diff options
Diffstat (limited to 'decoder/phrasebased_translator.cc')
-rw-r--r-- | decoder/phrasebased_translator.cc | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/decoder/phrasebased_translator.cc b/decoder/phrasebased_translator.cc new file mode 100644 index 00000000..726b3f9a --- /dev/null +++ b/decoder/phrasebased_translator.cc @@ -0,0 +1,206 @@ +#include "phrasebased_translator.h" + +#include <queue> +#include <iostream> +#include <tr1/unordered_map> +#include <tr1/unordered_set> + +#include <boost/tuple/tuple.hpp> +#include <boost/functional/hash.hpp> + +#include "sentence_metadata.h" +#include "tdict.h" +#include "hg.h" +#include "filelib.h" +#include "lattice.h" +#include "phrasetable_fst.h" +#include "array2d.h" + +using namespace std; +using namespace std::tr1; +using namespace boost::tuples; + +struct Coverage : public vector<bool> { + explicit Coverage(int n, bool v = false) : vector<bool>(n, v), first_gap() {} + void Cover(int i, int j) { + vector<bool>::iterator it = this->begin() + i; + vector<bool>::iterator end = this->begin() + j; + while (it != end) + *it++ = true; + if (first_gap == i) { + first_gap = j; + it = end; + while (*it && it != this->end()) { + ++it; + ++first_gap; + } + } + } + bool Collides(int i, int j) const { + vector<bool>::const_iterator it = this->begin() + i; + vector<bool>::const_iterator end = this->begin() + j; + while (it != end) + if (*it++) return true; + return false; + } + int GetFirstGap() const { return first_gap; } + private: + int first_gap; +}; +struct CoverageHash { + size_t operator()(const Coverage& cov) const { + return hasher_(static_cast<const vector<bool>&>(cov)); + } + private: + boost::hash<vector<bool> > hasher_; +}; +ostream& operator<<(ostream& os, const Coverage& cov) { + os << '['; + for (int i = 0; i < cov.size(); ++i) + os << (cov[i] ? '*' : '.'); + return os << " gap=" << cov.GetFirstGap() << ']'; +} + +typedef unordered_map<Coverage, int, CoverageHash> CoverageNodeMap; +typedef unordered_set<Coverage, CoverageHash> UniqueCoverageSet; + +struct PhraseBasedTranslatorImpl { + PhraseBasedTranslatorImpl(const boost::program_options::variables_map& conf) : + add_pass_through_rules(conf.count("add_pass_through_rules")), + max_distortion(conf["pb_max_distortion"].as<int>()), + kSOURCE_RULE(new TRule("[X] ||| [X,1] ||| [X,1]", true)), + kCONCAT_RULE(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]", true)), + kNT_TYPE(TD::Convert("X") * -1) { + assert(max_distortion >= 0); + vector<string> gfiles = conf["grammar"].as<vector<string> >(); + assert(gfiles.size() == 1); + cerr << "Reading phrasetable from " << gfiles.front() << endl; + ReadFile in(gfiles.front()); + fst.reset(LoadTextPhrasetable(in.stream())); + } + + struct State { + State(const Coverage& c, int _i, int _j, const FSTNode* q) : + coverage(c), i(_i), j(_j), fst(q) {} + Coverage coverage; + int i; + int j; + const FSTNode* fst; + }; + + // we keep track of unique coverages that have been extended since it's + // possible to "extend" the same coverage twice, e.g. translate "a b c" + // with phrases "a" "b" "a b" and "c". There are two ways to cover "a b" + void EnqueuePossibleContinuations(const Coverage& coverage, queue<State>* q, UniqueCoverageSet* ucs) { + if (ucs->insert(coverage).second) { + const int gap = coverage.GetFirstGap(); + const int end = min(static_cast<int>(coverage.size()), gap + max_distortion + 1); + for (int i = gap; i < end; ++i) + if (!coverage[i]) q->push(State(coverage, i, i, fst.get())); + } + } + + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector<double>& weights, + Hypergraph* minus_lm_forest) { + Lattice lattice; + LatticeTools::ConvertTextOrPLF(input, &lattice); + smeta->SetSourceLength(lattice.size()); + size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion); + minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100); + if (add_pass_through_rules) { + SparseVector<double> feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < lattice.size(); ++i) { + const vector<LatticeArc>& arcs = lattice[i]; + for (int j = 0; j < arcs.size(); ++j) { + fst->AddPassThroughTranslation(arcs[j].label, feats); + // TODO handle lattice edge features + } + } + } + CoverageNodeMap c; + queue<State> q; + UniqueCoverageSet ucs; + const Coverage empty_cov(lattice.size(), false); + const Coverage goal_cov(lattice.size(), true); + EnqueuePossibleContinuations(empty_cov, &q, &ucs); + c[empty_cov] = 0; // have to handle the left edge specially + while(!q.empty()) { + const State s = q.front(); + q.pop(); + // cerr << "(" << s.i << "," << s.j << " ptr=" << s.fst << ") cov=" << s.coverage << endl; + const vector<LatticeArc>& arcs = lattice[s.j]; + if (s.fst->HasData()) { + Coverage new_cov = s.coverage; + new_cov.Cover(s.i, s.j); + EnqueuePossibleContinuations(new_cov, &q, &ucs); + const vector<TRulePtr>& phrases = s.fst->GetTranslations()->GetRules(); + const int phrase_head_index = minus_lm_forest->AddNode(kNT_TYPE)->id_; + for (int i = 0; i < phrases.size(); ++i) { + Hypergraph::Edge* edge = minus_lm_forest->AddEdge(phrases[i], Hypergraph::TailNodeVector()); + edge->feature_values_ = edge->rule_->scores_; + minus_lm_forest->ConnectEdgeToHeadNode(edge->id_, phrase_head_index); + } + CoverageNodeMap::iterator cit = c.find(s.coverage); + assert(cit != c.end()); + const int tail_node_plus1 = cit->second; + if (tail_node_plus1 == 0) { // left edge + c[new_cov] = phrase_head_index + 1; + } else { // not left edge + int& head_node_plus1 = c[new_cov]; + if (!head_node_plus1) + head_node_plus1 = minus_lm_forest->AddNode(kNT_TYPE)->id_ + 1; + Hypergraph::TailNodeVector tail(2, tail_node_plus1 - 1); + tail[1] = phrase_head_index; + const int concat_edge = minus_lm_forest->AddEdge(kCONCAT_RULE, tail)->id_; + minus_lm_forest->ConnectEdgeToHeadNode(concat_edge, head_node_plus1 - 1); + } + } + if (s.j == lattice.size()) continue; + for (int l = 0; l < arcs.size(); ++l) { + const LatticeArc& arc = arcs[l]; + + const FSTNode* next_fst_state = s.fst->Extend(arc.label); + const int next_j = s.j + arc.dist2next; + if (next_fst_state && + !s.coverage.Collides(s.i, next_j)) { + q.push(State(s.coverage, s.i, next_j, next_fst_state)); + } + } + } + if (add_pass_through_rules) + fst->ClearPassThroughTranslations(); + int pregoal_plus1 = c[goal_cov]; + if (pregoal_plus1 > 0) { + TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [X,1]")); + int goal = minus_lm_forest->AddNode(TD::Convert("Goal") * -1)->id_; + int gedge = minus_lm_forest->AddEdge(kGOAL_RULE, Hypergraph::TailNodeVector(1, pregoal_plus1 - 1))->id_; + minus_lm_forest->ConnectEdgeToHeadNode(gedge, goal); + // they are almost topo, but not quite always + minus_lm_forest->TopologicallySortNodesAndEdges(goal); + minus_lm_forest->Reweight(weights); + return true; + } else { + return false; // composition failed + } + } + + const bool add_pass_through_rules; + const int max_distortion; + TRulePtr kSOURCE_RULE; + const TRulePtr kCONCAT_RULE; + const WordID kNT_TYPE; + boost::shared_ptr<FSTNode> fst; +}; + +PhraseBasedTranslator::PhraseBasedTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new PhraseBasedTranslatorImpl(conf)) {} + +bool PhraseBasedTranslator::TranslateImpl(const std::string& input, + SentenceMetadata* smeta, + const std::vector<double>& weights, + Hypergraph* minus_lm_forest) { + return pimpl_->Translate(input, smeta, weights, minus_lm_forest); +} |