diff options
Diffstat (limited to 'decoder/ff_wordalign.cc')
-rw-r--r-- | decoder/ff_wordalign.cc | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index a00b2c76..f07eda02 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -1,5 +1,6 @@ #include "ff_wordalign.h" +#include <sstream> #include <string> #include <cmath> @@ -126,6 +127,72 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } +// state: POS of src word used, number of trg words generated +SourcePOSBigram::SourcePOSBigram(const std::string& param) : + FeatureFunction(sizeof(WordID) + sizeof(int)) { + cerr << "Reading source POS tags from " << param << endl; + ReadFile rf(param); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + vector<WordID> v; + TD::ConvertSentence(line, &v); + pos_.push_back(v); + } + cerr << " (" << pos_.size() << " lines)\n"; +} + +void SourcePOSBigram::FireFeature(WordID left, + WordID right, + SparseVector<double>* features) const { + int& fid = fmap_[left][right]; + if (!fid) { + ostringstream os; + os << "SP:"; + if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); } + os << '_'; + if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); } + fid = FD::Convert(os.str()); + if (fid == 0) fid = -1; + } + if (fid < 0) return; + features->set_value(fid, 1.0); +} + +void SourcePOSBigram::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + WordID& out_context = *static_cast<WordID*>(context); + int& out_word_count = *(static_cast<int*>(context) + 1); + const int arity = edge.Arity(); + if (arity == 0) { + assert(smeta.GetSentenceID() < pos_.size()); + const vector<WordID>& pos_sent = pos_[smeta.GetSentenceID()]; + assert(edge.i_ < pos_sent.size()); + out_context = pos_sent[edge.i_]; + out_word_count = edge.rule_->EWords(); + assert(out_word_count == 1); // this is only defined for lex translation! + // revisit this if you want to translate into null words + } else if (arity == 2) { + WordID left = *static_cast<const WordID*>(ant_contexts[0]); + WordID right = *static_cast<const WordID*>(ant_contexts[1]); + int left_wc = *(static_cast<const int*>(ant_contexts[0]) + 1); + int right_wc = *(static_cast<const int*>(ant_contexts[0]) + 1); + if (left_wc == 1 && right_wc == 1) + FireFeature(-1, left, features); + FireFeature(left, right, features); + out_word_count = left_wc + right_wc; + if (out_word_count == smeta.GetSourceLength()) + FireFeature(right, -1, features); + out_context = right; + } +} + AlignerResults::AlignerResults(const std::string& param) : cur_sent_(-1), cur_grid_(NULL) { |