From 4d47dbd7da0434de67ac619392d516c678e1f2ca Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 18 Feb 2010 17:06:59 -0500 Subject: add generative word alignment model and primitive EM trainer. Model 1 and HMM are supported, without NULL source words --- decoder/lexalign.cc | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/lexalign.h | 18 +++++++++ decoder/lexcrf.cc | 113 ---------------------------------------------------- decoder/lexcrf.h | 18 --------- decoder/lextrans.cc | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/lextrans.h | 18 +++++++++ 6 files changed, 262 insertions(+), 131 deletions(-) create mode 100644 decoder/lexalign.cc create mode 100644 decoder/lexalign.h delete mode 100644 decoder/lexcrf.cc delete mode 100644 decoder/lexcrf.h create mode 100644 decoder/lextrans.cc create mode 100644 decoder/lextrans.h (limited to 'decoder') diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc new file mode 100644 index 00000000..ee3b5fe0 --- /dev/null +++ b/decoder/lexalign.cc @@ -0,0 +1,113 @@ +#include "lexalign.h" + +#include + +#include "filelib.h" +#include "hg.h" +#include "tdict.h" +#include "grammar.h" +#include "sentence_metadata.h" + +using namespace std; + +struct LexicalAlignImpl { + LexicalAlignImpl(const boost::program_options::variables_map& conf) : + use_null(conf.count("lexcrf_use_null") > 0), + kXCAT(TD::Convert("X")*-1), + kNULL(TD::Convert("")), + kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), + kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { + } + + void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + const int e_len = smeta.GetTargetLength(); + assert(e_len > 0); + const Lattice& target = smeta.GetReference(); + const int f_len = lattice.size(); + // hack to tell the feature function system how big the sentence pair is + const int f_start = (use_null ? -1 : 0); + int prev_node_id = -1; + for (int i = 0; i < e_len; ++i) { // for each word in the *target* + const WordID& e_i = target[i][0].label; + Hypergraph::Node* node = forest->AddNode(kXCAT); + const int new_node_id = node->id_; + for (int j = f_start; j < f_len; ++j) { // for each word in the source + const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label); + TRulePtr& rule = LexRule(src_sym, e_i); + Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); + edge->i_ = j; + edge->j_ = j+1; + edge->prev_i_ = i; + edge->prev_j_ = i+1; + edge->feature_values_ += edge->rule_->GetFeatureValues(); + forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); + } + if (prev_node_id >= 0) { + const int comb_node_id = forest->AddNode(kXCAT)->id_; + Hypergraph::TailNodeVector tail(2, prev_node_id); + tail[1] = new_node_id; + Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail); + forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id); + prev_node_id = comb_node_id; + } else { + prev_node_id = new_node_id; + } + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + } + + inline int LexFeatureId(const WordID& f, const WordID& e) { + map& e2fid = f2e2fid[f]; + map::iterator it = e2fid.find(e); + if (it != e2fid.end()) + return it->second; + int& fid = e2fid[e]; + if (f == 0) { + fid = FD::Convert("Lx__" + FD::Escape(TD::Convert(e))); + } else { + fid = FD::Convert("Lx_" + FD::Escape(TD::Convert(f)) + "_" + FD::Escape(TD::Convert(e))); + } + return fid; + } + + inline TRulePtr& LexRule(const WordID& f, const WordID& e) { + map& e2rule = f2e2rule[f]; + map::iterator it = e2rule.find(e); + if (it != e2rule.end()) + return it->second; + TRulePtr& tr = e2rule[e]; + tr.reset(TRule::CreateLexicalRule(f, e)); + tr->scores_.set_value(LexFeatureId(f, e), 1.0); + return tr; + } + + private: + const bool use_null; + const WordID kXCAT; + const WordID kNULL; + const TRulePtr kBINARY; + const TRulePtr kGOAL_RULE; + map > f2e2rule; + map > f2e2fid; + GrammarPtr grammar; +}; + +LexicalAlign::LexicalAlign(const boost::program_options::variables_map& conf) : + pimpl_(new LexicalAlignImpl(conf)) {} + +bool LexicalAlign::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* forest) { + Lattice lattice; + LatticeTools::ConvertTextToLattice(input, &lattice); + smeta->SetSourceLength(lattice.size()); + pimpl_->BuildTrellis(lattice, *smeta, forest); + forest->is_linear_chain_ = true; + forest->Reweight(weights); + return true; +} + diff --git a/decoder/lexalign.h b/decoder/lexalign.h new file mode 100644 index 00000000..30c89c57 --- /dev/null +++ b/decoder/lexalign.h @@ -0,0 +1,18 @@ +#ifndef _LEXALIGN_H_ +#define _LEXALIGN_H_ + +#include "translator.h" +#include "lattice.h" + +struct LexicalAlignImpl; +struct LexicalAlign : public Translator { + LexicalAlign(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc deleted file mode 100644 index b0e03c69..00000000 --- a/decoder/lexcrf.cc +++ /dev/null @@ -1,113 +0,0 @@ -#include "lexcrf.h" - -#include - -#include "filelib.h" -#include "hg.h" -#include "tdict.h" -#include "grammar.h" -#include "sentence_metadata.h" - -using namespace std; - -struct LexicalCRFImpl { - LexicalCRFImpl(const boost::program_options::variables_map& conf) : - use_null(conf.count("lexcrf_use_null") > 0), - kXCAT(TD::Convert("X")*-1), - kNULL(TD::Convert("")), - kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), - kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { - vector gfiles = conf["grammar"].as >(); - assert(gfiles.size() == 1); - ReadFile rf(gfiles.front()); - TextGrammar *tg = new TextGrammar; - grammar.reset(tg); - istream* in = rf.stream(); - int lc = 0; - bool flag = false; - while(*in) { - string line; - getline(*in, line); - if (line.empty()) continue; - ++lc; - TRulePtr r(TRule::CreateRulePhrasetable(line)); - tg->AddRule(r); - if (lc % 50000 == 0) { cerr << '.'; flag = true; } - if (lc % 2000000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } - } - if (flag) cerr << endl; - cerr << "Loaded " << lc << " rules\n"; - } - - void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { - const int e_len = smeta.GetTargetLength(); - assert(e_len > 0); - const int f_len = lattice.size(); - // hack to tell the feature function system how big the sentence pair is - const int f_start = (use_null ? -1 : 0); - int prev_node_id = -1; - for (int i = 0; i < e_len; ++i) { // for each word in the *target* - Hypergraph::Node* node = forest->AddNode(kXCAT); - const int new_node_id = node->id_; - for (int j = f_start; j < f_len; ++j) { // for each word in the source - const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label); - const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym); - if (!gi) { - cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; - abort(); - } - const RuleBin* rb = gi->GetRules(); - assert(rb); - for (int k = 0; k < rb->GetNumRules(); ++k) { - TRulePtr rule = rb->GetIthRule(k); - Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); - edge->i_ = j; - edge->j_ = j+1; - edge->prev_i_ = i; - edge->prev_j_ = i+1; - edge->feature_values_ += edge->rule_->GetFeatureValues(); - forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); - } - } - if (prev_node_id >= 0) { - const int comb_node_id = forest->AddNode(kXCAT)->id_; - Hypergraph::TailNodeVector tail(2, prev_node_id); - tail[1] = new_node_id; - Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail); - forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id); - prev_node_id = comb_node_id; - } else { - prev_node_id = new_node_id; - } - } - Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); - Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); - Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); - forest->ConnectEdgeToHeadNode(hg_edge, goal); - } - - private: - const bool use_null; - const WordID kXCAT; - const WordID kNULL; - const TRulePtr kBINARY; - const TRulePtr kGOAL_RULE; - GrammarPtr grammar; -}; - -LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) : - pimpl_(new LexicalCRFImpl(conf)) {} - -bool LexicalCRF::Translate(const string& input, - SentenceMetadata* smeta, - const vector& weights, - Hypergraph* forest) { - Lattice lattice; - LatticeTools::ConvertTextToLattice(input, &lattice); - smeta->SetSourceLength(lattice.size()); - pimpl_->BuildTrellis(lattice, *smeta, forest); - forest->is_linear_chain_ = true; - forest->Reweight(weights); - return true; -} - diff --git a/decoder/lexcrf.h b/decoder/lexcrf.h deleted file mode 100644 index 99362c81..00000000 --- a/decoder/lexcrf.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _LEXCRF_H_ -#define _LEXCRF_H_ - -#include "translator.h" -#include "lattice.h" - -struct LexicalCRFImpl; -struct LexicalCRF : public Translator { - LexicalCRF(const boost::program_options::variables_map& conf); - bool Translate(const std::string& input, - SentenceMetadata* smeta, - const std::vector& weights, - Hypergraph* forest); - private: - boost::shared_ptr pimpl_; -}; - -#endif diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc new file mode 100644 index 00000000..b0e03c69 --- /dev/null +++ b/decoder/lextrans.cc @@ -0,0 +1,113 @@ +#include "lexcrf.h" + +#include + +#include "filelib.h" +#include "hg.h" +#include "tdict.h" +#include "grammar.h" +#include "sentence_metadata.h" + +using namespace std; + +struct LexicalCRFImpl { + LexicalCRFImpl(const boost::program_options::variables_map& conf) : + use_null(conf.count("lexcrf_use_null") > 0), + kXCAT(TD::Convert("X")*-1), + kNULL(TD::Convert("")), + kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), + kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { + vector gfiles = conf["grammar"].as >(); + assert(gfiles.size() == 1); + ReadFile rf(gfiles.front()); + TextGrammar *tg = new TextGrammar; + grammar.reset(tg); + istream* in = rf.stream(); + int lc = 0; + bool flag = false; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + ++lc; + TRulePtr r(TRule::CreateRulePhrasetable(line)); + tg->AddRule(r); + if (lc % 50000 == 0) { cerr << '.'; flag = true; } + if (lc % 2000000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + } + if (flag) cerr << endl; + cerr << "Loaded " << lc << " rules\n"; + } + + void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + const int e_len = smeta.GetTargetLength(); + assert(e_len > 0); + const int f_len = lattice.size(); + // hack to tell the feature function system how big the sentence pair is + const int f_start = (use_null ? -1 : 0); + int prev_node_id = -1; + for (int i = 0; i < e_len; ++i) { // for each word in the *target* + Hypergraph::Node* node = forest->AddNode(kXCAT); + const int new_node_id = node->id_; + for (int j = f_start; j < f_len; ++j) { // for each word in the source + const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label); + const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym); + if (!gi) { + cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; + abort(); + } + const RuleBin* rb = gi->GetRules(); + assert(rb); + for (int k = 0; k < rb->GetNumRules(); ++k) { + TRulePtr rule = rb->GetIthRule(k); + Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); + edge->i_ = j; + edge->j_ = j+1; + edge->prev_i_ = i; + edge->prev_j_ = i+1; + edge->feature_values_ += edge->rule_->GetFeatureValues(); + forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); + } + } + if (prev_node_id >= 0) { + const int comb_node_id = forest->AddNode(kXCAT)->id_; + Hypergraph::TailNodeVector tail(2, prev_node_id); + tail[1] = new_node_id; + Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail); + forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id); + prev_node_id = comb_node_id; + } else { + prev_node_id = new_node_id; + } + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + } + + private: + const bool use_null; + const WordID kXCAT; + const WordID kNULL; + const TRulePtr kBINARY; + const TRulePtr kGOAL_RULE; + GrammarPtr grammar; +}; + +LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) : + pimpl_(new LexicalCRFImpl(conf)) {} + +bool LexicalCRF::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* forest) { + Lattice lattice; + LatticeTools::ConvertTextToLattice(input, &lattice); + smeta->SetSourceLength(lattice.size()); + pimpl_->BuildTrellis(lattice, *smeta, forest); + forest->is_linear_chain_ = true; + forest->Reweight(weights); + return true; +} + diff --git a/decoder/lextrans.h b/decoder/lextrans.h new file mode 100644 index 00000000..99362c81 --- /dev/null +++ b/decoder/lextrans.h @@ -0,0 +1,18 @@ +#ifndef _LEXCRF_H_ +#define _LEXCRF_H_ + +#include "translator.h" +#include "lattice.h" + +struct LexicalCRFImpl; +struct LexicalCRF : public Translator { + LexicalCRF(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif -- cgit v1.2.3