summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/tree2string_translator.cc46
1 files changed, 39 insertions, 7 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 7b37887e..b5b47d5d 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -5,6 +5,7 @@
#include <unordered_set>
#include <boost/shared_ptr.hpp>
#include <boost/functional/hash.hpp>
+#include "fast_lexical_cast.hpp"
#include "tree_fragment.h"
#include "translator.h"
#include "hg.h"
@@ -23,7 +24,7 @@ struct Tree2StringGrammarNode {
// this needs to be rewritten so it is fast and checks errors well
// use a lexer probably
-void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) {
+static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) {
string line;
while(getline(*in, line)) {
size_t pos = line.find("|||");
@@ -142,10 +143,12 @@ void AddDummyGoalNode(Hypergraph* hg) {
struct Tree2StringTranslatorImpl {
vector<boost::shared_ptr<Tree2StringGrammarNode>> root;
bool add_pass_through_rules;
+ bool has_multiple_states;
unsigned remove_grammars;
Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf,
bool has_multiple_states) :
- add_pass_through_rules(conf.count("add_pass_through_rules")) {
+ add_pass_through_rules(conf.count("add_pass_through_rules")),
+ has_multiple_states(has_multiple_states) {
if (conf.count("grammar")) {
const vector<string> gf = conf["grammar"].as<vector<string>>();
root.resize(gf.size());
@@ -158,6 +161,15 @@ struct Tree2StringTranslatorImpl {
}
}
+ // loads a per-sentence grammar
+ void LoadSupplementalGrammar(const string& gfile) {
+ root.resize(root.size() + 1);
+ root.back().reset(new Tree2StringGrammarNode);
+ ++remove_grammars;
+ ReadFile rf(gfile);
+ ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states);
+ }
+
void CreatePassThroughRules(const cdec::TreeFragment& tree) {
static const int kFIDlex = FD::Convert("PassThrough_Lexical");
static const int kFIDabs = FD::Convert("PassThrough_Abstract");
@@ -227,7 +239,7 @@ struct Tree2StringTranslatorImpl {
}
void RemoveGrammars() {
- assert(remove_grammars < root.size());
+ assert(remove_grammars <= root.size());
root.resize(root.size() - remove_grammars);
}
@@ -235,7 +247,6 @@ struct Tree2StringTranslatorImpl {
SentenceMetadata* smeta,
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
- remove_grammars = 0;
cdec::TreeFragment input_tree(input, false);
if (add_pass_through_rules) CreatePassThroughRules(input_tree);
Hypergraph hg;
@@ -371,6 +382,30 @@ Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::varia
bool has_multiple_states) :
pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {}
+void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
+ pimpl_->remove_grammars = 0;
+ if (kv.find("grammar0") != kv.end()) {
+ cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n";
+ abort();
+ }
+ unsigned gc = 0;
+ set<string> loaded;
+ while(true) {
+ string gkey = "grammar";
+ if (gc > 0) gkey += boost::lexical_cast<string>(gc);
+ ++gc;
+ map<string,string>::const_iterator it = kv.find(gkey);
+ if (it == kv.end()) break;
+ const string& gfile = it->second;
+ if (loaded.count(gfile) == 1) {
+ cerr << "Attempting to load " << gfile << " twice!\n";
+ abort();
+ }
+ loaded.insert(gfile);
+ pimpl_->LoadSupplementalGrammar(gfile);
+ }
+}
+
bool Tree2StringTranslator::TranslateImpl(const string& input,
SentenceMetadata* smeta,
const vector<double>& weights,
@@ -378,9 +413,6 @@ bool Tree2StringTranslator::TranslateImpl(const string& input,
return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
}
-void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
-}
-
void Tree2StringTranslator::SentenceCompleteImpl() {
pimpl_->RemoveGrammars();
}