From b9266b068a37dc46f8de813c59ffbef2e4b89280 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 23 Jun 2012 15:54:21 -0400 Subject: clean up tagger a bit --- decoder/ff_tagger.cc | 17 ++++++++++++++--- decoder/tagger.cc | 1 + 2 files changed, 15 insertions(+), 3 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index 019315a2..fd9210fa 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -8,6 +8,17 @@ using namespace std; +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) : FeatureFunction(sizeof(WordID)) { no_uni_ = (LowercaseString(param) == "no_uni"); @@ -28,7 +39,7 @@ void Tagger_BigramIndicator::FireFeature(const WordID& left, os << '_'; if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); } } - fid = FD::Convert(os.str()); + fid = FD::Convert(Escape(os.str())); } features->set_value(fid, 1.0); } @@ -90,7 +101,7 @@ void LexicalPairIndicator::FireFeature(WordID src, if (!fid) { ostringstream os; os << name_ << ':' << TD::Convert(src) << ':' << TD::Convert(trg); - fid = FD::Convert(os.str()); + fid = FD::Convert(Escape(os.str())); } features->set_value(fid, 1.0); } @@ -127,7 +138,7 @@ void OutputIndicator::FireFeature(WordID trg, if (escape.count(trg)) trg = escape[trg]; ostringstream os; os << "T:" << TD::Convert(trg); - fid = FD::Convert(os.str()); + fid = FD::Convert(Escape(os.str())); } features->set_value(fid, 1.0); } diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 54890e85..63e855c8 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -54,6 +54,7 @@ struct TaggerImpl { const int new_node_id = forest->AddNode(kXCAT)->id_; for (int k = 0; k < tagset_.size(); ++k) { TRulePtr rule(TRule::CreateLexicalRule(src, tagset_[k])); + rule->lhs_ = kXCAT; Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); edge->i_ = i; edge->j_ = i+1; -- cgit v1.2.3 From 5584ad22d3d3fae7ea195fe6cd7cf56cd0f35ba0 Mon Sep 17 00:00:00 2001 From: Waleed Ammar Date: Mon, 2 Jul 2012 16:51:21 -0400 Subject: (1) allow prefixes and separators used in feature instantiations to be configured. (2) minor refactoring. --- decoder/ff_ngrams.cc | 85 ++++++++++++++++++++++++++++++++++++++++++++-------- decoder/ff_ngrams.h | 2 +- 2 files changed, 73 insertions(+), 14 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc index d6d79f5e..9c13fdbb 100644 --- a/decoder/ff_ngrams.cc +++ b/decoder/ff_ngrams.cc @@ -48,6 +48,9 @@ struct State { namespace { string Escape(const string& x) { + if (x.find('=') == string::npos && x.find(';') == string::npos) { + return x; + } string y = x; for (int i = 0; i < y.size(); ++i) { if (y[i] == '=') y[i]='_'; @@ -57,10 +60,17 @@ namespace { } } -static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) { +static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector& prefixes, string& target_separator) { vector const& argv=SplitOnWhitespace(in); *explicit_markers = false; *order = 3; + prefixes.push_back("NOT-USED"); + prefixes.push_back("U:"); // default unigram prefix + prefixes.push_back("B:"); // default bigram prefix + prefixes.push_back("T:"); // ...etc + prefixes.push_back("4:"); // ...etc + prefixes.push_back("5:"); // max allowed! + target_separator = "_"; #define LMSPEC_NEXTARG if (i==argv.end()) { \ cerr << "Missing argument for "<<*last<<". "; goto usage; \ } else { ++i; } @@ -73,6 +83,30 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) case 'x': *explicit_markers = true; break; + case 'U': + LMSPEC_NEXTARG; + prefixes[1] = *i; + break; + case 'B': + LMSPEC_NEXTARG; + prefixes[2] = *i; + break; + case 'T': + LMSPEC_NEXTARG; + prefixes[3] = *i; + break; + case '4': + LMSPEC_NEXTARG; + prefixes[4] = *i; + break; + case '5': + LMSPEC_NEXTARG; + prefixes[5] = *i; + break; + case 'S': + LMSPEC_NEXTARG; + target_separator = *i; + break; case 'o': LMSPEC_NEXTARG; *order=atoi((*i).c_str()); break; @@ -86,7 +120,29 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) } return true; usage: - cerr << "NgramFeatures is incorrect!\n"; + cerr << "Wrong parameters for NgramFeatures.\n\n" + + << "NgramFeatures Usage: \n" + << " feature_function=NgramFeatures filename.lm [-x] [-o ] \n" + << " [-U ] [-B ][-T ]\n" + << " [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S ]\n\n" + + << "Defaults: \n" + << " = 3\n" + << " = U:\n" + << " = B:\n" + << " = T:\n" + << " <4-gram-prefix> = 4:\n" + << " <5-gram-prefix> = 5:\n" + << " = _\n" + << " -x (i.e. explicit sos/eos markers) is turned off\n\n" + + << "Example configuration: \n" + << " feature_function=NgramFeatures -o 3 -T tri: -S |\n\n" + + << "Example feature instantiation: \n" + << " tri:a|b|c \n\n"; + return false; } @@ -158,16 +214,12 @@ class NgramDetectorImpl { int& fid = ft->fids[curword]; ++n; if (!fid) { - const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.) ostringstream os; - os << code[n] << ':'; + os << prefixes_[n]; for (int i = n-1; i >= 0; --i) { - os << (i != n-1 ? "_" : ""); + os << (i != n-1 ? target_separator_ : ""); const string& tok = TD::Convert(buf[i]); - if (tok.find('=') == string::npos) - os << tok; - else - os << Escape(tok); + os << Escape(tok); } fid = FD::Convert(os.str()); } @@ -297,7 +349,8 @@ class NgramDetectorImpl { } public: - explicit NgramDetectorImpl(bool explicit_markers, unsigned order) : + explicit NgramDetectorImpl(bool explicit_markers, unsigned order, + vector& prefixes, string& target_separator) : kCDEC_UNK(TD::Convert("")) , add_sos_eos_(!explicit_markers) { order_ = order; @@ -305,6 +358,8 @@ class NgramDetectorImpl { unscored_size_offset_ = (order_ - 1) * sizeof(WordID); is_complete_offset_ = unscored_size_offset_ + 1; unscored_words_offset_ = is_complete_offset_ + 1; + prefixes_ = prefixes; + target_separator_ = target_separator; // special handling of beginning / ending sentence markers dummy_state_ = new char[state_size_]; @@ -340,6 +395,8 @@ class NgramDetectorImpl { char* dummy_state_; vector dummy_ants_; TRulePtr dummy_rule_; + vector prefixes_; + string target_separator_; struct FidTree { map fids; map levels; @@ -348,11 +405,13 @@ class NgramDetectorImpl { }; NgramDetector::NgramDetector(const string& param) { - string filename, mapfile, featname; + string filename, mapfile, featname, target_separator; + vector prefixes; bool explicit_markers = false; unsigned order = 3; - ParseArgs(param, &explicit_markers, &order); - pimpl_ = new NgramDetectorImpl(explicit_markers, order); + ParseArgs(param, &explicit_markers, &order, prefixes, target_separator); + pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes, + target_separator); SetStateSize(pimpl_->ReserveStateSize()); } diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h index 82f61b33..064dbb49 100644 --- a/decoder/ff_ngrams.h +++ b/decoder/ff_ngrams.h @@ -10,7 +10,7 @@ struct NgramDetectorImpl; class NgramDetector : public FeatureFunction { public: - // param = "filename.lm [-o n]" + // param = "filename.lm [-o ] [-U ] [-B ] [-T ] [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S ] NgramDetector(const std::string& param); ~NgramDetector(); virtual void FinalTraversalFeatures(const void* context, -- cgit v1.2.3 From 09d085e16f31b36b1177b313ee54414eb29d9082 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 20 Jul 2012 00:46:10 -0400 Subject: enable programmatic setting of grammar objects, cleanup of grammar handling --- decoder/decoder.cc | 10 +++----- decoder/decoder.h | 6 +++-- decoder/scfg_translator.cc | 63 ++++++++++++++++++---------------------------- decoder/translator.h | 4 +-- dtrain/dtrain.cc | 4 +-- python/src/_cdec.pyx | 2 +- 6 files changed, 38 insertions(+), 51 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 333f0fb6..ad4e9e07 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -743,16 +743,14 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) { } vector& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); } const vector& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } -void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { - assert(pimpl_->translator->GetDecoderType() == "SCFG"); - static_cast(*pimpl_->translator).SetSupplementalGrammar(grammar_string); +void Decoder::AddSupplementalGrammar(GrammarPtr gp) { + static_cast(*pimpl_->translator).AddSupplementalGrammar(gp); } -void Decoder::SetSentenceGrammarFromString(const std::string& grammar_str) { +void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) { assert(pimpl_->translator->GetDecoderType() == "SCFG"); - static_cast(*pimpl_->translator).SetSentenceGrammarFromString(grammar_str); + static_cast(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string); } - bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { string buf = input; NgramCache::Clear(); // clear ngram cache for remote LM (if used) diff --git a/decoder/decoder.h b/decoder/decoder.h index 6b2f7b16..bef2ff5e 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -37,6 +37,8 @@ struct DecoderObserver { virtual void NotifyDecodingComplete(const SentenceMetadata& smeta); }; +struct Grammar; // TODO once the decoder interface is cleaned up, + // this should be somewhere else struct Decoder { Decoder(int argc, char** argv); Decoder(std::istream* config_file); @@ -54,8 +56,8 @@ struct Decoder { // add grammar rules (currently only supported by SCFG decoders) // that will be used on subsequent calls to Decode. rules should be in standard // text format. This function does NOT read from a file. - void SetSupplementalGrammar(const std::string& grammar); - void SetSentenceGrammarFromString(const std::string& grammar_str); + void AddSupplementalGrammar(boost::shared_ptr gp); + void AddSupplementalGrammarFromString(const std::string& grammar_string); private: boost::program_options::variables_map conf; boost::shared_ptr pimpl_; diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 185f979a..aaa6c40b 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -20,7 +20,6 @@ #define reverse_foreach BOOST_REVERSE_FOREACH using namespace std; -static bool usingSentenceGrammar = false; static bool printGrammarsUsed = false; struct SCFGTranslatorImpl { @@ -91,31 +90,31 @@ struct SCFGTranslatorImpl { bool show_tree_structure_; unsigned int ctf_iterations_; vector grammars; - GrammarPtr sup_grammar_; + set sup_grammars_; - struct Equals { Equals(const GrammarPtr& v) : v_(v) {} - bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; }; + struct ContainedIn { + ContainedIn(const set& gs) : gs_(gs) {} + bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); } + const set& gs_; + }; - void SetSupplementalGrammar(const std::string& grammar_string) { - grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end()); + void AddSupplementalGrammarFromString(const std::string& grammar_string) { + grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); istringstream in(grammar_string); - sup_grammar_.reset(new TextGrammar(&in)); - grammars.push_back(sup_grammar_); + TextGrammar* sent_grammar = new TextGrammar(&in); + sent_grammar->SetMaxSpan(max_span_limit); + sent_grammar->SetGrammarName("SupFromString"); + AddSupplementalGrammar(GrammarPtr(sent_grammar)); } - struct NameEquals { NameEquals(const string name) : name_(name) {} - bool operator()(const GrammarPtr& x) const { return x->GetGrammarName() == name_; } const string name_; }; + void AddSupplementalGrammar(GrammarPtr gp) { + sup_grammars_.insert(gp); + grammars.push_back(gp); + } - void SetSentenceGrammarFromString(const std::string& grammar_str) { - assert(grammar_str != ""); - if (!SILENT) cerr << "Setting sentence grammar" << endl; - usingSentenceGrammar = true; - istringstream in(grammar_str); - TextGrammar* sent_grammar = new TextGrammar(&in); - sent_grammar->SetMaxSpan(max_span_limit); - sent_grammar->SetGrammarName("__psg"); - grammars.erase(remove_if(grammars.begin(), grammars.end(), NameEquals("__psg")), grammars.end()); - grammars.push_back(GrammarPtr(sent_grammar)); + void RemoveSupplementalGrammars() { + grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); + sup_grammars_.clear(); } bool Translate(const string& input, @@ -301,34 +300,22 @@ Check for grammar pointer in the sentence markup, for use with sentence specific void SCFGTranslator::ProcessMarkupHintsImpl(const map& kv) { map::const_iterator it = kv.find("grammar"); - - if (it == kv.end()) { - usingSentenceGrammar= false; - return; - } - //Create sentence specific grammar from specified file name and load grammar into list of grammars - usingSentenceGrammar = true; TextGrammar* sentGrammar = new TextGrammar(it->second); sentGrammar->SetMaxSpan(pimpl_->max_span_limit); sentGrammar->SetGrammarName(it->second); - pimpl_->grammars.push_back(GrammarPtr(sentGrammar)); - + pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); } -void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) { - pimpl_->SetSupplementalGrammar(grammar); +void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) { + pimpl_->AddSupplementalGrammarFromString(grammar); } -void SCFGTranslator::SetSentenceGrammarFromString(const std::string& grammar_str) { - pimpl_->SetSentenceGrammarFromString(grammar_str); +void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) { + pimpl_->AddSupplementalGrammar(grammar); } void SCFGTranslator::SentenceCompleteImpl() { - - if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars - { - pimpl_->grammars.pop_back(); - } + pimpl_->RemoveSupplementalGrammars(); } std::string SCFGTranslator::GetDecoderType() const { diff --git a/decoder/translator.h b/decoder/translator.h index cfd3b08a..fc2bb760 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -58,8 +58,8 @@ class SCFGTranslatorImpl; class SCFGTranslator : public Translator { public: SCFGTranslator(const boost::program_options::variables_map& conf); - void SetSupplementalGrammar(const std::string& grammar); - void SetSentenceGrammarFromString(const std::string& grammar); + void AddSupplementalGrammar(GrammarPtr gp); + void AddSupplementalGrammarFromString(const std::string& grammar); virtual std::string GetDecoderType() const; protected: bool TranslateImpl(const std::string& src, diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index b3e62914..b7a4bb6f 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -364,7 +364,7 @@ main(int argc, char** argv) boost::replace_all(in, "\t", "\n"); in += "\n"; grammar_buf_out << in << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl; - decoder.SetSentenceGrammarFromString(in); + decoder.AddSupplementalGrammarFromString(in); src_str_buf.push_back(in_split[1]); // decode observer->SetRef(ref_ids); @@ -378,7 +378,7 @@ main(int argc, char** argv) if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break; grammar_str += rule + "\n"; } - decoder.SetSentenceGrammarFromString(grammar_str); + decoder.AddSupplementalGrammarFromString(grammar_str); // decode observer->SetRef(ref_ids_buf[ii]); decoder.Decode(src_str_buf[ii], observer); diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx index 879e8b72..164d6570 100644 --- a/python/src/_cdec.pyx +++ b/python/src/_cdec.pyx @@ -98,7 +98,7 @@ cdef class Decoder: else: raise TypeError('Cannot translate input type %s' % type(sentence)) if grammar: - self.dec.SetSentenceGrammarFromString(string( grammar)) + self.dec.AddSupplementalGrammarFromString(string( grammar)) cdef decoder.BasicObserver observer = decoder.BasicObserver() self.dec.Decode(string(inp), &observer) if observer.hypergraph == NULL: -- cgit v1.2.3 From c4c9c2febd5af552ecddc215758e32b88013fbc7 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 20 Jul 2012 01:06:50 -0400 Subject: fix --- decoder/scfg_translator.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'decoder') diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index aaa6c40b..a978cfc2 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -299,11 +299,12 @@ Check for grammar pointer in the sentence markup, for use with sentence specific */ void SCFGTranslator::ProcessMarkupHintsImpl(const map& kv) { map::const_iterator it = kv.find("grammar"); - - TextGrammar* sentGrammar = new TextGrammar(it->second); - sentGrammar->SetMaxSpan(pimpl_->max_span_limit); - sentGrammar->SetGrammarName(it->second); - pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); + if (it != kv.end()) { + TextGrammar* sentGrammar = new TextGrammar(it->second); + sentGrammar->SetMaxSpan(pimpl_->max_span_limit); + sentGrammar->SetGrammarName(it->second); + pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); + } } void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) { -- cgit v1.2.3 From 16d0e2a34df3e32e8992aeda3ad2de7a6e525f14 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 28 Jul 2012 18:35:11 -0400 Subject: slight improvement to the way NTs are handled in the FST-CFG composition algorithm, new rescoring mode --- decoder/Makefile.am | 2 + decoder/decoder.cc | 6 ++- decoder/earley_composer.cc | 38 +++++++++++++++--- decoder/hg.cc | 63 ------------------------------ decoder/hg.h | 17 ++------ decoder/hg_remove_eps.cc | 91 +++++++++++++++++++++++++++++++++++++++++++ decoder/hg_remove_eps.h | 13 +++++++ decoder/rescore_translator.cc | 57 +++++++++++++++++++++++++++ decoder/translator.h | 13 +++++++ decoder/trule.cc | 4 +- 10 files changed, 218 insertions(+), 86 deletions(-) create mode 100644 decoder/hg_remove_eps.cc create mode 100644 decoder/hg_remove_eps.h create mode 100644 decoder/rescore_translator.cc (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 00d01e53..0a792549 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -37,9 +37,11 @@ libcdec_a_SOURCES = \ fst_translator.cc \ csplit.cc \ translator.cc \ + rescore_translator.cc \ scfg_translator.cc \ hg.cc \ hg_io.cc \ + hg_remove_eps.cc \ decoder.cc \ hg_intersect.cc \ hg_sampler.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index ad4e9e07..a6f7b1ce 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -527,8 +527,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; + if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -675,6 +675,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream translator.reset(new LexicalTrans(conf)); else if (formalism == "lexalign") translator.reset(new LexicalAlign(conf)); + else if (formalism == "rescore") + translator.reset(new RescoreTranslator(conf)); else if (formalism == "tagger") translator.reset(new Tagger(conf)); else diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc index d265d954..efce70a6 100644 --- a/decoder/earley_composer.cc +++ b/decoder/earley_composer.cc @@ -16,6 +16,7 @@ #include "sparse_vector.h" #include "tdict.h" #include "hg.h" +#include "hg_remove_eps.h" using namespace std; using namespace std::tr1; @@ -48,6 +49,27 @@ static void InitializeConstants() { } //////////////////////////////////////////////////////////// +TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) { + TRule* r = new TRule(*kX1X2); + r->lhs_ = lhs; + r->f_[0] = rhs1; + r->f_[1] = rhs2; + return TRulePtr(r); +} + +TRulePtr CreateUnaryRule(int lhs, int rhs1) { + TRule* r = new TRule(*kX1); + r->lhs_ = lhs; + r->f_[0] = rhs1; + return TRulePtr(r); +} + +TRulePtr CreateEpsilonRule(int lhs) { + TRule* r = new TRule(*kEPSRule); + r->lhs_ = lhs; + return TRulePtr(r); +} + class EGrammarNode { friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); friend void AddGrammarRule(const string& r, map* g); @@ -356,7 +378,7 @@ class EarleyComposerImpl { } if (goal_node) { forest->PruneUnreachable(goal_node->id_); - forest->EpsilonRemove(kEPS); + RemoveEpsilons(forest, kEPS); } FreeAll(); return goal_node; @@ -557,24 +579,30 @@ class EarleyComposerImpl { } Hypergraph::Node*& head_node = edge2node[edge]; if (!head_node) - head_node = hg->AddNode(kPHRASE); + head_node = hg->AddNode(edge->cat); if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) { assert(goal_node == NULL || goal_node == head_node); goal_node = head_node; } + int rhs1 = 0; + int rhs2 = 0; Hypergraph::TailNodeVector tail; SparseVector extra; if (edge->IsCreatedByPredict()) { // extra.set_value(FD::Convert("predict"), 1); } else if (edge->IsCreatedByScan()) { tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; if (tps) { tail.push_back(tps->id_); + rhs2 = kPHRASE; } //extra.set_value(FD::Convert("scan"), 1); } else if (edge->IsCreatedByComplete()) { tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; tail.push_back(edge2node[edge->passive_parent]->id_); + rhs2 = edge->passive_parent->cat; //extra.set_value(FD::Convert("complete"), 1); } else { assert(!"unexpected edge type!"); @@ -592,11 +620,11 @@ class EarleyComposerImpl { #endif Hypergraph::Edge* hg_edge = NULL; if (tail.size() == 0) { - hg_edge = hg->AddEdge(kEPSRule, tail); + hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail); } else if (tail.size() == 1) { - hg_edge = hg->AddEdge(kX1, tail); + hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail); } else if (tail.size() == 2) { - hg_edge = hg->AddEdge(kX1X2, tail); + hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail); } if (edge->features) hg_edge->feature_values_ += *edge->features; diff --git a/decoder/hg.cc b/decoder/hg.cc index dd272221..7240a8ab 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -605,69 +605,6 @@ void Hypergraph::TopologicallySortNodesAndEdges(int goal_index, #endif } -TRulePtr Hypergraph::kEPSRule; -TRulePtr Hypergraph::kUnaryRule; - -void Hypergraph::EpsilonRemove(WordID eps) { - if (!kEPSRule) { - kEPSRule.reset(new TRule("[X] ||| ||| ")); - kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); - } - vector kill(edges_.size(), false); - for (unsigned i = 0; i < edges_.size(); ++i) { - const Edge& edge = edges_[i]; - if (edge.tail_nodes_.empty() && - edge.rule_->f_.size() == 1 && - edge.rule_->f_[0] == eps) { - kill[i] = true; - if (!edge.feature_values_.empty()) { - Node& node = nodes_[edge.head_node_]; - if (node.in_edges_.size() != 1) { - cerr << "[WARNING] edge with features going into non-empty node - can't promote\n"; - // this *probably* means that there are multiple derivations of the - // same sequence via different paths through the input forest - // this needs to be investigated and fixed - } else { - for (unsigned j = 0; j < node.out_edges_.size(); ++j) - edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; - // cerr << "PROMOTED " << edge.feature_values_ << endl; - } - } - } - } - bool created_eps = false; - PruneEdges(kill); - for (unsigned i = 0; i < nodes_.size(); ++i) { - const Node& node = nodes_[i]; - if (node.in_edges_.empty()) { - for (unsigned j = 0; j < node.out_edges_.size(); ++j) { - Edge& edge = edges_[node.out_edges_[j]]; - if (edge.rule_->Arity() == 2) { - assert(edge.rule_->f_.size() == 2); - assert(edge.rule_->e_.size() == 2); - edge.rule_ = kUnaryRule; - unsigned cur = node.id_; - int t = -1; - assert(edge.tail_nodes_.size() == 2); - for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; } - assert(t != -1); - edge.tail_nodes_.resize(1); - edge.tail_nodes_[0] = t; - } else { - edge.rule_ = kEPSRule; - edge.rule_->f_[0] = eps; - edge.rule_->e_[0] = eps; - edge.tail_nodes_.clear(); - created_eps = true; - } - } - } - } - vector k2(edges_.size(), false); - PruneEdges(k2); - if (created_eps) EpsilonRemove(eps); -} - struct EdgeWeightSorter { const Hypergraph& hg; EdgeWeightSorter(const Hypergraph& h) : hg(h) {} diff --git a/decoder/hg.h b/decoder/hg.h index 91d25f01..591e98ce 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -148,7 +148,7 @@ public: void show(std::ostream &o,unsigned mask=SPAN|RULE) const { o<<'{'; if (mask&CATEGORY) - o<GetLHS()); + o<< '[' << TD::Convert(-rule_->GetLHS()) << ']'; if (mask&PREV_SPAN) o<<'<'<'; if (mask&SPAN) @@ -156,9 +156,9 @@ public: if (mask&PROB) o<<" p="<AsString(mask&RULE_LHS); + o<<' '<AsString(mask&RULE_LHS); if (USE_INFO_EDGE) { std::string const& i=info(); if (mask&&!i.empty()) o << " |||"< @@ -535,9 +527,6 @@ public: private: Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} - - static TRulePtr kEPSRule; - static TRulePtr kUnaryRule; }; diff --git a/decoder/hg_remove_eps.cc b/decoder/hg_remove_eps.cc new file mode 100644 index 00000000..050c4876 --- /dev/null +++ b/decoder/hg_remove_eps.cc @@ -0,0 +1,91 @@ +#include "hg_remove_eps.h" + +#include + +#include "trule.h" +#include "hg.h" + +using namespace std; + +namespace { + TRulePtr kEPSRule; + TRulePtr kUnaryRule; + + TRulePtr CreateUnaryRule(int lhs, int rhs) { + if (!kUnaryRule) kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + TRule* r = new TRule(*kUnaryRule); + assert(lhs < 0); + assert(rhs < 0); + r->lhs_ = lhs; + r->f_[0] = rhs; + return TRulePtr(r); + } + + TRulePtr CreateEpsilonRule(int lhs, WordID eps) { + if (!kEPSRule) kEPSRule.reset(new TRule("[X] ||| ||| ")); + TRule* r = new TRule(*kEPSRule); + r->lhs_ = lhs; + assert(lhs < 0); + assert(eps > 0); + r->e_[0] = eps; + r->f_[0] = eps; + return TRulePtr(r); + } +} + +void RemoveEpsilons(Hypergraph* g, WordID eps) { + vector kill(g->edges_.size(), false); + for (unsigned i = 0; i < g->edges_.size(); ++i) { + const Hypergraph::Edge& edge = g->edges_[i]; + if (edge.tail_nodes_.empty() && + edge.rule_->f_.size() == 1 && + edge.rule_->f_[0] == eps) { + kill[i] = true; + if (!edge.feature_values_.empty()) { + Hypergraph::Node& node = g->nodes_[edge.head_node_]; + if (node.in_edges_.size() != 1) { + cerr << "[WARNING] edge with features going into non-empty node - can't promote\n"; + // this *probably* means that there are multiple derivations of the + // same sequence via different paths through the input forest + // this needs to be investigated and fixed + } else { + for (unsigned j = 0; j < node.out_edges_.size(); ++j) + g->edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; + // cerr << "PROMOTED " << edge.feature_values_ << endl; + } + } + } + } + bool created_eps = false; + g->PruneEdges(kill); + for (unsigned i = 0; i < g->nodes_.size(); ++i) { + const Hypergraph::Node& node = g->nodes_[i]; + if (node.in_edges_.empty()) { + for (unsigned j = 0; j < node.out_edges_.size(); ++j) { + Hypergraph::Edge& edge = g->edges_[node.out_edges_[j]]; + const int lhs = edge.rule_->lhs_; + if (edge.rule_->Arity() == 2) { + assert(edge.rule_->f_.size() == 2); + assert(edge.rule_->e_.size() == 2); + unsigned cur = node.id_; + int t = -1; + assert(edge.tail_nodes_.size() == 2); + int rhs = 0; + for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; rhs = edge.rule_->f_[i]; } + assert(t != -1); + edge.tail_nodes_.resize(1); + edge.tail_nodes_[0] = t; + edge.rule_ = CreateUnaryRule(lhs, rhs); + } else { + edge.rule_ = CreateEpsilonRule(lhs, eps); + edge.tail_nodes_.clear(); + created_eps = true; + } + } + } + } + vector k2(g->edges_.size(), false); + g->PruneEdges(k2); + if (created_eps) RemoveEpsilons(g, eps); +} + diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h new file mode 100644 index 00000000..82f06039 --- /dev/null +++ b/decoder/hg_remove_eps.h @@ -0,0 +1,13 @@ +#ifndef _HG_REMOVE_EPS_H_ +#define _HG_REMOVE_EPS_H_ + +#include "wordid.h" +class Hypergraph; + +// This is not a complete implementation of the general algorithm for +// doing this. It makes a few weird assumptions, for example, that +// if some nonterminal X rewrites as eps, then that is the only thing +// that it rewrites as. This needs to be fixed for the general case! +void RemoveEpsilons(Hypergraph* g, WordID eps); + +#endif diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc new file mode 100644 index 00000000..5c417393 --- /dev/null +++ b/decoder/rescore_translator.cc @@ -0,0 +1,57 @@ +#include "translator.h" + +#include +#include + +#include "sentence_metadata.h" +#include "hg.h" +#include "hg_io.h" +#include "tdict.h" + +using namespace std; + +struct RescoreTranslatorImpl { + RescoreTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1) { + } + + bool Translate(const string& input, + const vector& weights, + Hypergraph* forest) { + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + if (!HypergraphIO::ReadFromJSON(&is, forest)) { + cerr << "Parse error while reading HG from JSON.\n"; + abort(); + } + } else { + cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; + abort(); + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(kGOAL); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + return true; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; +}; + +RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new RescoreTranslatorImpl(conf)) {} + +bool RescoreTranslator::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + diff --git a/decoder/translator.h b/decoder/translator.h index fc2bb760..c0800e84 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -85,4 +85,17 @@ class FSTTranslator : public Translator { boost::shared_ptr pimpl_; }; +class RescoreTranslatorImpl; +class RescoreTranslator : public Translator { + public: + RescoreTranslator(const boost::program_options::variables_map& conf); + private: + bool TranslateImpl(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + #endif diff --git a/decoder/trule.cc b/decoder/trule.cc index 187a003d..896f9f3d 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -237,9 +237,9 @@ void TRule::ComputeArity() { string TRule::AsString(bool verbose) const { ostringstream os; int idx = 0; - if (lhs_ && verbose) { + if (lhs_) { os << '[' << TD::Convert(lhs_ * -1) << "] |||"; - } + } else { os << "NOLHS |||"; } for (unsigned i = 0; i < f_.size(); ++i) { const WordID& w = f_[i]; if (w < 0) { -- cgit v1.2.3 From c7d1b04980d9d90458625a7f8e92985c7409a78d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 29 Jul 2012 21:04:14 -0400 Subject: fix grammar converter to remove edges that cannot exist in any valid derivation --- decoder/hg_io.cc | 1 + decoder/inside_outside.h | 4 ---- decoder/rescore_translator.cc | 1 + training/grammar_convert.cc | 27 +++++++++++++++++++++++++++ 4 files changed, 29 insertions(+), 4 deletions(-) (limited to 'decoder') diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index bfb2fb80..8bd40387 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -261,6 +261,7 @@ static void WriteRule(const TRule& r, ostream* out) { } bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { + if (hg.empty()) { *out << "{}\n"; return true; } map rid; ostream& o = *out; rid[NULL] = 0; diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h index bb7f9fcc..f73a1d3f 100644 --- a/decoder/inside_outside.h +++ b/decoder/inside_outside.h @@ -41,10 +41,6 @@ WeightType Inside(const Hypergraph& hg, WeightType* const cur_node_inside_score = &inside_score[i]; Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_; const unsigned num_in_edges = in.size(); - if (num_in_edges == 0) { - *cur_node_inside_score = WeightType(1); //FIXME: why not call weight(edge) instead? - continue; - } for (unsigned j = 0; j < num_in_edges; ++j) { const Hypergraph::Edge& edge = hg.edges_[in[j]]; WeightType score = weight(edge); diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 5c417393..10192f7a 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -20,6 +20,7 @@ struct RescoreTranslatorImpl { bool Translate(const string& input, const vector& weights, Hypergraph* forest) { + if (input == "{}") return false; if (input.find("{\"rules\"") == 0) { istringstream is(input); Hypergraph src_cfg_hg; diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc index bf8abb26..607a7cb9 100644 --- a/training/grammar_convert.cc +++ b/training/grammar_convert.cc @@ -9,6 +9,7 @@ #include #include +#include "inside_outside.h" #include "tdict.h" #include "filelib.h" #include "hg.h" @@ -69,6 +70,32 @@ void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { if (hg->nodes_.size() != old_size) { cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; } + vector inside; // inside score at each node + double p = Inside(*hg, &inside); + if (!p) { + cerr << "Warning! Grammar defines the empty language!\n"; + hg->clear(); + return; + } + vector prune(hg->edges_.size(), false); + int bad_edges = 0; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + bool bad = false; + for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) { + if (!inside[edge.tail_nodes_[j]]) { + bad = true; + ++bad_edges; + } + } + prune[i] = bad; + } + cerr << "Removing " << bad_edges << " bad edges from the grammar.\n"; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + if (prune[i]) + cerr << " " << hg->edges_[i].rule_->AsString() << endl; + } + hg->PruneEdges(prune); } void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { -- cgit v1.2.3