diff options
Diffstat (limited to 'decoder/fst_translator.cc')
-rw-r--r-- | decoder/fst_translator.cc | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc new file mode 100644 index 00000000..38dbd717 --- /dev/null +++ b/decoder/fst_translator.cc @@ -0,0 +1,91 @@ +#include "translator.h" + +#include <sstream> +#include <boost/shared_ptr.hpp> + +#include "sentence_metadata.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "earley_composer.h" +#include "phrasetable_fst.h" +#include "tdict.h" + +using namespace std; + +struct FSTTranslatorImpl { + FSTTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as<string>()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1), + add_pass_through_rules(conf.count("add_pass_through_rules")) { + fst.reset(LoadTextPhrasetable(conf["grammar"].as<vector<string> >())); + ec.reset(new EarleyComposer(fst.get())); + } + + bool Translate(const string& input, + const vector<double>& weights, + Hypergraph* forest) { + bool composed = false; + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + assert(HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)); + if (add_pass_through_rules) { + SparseVector<double> feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < src_cfg_hg.edges_.size(); ++i) { + const vector<WordID>& f = src_cfg_hg.edges_[i].rule_->f_; + for (int j = 0; j < f.size(); ++j) { + if (f[j] > 0) { + fst->AddPassThroughTranslation(f[j], feats); + } + } + } + } + composed = ec->Compose(src_cfg_hg, forest); + } else { + const string dummy_grammar("[" + goal_sym + "] ||| " + input + " ||| TOP=1"); + cerr << " Dummy grammar: " << dummy_grammar << endl; + istringstream is(dummy_grammar); + if (add_pass_through_rules) { + vector<WordID> words; + TD::ConvertSentence(input, &words); + SparseVector<double> feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < words.size(); ++i) + fst->AddPassThroughTranslation(words[i], feats); + } + composed = ec->Compose(&is, forest); + } + if (composed) { + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + } + if (add_pass_through_rules) + fst->ClearPassThroughTranslations(); + return composed; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; + const bool add_pass_through_rules; + boost::shared_ptr<EarleyComposer> ec; + boost::shared_ptr<FSTNode> fst; +}; + +FSTTranslator::FSTTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new FSTTranslatorImpl(conf)) {} + +bool FSTTranslator::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector<double>& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + |