summaryrefslogtreecommitdiff
path: root/decoder/ff_wordalign.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ff_wordalign.cc')
-rw-r--r--decoder/ff_wordalign.cc67
1 files changed, 67 insertions, 0 deletions
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index a00b2c76..f07eda02 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -1,5 +1,6 @@
#include "ff_wordalign.h"
+#include <sstream>
#include <string>
#include <cmath>
@@ -126,6 +127,72 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+// state: POS of src word used, number of trg words generated
+SourcePOSBigram::SourcePOSBigram(const std::string& param) :
+ FeatureFunction(sizeof(WordID) + sizeof(int)) {
+ cerr << "Reading source POS tags from " << param << endl;
+ ReadFile rf(param);
+ istream& in = *rf.stream();
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ vector<WordID> v;
+ TD::ConvertSentence(line, &v);
+ pos_.push_back(v);
+ }
+ cerr << " (" << pos_.size() << " lines)\n";
+}
+
+void SourcePOSBigram::FireFeature(WordID left,
+ WordID right,
+ SparseVector<double>* features) const {
+ int& fid = fmap_[left][right];
+ if (!fid) {
+ ostringstream os;
+ os << "SP:";
+ if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); }
+ os << '_';
+ if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); }
+ fid = FD::Convert(os.str());
+ if (fid == 0) fid = -1;
+ }
+ if (fid < 0) return;
+ features->set_value(fid, 1.0);
+}
+
+void SourcePOSBigram::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ WordID& out_context = *static_cast<WordID*>(context);
+ int& out_word_count = *(static_cast<int*>(context) + 1);
+ const int arity = edge.Arity();
+ if (arity == 0) {
+ assert(smeta.GetSentenceID() < pos_.size());
+ const vector<WordID>& pos_sent = pos_[smeta.GetSentenceID()];
+ assert(edge.i_ < pos_sent.size());
+ out_context = pos_sent[edge.i_];
+ out_word_count = edge.rule_->EWords();
+ assert(out_word_count == 1); // this is only defined for lex translation!
+ // revisit this if you want to translate into null words
+ } else if (arity == 2) {
+ WordID left = *static_cast<const WordID*>(ant_contexts[0]);
+ WordID right = *static_cast<const WordID*>(ant_contexts[1]);
+ int left_wc = *(static_cast<const int*>(ant_contexts[0]) + 1);
+ int right_wc = *(static_cast<const int*>(ant_contexts[0]) + 1);
+ if (left_wc == 1 && right_wc == 1)
+ FireFeature(-1, left, features);
+ FireFeature(left, right, features);
+ out_word_count = left_wc + right_wc;
+ if (out_word_count == smeta.GetSourceLength())
+ FireFeature(right, -1, features);
+ out_context = right;
+ }
+}
+
AlignerResults::AlignerResults(const std::string& param) :
cur_sent_(-1),
cur_grid_(NULL) {