summaryrefslogtreecommitdiff
path: root/src/phrasebased_translator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/phrasebased_translator.cc')
-rw-r--r--src/phrasebased_translator.cc206
1 files changed, 0 insertions, 206 deletions
diff --git a/src/phrasebased_translator.cc b/src/phrasebased_translator.cc
deleted file mode 100644
index 5eb70876..00000000
--- a/src/phrasebased_translator.cc
+++ /dev/null
@@ -1,206 +0,0 @@
-#include "phrasebased_translator.h"
-
-#include <queue>
-#include <iostream>
-#include <tr1/unordered_map>
-#include <tr1/unordered_set>
-
-#include <boost/tuple/tuple.hpp>
-#include <boost/functional/hash.hpp>
-
-#include "sentence_metadata.h"
-#include "tdict.h"
-#include "hg.h"
-#include "filelib.h"
-#include "lattice.h"
-#include "phrasetable_fst.h"
-#include "array2d.h"
-
-using namespace std;
-using namespace std::tr1;
-using namespace boost::tuples;
-
-struct Coverage : public vector<bool> {
- explicit Coverage(int n, bool v = false) : vector<bool>(n, v), first_gap() {}
- void Cover(int i, int j) {
- vector<bool>::iterator it = this->begin() + i;
- vector<bool>::iterator end = this->begin() + j;
- while (it != end)
- *it++ = true;
- if (first_gap == i) {
- first_gap = j;
- it = end;
- while (*it && it != this->end()) {
- ++it;
- ++first_gap;
- }
- }
- }
- bool Collides(int i, int j) const {
- vector<bool>::const_iterator it = this->begin() + i;
- vector<bool>::const_iterator end = this->begin() + j;
- while (it != end)
- if (*it++) return true;
- return false;
- }
- int GetFirstGap() const { return first_gap; }
- private:
- int first_gap;
-};
-struct CoverageHash {
- size_t operator()(const Coverage& cov) const {
- return hasher_(static_cast<const vector<bool>&>(cov));
- }
- private:
- boost::hash<vector<bool> > hasher_;
-};
-ostream& operator<<(ostream& os, const Coverage& cov) {
- os << '[';
- for (int i = 0; i < cov.size(); ++i)
- os << (cov[i] ? '*' : '.');
- return os << " gap=" << cov.GetFirstGap() << ']';
-}
-
-typedef unordered_map<Coverage, int, CoverageHash> CoverageNodeMap;
-typedef unordered_set<Coverage, CoverageHash> UniqueCoverageSet;
-
-struct PhraseBasedTranslatorImpl {
- PhraseBasedTranslatorImpl(const boost::program_options::variables_map& conf) :
- add_pass_through_rules(conf.count("add_pass_through_rules")),
- max_distortion(conf["pb_max_distortion"].as<int>()),
- kSOURCE_RULE(new TRule("[X] ||| [X,1] ||| [X,1]", true)),
- kCONCAT_RULE(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]", true)),
- kNT_TYPE(TD::Convert("X") * -1) {
- assert(max_distortion >= 0);
- vector<string> gfiles = conf["grammar"].as<vector<string> >();
- assert(gfiles.size() == 1);
- cerr << "Reading phrasetable from " << gfiles.front() << endl;
- ReadFile in(gfiles.front());
- fst.reset(LoadTextPhrasetable(in.stream()));
- }
-
- struct State {
- State(const Coverage& c, int _i, int _j, const FSTNode* q) :
- coverage(c), i(_i), j(_j), fst(q) {}
- Coverage coverage;
- int i;
- int j;
- const FSTNode* fst;
- };
-
- // we keep track of unique coverages that have been extended since it's
- // possible to "extend" the same coverage twice, e.g. translate "a b c"
- // with phrases "a" "b" "a b" and "c". There are two ways to cover "a b"
- void EnqueuePossibleContinuations(const Coverage& coverage, queue<State>* q, UniqueCoverageSet* ucs) {
- if (ucs->insert(coverage).second) {
- const int gap = coverage.GetFirstGap();
- const int end = min(static_cast<int>(coverage.size()), gap + max_distortion + 1);
- for (int i = gap; i < end; ++i)
- if (!coverage[i]) q->push(State(coverage, i, i, fst.get()));
- }
- }
-
- bool Translate(const std::string& input,
- SentenceMetadata* smeta,
- const std::vector<double>& weights,
- Hypergraph* minus_lm_forest) {
- Lattice lattice;
- LatticeTools::ConvertTextOrPLF(input, &lattice);
- smeta->SetSourceLength(lattice.size());
- size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);
- minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);
- if (add_pass_through_rules) {
- SparseVector<double> feats;
- feats.set_value(FD::Convert("PassThrough"), 1);
- for (int i = 0; i < lattice.size(); ++i) {
- const vector<LatticeArc>& arcs = lattice[i];
- for (int j = 0; j < arcs.size(); ++j) {
- fst->AddPassThroughTranslation(arcs[j].label, feats);
- // TODO handle lattice edge features
- }
- }
- }
- CoverageNodeMap c;
- queue<State> q;
- UniqueCoverageSet ucs;
- const Coverage empty_cov(lattice.size(), false);
- const Coverage goal_cov(lattice.size(), true);
- EnqueuePossibleContinuations(empty_cov, &q, &ucs);
- c[empty_cov] = 0; // have to handle the left edge specially
- while(!q.empty()) {
- const State s = q.front();
- q.pop();
- // cerr << "(" << s.i << "," << s.j << " ptr=" << s.fst << ") cov=" << s.coverage << endl;
- const vector<LatticeArc>& arcs = lattice[s.j];
- if (s.fst->HasData()) {
- Coverage new_cov = s.coverage;
- new_cov.Cover(s.i, s.j);
- EnqueuePossibleContinuations(new_cov, &q, &ucs);
- const vector<TRulePtr>& phrases = s.fst->GetTranslations()->GetRules();
- const int phrase_head_index = minus_lm_forest->AddNode(kNT_TYPE)->id_;
- for (int i = 0; i < phrases.size(); ++i) {
- Hypergraph::Edge* edge = minus_lm_forest->AddEdge(phrases[i], Hypergraph::TailNodeVector());
- edge->feature_values_ = edge->rule_->scores_;
- minus_lm_forest->ConnectEdgeToHeadNode(edge->id_, phrase_head_index);
- }
- CoverageNodeMap::iterator cit = c.find(s.coverage);
- assert(cit != c.end());
- const int tail_node_plus1 = cit->second;
- if (tail_node_plus1 == 0) { // left edge
- c[new_cov] = phrase_head_index + 1;
- } else { // not left edge
- int& head_node_plus1 = c[new_cov];
- if (!head_node_plus1)
- head_node_plus1 = minus_lm_forest->AddNode(kNT_TYPE)->id_ + 1;
- Hypergraph::TailNodeVector tail(2, tail_node_plus1 - 1);
- tail[1] = phrase_head_index;
- const int concat_edge = minus_lm_forest->AddEdge(kCONCAT_RULE, tail)->id_;
- minus_lm_forest->ConnectEdgeToHeadNode(concat_edge, head_node_plus1 - 1);
- }
- }
- if (s.j == lattice.size()) continue;
- for (int l = 0; l < arcs.size(); ++l) {
- const LatticeArc& arc = arcs[l];
-
- const FSTNode* next_fst_state = s.fst->Extend(arc.label);
- const int next_j = s.j + arc.dist2next;
- if (next_fst_state &&
- !s.coverage.Collides(s.i, next_j)) {
- q.push(State(s.coverage, s.i, next_j, next_fst_state));
- }
- }
- }
- if (add_pass_through_rules)
- fst->ClearPassThroughTranslations();
- int pregoal_plus1 = c[goal_cov];
- if (pregoal_plus1 > 0) {
- TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [X,1]"));
- int goal = minus_lm_forest->AddNode(TD::Convert("Goal") * -1)->id_;
- int gedge = minus_lm_forest->AddEdge(kGOAL_RULE, Hypergraph::TailNodeVector(1, pregoal_plus1 - 1))->id_;
- minus_lm_forest->ConnectEdgeToHeadNode(gedge, goal);
- // they are almost topo, but not quite always
- minus_lm_forest->TopologicallySortNodesAndEdges(goal);
- minus_lm_forest->Reweight(weights);
- return true;
- } else {
- return false; // composition failed
- }
- }
-
- const bool add_pass_through_rules;
- const int max_distortion;
- TRulePtr kSOURCE_RULE;
- const TRulePtr kCONCAT_RULE;
- const WordID kNT_TYPE;
- boost::shared_ptr<FSTNode> fst;
-};
-
-PhraseBasedTranslator::PhraseBasedTranslator(const boost::program_options::variables_map& conf) :
- pimpl_(new PhraseBasedTranslatorImpl(conf)) {}
-
-bool PhraseBasedTranslator::Translate(const std::string& input,
- SentenceMetadata* smeta,
- const std::vector<double>& weights,
- Hypergraph* minus_lm_forest) {
- return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
-}