summaryrefslogtreecommitdiff
path: root/decoder/fst_translator.cc
blob: 074de4c99e64fe308375dcac3858483dd612c8f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include "translator.h"

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <boost/shared_ptr.hpp>

#include "sentence_metadata.h"
#include "filelib.h"
#include "hg.h"
#include "hg_io.h"
#include "earley_composer.h"
#include "phrasetable_fst.h"
#include "tdict.h"

using namespace std;

struct FSTTranslatorImpl {
  FSTTranslatorImpl(const boost::program_options::variables_map& conf) :
      goal_sym(conf["goal"].as<string>()),
      kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")),
      kGOAL(TD::Convert("Goal") * -1),
      add_pass_through_rules(conf.count("add_pass_through_rules")) {
    fst.reset(LoadTextPhrasetable(conf["grammar"].as<vector<string> >()));
    ec.reset(new EarleyComposer(fst.get()));
  }

  bool Translate(const string& input,
                 const vector<double>& weights,
                 Hypergraph* forest) {
    bool composed = false;
    if (input.find("{\"rules\"") == 0) {
      istringstream is(input);
      Hypergraph src_cfg_hg;
      if (!HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)) {
        cerr << "Failed to read HG from JSON.\n";
        abort();
      }
      if (add_pass_through_rules) {
        SparseVector<double> feats;
        feats.set_value(FD::Convert("PassThrough"), 1);
        for (int i = 0; i < src_cfg_hg.edges_.size(); ++i) {
          const vector<WordID>& f = src_cfg_hg.edges_[i].rule_->f_;
          for (int j = 0; j < f.size(); ++j) {
            if (f[j] > 0) {
              fst->AddPassThroughTranslation(f[j], feats);
            }
          }
        }
      }
      composed = ec->Compose(src_cfg_hg, forest);
    } else {
      const string dummy_grammar("[" + goal_sym + "] ||| " + input + " ||| TOP=1");
      cerr << "  Dummy grammar: " << dummy_grammar << endl;
      istringstream is(dummy_grammar);
      if (add_pass_through_rules) {
        vector<WordID> words;
        TD::ConvertSentence(input, &words);
        SparseVector<double> feats;
        feats.set_value(FD::Convert("PassThrough"), 1);
        for (int i = 0; i < words.size(); ++i)
          fst->AddPassThroughTranslation(words[i], feats);
      }
      composed = ec->Compose(&is, forest);
    }
    if (composed) {
      Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
      Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
      Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
      forest->ConnectEdgeToHeadNode(hg_edge, goal);
      forest->Reweight(weights);
    }
    if (add_pass_through_rules)
      fst->ClearPassThroughTranslations();
    return composed;
  }

  const string goal_sym;
  const TRulePtr kGOAL_RULE;
  const WordID kGOAL;
  const bool add_pass_through_rules;
  boost::shared_ptr<EarleyComposer> ec;
  boost::shared_ptr<FSTNode> fst;
};

// Builds the translator by creating its FSTTranslatorImpl from `conf`; every
// option lookup happens inside the FSTTranslatorImpl constructor.
FSTTranslator::FSTTranslator(const boost::program_options::variables_map& conf) :
  pimpl_(new FSTTranslatorImpl(conf)) {}

// Translates one input by delegating to FSTTranslatorImpl::Translate and
// returns that call's result.  The source length recorded in `smeta` is always
// 0 because this translator has no way to compute it.
bool FSTTranslator::TranslateImpl(const string& input,
                                  SentenceMetadata* smeta,
                                  const vector<double>& weights,
                                  Hypergraph* minus_lm_forest) {
  // Source length is unknown here, so record a placeholder of zero.
  smeta->SetSourceLength(0);

  const bool translated = pimpl_->Translate(input, weights, minus_lm_forest);
  return translated;
}