diff options
Diffstat (limited to 'decoder/rescore_translator.cc')
-rw-r--r-- | decoder/rescore_translator.cc | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc new file mode 100644 index 00000000..5c417393 --- /dev/null +++ b/decoder/rescore_translator.cc @@ -0,0 +1,57 @@ +#include "translator.h" + +#include <sstream> +#include <boost/shared_ptr.hpp> + +#include "sentence_metadata.h" +#include "hg.h" +#include "hg_io.h" +#include "tdict.h" + +using namespace std; + +struct RescoreTranslatorImpl { + RescoreTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as<string>()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1) { + } + + bool Translate(const string& input, + const vector<double>& weights, + Hypergraph* forest) { + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + if (!HypergraphIO::ReadFromJSON(&is, forest)) { + cerr << "Parse error while reading HG from JSON.\n"; + abort(); + } + } else { + cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; + abort(); + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(kGOAL); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + return true; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; +}; + +RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new RescoreTranslatorImpl(conf)) {} + +bool RescoreTranslator::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector<double>& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + |