From 0172721855098ca02b207231a654dffa5e4eb1c9 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 22 Jun 2010 05:12:27 +0000 Subject: initial checkin git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/fst_translator.cc | 91 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 decoder/fst_translator.cc (limited to 'decoder/fst_translator.cc') diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc new file mode 100644 index 00000000..38dbd717 --- /dev/null +++ b/decoder/fst_translator.cc @@ -0,0 +1,91 @@ +#include "translator.h" + +#include +#include + +#include "sentence_metadata.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "earley_composer.h" +#include "phrasetable_fst.h" +#include "tdict.h" + +using namespace std; + +struct FSTTranslatorImpl { + FSTTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1), + add_pass_through_rules(conf.count("add_pass_through_rules")) { + fst.reset(LoadTextPhrasetable(conf["grammar"].as >())); + ec.reset(new EarleyComposer(fst.get())); + } + + bool Translate(const string& input, + const vector& weights, + Hypergraph* forest) { + bool composed = false; + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + assert(HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)); + if (add_pass_through_rules) { + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < src_cfg_hg.edges_.size(); ++i) { + const vector& f = src_cfg_hg.edges_[i].rule_->f_; + for (int j = 0; j < f.size(); ++j) { + if (f[j] > 0) { + fst->AddPassThroughTranslation(f[j], feats); + } + } + } + } + composed = ec->Compose(src_cfg_hg, forest); + } else { + const string dummy_grammar("[" + goal_sym + "] ||| " + input + " ||| TOP=1"); + cerr << " Dummy grammar: " << dummy_grammar << endl; + istringstream is(dummy_grammar); + if (add_pass_through_rules) { + vector words; + TD::ConvertSentence(input, &words); + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < words.size(); ++i) + fst->AddPassThroughTranslation(words[i], feats); + } + composed = ec->Compose(&is, forest); + } + if (composed) { + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + } + if (add_pass_through_rules) + fst->ClearPassThroughTranslations(); + return composed; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; + const bool add_pass_through_rules; + boost::shared_ptr ec; + boost::shared_ptr fst; +}; + +FSTTranslator::FSTTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new FSTTranslatorImpl(conf)) {} + +bool FSTTranslator::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + -- cgit v1.2.3