diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/tree2string_translator.cc | 46 | 
1 files changed, 39 insertions, 7 deletions
| diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7b37887e..b5b47d5d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -5,6 +5,7 @@  #include <unordered_set>  #include <boost/shared_ptr.hpp>  #include <boost/functional/hash.hpp> +#include "fast_lexical_cast.hpp"  #include "tree_fragment.h"  #include "translator.h"  #include "hg.h" @@ -23,7 +24,7 @@ struct Tree2StringGrammarNode {  // this needs to be rewritten so it is fast and checks errors well  // use a lexer probably -void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { +static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) {    string line;    while(getline(*in, line)) {      size_t pos = line.find("|||"); @@ -142,10 +143,12 @@ void AddDummyGoalNode(Hypergraph* hg) {  struct Tree2StringTranslatorImpl {    vector<boost::shared_ptr<Tree2StringGrammarNode>> root;    bool add_pass_through_rules; +  bool has_multiple_states;    unsigned remove_grammars;    Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf,                              bool has_multiple_states) : -      add_pass_through_rules(conf.count("add_pass_through_rules")) { +      add_pass_through_rules(conf.count("add_pass_through_rules")), +      has_multiple_states(has_multiple_states) {      if (conf.count("grammar")) {        const vector<string> gf = conf["grammar"].as<vector<string>>();        root.resize(gf.size()); @@ -158,6 +161,15 @@ struct Tree2StringTranslatorImpl {      }    } +  // loads a per-sentence grammar +  void LoadSupplementalGrammar(const string& gfile) { +    root.resize(root.size() + 1); +    root.back().reset(new Tree2StringGrammarNode); +    ++remove_grammars; +    ReadFile rf(gfile); +    ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states); +  } +    void CreatePassThroughRules(const cdec::TreeFragment& tree) {      static const int kFIDlex = FD::Convert("PassThrough_Lexical");      static const int kFIDabs = FD::Convert("PassThrough_Abstract"); @@ -227,7 +239,7 @@ struct Tree2StringTranslatorImpl {    }    void RemoveGrammars() { -    assert(remove_grammars < root.size()); +    assert(remove_grammars <= root.size());      root.resize(root.size() - remove_grammars);    } @@ -235,7 +247,6 @@ struct Tree2StringTranslatorImpl {                   SentenceMetadata* smeta,                   const vector<double>& weights,                   Hypergraph* minus_lm_forest) { -    remove_grammars = 0;      cdec::TreeFragment input_tree(input, false);      if (add_pass_through_rules) CreatePassThroughRules(input_tree);      Hypergraph hg; @@ -371,6 +382,30 @@ Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::varia                                               bool has_multiple_states) :    pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { +  pimpl_->remove_grammars = 0; +  if (kv.find("grammar0") != kv.end()) { +    cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n"; +    abort(); +  } +  unsigned gc = 0; +  set<string> loaded; +  while(true) { +    string gkey = "grammar"; +    if (gc > 0) gkey += boost::lexical_cast<string>(gc); +    ++gc; +    map<string,string>::const_iterator it = kv.find(gkey); +    if (it == kv.end()) break; +    const string& gfile = it->second; +    if (loaded.count(gfile) == 1) { +      cerr << "Attempting to load " << gfile << " twice!\n"; +      abort(); +    } +    loaded.insert(gfile); +    pimpl_->LoadSupplementalGrammar(gfile); +  } +} +  bool Tree2StringTranslator::TranslateImpl(const string& input,                                 SentenceMetadata* smeta,                                 const vector<double>& weights, @@ -378,9 +413,6 @@ bool Tree2StringTranslator::TranslateImpl(const string& input,    return pimpl_->Translate(input, smeta, weights, minus_lm_forest);  } -void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { -} -  void Tree2StringTranslator::SentenceCompleteImpl() {    pimpl_->RemoveGrammars();  } | 
