summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am2
-rw-r--r--decoder/decoder.cc16
-rw-r--r--decoder/decoder.h6
-rw-r--r--decoder/earley_composer.cc38
-rw-r--r--decoder/hg.cc63
-rw-r--r--decoder/hg.h17
-rw-r--r--decoder/hg_io.cc1
-rw-r--r--decoder/hg_remove_eps.cc91
-rw-r--r--decoder/hg_remove_eps.h13
-rw-r--r--decoder/inside_outside.h4
-rw-r--r--decoder/rescore_translator.cc58
-rw-r--r--decoder/scfg_translator.cc70
-rw-r--r--decoder/translator.h17
-rw-r--r--decoder/trule.cc4
14 files changed, 259 insertions, 141 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 00d01e53..0a792549 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -37,9 +37,11 @@ libcdec_a_SOURCES = \
fst_translator.cc \
csplit.cc \
translator.cc \
+ rescore_translator.cc \
scfg_translator.cc \
hg.cc \
hg_io.cc \
+ hg_remove_eps.cc \
decoder.cc \
hg_intersect.cc \
hg_sampler.cc \
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 333f0fb6..a6f7b1ce 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -527,8 +527,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
formalism = LowercaseString(str("formalism",conf));
- if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") {
- cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n";
+ if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -675,6 +675,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
translator.reset(new LexicalTrans(conf));
else if (formalism == "lexalign")
translator.reset(new LexicalAlign(conf));
+ else if (formalism == "rescore")
+ translator.reset(new RescoreTranslator(conf));
else if (formalism == "tagger")
translator.reset(new Tagger(conf));
else
@@ -743,16 +745,14 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {
}
vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); }
const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); }
-void Decoder::SetSupplementalGrammar(const std::string& grammar_string) {
- assert(pimpl_->translator->GetDecoderType() == "SCFG");
- static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string);
+void Decoder::AddSupplementalGrammar(GrammarPtr gp) {
+ static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammar(gp);
}
-void Decoder::SetSentenceGrammarFromString(const std::string& grammar_str) {
+void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) {
assert(pimpl_->translator->GetDecoderType() == "SCFG");
- static_cast<SCFGTranslator&>(*pimpl_->translator).SetSentenceGrammarFromString(grammar_str);
+ static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);
}
-
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string buf = input;
NgramCache::Clear(); // clear ngram cache for remote LM (if used)
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 6b2f7b16..bef2ff5e 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -37,6 +37,8 @@ struct DecoderObserver {
virtual void NotifyDecodingComplete(const SentenceMetadata& smeta);
};
+struct Grammar; // TODO once the decoder interface is cleaned up,
+ // this should be somewhere else
struct Decoder {
Decoder(int argc, char** argv);
Decoder(std::istream* config_file);
@@ -54,8 +56,8 @@ struct Decoder {
// add grammar rules (currently only supported by SCFG decoders)
// that will be used on subsequent calls to Decode. rules should be in standard
// text format. This function does NOT read from a file.
- void SetSupplementalGrammar(const std::string& grammar);
- void SetSentenceGrammarFromString(const std::string& grammar_str);
+ void AddSupplementalGrammar(boost::shared_ptr<Grammar> gp);
+ void AddSupplementalGrammarFromString(const std::string& grammar_string);
private:
boost::program_options::variables_map conf;
boost::shared_ptr<DecoderImpl> pimpl_;
diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc
index d265d954..efce70a6 100644
--- a/decoder/earley_composer.cc
+++ b/decoder/earley_composer.cc
@@ -16,6 +16,7 @@
#include "sparse_vector.h"
#include "tdict.h"
#include "hg.h"
+#include "hg_remove_eps.h"
using namespace std;
using namespace std::tr1;
@@ -48,6 +49,27 @@ static void InitializeConstants() {
}
////////////////////////////////////////////////////////////
+TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) {
+ TRule* r = new TRule(*kX1X2);
+ r->lhs_ = lhs;
+ r->f_[0] = rhs1;
+ r->f_[1] = rhs2;
+ return TRulePtr(r);
+}
+
+TRulePtr CreateUnaryRule(int lhs, int rhs1) {
+ TRule* r = new TRule(*kX1);
+ r->lhs_ = lhs;
+ r->f_[0] = rhs1;
+ return TRulePtr(r);
+}
+
+TRulePtr CreateEpsilonRule(int lhs) {
+ TRule* r = new TRule(*kEPSRule);
+ r->lhs_ = lhs;
+ return TRulePtr(r);
+}
+
class EGrammarNode {
friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);
friend void AddGrammarRule(const string& r, map<WordID, EGrammarNode>* g);
@@ -356,7 +378,7 @@ class EarleyComposerImpl {
}
if (goal_node) {
forest->PruneUnreachable(goal_node->id_);
- forest->EpsilonRemove(kEPS);
+ RemoveEpsilons(forest, kEPS);
}
FreeAll();
return goal_node;
@@ -557,24 +579,30 @@ class EarleyComposerImpl {
}
Hypergraph::Node*& head_node = edge2node[edge];
if (!head_node)
- head_node = hg->AddNode(kPHRASE);
+ head_node = hg->AddNode(edge->cat);
if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) {
assert(goal_node == NULL || goal_node == head_node);
goal_node = head_node;
}
+ int rhs1 = 0;
+ int rhs2 = 0;
Hypergraph::TailNodeVector tail;
SparseVector<double> extra;
if (edge->IsCreatedByPredict()) {
// extra.set_value(FD::Convert("predict"), 1);
} else if (edge->IsCreatedByScan()) {
tail.push_back(edge2node[edge->active_parent]->id_);
+ rhs1 = edge->active_parent->cat;
if (tps) {
tail.push_back(tps->id_);
+ rhs2 = kPHRASE;
}
//extra.set_value(FD::Convert("scan"), 1);
} else if (edge->IsCreatedByComplete()) {
tail.push_back(edge2node[edge->active_parent]->id_);
+ rhs1 = edge->active_parent->cat;
tail.push_back(edge2node[edge->passive_parent]->id_);
+ rhs2 = edge->passive_parent->cat;
//extra.set_value(FD::Convert("complete"), 1);
} else {
assert(!"unexpected edge type!");
@@ -592,11 +620,11 @@ class EarleyComposerImpl {
#endif
Hypergraph::Edge* hg_edge = NULL;
if (tail.size() == 0) {
- hg_edge = hg->AddEdge(kEPSRule, tail);
+ hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail);
} else if (tail.size() == 1) {
- hg_edge = hg->AddEdge(kX1, tail);
+ hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail);
} else if (tail.size() == 2) {
- hg_edge = hg->AddEdge(kX1X2, tail);
+ hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail);
}
if (edge->features)
hg_edge->feature_values_ += *edge->features;
diff --git a/decoder/hg.cc b/decoder/hg.cc
index dd272221..7240a8ab 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -605,69 +605,6 @@ void Hypergraph::TopologicallySortNodesAndEdges(int goal_index,
#endif
}
-TRulePtr Hypergraph::kEPSRule;
-TRulePtr Hypergraph::kUnaryRule;
-
-void Hypergraph::EpsilonRemove(WordID eps) {
- if (!kEPSRule) {
- kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
- kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
- }
- vector<bool> kill(edges_.size(), false);
- for (unsigned i = 0; i < edges_.size(); ++i) {
- const Edge& edge = edges_[i];
- if (edge.tail_nodes_.empty() &&
- edge.rule_->f_.size() == 1 &&
- edge.rule_->f_[0] == eps) {
- kill[i] = true;
- if (!edge.feature_values_.empty()) {
- Node& node = nodes_[edge.head_node_];
- if (node.in_edges_.size() != 1) {
- cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n";
- // this *probably* means that there are multiple derivations of the
- // same sequence via different paths through the input forest
- // this needs to be investigated and fixed
- } else {
- for (unsigned j = 0; j < node.out_edges_.size(); ++j)
- edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_;
- // cerr << "PROMOTED " << edge.feature_values_ << endl;
- }
- }
- }
- }
- bool created_eps = false;
- PruneEdges(kill);
- for (unsigned i = 0; i < nodes_.size(); ++i) {
- const Node& node = nodes_[i];
- if (node.in_edges_.empty()) {
- for (unsigned j = 0; j < node.out_edges_.size(); ++j) {
- Edge& edge = edges_[node.out_edges_[j]];
- if (edge.rule_->Arity() == 2) {
- assert(edge.rule_->f_.size() == 2);
- assert(edge.rule_->e_.size() == 2);
- edge.rule_ = kUnaryRule;
- unsigned cur = node.id_;
- int t = -1;
- assert(edge.tail_nodes_.size() == 2);
- for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; }
- assert(t != -1);
- edge.tail_nodes_.resize(1);
- edge.tail_nodes_[0] = t;
- } else {
- edge.rule_ = kEPSRule;
- edge.rule_->f_[0] = eps;
- edge.rule_->e_[0] = eps;
- edge.tail_nodes_.clear();
- created_eps = true;
- }
- }
- }
- }
- vector<bool> k2(edges_.size(), false);
- PruneEdges(k2);
- if (created_eps) EpsilonRemove(eps);
-}
-
struct EdgeWeightSorter {
const Hypergraph& hg;
EdgeWeightSorter(const Hypergraph& h) : hg(h) {}
diff --git a/decoder/hg.h b/decoder/hg.h
index 91d25f01..591e98ce 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -148,7 +148,7 @@ public:
void show(std::ostream &o,unsigned mask=SPAN|RULE) const {
o<<'{';
if (mask&CATEGORY)
- o<<TD::Convert(rule_->GetLHS());
+ o<< '[' << TD::Convert(-rule_->GetLHS()) << ']';
if (mask&PREV_SPAN)
o<<'<'<<prev_i_<<','<<prev_j_<<'>';
if (mask&SPAN)
@@ -156,9 +156,9 @@ public:
if (mask&PROB)
o<<" p="<<edge_prob_;
if (mask&FEATURES)
- o<<" "<<feature_values_;
+ o<<' '<<feature_values_;
if (mask&RULE)
- o<<rule_->AsString(mask&RULE_LHS);
+ o<<' '<<rule_->AsString(mask&RULE_LHS);
if (USE_INFO_EDGE) {
std::string const& i=info();
if (mask&&!i.empty()) o << " |||"<<i; // remember, the initial space is expected as part of i
@@ -384,14 +384,6 @@ public:
// compute the total number of paths in the forest
double NumberOfPaths() const;
- // BEWARE. this assumes that the source and target language
- // strings are identical and that there are no loops.
- // It assumes a bunch of other things about where the
- // epsilons will be. It tries to assert failure if you
- // break these assumptions, but it may not.
- // TODO - make this work
- void EpsilonRemove(WordID eps);
-
// multiple the weights vector by the edge feature vector
// (inner product) to set the edge probabilities
template <class V>
@@ -535,9 +527,6 @@ public:
private:
Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {}
-
- static TRulePtr kEPSRule;
- static TRulePtr kUnaryRule;
};
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index bfb2fb80..8bd40387 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -261,6 +261,7 @@ static void WriteRule(const TRule& r, ostream* out) {
}
bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
+ if (hg.empty()) { *out << "{}\n"; return true; }
map<const TRule*, int> rid;
ostream& o = *out;
rid[NULL] = 0;
diff --git a/decoder/hg_remove_eps.cc b/decoder/hg_remove_eps.cc
new file mode 100644
index 00000000..050c4876
--- /dev/null
+++ b/decoder/hg_remove_eps.cc
@@ -0,0 +1,91 @@
+#include "hg_remove_eps.h"
+
+#include <cassert>
+
+#include "trule.h"
+#include "hg.h"
+
+using namespace std;
+
+namespace {
+ TRulePtr kEPSRule;
+ TRulePtr kUnaryRule;
+
+ TRulePtr CreateUnaryRule(int lhs, int rhs) {
+ if (!kUnaryRule) kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
+ TRule* r = new TRule(*kUnaryRule);
+ assert(lhs < 0);
+ assert(rhs < 0);
+ r->lhs_ = lhs;
+ r->f_[0] = rhs;
+ return TRulePtr(r);
+ }
+
+ TRulePtr CreateEpsilonRule(int lhs, WordID eps) {
+ if (!kEPSRule) kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
+ TRule* r = new TRule(*kEPSRule);
+ r->lhs_ = lhs;
+ assert(lhs < 0);
+ assert(eps > 0);
+ r->e_[0] = eps;
+ r->f_[0] = eps;
+ return TRulePtr(r);
+ }
+}
+
+void RemoveEpsilons(Hypergraph* g, WordID eps) {
+ vector<bool> kill(g->edges_.size(), false);
+ for (unsigned i = 0; i < g->edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = g->edges_[i];
+ if (edge.tail_nodes_.empty() &&
+ edge.rule_->f_.size() == 1 &&
+ edge.rule_->f_[0] == eps) {
+ kill[i] = true;
+ if (!edge.feature_values_.empty()) {
+ Hypergraph::Node& node = g->nodes_[edge.head_node_];
+ if (node.in_edges_.size() != 1) {
+ cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n";
+ // this *probably* means that there are multiple derivations of the
+ // same sequence via different paths through the input forest
+ // this needs to be investigated and fixed
+ } else {
+ for (unsigned j = 0; j < node.out_edges_.size(); ++j)
+ g->edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_;
+ // cerr << "PROMOTED " << edge.feature_values_ << endl;
+ }
+ }
+ }
+ }
+ bool created_eps = false;
+ g->PruneEdges(kill);
+ for (unsigned i = 0; i < g->nodes_.size(); ++i) {
+ const Hypergraph::Node& node = g->nodes_[i];
+ if (node.in_edges_.empty()) {
+ for (unsigned j = 0; j < node.out_edges_.size(); ++j) {
+ Hypergraph::Edge& edge = g->edges_[node.out_edges_[j]];
+ const int lhs = edge.rule_->lhs_;
+ if (edge.rule_->Arity() == 2) {
+ assert(edge.rule_->f_.size() == 2);
+ assert(edge.rule_->e_.size() == 2);
+ unsigned cur = node.id_;
+ int t = -1;
+ assert(edge.tail_nodes_.size() == 2);
+ int rhs = 0;
+ for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; rhs = edge.rule_->f_[i]; }
+ assert(t != -1);
+ edge.tail_nodes_.resize(1);
+ edge.tail_nodes_[0] = t;
+ edge.rule_ = CreateUnaryRule(lhs, rhs);
+ } else {
+ edge.rule_ = CreateEpsilonRule(lhs, eps);
+ edge.tail_nodes_.clear();
+ created_eps = true;
+ }
+ }
+ }
+ }
+ vector<bool> k2(g->edges_.size(), false);
+ g->PruneEdges(k2);
+ if (created_eps) RemoveEpsilons(g, eps);
+}
+
diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h
new file mode 100644
index 00000000..82f06039
--- /dev/null
+++ b/decoder/hg_remove_eps.h
@@ -0,0 +1,13 @@
+#ifndef _HG_REMOVE_EPS_H_
+#define _HG_REMOVE_EPS_H_
+
+#include "wordid.h"
+class Hypergraph;
+
+// This is not a complete implementation of the general algorithm for
+// doing this. It makes a few weird assumptions, for example, that
+// if some nonterminal X rewrites as eps, then that is the only thing
+// that it rewrites as. This needs to be fixed for the general case!
+void RemoveEpsilons(Hypergraph* g, WordID eps);
+
+#endif
diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h
index bb7f9fcc..f73a1d3f 100644
--- a/decoder/inside_outside.h
+++ b/decoder/inside_outside.h
@@ -41,10 +41,6 @@ WeightType Inside(const Hypergraph& hg,
WeightType* const cur_node_inside_score = &inside_score[i];
Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_;
const unsigned num_in_edges = in.size();
- if (num_in_edges == 0) {
- *cur_node_inside_score = WeightType(1); //FIXME: why not call weight(edge) instead?
- continue;
- }
for (unsigned j = 0; j < num_in_edges; ++j) {
const Hypergraph::Edge& edge = hg.edges_[in[j]];
WeightType score = weight(edge);
diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc
new file mode 100644
index 00000000..10192f7a
--- /dev/null
+++ b/decoder/rescore_translator.cc
@@ -0,0 +1,58 @@
+#include "translator.h"
+
+#include <sstream>
+#include <boost/shared_ptr.hpp>
+
+#include "sentence_metadata.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "tdict.h"
+
+using namespace std;
+
+struct RescoreTranslatorImpl {
+ RescoreTranslatorImpl(const boost::program_options::variables_map& conf) :
+ goal_sym(conf["goal"].as<string>()),
+ kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")),
+ kGOAL(TD::Convert("Goal") * -1) {
+ }
+
+ bool Translate(const string& input,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ if (input == "{}") return false;
+ if (input.find("{\"rules\"") == 0) {
+ istringstream is(input);
+ Hypergraph src_cfg_hg;
+ if (!HypergraphIO::ReadFromJSON(&is, forest)) {
+ cerr << "Parse error while reading HG from JSON.\n";
+ abort();
+ }
+ } else {
+ cerr << "Can only read HG input from JSON: use training/grammar_convert\n";
+ abort();
+ }
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(kGOAL);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ forest->Reweight(weights);
+ return true;
+ }
+
+ const string goal_sym;
+ const TRulePtr kGOAL_RULE;
+ const WordID kGOAL;
+};
+
+RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new RescoreTranslatorImpl(conf)) {}
+
+bool RescoreTranslator::TranslateImpl(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ smeta->SetSourceLength(0); // don't know how to compute this
+ return pimpl_->Translate(input, weights, minus_lm_forest);
+}
+
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 185f979a..a978cfc2 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -20,7 +20,6 @@
#define reverse_foreach BOOST_REVERSE_FOREACH
using namespace std;
-static bool usingSentenceGrammar = false;
static bool printGrammarsUsed = false;
struct SCFGTranslatorImpl {
@@ -91,31 +90,31 @@ struct SCFGTranslatorImpl {
bool show_tree_structure_;
unsigned int ctf_iterations_;
vector<GrammarPtr> grammars;
- GrammarPtr sup_grammar_;
+ set<GrammarPtr> sup_grammars_;
- struct Equals { Equals(const GrammarPtr& v) : v_(v) {}
- bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; };
+ struct ContainedIn {
+ ContainedIn(const set<GrammarPtr>& gs) : gs_(gs) {}
+ bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); }
+ const set<GrammarPtr>& gs_;
+ };
- void SetSupplementalGrammar(const std::string& grammar_string) {
- grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end());
+ void AddSupplementalGrammarFromString(const std::string& grammar_string) {
+ grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end());
istringstream in(grammar_string);
- sup_grammar_.reset(new TextGrammar(&in));
- grammars.push_back(sup_grammar_);
+ TextGrammar* sent_grammar = new TextGrammar(&in);
+ sent_grammar->SetMaxSpan(max_span_limit);
+ sent_grammar->SetGrammarName("SupFromString");
+ AddSupplementalGrammar(GrammarPtr(sent_grammar));
}
- struct NameEquals { NameEquals(const string name) : name_(name) {}
- bool operator()(const GrammarPtr& x) const { return x->GetGrammarName() == name_; } const string name_; };
+ void AddSupplementalGrammar(GrammarPtr gp) {
+ sup_grammars_.insert(gp);
+ grammars.push_back(gp);
+ }
- void SetSentenceGrammarFromString(const std::string& grammar_str) {
- assert(grammar_str != "");
- if (!SILENT) cerr << "Setting sentence grammar" << endl;
- usingSentenceGrammar = true;
- istringstream in(grammar_str);
- TextGrammar* sent_grammar = new TextGrammar(&in);
- sent_grammar->SetMaxSpan(max_span_limit);
- sent_grammar->SetGrammarName("__psg");
- grammars.erase(remove_if(grammars.begin(), grammars.end(), NameEquals("__psg")), grammars.end());
- grammars.push_back(GrammarPtr(sent_grammar));
+ void RemoveSupplementalGrammars() {
+ grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end());
+ sup_grammars_.clear();
}
bool Translate(const string& input,
@@ -300,35 +299,24 @@ Check for grammar pointer in the sentence markup, for use with sentence specific
*/
void SCFGTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
map<string,string>::const_iterator it = kv.find("grammar");
-
-
- if (it == kv.end()) {
- usingSentenceGrammar= false;
- return;
+ if (it != kv.end()) {
+ TextGrammar* sentGrammar = new TextGrammar(it->second);
+ sentGrammar->SetMaxSpan(pimpl_->max_span_limit);
+ sentGrammar->SetGrammarName(it->second);
+ pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar));
}
- //Create sentence specific grammar from specified file name and load grammar into list of grammars
- usingSentenceGrammar = true;
- TextGrammar* sentGrammar = new TextGrammar(it->second);
- sentGrammar->SetMaxSpan(pimpl_->max_span_limit);
- sentGrammar->SetGrammarName(it->second);
- pimpl_->grammars.push_back(GrammarPtr(sentGrammar));
-
}
-void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) {
- pimpl_->SetSupplementalGrammar(grammar);
+void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) {
+ pimpl_->AddSupplementalGrammarFromString(grammar);
}
-void SCFGTranslator::SetSentenceGrammarFromString(const std::string& grammar_str) {
- pimpl_->SetSentenceGrammarFromString(grammar_str);
+void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) {
+ pimpl_->AddSupplementalGrammar(grammar);
}
void SCFGTranslator::SentenceCompleteImpl() {
-
- if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars
- {
- pimpl_->grammars.pop_back();
- }
+ pimpl_->RemoveSupplementalGrammars();
}
std::string SCFGTranslator::GetDecoderType() const {
diff --git a/decoder/translator.h b/decoder/translator.h
index cfd3b08a..c0800e84 100644
--- a/decoder/translator.h
+++ b/decoder/translator.h
@@ -58,8 +58,8 @@ class SCFGTranslatorImpl;
class SCFGTranslator : public Translator {
public:
SCFGTranslator(const boost::program_options::variables_map& conf);
- void SetSupplementalGrammar(const std::string& grammar);
- void SetSentenceGrammarFromString(const std::string& grammar);
+ void AddSupplementalGrammar(GrammarPtr gp);
+ void AddSupplementalGrammarFromString(const std::string& grammar);
virtual std::string GetDecoderType() const;
protected:
bool TranslateImpl(const std::string& src,
@@ -85,4 +85,17 @@ class FSTTranslator : public Translator {
boost::shared_ptr<FSTTranslatorImpl> pimpl_;
};
+class RescoreTranslatorImpl;
+class RescoreTranslator : public Translator {
+ public:
+ RescoreTranslator(const boost::program_options::variables_map& conf);
+ private:
+ bool TranslateImpl(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<RescoreTranslatorImpl> pimpl_;
+};
+
#endif
diff --git a/decoder/trule.cc b/decoder/trule.cc
index 187a003d..896f9f3d 100644
--- a/decoder/trule.cc
+++ b/decoder/trule.cc
@@ -237,9 +237,9 @@ void TRule::ComputeArity() {
string TRule::AsString(bool verbose) const {
ostringstream os;
int idx = 0;
- if (lhs_ && verbose) {
+ if (lhs_) {
os << '[' << TD::Convert(lhs_ * -1) << "] |||";
- }
+ } else { os << "NOLHS |||"; }
for (unsigned i = 0; i < f_.size(); ++i) {
const WordID& w = f_[i];
if (w < 0) {