summaryrefslogtreecommitdiff
path: root/src/phrasebased_translator.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
committerChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
commit671c21451542e2dd20e45b4033d44d8e8735f87b (patch)
treeb1773b077dd65b826f067a423d26f7942ce4e043 /src/phrasebased_translator.cc
initial check in
Diffstat (limited to 'src/phrasebased_translator.cc')
-rw-r--r--src/phrasebased_translator.cc206
1 files changed, 206 insertions, 0 deletions
diff --git a/src/phrasebased_translator.cc b/src/phrasebased_translator.cc
new file mode 100644
index 00000000..5eb70876
--- /dev/null
+++ b/src/phrasebased_translator.cc
@@ -0,0 +1,206 @@
+#include "phrasebased_translator.h"
+
+#include <queue>
+#include <iostream>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/tuple/tuple.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "sentence_metadata.h"
+#include "tdict.h"
+#include "hg.h"
+#include "filelib.h"
+#include "lattice.h"
+#include "phrasetable_fst.h"
+#include "array2d.h"
+
+using namespace std;
+using namespace std::tr1;
+using namespace boost::tuples;
+
+struct Coverage : public vector<bool> {
+ explicit Coverage(int n, bool v = false) : vector<bool>(n, v), first_gap() {}
+ void Cover(int i, int j) {
+ vector<bool>::iterator it = this->begin() + i;
+ vector<bool>::iterator end = this->begin() + j;
+ while (it != end)
+ *it++ = true;
+ if (first_gap == i) {
+ first_gap = j;
+ it = end;
+ while (*it && it != this->end()) {
+ ++it;
+ ++first_gap;
+ }
+ }
+ }
+ bool Collides(int i, int j) const {
+ vector<bool>::const_iterator it = this->begin() + i;
+ vector<bool>::const_iterator end = this->begin() + j;
+ while (it != end)
+ if (*it++) return true;
+ return false;
+ }
+ int GetFirstGap() const { return first_gap; }
+ private:
+ int first_gap;
+};
+struct CoverageHash {
+ size_t operator()(const Coverage& cov) const {
+ return hasher_(static_cast<const vector<bool>&>(cov));
+ }
+ private:
+ boost::hash<vector<bool> > hasher_;
+};
+ostream& operator<<(ostream& os, const Coverage& cov) {
+ os << '[';
+ for (int i = 0; i < cov.size(); ++i)
+ os << (cov[i] ? '*' : '.');
+ return os << " gap=" << cov.GetFirstGap() << ']';
+}
+
+typedef unordered_map<Coverage, int, CoverageHash> CoverageNodeMap;
+typedef unordered_set<Coverage, CoverageHash> UniqueCoverageSet;
+
+struct PhraseBasedTranslatorImpl {
+ PhraseBasedTranslatorImpl(const boost::program_options::variables_map& conf) :
+ add_pass_through_rules(conf.count("add_pass_through_rules")),
+ max_distortion(conf["pb_max_distortion"].as<int>()),
+ kSOURCE_RULE(new TRule("[X] ||| [X,1] ||| [X,1]", true)),
+ kCONCAT_RULE(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]", true)),
+ kNT_TYPE(TD::Convert("X") * -1) {
+ assert(max_distortion >= 0);
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ assert(gfiles.size() == 1);
+ cerr << "Reading phrasetable from " << gfiles.front() << endl;
+ ReadFile in(gfiles.front());
+ fst.reset(LoadTextPhrasetable(in.stream()));
+ }
+
+ struct State {
+ State(const Coverage& c, int _i, int _j, const FSTNode* q) :
+ coverage(c), i(_i), j(_j), fst(q) {}
+ Coverage coverage;
+ int i;
+ int j;
+ const FSTNode* fst;
+ };
+
+ // we keep track of unique coverages that have been extended since it's
+ // possible to "extend" the same coverage twice, e.g. translate "a b c"
+ // with phrases "a" "b" "a b" and "c". There are two ways to cover "a b"
+ void EnqueuePossibleContinuations(const Coverage& coverage, queue<State>* q, UniqueCoverageSet* ucs) {
+ if (ucs->insert(coverage).second) {
+ const int gap = coverage.GetFirstGap();
+ const int end = min(static_cast<int>(coverage.size()), gap + max_distortion + 1);
+ for (int i = gap; i < end; ++i)
+ if (!coverage[i]) q->push(State(coverage, i, i, fst.get()));
+ }
+ }
+
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ Lattice lattice;
+ LatticeTools::ConvertTextOrPLF(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);
+ minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);
+ if (add_pass_through_rules) {
+ SparseVector<double> feats;
+ feats.set_value(FD::Convert("PassThrough"), 1);
+ for (int i = 0; i < lattice.size(); ++i) {
+ const vector<LatticeArc>& arcs = lattice[i];
+ for (int j = 0; j < arcs.size(); ++j) {
+ fst->AddPassThroughTranslation(arcs[j].label, feats);
+ // TODO handle lattice edge features
+ }
+ }
+ }
+ CoverageNodeMap c;
+ queue<State> q;
+ UniqueCoverageSet ucs;
+ const Coverage empty_cov(lattice.size(), false);
+ const Coverage goal_cov(lattice.size(), true);
+ EnqueuePossibleContinuations(empty_cov, &q, &ucs);
+ c[empty_cov] = 0; // have to handle the left edge specially
+ while(!q.empty()) {
+ const State s = q.front();
+ q.pop();
+ // cerr << "(" << s.i << "," << s.j << " ptr=" << s.fst << ") cov=" << s.coverage << endl;
+ const vector<LatticeArc>& arcs = lattice[s.j];
+ if (s.fst->HasData()) {
+ Coverage new_cov = s.coverage;
+ new_cov.Cover(s.i, s.j);
+ EnqueuePossibleContinuations(new_cov, &q, &ucs);
+ const vector<TRulePtr>& phrases = s.fst->GetTranslations()->GetRules();
+ const int phrase_head_index = minus_lm_forest->AddNode(kNT_TYPE)->id_;
+ for (int i = 0; i < phrases.size(); ++i) {
+ Hypergraph::Edge* edge = minus_lm_forest->AddEdge(phrases[i], Hypergraph::TailNodeVector());
+ edge->feature_values_ = edge->rule_->scores_;
+ minus_lm_forest->ConnectEdgeToHeadNode(edge->id_, phrase_head_index);
+ }
+ CoverageNodeMap::iterator cit = c.find(s.coverage);
+ assert(cit != c.end());
+ const int tail_node_plus1 = cit->second;
+ if (tail_node_plus1 == 0) { // left edge
+ c[new_cov] = phrase_head_index + 1;
+ } else { // not left edge
+ int& head_node_plus1 = c[new_cov];
+ if (!head_node_plus1)
+ head_node_plus1 = minus_lm_forest->AddNode(kNT_TYPE)->id_ + 1;
+ Hypergraph::TailNodeVector tail(2, tail_node_plus1 - 1);
+ tail[1] = phrase_head_index;
+ const int concat_edge = minus_lm_forest->AddEdge(kCONCAT_RULE, tail)->id_;
+ minus_lm_forest->ConnectEdgeToHeadNode(concat_edge, head_node_plus1 - 1);
+ }
+ }
+ if (s.j == lattice.size()) continue;
+ for (int l = 0; l < arcs.size(); ++l) {
+ const LatticeArc& arc = arcs[l];
+
+ const FSTNode* next_fst_state = s.fst->Extend(arc.label);
+ const int next_j = s.j + arc.dist2next;
+ if (next_fst_state &&
+ !s.coverage.Collides(s.i, next_j)) {
+ q.push(State(s.coverage, s.i, next_j, next_fst_state));
+ }
+ }
+ }
+ if (add_pass_through_rules)
+ fst->ClearPassThroughTranslations();
+ int pregoal_plus1 = c[goal_cov];
+ if (pregoal_plus1 > 0) {
+ TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [X,1]"));
+ int goal = minus_lm_forest->AddNode(TD::Convert("Goal") * -1)->id_;
+ int gedge = minus_lm_forest->AddEdge(kGOAL_RULE, Hypergraph::TailNodeVector(1, pregoal_plus1 - 1))->id_;
+ minus_lm_forest->ConnectEdgeToHeadNode(gedge, goal);
+ // they are almost topo, but not quite always
+ minus_lm_forest->TopologicallySortNodesAndEdges(goal);
+ minus_lm_forest->Reweight(weights);
+ return true;
+ } else {
+ return false; // composition failed
+ }
+ }
+
+ const bool add_pass_through_rules;
+ const int max_distortion;
+ TRulePtr kSOURCE_RULE;
+ const TRulePtr kCONCAT_RULE;
+ const WordID kNT_TYPE;
+ boost::shared_ptr<FSTNode> fst;
+};
+
+PhraseBasedTranslator::PhraseBasedTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new PhraseBasedTranslatorImpl(conf)) {}
+
+bool PhraseBasedTranslator::Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
+}