diff options
-rw-r--r-- | decoder/tree2string_translator.cc | 46 |
1 files changed, 39 insertions, 7 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7b37887e..b5b47d5d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -5,6 +5,7 @@ #include <unordered_set> #include <boost/shared_ptr.hpp> #include <boost/functional/hash.hpp> +#include "fast_lexical_cast.hpp" #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -23,7 +24,7 @@ struct Tree2StringGrammarNode { // this needs to be rewritten so it is fast and checks errors well // use a lexer probably -void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { +static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -142,10 +143,12 @@ void AddDummyGoalNode(Hypergraph* hg) { struct Tree2StringTranslatorImpl { vector<boost::shared_ptr<Tree2StringGrammarNode>> root; bool add_pass_through_rules; + bool has_multiple_states; unsigned remove_grammars; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, bool has_multiple_states) : - add_pass_through_rules(conf.count("add_pass_through_rules")) { + add_pass_through_rules(conf.count("add_pass_through_rules")), + has_multiple_states(has_multiple_states) { if (conf.count("grammar")) { const vector<string> gf = conf["grammar"].as<vector<string>>(); root.resize(gf.size()); @@ -158,6 +161,15 @@ struct Tree2StringTranslatorImpl { } } + // loads a per-sentence grammar + void LoadSupplementalGrammar(const string& gfile) { + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + ReadFile rf(gfile); + ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states); + } + void CreatePassThroughRules(const cdec::TreeFragment& tree) { static const int kFIDlex = FD::Convert("PassThrough_Lexical"); static const int kFIDabs = FD::Convert("PassThrough_Abstract"); @@ -227,7 +239,7 @@ struct Tree2StringTranslatorImpl { } void RemoveGrammars() { - assert(remove_grammars < root.size()); + assert(remove_grammars <= root.size()); root.resize(root.size() - remove_grammars); } @@ -235,7 +247,6 @@ struct Tree2StringTranslatorImpl { SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { - remove_grammars = 0; cdec::TreeFragment input_tree(input, false); if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; @@ -371,6 +382,30 @@ Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::varia bool has_multiple_states) : pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { + pimpl_->remove_grammars = 0; + if (kv.find("grammar0") != kv.end()) { + cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n"; + abort(); + } + unsigned gc = 0; + set<string> loaded; + while(true) { + string gkey = "grammar"; + if (gc > 0) gkey += boost::lexical_cast<string>(gc); + ++gc; + map<string,string>::const_iterator it = kv.find(gkey); + if (it == kv.end()) break; + const string& gfile = it->second; + if (loaded.count(gfile) == 1) { + cerr << "Attempting to load " << gfile << " twice!\n"; + abort(); + } + loaded.insert(gfile); + pimpl_->LoadSupplementalGrammar(gfile); + } +} + bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, const vector<double>& weights, @@ -378,9 +413,6 @@ bool Tree2StringTranslator::TranslateImpl(const string& input, return pimpl_->Translate(input, smeta, weights, minus_lm_forest); } -void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { -} - void Tree2StringTranslator::SentenceCompleteImpl() { pimpl_->RemoveGrammars(); } |