summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2009-12-19 14:32:28 -0500
committerChris Dyer <redpony@gmail.com>2009-12-19 14:32:28 -0500
commit27db9d8c05188f64c17d61c394d3dafe8b8e93d8 (patch)
tree688930b6e95b6801ffe7d722f33a4f56712ecd21 /decoder
parent39b9c1e0aaec81492d81e541daf7703ba8c517ff (diff)
cool new alignment feature
Diffstat (limited to 'decoder')
-rw-r--r--decoder/cdec_ff.cc1
-rw-r--r--decoder/ff_wordalign.cc67
-rw-r--r--decoder/ff_wordalign.h20
-rw-r--r--decoder/lexcrf.cc6
4 files changed, 91 insertions, 3 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index bb2c9d34..437de428 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -15,6 +15,7 @@ void register_feature_functions() {
global_ff_registry->Register("SourceWordPenalty", new FFFactory<SourceWordPenalty>);
global_ff_registry->Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
global_ff_registry->Register("MarkovJump", new FFFactory<MarkovJump>);
+ global_ff_registry->Register("SourcePOSBigram", new FFFactory<SourcePOSBigram>);
global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory<BlunsomSynchronousParseHack>);
global_ff_registry->Register("AlignerResults", new FFFactory<AlignerResults>);
global_ff_registry->Register("CSplit_BasicFeatures", new FFFactory<BasicCSplitFeatures>);
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index a00b2c76..f07eda02 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -1,5 +1,6 @@
#include "ff_wordalign.h"
+#include <sstream>
#include <string>
#include <cmath>
@@ -126,6 +127,72 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+// state: POS of src word used, number of trg words generated
+SourcePOSBigram::SourcePOSBigram(const std::string& param) :
+ FeatureFunction(sizeof(WordID) + sizeof(int)) {
+ cerr << "Reading source POS tags from " << param << endl;
+ ReadFile rf(param);
+ istream& in = *rf.stream();
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ vector<WordID> v;
+ TD::ConvertSentence(line, &v);
+ pos_.push_back(v);
+ }
+ cerr << " (" << pos_.size() << " lines)\n";
+}
+
+void SourcePOSBigram::FireFeature(WordID left,
+ WordID right,
+ SparseVector<double>* features) const {
+ int& fid = fmap_[left][right];
+ if (!fid) {
+ ostringstream os;
+ os << "SP:";
+ if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); }
+ os << '_';
+ if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); }
+ fid = FD::Convert(os.str());
+ if (fid == 0) fid = -1;
+ }
+ if (fid < 0) return;
+ features->set_value(fid, 1.0);
+}
+
+void SourcePOSBigram::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ WordID& out_context = *static_cast<WordID*>(context);
+ int& out_word_count = *(static_cast<int*>(context) + 1);
+ const int arity = edge.Arity();
+ if (arity == 0) {
+ assert(smeta.GetSentenceID() < pos_.size());
+ const vector<WordID>& pos_sent = pos_[smeta.GetSentenceID()];
+ assert(edge.i_ < pos_sent.size());
+ out_context = pos_sent[edge.i_];
+ out_word_count = edge.rule_->EWords();
+ assert(out_word_count == 1); // this is only defined for lex translation!
+ // revisit this if you want to translate into null words
+ } else if (arity == 2) {
+ WordID left = *static_cast<const WordID*>(ant_contexts[0]);
+ WordID right = *static_cast<const WordID*>(ant_contexts[1]);
+ int left_wc = *(static_cast<const int*>(ant_contexts[0]) + 1);
+ int right_wc = *(static_cast<const int*>(ant_contexts[0]) + 1);
+ if (left_wc == 1 && right_wc == 1)
+ FireFeature(-1, left, features);
+ FireFeature(left, right, features);
+ out_word_count = left_wc + right_wc;
+ if (out_word_count == smeta.GetSourceLength())
+ FireFeature(right, -1, features);
+ out_context = right;
+ }
+}
+
AlignerResults::AlignerResults(const std::string& param) :
cur_sent_(-1),
cur_grid_(NULL) {
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index 4a8b59c7..554dd23e 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -38,6 +38,26 @@ class MarkovJump : public FeatureFunction {
std::string template_;
};
+typedef std::map<WordID, int> Class2FID;
+typedef std::map<WordID, Class2FID> Class2Class2FID;
+class SourcePOSBigram : public FeatureFunction {
+ public:
+ SourcePOSBigram(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ void FireFeature(WordID src,
+ WordID trg,
+ SparseVector<double>* features) const;
+ mutable Class2Class2FID fmap_;
+ std::vector<std::vector<WordID> > pos_;
+};
+
class AlignerResults : public FeatureFunction {
public:
AlignerResults(const std::string& param);
diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc
index 9f96de9f..b80d055c 100644
--- a/decoder/lexcrf.cc
+++ b/decoder/lexcrf.cc
@@ -46,7 +46,7 @@ struct LexicalCRFImpl {
// hack to tell the feature function system how big the sentence pair is
const int f_start = (use_null ? -1 : 0);
int prev_node_id = -1;
- for (int i = 0; i < e_len; ++i) { // for each word in the *ref*
+ for (int i = 0; i < e_len; ++i) { // for each word in the *target*
Hypergraph::Node* node = forest->AddNode(kXCAT);
const int new_node_id = node->id_;
for (int j = f_start; j < f_len; ++j) { // for each word in the source
@@ -73,8 +73,8 @@ struct LexicalCRFImpl {
const int comb_node_id = forest->AddNode(kXCAT)->id_;
Hypergraph::TailNodeVector tail(2, prev_node_id);
tail[1] = new_node_id;
- const int edge_id = forest->AddEdge(kBINARY, tail)->id_;
- forest->ConnectEdgeToHeadNode(edge_id, comb_node_id);
+ Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail);
+ forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id);
prev_node_id = comb_node_id;
} else {
prev_node_id = new_node_id;