From 27db9d8c05188f64c17d61c394d3dafe8b8e93d8 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 19 Dec 2009 14:32:28 -0500 Subject: cool new alignment feature --- decoder/cdec_ff.cc | 1 + decoder/ff_wordalign.cc | 67 ++++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_wordalign.h | 20 ++++++++++++++ decoder/lexcrf.cc | 6 ++--- training/cluster-ptrain.pl | 14 ++++++++++ 5 files changed, 105 insertions(+), 3 deletions(-) diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index bb2c9d34..437de428 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -15,6 +15,7 @@ void register_feature_functions() { global_ff_registry->Register("SourceWordPenalty", new FFFactory); global_ff_registry->Register("RelativeSentencePosition", new FFFactory); global_ff_registry->Register("MarkovJump", new FFFactory); + global_ff_registry->Register("SourcePOSBigram", new FFFactory); global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory); global_ff_registry->Register("AlignerResults", new FFFactory); global_ff_registry->Register("CSplit_BasicFeatures", new FFFactory); diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index a00b2c76..f07eda02 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -1,5 +1,6 @@ #include "ff_wordalign.h" +#include #include #include @@ -126,6 +127,72 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } +// state: POS of src word used, number of trg words generated +SourcePOSBigram::SourcePOSBigram(const std::string& param) : + FeatureFunction(sizeof(WordID) + sizeof(int)) { + cerr << "Reading source POS tags from " << param << endl; + ReadFile rf(param); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + vector v; + TD::ConvertSentence(line, &v); + pos_.push_back(v); + } + cerr << " (" << pos_.size() << " lines)\n"; +} + +void SourcePOSBigram::FireFeature(WordID left, + WordID right, + SparseVector* features) const { + int& fid = fmap_[left][right]; + if (!fid) { + ostringstream os; + os << "SP:"; + if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); } + os << '_'; + if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); } + fid = FD::Convert(os.str()); + if (fid == 0) fid = -1; + } + if (fid < 0) return; + features->set_value(fid, 1.0); +} + +void SourcePOSBigram::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + WordID& out_context = *static_cast(context); + int& out_word_count = *(static_cast(context) + 1); + const int arity = edge.Arity(); + if (arity == 0) { + assert(smeta.GetSentenceID() < pos_.size()); + const vector& pos_sent = pos_[smeta.GetSentenceID()]; + assert(edge.i_ < pos_sent.size()); + out_context = pos_sent[edge.i_]; + out_word_count = edge.rule_->EWords(); + assert(out_word_count == 1); // this is only defined for lex translation! + // revisit this if you want to translate into null words + } else if (arity == 2) { + WordID left = *static_cast(ant_contexts[0]); + WordID right = *static_cast(ant_contexts[1]); + int left_wc = *(static_cast(ant_contexts[0]) + 1); + int right_wc = *(static_cast(ant_contexts[0]) + 1); + if (left_wc == 1 && right_wc == 1) + FireFeature(-1, left, features); + FireFeature(left, right, features); + out_word_count = left_wc + right_wc; + if (out_word_count == smeta.GetSourceLength()) + FireFeature(right, -1, features); + out_context = right; + } +} + AlignerResults::AlignerResults(const std::string& param) : cur_sent_(-1), cur_grid_(NULL) { diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index 4a8b59c7..554dd23e 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -38,6 +38,26 @@ class MarkovJump : public FeatureFunction { std::string template_; }; +typedef std::map Class2FID; +typedef std::map Class2Class2FID; +class SourcePOSBigram : public FeatureFunction { + public: + SourcePOSBigram(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + void FireFeature(WordID src, + WordID trg, + SparseVector* features) const; + mutable Class2Class2FID fmap_; + std::vector > pos_; +}; + class AlignerResults : public FeatureFunction { public: AlignerResults(const std::string& param); diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc index 9f96de9f..b80d055c 100644 --- a/decoder/lexcrf.cc +++ b/decoder/lexcrf.cc @@ -46,7 +46,7 @@ struct LexicalCRFImpl { // hack to tell the feature function system how big the sentence pair is const int f_start = (use_null ? -1 : 0); int prev_node_id = -1; - for (int i = 0; i < e_len; ++i) { // for each word in the *ref* + for (int i = 0; i < e_len; ++i) { // for each word in the *target* Hypergraph::Node* node = forest->AddNode(kXCAT); const int new_node_id = node->id_; for (int j = f_start; j < f_len; ++j) { // for each word in the source @@ -73,8 +73,8 @@ struct LexicalCRFImpl { const int comb_node_id = forest->AddNode(kXCAT)->id_; Hypergraph::TailNodeVector tail(2, prev_node_id); tail[1] = new_node_id; - const int edge_id = forest->AddEdge(kBINARY, tail)->id_; - forest->ConnectEdgeToHeadNode(edge_id, comb_node_id); + Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail); + forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id); prev_node_id = comb_node_id; } else { prev_node_id = new_node_id; diff --git a/training/cluster-ptrain.pl b/training/cluster-ptrain.pl index 33aab25d..8944ae34 100755 --- a/training/cluster-ptrain.pl +++ b/training/cluster-ptrain.pl @@ -104,7 +104,21 @@ if ($restart) { } else { `cp $initial_weights $dir/weights.1.gz`; } + open T, "<$training_corpus" or die "Can't read $training_corpus: $!"; + open TO, ">$dir/training.in"; + my $lc = 0; + while() { + chomp; + s/^\s+//; + s/\s+$//; + die "Expected A ||| B in input file" unless / \|\|\| /; + print TO "$_\n"; + $lc++; + } + close T; + close TO; } +$training_corpus = "$dir/training.in"; my $iter_attempts = 1; while ($iter < $max_iteration) { -- cgit v1.2.3