From 09d085e16f31b36b1177b313ee54414eb29d9082 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 20 Jul 2012 00:46:10 -0400 Subject: enable programmatic setting of grammar objects, cleanup of grammar handling --- decoder/decoder.cc | 10 +++----- decoder/decoder.h | 6 +++-- decoder/scfg_translator.cc | 63 ++++++++++++++++++---------------------------- decoder/translator.h | 4 +-- 4 files changed, 35 insertions(+), 48 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 333f0fb6..ad4e9e07 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -743,16 +743,14 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) { } vector& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); } const vector& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } -void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { - assert(pimpl_->translator->GetDecoderType() == "SCFG"); - static_cast(*pimpl_->translator).SetSupplementalGrammar(grammar_string); +void Decoder::AddSupplementalGrammar(GrammarPtr gp) { + static_cast(*pimpl_->translator).AddSupplementalGrammar(gp); } -void Decoder::SetSentenceGrammarFromString(const std::string& grammar_str) { +void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) { assert(pimpl_->translator->GetDecoderType() == "SCFG"); - static_cast(*pimpl_->translator).SetSentenceGrammarFromString(grammar_str); + static_cast(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string); } - bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { string buf = input; NgramCache::Clear(); // clear ngram cache for remote LM (if used) diff --git a/decoder/decoder.h b/decoder/decoder.h index 6b2f7b16..bef2ff5e 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -37,6 +37,8 @@ struct DecoderObserver { virtual void NotifyDecodingComplete(const SentenceMetadata& smeta); }; +struct Grammar; // TODO once the decoder interface is cleaned up, + // this should be somewhere else struct Decoder { Decoder(int argc, char** argv); Decoder(std::istream* config_file); @@ -54,8 +56,8 @@ struct Decoder { // add grammar rules (currently only supported by SCFG decoders) // that will be used on subsequent calls to Decode. rules should be in standard // text format. This function does NOT read from a file. - void SetSupplementalGrammar(const std::string& grammar); - void SetSentenceGrammarFromString(const std::string& grammar_str); + void AddSupplementalGrammar(boost::shared_ptr gp); + void AddSupplementalGrammarFromString(const std::string& grammar_string); private: boost::program_options::variables_map conf; boost::shared_ptr pimpl_; diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 185f979a..aaa6c40b 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -20,7 +20,6 @@ #define reverse_foreach BOOST_REVERSE_FOREACH using namespace std; -static bool usingSentenceGrammar = false; static bool printGrammarsUsed = false; struct SCFGTranslatorImpl { @@ -91,31 +90,31 @@ struct SCFGTranslatorImpl { bool show_tree_structure_; unsigned int ctf_iterations_; vector grammars; - GrammarPtr sup_grammar_; + set sup_grammars_; - struct Equals { Equals(const GrammarPtr& v) : v_(v) {} - bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; }; + struct ContainedIn { + ContainedIn(const set& gs) : gs_(gs) {} + bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); } + const set& gs_; + }; - void SetSupplementalGrammar(const std::string& grammar_string) { - grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end()); + void AddSupplementalGrammarFromString(const std::string& grammar_string) { + grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); istringstream in(grammar_string); - sup_grammar_.reset(new TextGrammar(&in)); - grammars.push_back(sup_grammar_); + TextGrammar* sent_grammar = new TextGrammar(&in); + sent_grammar->SetMaxSpan(max_span_limit); + sent_grammar->SetGrammarName("SupFromString"); + AddSupplementalGrammar(GrammarPtr(sent_grammar)); } - struct NameEquals { NameEquals(const string name) : name_(name) {} - bool operator()(const GrammarPtr& x) const { return x->GetGrammarName() == name_; } const string name_; }; + void AddSupplementalGrammar(GrammarPtr gp) { + sup_grammars_.insert(gp); + grammars.push_back(gp); + } - void SetSentenceGrammarFromString(const std::string& grammar_str) { - assert(grammar_str != ""); - if (!SILENT) cerr << "Setting sentence grammar" << endl; - usingSentenceGrammar = true; - istringstream in(grammar_str); - TextGrammar* sent_grammar = new TextGrammar(&in); - sent_grammar->SetMaxSpan(max_span_limit); - sent_grammar->SetGrammarName("__psg"); - grammars.erase(remove_if(grammars.begin(), grammars.end(), NameEquals("__psg")), grammars.end()); - grammars.push_back(GrammarPtr(sent_grammar)); + void RemoveSupplementalGrammars() { + grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); + sup_grammars_.clear(); } bool Translate(const string& input, @@ -301,34 +300,22 @@ Check for grammar pointer in the sentence markup, for use with sentence specific void SCFGTranslator::ProcessMarkupHintsImpl(const map& kv) { map::const_iterator it = kv.find("grammar"); - - if (it == kv.end()) { - usingSentenceGrammar= false; - return; - } - //Create sentence specific grammar from specified file name and load grammar into list of grammars - usingSentenceGrammar = true; TextGrammar* sentGrammar = new TextGrammar(it->second); sentGrammar->SetMaxSpan(pimpl_->max_span_limit); sentGrammar->SetGrammarName(it->second); - pimpl_->grammars.push_back(GrammarPtr(sentGrammar)); - + pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); } -void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) { - pimpl_->SetSupplementalGrammar(grammar); +void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) { + pimpl_->AddSupplementalGrammarFromString(grammar); } -void SCFGTranslator::SetSentenceGrammarFromString(const std::string& grammar_str) { - pimpl_->SetSentenceGrammarFromString(grammar_str); +void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) { + pimpl_->AddSupplementalGrammar(grammar); } void SCFGTranslator::SentenceCompleteImpl() { - - if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars - { - pimpl_->grammars.pop_back(); - } + pimpl_->RemoveSupplementalGrammars(); } std::string SCFGTranslator::GetDecoderType() const { diff --git a/decoder/translator.h b/decoder/translator.h index cfd3b08a..fc2bb760 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -58,8 +58,8 @@ class SCFGTranslatorImpl; class SCFGTranslator : public Translator { public: SCFGTranslator(const boost::program_options::variables_map& conf); - void SetSupplementalGrammar(const std::string& grammar); - void SetSentenceGrammarFromString(const std::string& grammar); + void AddSupplementalGrammar(GrammarPtr gp); + void AddSupplementalGrammarFromString(const std::string& grammar); virtual std::string GetDecoderType() const; protected: bool TranslateImpl(const std::string& src, -- cgit v1.2.3 From c4c9c2febd5af552ecddc215758e32b88013fbc7 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 20 Jul 2012 01:06:50 -0400 Subject: fix --- decoder/scfg_translator.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'decoder') diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index aaa6c40b..a978cfc2 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -299,11 +299,12 @@ Check for grammar pointer in the sentence markup, for use with sentence specific */ void SCFGTranslator::ProcessMarkupHintsImpl(const map& kv) { map::const_iterator it = kv.find("grammar"); - - TextGrammar* sentGrammar = new TextGrammar(it->second); - sentGrammar->SetMaxSpan(pimpl_->max_span_limit); - sentGrammar->SetGrammarName(it->second); - pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); + if (it != kv.end()) { + TextGrammar* sentGrammar = new TextGrammar(it->second); + sentGrammar->SetMaxSpan(pimpl_->max_span_limit); + sentGrammar->SetGrammarName(it->second); + pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); + } } void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) { -- cgit v1.2.3 From 16d0e2a34df3e32e8992aeda3ad2de7a6e525f14 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 28 Jul 2012 18:35:11 -0400 Subject: slight improvement to the way NTs are handled in the FST-CFG composition algorithm, new rescoring mode --- decoder/Makefile.am | 2 + decoder/decoder.cc | 6 ++- decoder/earley_composer.cc | 38 +++++++++++++++--- decoder/hg.cc | 63 ------------------------------ decoder/hg.h | 17 ++------ decoder/hg_remove_eps.cc | 91 +++++++++++++++++++++++++++++++++++++++++++ decoder/hg_remove_eps.h | 13 +++++++ decoder/rescore_translator.cc | 57 +++++++++++++++++++++++++++ decoder/translator.h | 13 +++++++ decoder/trule.cc | 4 +- 10 files changed, 218 insertions(+), 86 deletions(-) create mode 100644 decoder/hg_remove_eps.cc create mode 100644 decoder/hg_remove_eps.h create mode 100644 decoder/rescore_translator.cc (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 00d01e53..0a792549 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -37,9 +37,11 @@ libcdec_a_SOURCES = \ fst_translator.cc \ csplit.cc \ translator.cc \ + rescore_translator.cc \ scfg_translator.cc \ hg.cc \ hg_io.cc \ + hg_remove_eps.cc \ decoder.cc \ hg_intersect.cc \ hg_sampler.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index ad4e9e07..a6f7b1ce 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -527,8 +527,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; + if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -675,6 +675,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream translator.reset(new LexicalTrans(conf)); else if (formalism == "lexalign") translator.reset(new LexicalAlign(conf)); + else if (formalism == "rescore") + translator.reset(new RescoreTranslator(conf)); else if (formalism == "tagger") translator.reset(new Tagger(conf)); else diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc index d265d954..efce70a6 100644 --- a/decoder/earley_composer.cc +++ b/decoder/earley_composer.cc @@ -16,6 +16,7 @@ #include "sparse_vector.h" #include "tdict.h" #include "hg.h" +#include "hg_remove_eps.h" using namespace std; using namespace std::tr1; @@ -48,6 +49,27 @@ static void InitializeConstants() { } //////////////////////////////////////////////////////////// +TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) { + TRule* r = new TRule(*kX1X2); + r->lhs_ = lhs; + r->f_[0] = rhs1; + r->f_[1] = rhs2; + return TRulePtr(r); +} + +TRulePtr CreateUnaryRule(int lhs, int rhs1) { + TRule* r = new TRule(*kX1); + r->lhs_ = lhs; + r->f_[0] = rhs1; + return TRulePtr(r); +} + +TRulePtr CreateEpsilonRule(int lhs) { + TRule* r = new TRule(*kEPSRule); + r->lhs_ = lhs; + return TRulePtr(r); +} + class EGrammarNode { friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); friend void AddGrammarRule(const string& r, map* g); @@ -356,7 +378,7 @@ class EarleyComposerImpl { } if (goal_node) { forest->PruneUnreachable(goal_node->id_); - forest->EpsilonRemove(kEPS); + RemoveEpsilons(forest, kEPS); } FreeAll(); return goal_node; @@ -557,24 +579,30 @@ class EarleyComposerImpl { } Hypergraph::Node*& head_node = edge2node[edge]; if (!head_node) - head_node = hg->AddNode(kPHRASE); + head_node = hg->AddNode(edge->cat); if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) { assert(goal_node == NULL || goal_node == head_node); goal_node = head_node; } + int rhs1 = 0; + int rhs2 = 0; Hypergraph::TailNodeVector tail; SparseVector extra; if (edge->IsCreatedByPredict()) { // extra.set_value(FD::Convert("predict"), 1); } else if (edge->IsCreatedByScan()) { tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; if (tps) { tail.push_back(tps->id_); + rhs2 = kPHRASE; } //extra.set_value(FD::Convert("scan"), 1); } else if (edge->IsCreatedByComplete()) { tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; tail.push_back(edge2node[edge->passive_parent]->id_); + rhs2 = edge->passive_parent->cat; //extra.set_value(FD::Convert("complete"), 1); } else { assert(!"unexpected edge type!"); @@ -592,11 +620,11 @@ class EarleyComposerImpl { #endif Hypergraph::Edge* hg_edge = NULL; if (tail.size() == 0) { - hg_edge = hg->AddEdge(kEPSRule, tail); + hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail); } else if (tail.size() == 1) { - hg_edge = hg->AddEdge(kX1, tail); + hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail); } else if (tail.size() == 2) { - hg_edge = hg->AddEdge(kX1X2, tail); + hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail); } if (edge->features) hg_edge->feature_values_ += *edge->features; diff --git a/decoder/hg.cc b/decoder/hg.cc index dd272221..7240a8ab 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -605,69 +605,6 @@ void Hypergraph::TopologicallySortNodesAndEdges(int goal_index, #endif } -TRulePtr Hypergraph::kEPSRule; -TRulePtr Hypergraph::kUnaryRule; - -void Hypergraph::EpsilonRemove(WordID eps) { - if (!kEPSRule) { - kEPSRule.reset(new TRule("[X] ||| ||| ")); - kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); - } - vector kill(edges_.size(), false); - for (unsigned i = 0; i < edges_.size(); ++i) { - const Edge& edge = edges_[i]; - if (edge.tail_nodes_.empty() && - edge.rule_->f_.size() == 1 && - edge.rule_->f_[0] == eps) { - kill[i] = true; - if (!edge.feature_values_.empty()) { - Node& node = nodes_[edge.head_node_]; - if (node.in_edges_.size() != 1) { - cerr << "[WARNING] edge with features going into non-empty node - can't promote\n"; - // this *probably* means that there are multiple derivations of the - // same sequence via different paths through the input forest - // this needs to be investigated and fixed - } else { - for (unsigned j = 0; j < node.out_edges_.size(); ++j) - edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; - // cerr << "PROMOTED " << edge.feature_values_ << endl; - } - } - } - } - bool created_eps = false; - PruneEdges(kill); - for (unsigned i = 0; i < nodes_.size(); ++i) { - const Node& node = nodes_[i]; - if (node.in_edges_.empty()) { - for (unsigned j = 0; j < node.out_edges_.size(); ++j) { - Edge& edge = edges_[node.out_edges_[j]]; - if (edge.rule_->Arity() == 2) { - assert(edge.rule_->f_.size() == 2); - assert(edge.rule_->e_.size() == 2); - edge.rule_ = kUnaryRule; - unsigned cur = node.id_; - int t = -1; - assert(edge.tail_nodes_.size() == 2); - for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; } - assert(t != -1); - edge.tail_nodes_.resize(1); - edge.tail_nodes_[0] = t; - } else { - edge.rule_ = kEPSRule; - edge.rule_->f_[0] = eps; - edge.rule_->e_[0] = eps; - edge.tail_nodes_.clear(); - created_eps = true; - } - } - } - } - vector k2(edges_.size(), false); - PruneEdges(k2); - if (created_eps) EpsilonRemove(eps); -} - struct EdgeWeightSorter { const Hypergraph& hg; EdgeWeightSorter(const Hypergraph& h) : hg(h) {} diff --git a/decoder/hg.h b/decoder/hg.h index 91d25f01..591e98ce 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -148,7 +148,7 @@ public: void show(std::ostream &o,unsigned mask=SPAN|RULE) const { o<<'{'; if (mask&CATEGORY) - o<GetLHS()); + o<< '[' << TD::Convert(-rule_->GetLHS()) << ']'; if (mask&PREV_SPAN) o<<'<'<'; if (mask&SPAN) @@ -156,9 +156,9 @@ public: if (mask&PROB) o<<" p="<AsString(mask&RULE_LHS); + o<<' '<AsString(mask&RULE_LHS); if (USE_INFO_EDGE) { std::string const& i=info(); if (mask&&!i.empty()) o << " |||"< @@ -535,9 +527,6 @@ public: private: Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} - - static TRulePtr kEPSRule; - static TRulePtr kUnaryRule; }; diff --git a/decoder/hg_remove_eps.cc b/decoder/hg_remove_eps.cc new file mode 100644 index 00000000..050c4876 --- /dev/null +++ b/decoder/hg_remove_eps.cc @@ -0,0 +1,91 @@ +#include "hg_remove_eps.h" + +#include + +#include "trule.h" +#include "hg.h" + +using namespace std; + +namespace { + TRulePtr kEPSRule; + TRulePtr kUnaryRule; + + TRulePtr CreateUnaryRule(int lhs, int rhs) { + if (!kUnaryRule) kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + TRule* r = new TRule(*kUnaryRule); + assert(lhs < 0); + assert(rhs < 0); + r->lhs_ = lhs; + r->f_[0] = rhs; + return TRulePtr(r); + } + + TRulePtr CreateEpsilonRule(int lhs, WordID eps) { + if (!kEPSRule) kEPSRule.reset(new TRule("[X] ||| ||| ")); + TRule* r = new TRule(*kEPSRule); + r->lhs_ = lhs; + assert(lhs < 0); + assert(eps > 0); + r->e_[0] = eps; + r->f_[0] = eps; + return TRulePtr(r); + } +} + +void RemoveEpsilons(Hypergraph* g, WordID eps) { + vector kill(g->edges_.size(), false); + for (unsigned i = 0; i < g->edges_.size(); ++i) { + const Hypergraph::Edge& edge = g->edges_[i]; + if (edge.tail_nodes_.empty() && + edge.rule_->f_.size() == 1 && + edge.rule_->f_[0] == eps) { + kill[i] = true; + if (!edge.feature_values_.empty()) { + Hypergraph::Node& node = g->nodes_[edge.head_node_]; + if (node.in_edges_.size() != 1) { + cerr << "[WARNING] edge with features going into non-empty node - can't promote\n"; + // this *probably* means that there are multiple derivations of the + // same sequence via different paths through the input forest + // this needs to be investigated and fixed + } else { + for (unsigned j = 0; j < node.out_edges_.size(); ++j) + g->edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; + // cerr << "PROMOTED " << edge.feature_values_ << endl; + } + } + } + } + bool created_eps = false; + g->PruneEdges(kill); + for (unsigned i = 0; i < g->nodes_.size(); ++i) { + const Hypergraph::Node& node = g->nodes_[i]; + if (node.in_edges_.empty()) { + for (unsigned j = 0; j < node.out_edges_.size(); ++j) { + Hypergraph::Edge& edge = g->edges_[node.out_edges_[j]]; + const int lhs = edge.rule_->lhs_; + if (edge.rule_->Arity() == 2) { + assert(edge.rule_->f_.size() == 2); + assert(edge.rule_->e_.size() == 2); + unsigned cur = node.id_; + int t = -1; + assert(edge.tail_nodes_.size() == 2); + int rhs = 0; + for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; rhs = edge.rule_->f_[i]; } + assert(t != -1); + edge.tail_nodes_.resize(1); + edge.tail_nodes_[0] = t; + edge.rule_ = CreateUnaryRule(lhs, rhs); + } else { + edge.rule_ = CreateEpsilonRule(lhs, eps); + edge.tail_nodes_.clear(); + created_eps = true; + } + } + } + } + vector k2(g->edges_.size(), false); + g->PruneEdges(k2); + if (created_eps) RemoveEpsilons(g, eps); +} + diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h new file mode 100644 index 00000000..82f06039 --- /dev/null +++ b/decoder/hg_remove_eps.h @@ -0,0 +1,13 @@ +#ifndef _HG_REMOVE_EPS_H_ +#define _HG_REMOVE_EPS_H_ + +#include "wordid.h" +class Hypergraph; + +// This is not a complete implementation of the general algorithm for +// doing this. It makes a few weird assumptions, for example, that +// if some nonterminal X rewrites as eps, then that is the only thing +// that it rewrites as. This needs to be fixed for the general case! +void RemoveEpsilons(Hypergraph* g, WordID eps); + +#endif diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc new file mode 100644 index 00000000..5c417393 --- /dev/null +++ b/decoder/rescore_translator.cc @@ -0,0 +1,57 @@ +#include "translator.h" + +#include +#include + +#include "sentence_metadata.h" +#include "hg.h" +#include "hg_io.h" +#include "tdict.h" + +using namespace std; + +struct RescoreTranslatorImpl { + RescoreTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1) { + } + + bool Translate(const string& input, + const vector& weights, + Hypergraph* forest) { + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + if (!HypergraphIO::ReadFromJSON(&is, forest)) { + cerr << "Parse error while reading HG from JSON.\n"; + abort(); + } + } else { + cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; + abort(); + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(kGOAL); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + return true; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; +}; + +RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new RescoreTranslatorImpl(conf)) {} + +bool RescoreTranslator::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + diff --git a/decoder/translator.h b/decoder/translator.h index fc2bb760..c0800e84 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -85,4 +85,17 @@ class FSTTranslator : public Translator { boost::shared_ptr pimpl_; }; +class RescoreTranslatorImpl; +class RescoreTranslator : public Translator { + public: + RescoreTranslator(const boost::program_options::variables_map& conf); + private: + bool TranslateImpl(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + #endif diff --git a/decoder/trule.cc b/decoder/trule.cc index 187a003d..896f9f3d 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -237,9 +237,9 @@ void TRule::ComputeArity() { string TRule::AsString(bool verbose) const { ostringstream os; int idx = 0; - if (lhs_ && verbose) { + if (lhs_) { os << '[' << TD::Convert(lhs_ * -1) << "] |||"; - } + } else { os << "NOLHS |||"; } for (unsigned i = 0; i < f_.size(); ++i) { const WordID& w = f_[i]; if (w < 0) { -- cgit v1.2.3 From c7d1b04980d9d90458625a7f8e92985c7409a78d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 29 Jul 2012 21:04:14 -0400 Subject: fix grammar converter to remove edges that cannot exist in any valid derivation --- decoder/hg_io.cc | 1 + decoder/inside_outside.h | 4 ---- decoder/rescore_translator.cc | 1 + training/grammar_convert.cc | 27 +++++++++++++++++++++++++++ 4 files changed, 29 insertions(+), 4 deletions(-) (limited to 'decoder') diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index bfb2fb80..8bd40387 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -261,6 +261,7 @@ static void WriteRule(const TRule& r, ostream* out) { } bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { + if (hg.empty()) { *out << "{}\n"; return true; } map rid; ostream& o = *out; rid[NULL] = 0; diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h index bb7f9fcc..f73a1d3f 100644 --- a/decoder/inside_outside.h +++ b/decoder/inside_outside.h @@ -41,10 +41,6 @@ WeightType Inside(const Hypergraph& hg, WeightType* const cur_node_inside_score = &inside_score[i]; Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_; const unsigned num_in_edges = in.size(); - if (num_in_edges == 0) { - *cur_node_inside_score = WeightType(1); //FIXME: why not call weight(edge) instead? - continue; - } for (unsigned j = 0; j < num_in_edges; ++j) { const Hypergraph::Edge& edge = hg.edges_[in[j]]; WeightType score = weight(edge); diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 5c417393..10192f7a 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -20,6 +20,7 @@ struct RescoreTranslatorImpl { bool Translate(const string& input, const vector& weights, Hypergraph* forest) { + if (input == "{}") return false; if (input.find("{\"rules\"") == 0) { istringstream is(input); Hypergraph src_cfg_hg; diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc index bf8abb26..607a7cb9 100644 --- a/training/grammar_convert.cc +++ b/training/grammar_convert.cc @@ -9,6 +9,7 @@ #include #include +#include "inside_outside.h" #include "tdict.h" #include "filelib.h" #include "hg.h" @@ -69,6 +70,32 @@ void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { if (hg->nodes_.size() != old_size) { cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; } + vector inside; // inside score at each node + double p = Inside(*hg, &inside); + if (!p) { + cerr << "Warning! Grammar defines the empty language!\n"; + hg->clear(); + return; + } + vector prune(hg->edges_.size(), false); + int bad_edges = 0; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + bool bad = false; + for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) { + if (!inside[edge.tail_nodes_[j]]) { + bad = true; + ++bad_edges; + } + } + prune[i] = bad; + } + cerr << "Removing " << bad_edges << " bad edges from the grammar.\n"; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + if (prune[i]) + cerr << " " << hg->edges_[i].rule_->AsString() << endl; + } + hg->PruneEdges(prune); } void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { -- cgit v1.2.3