diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 |
commit | be1ab0a8937f9c5668ea5e6c31b798e87672e55e (patch) | |
tree | a13aad60ab6cced213401bce6a38ac885ba171ba /decoder | |
parent | e5d6f4ae41009c26978ecd62668501af9762b0bc (diff) | |
parent | 9fe0219562e5db25171cce8776381600ff9a5649 (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/Makefile.am | 2 | ||||
-rw-r--r-- | decoder/decoder.cc | 16 | ||||
-rw-r--r-- | decoder/decoder.h | 6 | ||||
-rw-r--r-- | decoder/earley_composer.cc | 38 | ||||
-rw-r--r-- | decoder/ff_ngrams.cc | 85 | ||||
-rw-r--r-- | decoder/ff_ngrams.h | 2 | ||||
-rw-r--r-- | decoder/ff_tagger.cc | 17 | ||||
-rw-r--r-- | decoder/hg.cc | 63 | ||||
-rw-r--r-- | decoder/hg.h | 17 | ||||
-rw-r--r-- | decoder/hg_io.cc | 1 | ||||
-rw-r--r-- | decoder/hg_remove_eps.cc | 91 | ||||
-rw-r--r-- | decoder/hg_remove_eps.h | 13 | ||||
-rw-r--r-- | decoder/inside_outside.h | 4 | ||||
-rw-r--r-- | decoder/rescore_translator.cc | 58 | ||||
-rw-r--r-- | decoder/scfg_translator.cc | 70 | ||||
-rw-r--r-- | decoder/tagger.cc | 1 | ||||
-rw-r--r-- | decoder/translator.h | 17 | ||||
-rw-r--r-- | decoder/trule.cc | 4 |
18 files changed, 347 insertions, 158 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 00d01e53..0a792549 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -37,9 +37,11 @@ libcdec_a_SOURCES = \ fst_translator.cc \ csplit.cc \ translator.cc \ + rescore_translator.cc \ scfg_translator.cc \ hg.cc \ hg_io.cc \ + hg_remove_eps.cc \ decoder.cc \ hg_intersect.cc \ hg_sampler.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 333f0fb6..a6f7b1ce 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -527,8 +527,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; + if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -675,6 +675,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream translator.reset(new LexicalTrans(conf)); else if (formalism == "lexalign") translator.reset(new LexicalAlign(conf)); + else if (formalism == "rescore") + translator.reset(new RescoreTranslator(conf)); else if (formalism == "tagger") translator.reset(new Tagger(conf)); else @@ -743,16 +745,14 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) { } vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); } const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } -void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { - assert(pimpl_->translator->GetDecoderType() == "SCFG"); - static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string); +void Decoder::AddSupplementalGrammar(GrammarPtr gp) { + static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammar(gp); } -void Decoder::SetSentenceGrammarFromString(const std::string& grammar_str) { +void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) { assert(pimpl_->translator->GetDecoderType() == "SCFG"); - static_cast<SCFGTranslator&>(*pimpl_->translator).SetSentenceGrammarFromString(grammar_str); + static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string); } - bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { string buf = input; NgramCache::Clear(); // clear ngram cache for remote LM (if used) diff --git a/decoder/decoder.h b/decoder/decoder.h index 6b2f7b16..bef2ff5e 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -37,6 +37,8 @@ struct DecoderObserver { virtual void NotifyDecodingComplete(const SentenceMetadata& smeta); }; +struct Grammar; // TODO once the decoder interface is cleaned up, + // this should be somewhere else struct Decoder { Decoder(int argc, char** argv); Decoder(std::istream* config_file); @@ -54,8 +56,8 @@ struct Decoder { // add grammar rules (currently only supported by SCFG decoders) // that will be used on subsequent calls to Decode. rules should be in standard // text format. This function does NOT read from a file. - void SetSupplementalGrammar(const std::string& grammar); - void SetSentenceGrammarFromString(const std::string& grammar_str); + void AddSupplementalGrammar(boost::shared_ptr<Grammar> gp); + void AddSupplementalGrammarFromString(const std::string& grammar_string); private: boost::program_options::variables_map conf; boost::shared_ptr<DecoderImpl> pimpl_; diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc index d265d954..efce70a6 100644 --- a/decoder/earley_composer.cc +++ b/decoder/earley_composer.cc @@ -16,6 +16,7 @@ #include "sparse_vector.h" #include "tdict.h" #include "hg.h" +#include "hg_remove_eps.h" using namespace std; using namespace std::tr1; @@ -48,6 +49,27 @@ static void InitializeConstants() { } //////////////////////////////////////////////////////////// +TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) { + TRule* r = new TRule(*kX1X2); + r->lhs_ = lhs; + r->f_[0] = rhs1; + r->f_[1] = rhs2; + return TRulePtr(r); +} + +TRulePtr CreateUnaryRule(int lhs, int rhs1) { + TRule* r = new TRule(*kX1); + r->lhs_ = lhs; + r->f_[0] = rhs1; + return TRulePtr(r); +} + +TRulePtr CreateEpsilonRule(int lhs) { + TRule* r = new TRule(*kEPSRule); + r->lhs_ = lhs; + return TRulePtr(r); +} + class EGrammarNode { friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); friend void AddGrammarRule(const string& r, map<WordID, EGrammarNode>* g); @@ -356,7 +378,7 @@ class EarleyComposerImpl { } if (goal_node) { forest->PruneUnreachable(goal_node->id_); - forest->EpsilonRemove(kEPS); + RemoveEpsilons(forest, kEPS); } FreeAll(); return goal_node; @@ -557,24 +579,30 @@ class EarleyComposerImpl { } Hypergraph::Node*& head_node = edge2node[edge]; if (!head_node) - head_node = hg->AddNode(kPHRASE); + head_node = hg->AddNode(edge->cat); if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) { assert(goal_node == NULL || goal_node == head_node); goal_node = head_node; } + int rhs1 = 0; + int rhs2 = 0; Hypergraph::TailNodeVector tail; SparseVector<double> extra; if (edge->IsCreatedByPredict()) { // extra.set_value(FD::Convert("predict"), 1); } else if (edge->IsCreatedByScan()) { tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; if (tps) { tail.push_back(tps->id_); + rhs2 = kPHRASE; } //extra.set_value(FD::Convert("scan"), 1); } else if (edge->IsCreatedByComplete()) { tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; tail.push_back(edge2node[edge->passive_parent]->id_); + rhs2 = edge->passive_parent->cat; //extra.set_value(FD::Convert("complete"), 1); } else { assert(!"unexpected edge type!"); @@ -592,11 +620,11 @@ class EarleyComposerImpl { #endif Hypergraph::Edge* hg_edge = NULL; if (tail.size() == 0) { - hg_edge = hg->AddEdge(kEPSRule, tail); + hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail); } else if (tail.size() == 1) { - hg_edge = hg->AddEdge(kX1, tail); + hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail); } else if (tail.size() == 2) { - hg_edge = hg->AddEdge(kX1X2, tail); + hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail); } if (edge->features) hg_edge->feature_values_ += *edge->features; diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc index d6d79f5e..9c13fdbb 100644 --- a/decoder/ff_ngrams.cc +++ b/decoder/ff_ngrams.cc @@ -48,6 +48,9 @@ struct State { namespace { string Escape(const string& x) { + if (x.find('=') == string::npos && x.find(';') == string::npos) { + return x; + } string y = x; for (int i = 0; i < y.size(); ++i) { if (y[i] == '=') y[i]='_'; @@ -57,10 +60,17 @@ namespace { } } -static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) { +static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator) { vector<string> const& argv=SplitOnWhitespace(in); *explicit_markers = false; *order = 3; + prefixes.push_back("NOT-USED"); + prefixes.push_back("U:"); // default unigram prefix + prefixes.push_back("B:"); // default bigram prefix + prefixes.push_back("T:"); // ...etc + prefixes.push_back("4:"); // ...etc + prefixes.push_back("5:"); // max allowed! + target_separator = "_"; #define LMSPEC_NEXTARG if (i==argv.end()) { \ cerr << "Missing argument for "<<*last<<". "; goto usage; \ } else { ++i; } @@ -73,6 +83,30 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) case 'x': *explicit_markers = true; break; + case 'U': + LMSPEC_NEXTARG; + prefixes[1] = *i; + break; + case 'B': + LMSPEC_NEXTARG; + prefixes[2] = *i; + break; + case 'T': + LMSPEC_NEXTARG; + prefixes[3] = *i; + break; + case '4': + LMSPEC_NEXTARG; + prefixes[4] = *i; + break; + case '5': + LMSPEC_NEXTARG; + prefixes[5] = *i; + break; + case 'S': + LMSPEC_NEXTARG; + target_separator = *i; + break; case 'o': LMSPEC_NEXTARG; *order=atoi((*i).c_str()); break; @@ -86,7 +120,29 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) } return true; usage: - cerr << "NgramFeatures is incorrect!\n"; + cerr << "Wrong parameters for NgramFeatures.\n\n" + + << "NgramFeatures Usage: \n" + << " feature_function=NgramFeatures filename.lm [-x] [-o <order>] \n" + << " [-U <unigram-prefix>] [-B <bigram-prefix>][-T <trigram-prefix>]\n" + << " [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]\n\n" + + << "Defaults: \n" + << " <order> = 3\n" + << " <unigram-prefix> = U:\n" + << " <bigram-prefix> = B:\n" + << " <trigram-prefix> = T:\n" + << " <4-gram-prefix> = 4:\n" + << " <5-gram-prefix> = 5:\n" + << " <separator> = _\n" + << " -x (i.e. explicit sos/eos markers) is turned off\n\n" + + << "Example configuration: \n" + << " feature_function=NgramFeatures -o 3 -T tri: -S |\n\n" + + << "Example feature instantiation: \n" + << " tri:a|b|c \n\n"; + return false; } @@ -158,16 +214,12 @@ class NgramDetectorImpl { int& fid = ft->fids[curword]; ++n; if (!fid) { - const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.) ostringstream os; - os << code[n] << ':'; + os << prefixes_[n]; for (int i = n-1; i >= 0; --i) { - os << (i != n-1 ? "_" : ""); + os << (i != n-1 ? target_separator_ : ""); const string& tok = TD::Convert(buf[i]); - if (tok.find('=') == string::npos) - os << tok; - else - os << Escape(tok); + os << Escape(tok); } fid = FD::Convert(os.str()); } @@ -297,7 +349,8 @@ class NgramDetectorImpl { } public: - explicit NgramDetectorImpl(bool explicit_markers, unsigned order) : + explicit NgramDetectorImpl(bool explicit_markers, unsigned order, + vector<string>& prefixes, string& target_separator) : kCDEC_UNK(TD::Convert("<unk>")) , add_sos_eos_(!explicit_markers) { order_ = order; @@ -305,6 +358,8 @@ class NgramDetectorImpl { unscored_size_offset_ = (order_ - 1) * sizeof(WordID); is_complete_offset_ = unscored_size_offset_ + 1; unscored_words_offset_ = is_complete_offset_ + 1; + prefixes_ = prefixes; + target_separator_ = target_separator; // special handling of beginning / ending sentence markers dummy_state_ = new char[state_size_]; @@ -340,6 +395,8 @@ class NgramDetectorImpl { char* dummy_state_; vector<const void*> dummy_ants_; TRulePtr dummy_rule_; + vector<string> prefixes_; + string target_separator_; struct FidTree { map<WordID, int> fids; map<WordID, FidTree> levels; @@ -348,11 +405,13 @@ class NgramDetectorImpl { }; NgramDetector::NgramDetector(const string& param) { - string filename, mapfile, featname; + string filename, mapfile, featname, target_separator; + vector<string> prefixes; bool explicit_markers = false; unsigned order = 3; - ParseArgs(param, &explicit_markers, &order); - pimpl_ = new NgramDetectorImpl(explicit_markers, order); + ParseArgs(param, &explicit_markers, &order, prefixes, target_separator); + pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes, + target_separator); SetStateSize(pimpl_->ReserveStateSize()); } diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h index 82f61b33..064dbb49 100644 --- a/decoder/ff_ngrams.h +++ b/decoder/ff_ngrams.h @@ -10,7 +10,7 @@ struct NgramDetectorImpl; class NgramDetector : public FeatureFunction { public: - // param = "filename.lm [-o n]" + // param = "filename.lm [-o <order>] [-U <unigram-prefix>] [-B <bigram-prefix>] [-T <trigram-prefix>] [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>] NgramDetector(const std::string& param); ~NgramDetector(); virtual void FinalTraversalFeatures(const void* context, diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index 019315a2..fd9210fa 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -8,6 +8,17 @@ using namespace std; +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) : FeatureFunction(sizeof(WordID)) { no_uni_ = (LowercaseString(param) == "no_uni"); @@ -28,7 +39,7 @@ void Tagger_BigramIndicator::FireFeature(const WordID& left, os << '_'; if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); } } - fid = FD::Convert(os.str()); + fid = FD::Convert(Escape(os.str())); } features->set_value(fid, 1.0); } @@ -90,7 +101,7 @@ void LexicalPairIndicator::FireFeature(WordID src, if (!fid) { ostringstream os; os << name_ << ':' << TD::Convert(src) << ':' << TD::Convert(trg); - fid = FD::Convert(os.str()); + fid = FD::Convert(Escape(os.str())); } features->set_value(fid, 1.0); } @@ -127,7 +138,7 @@ void OutputIndicator::FireFeature(WordID trg, if (escape.count(trg)) trg = escape[trg]; ostringstream os; os << "T:" << TD::Convert(trg); - fid = FD::Convert(os.str()); + fid = FD::Convert(Escape(os.str())); } features->set_value(fid, 1.0); } diff --git a/decoder/hg.cc b/decoder/hg.cc index dd272221..7240a8ab 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -605,69 +605,6 @@ void Hypergraph::TopologicallySortNodesAndEdges(int goal_index, #endif } -TRulePtr Hypergraph::kEPSRule; -TRulePtr Hypergraph::kUnaryRule; - -void Hypergraph::EpsilonRemove(WordID eps) { - if (!kEPSRule) { - kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>")); - kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); - } - vector<bool> kill(edges_.size(), false); - for (unsigned i = 0; i < edges_.size(); ++i) { - const Edge& edge = edges_[i]; - if (edge.tail_nodes_.empty() && - edge.rule_->f_.size() == 1 && - edge.rule_->f_[0] == eps) { - kill[i] = true; - if (!edge.feature_values_.empty()) { - Node& node = nodes_[edge.head_node_]; - if (node.in_edges_.size() != 1) { - cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n"; - // this *probably* means that there are multiple derivations of the - // same sequence via different paths through the input forest - // this needs to be investigated and fixed - } else { - for (unsigned j = 0; j < node.out_edges_.size(); ++j) - edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; - // cerr << "PROMOTED " << edge.feature_values_ << endl; - } - } - } - } - bool created_eps = false; - PruneEdges(kill); - for (unsigned i = 0; i < nodes_.size(); ++i) { - const Node& node = nodes_[i]; - if (node.in_edges_.empty()) { - for (unsigned j = 0; j < node.out_edges_.size(); ++j) { - Edge& edge = edges_[node.out_edges_[j]]; - if (edge.rule_->Arity() == 2) { - assert(edge.rule_->f_.size() == 2); - assert(edge.rule_->e_.size() == 2); - edge.rule_ = kUnaryRule; - unsigned cur = node.id_; - int t = -1; - assert(edge.tail_nodes_.size() == 2); - for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; } - assert(t != -1); - edge.tail_nodes_.resize(1); - edge.tail_nodes_[0] = t; - } else { - edge.rule_ = kEPSRule; - edge.rule_->f_[0] = eps; - edge.rule_->e_[0] = eps; - edge.tail_nodes_.clear(); - created_eps = true; - } - } - } - } - vector<bool> k2(edges_.size(), false); - PruneEdges(k2); - if (created_eps) EpsilonRemove(eps); -} - struct EdgeWeightSorter { const Hypergraph& hg; EdgeWeightSorter(const Hypergraph& h) : hg(h) {} diff --git a/decoder/hg.h b/decoder/hg.h index 91d25f01..591e98ce 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -148,7 +148,7 @@ public: void show(std::ostream &o,unsigned mask=SPAN|RULE) const { o<<'{'; if (mask&CATEGORY) - o<<TD::Convert(rule_->GetLHS()); + o<< '[' << TD::Convert(-rule_->GetLHS()) << ']'; if (mask&PREV_SPAN) o<<'<'<<prev_i_<<','<<prev_j_<<'>'; if (mask&SPAN) @@ -156,9 +156,9 @@ public: if (mask&PROB) o<<" p="<<edge_prob_; if (mask&FEATURES) - o<<" "<<feature_values_; + o<<' '<<feature_values_; if (mask&RULE) - o<<rule_->AsString(mask&RULE_LHS); + o<<' '<<rule_->AsString(mask&RULE_LHS); if (USE_INFO_EDGE) { std::string const& i=info(); if (mask&&!i.empty()) o << " |||"<<i; // remember, the initial space is expected as part of i @@ -384,14 +384,6 @@ public: // compute the total number of paths in the forest double NumberOfPaths() const; - // BEWARE. this assumes that the source and target language - // strings are identical and that there are no loops. - // It assumes a bunch of other things about where the - // epsilons will be. It tries to assert failure if you - // break these assumptions, but it may not. - // TODO - make this work - void EpsilonRemove(WordID eps); - // multiple the weights vector by the edge feature vector // (inner product) to set the edge probabilities template <class V> @@ -535,9 +527,6 @@ public: private: Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} - - static TRulePtr kEPSRule; - static TRulePtr kUnaryRule; }; diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index bfb2fb80..8bd40387 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -261,6 +261,7 @@ static void WriteRule(const TRule& r, ostream* out) { } bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { + if (hg.empty()) { *out << "{}\n"; return true; } map<const TRule*, int> rid; ostream& o = *out; rid[NULL] = 0; diff --git a/decoder/hg_remove_eps.cc b/decoder/hg_remove_eps.cc new file mode 100644 index 00000000..050c4876 --- /dev/null +++ b/decoder/hg_remove_eps.cc @@ -0,0 +1,91 @@ +#include "hg_remove_eps.h" + +#include <cassert> + +#include "trule.h" +#include "hg.h" + +using namespace std; + +namespace { + TRulePtr kEPSRule; + TRulePtr kUnaryRule; + + TRulePtr CreateUnaryRule(int lhs, int rhs) { + if (!kUnaryRule) kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + TRule* r = new TRule(*kUnaryRule); + assert(lhs < 0); + assert(rhs < 0); + r->lhs_ = lhs; + r->f_[0] = rhs; + return TRulePtr(r); + } + + TRulePtr CreateEpsilonRule(int lhs, WordID eps) { + if (!kEPSRule) kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>")); + TRule* r = new TRule(*kEPSRule); + r->lhs_ = lhs; + assert(lhs < 0); + assert(eps > 0); + r->e_[0] = eps; + r->f_[0] = eps; + return TRulePtr(r); + } +} + +void RemoveEpsilons(Hypergraph* g, WordID eps) { + vector<bool> kill(g->edges_.size(), false); + for (unsigned i = 0; i < g->edges_.size(); ++i) { + const Hypergraph::Edge& edge = g->edges_[i]; + if (edge.tail_nodes_.empty() && + edge.rule_->f_.size() == 1 && + edge.rule_->f_[0] == eps) { + kill[i] = true; + if (!edge.feature_values_.empty()) { + Hypergraph::Node& node = g->nodes_[edge.head_node_]; + if (node.in_edges_.size() != 1) { + cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n"; + // this *probably* means that there are multiple derivations of the + // same sequence via different paths through the input forest + // this needs to be investigated and fixed + } else { + for (unsigned j = 0; j < node.out_edges_.size(); ++j) + g->edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; + // cerr << "PROMOTED " << edge.feature_values_ << endl; + } + } + } + } + bool created_eps = false; + g->PruneEdges(kill); + for (unsigned i = 0; i < g->nodes_.size(); ++i) { + const Hypergraph::Node& node = g->nodes_[i]; + if (node.in_edges_.empty()) { + for (unsigned j = 0; j < node.out_edges_.size(); ++j) { + Hypergraph::Edge& edge = g->edges_[node.out_edges_[j]]; + const int lhs = edge.rule_->lhs_; + if (edge.rule_->Arity() == 2) { + assert(edge.rule_->f_.size() == 2); + assert(edge.rule_->e_.size() == 2); + unsigned cur = node.id_; + int t = -1; + assert(edge.tail_nodes_.size() == 2); + int rhs = 0; + for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; rhs = edge.rule_->f_[i]; } + assert(t != -1); + edge.tail_nodes_.resize(1); + edge.tail_nodes_[0] = t; + edge.rule_ = CreateUnaryRule(lhs, rhs); + } else { + edge.rule_ = CreateEpsilonRule(lhs, eps); + edge.tail_nodes_.clear(); + created_eps = true; + } + } + } + } + vector<bool> k2(g->edges_.size(), false); + g->PruneEdges(k2); + if (created_eps) RemoveEpsilons(g, eps); +} + diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h new file mode 100644 index 00000000..82f06039 --- /dev/null +++ b/decoder/hg_remove_eps.h @@ -0,0 +1,13 @@ +#ifndef _HG_REMOVE_EPS_H_ +#define _HG_REMOVE_EPS_H_ + +#include "wordid.h" +class Hypergraph; + +// This is not a complete implementation of the general algorithm for +// doing this. It makes a few weird assumptions, for example, that +// if some nonterminal X rewrites as eps, then that is the only thing +// that it rewrites as. This needs to be fixed for the general case! +void RemoveEpsilons(Hypergraph* g, WordID eps); + +#endif diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h index bb7f9fcc..f73a1d3f 100644 --- a/decoder/inside_outside.h +++ b/decoder/inside_outside.h @@ -41,10 +41,6 @@ WeightType Inside(const Hypergraph& hg, WeightType* const cur_node_inside_score = &inside_score[i]; Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_; const unsigned num_in_edges = in.size(); - if (num_in_edges == 0) { - *cur_node_inside_score = WeightType(1); //FIXME: why not call weight(edge) instead? - continue; - } for (unsigned j = 0; j < num_in_edges; ++j) { const Hypergraph::Edge& edge = hg.edges_[in[j]]; WeightType score = weight(edge); diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc new file mode 100644 index 00000000..10192f7a --- /dev/null +++ b/decoder/rescore_translator.cc @@ -0,0 +1,58 @@ +#include "translator.h" + +#include <sstream> +#include <boost/shared_ptr.hpp> + +#include "sentence_metadata.h" +#include "hg.h" +#include "hg_io.h" +#include "tdict.h" + +using namespace std; + +struct RescoreTranslatorImpl { + RescoreTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as<string>()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1) { + } + + bool Translate(const string& input, + const vector<double>& weights, + Hypergraph* forest) { + if (input == "{}") return false; + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + if (!HypergraphIO::ReadFromJSON(&is, forest)) { + cerr << "Parse error while reading HG from JSON.\n"; + abort(); + } + } else { + cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; + abort(); + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(kGOAL); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + return true; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; +}; + +RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new RescoreTranslatorImpl(conf)) {} + +bool RescoreTranslator::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector<double>& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 185f979a..a978cfc2 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -20,7 +20,6 @@ #define reverse_foreach BOOST_REVERSE_FOREACH using namespace std; -static bool usingSentenceGrammar = false; static bool printGrammarsUsed = false; struct SCFGTranslatorImpl { @@ -91,31 +90,31 @@ struct SCFGTranslatorImpl { bool show_tree_structure_; unsigned int ctf_iterations_; vector<GrammarPtr> grammars; - GrammarPtr sup_grammar_; + set<GrammarPtr> sup_grammars_; - struct Equals { Equals(const GrammarPtr& v) : v_(v) {} - bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; }; + struct ContainedIn { + ContainedIn(const set<GrammarPtr>& gs) : gs_(gs) {} + bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); } + const set<GrammarPtr>& gs_; + }; - void SetSupplementalGrammar(const std::string& grammar_string) { - grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end()); + void AddSupplementalGrammarFromString(const std::string& grammar_string) { + grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); istringstream in(grammar_string); - sup_grammar_.reset(new TextGrammar(&in)); - grammars.push_back(sup_grammar_); + TextGrammar* sent_grammar = new TextGrammar(&in); + sent_grammar->SetMaxSpan(max_span_limit); + sent_grammar->SetGrammarName("SupFromString"); + AddSupplementalGrammar(GrammarPtr(sent_grammar)); } - struct NameEquals { NameEquals(const string name) : name_(name) {} - bool operator()(const GrammarPtr& x) const { return x->GetGrammarName() == name_; } const string name_; }; + void AddSupplementalGrammar(GrammarPtr gp) { + sup_grammars_.insert(gp); + grammars.push_back(gp); + } - void SetSentenceGrammarFromString(const std::string& grammar_str) { - assert(grammar_str != ""); - if (!SILENT) cerr << "Setting sentence grammar" << endl; - usingSentenceGrammar = true; - istringstream in(grammar_str); - TextGrammar* sent_grammar = new TextGrammar(&in); - sent_grammar->SetMaxSpan(max_span_limit); - sent_grammar->SetGrammarName("__psg"); - grammars.erase(remove_if(grammars.begin(), grammars.end(), NameEquals("__psg")), grammars.end()); - grammars.push_back(GrammarPtr(sent_grammar)); + void RemoveSupplementalGrammars() { + grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); + sup_grammars_.clear(); } bool Translate(const string& input, @@ -300,35 +299,24 @@ Check for grammar pointer in the sentence markup, for use with sentence specific */ void SCFGTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { map<string,string>::const_iterator it = kv.find("grammar"); - - - if (it == kv.end()) { - usingSentenceGrammar= false; - return; + if (it != kv.end()) { + TextGrammar* sentGrammar = new TextGrammar(it->second); + sentGrammar->SetMaxSpan(pimpl_->max_span_limit); + sentGrammar->SetGrammarName(it->second); + pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); } - //Create sentence specific grammar from specified file name and load grammar into list of grammars - usingSentenceGrammar = true; - TextGrammar* sentGrammar = new TextGrammar(it->second); - sentGrammar->SetMaxSpan(pimpl_->max_span_limit); - sentGrammar->SetGrammarName(it->second); - pimpl_->grammars.push_back(GrammarPtr(sentGrammar)); - } -void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) { - pimpl_->SetSupplementalGrammar(grammar); +void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) { + pimpl_->AddSupplementalGrammarFromString(grammar); } -void SCFGTranslator::SetSentenceGrammarFromString(const std::string& grammar_str) { - pimpl_->SetSentenceGrammarFromString(grammar_str); +void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) { + pimpl_->AddSupplementalGrammar(grammar); } void SCFGTranslator::SentenceCompleteImpl() { - - if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars - { - pimpl_->grammars.pop_back(); - } + pimpl_->RemoveSupplementalGrammars(); } std::string SCFGTranslator::GetDecoderType() const { diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 54890e85..63e855c8 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -54,6 +54,7 @@ struct TaggerImpl { const int new_node_id = forest->AddNode(kXCAT)->id_; for (int k = 0; k < tagset_.size(); ++k) { TRulePtr rule(TRule::CreateLexicalRule(src, tagset_[k])); + rule->lhs_ = kXCAT; Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); edge->i_ = i; edge->j_ = i+1; diff --git a/decoder/translator.h b/decoder/translator.h index cfd3b08a..c0800e84 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -58,8 +58,8 @@ class SCFGTranslatorImpl; class SCFGTranslator : public Translator { public: SCFGTranslator(const boost::program_options::variables_map& conf); - void SetSupplementalGrammar(const std::string& grammar); - void SetSentenceGrammarFromString(const std::string& grammar); + void AddSupplementalGrammar(GrammarPtr gp); + void AddSupplementalGrammarFromString(const std::string& grammar); virtual std::string GetDecoderType() const; protected: bool TranslateImpl(const std::string& src, @@ -85,4 +85,17 @@ class FSTTranslator : public Translator { boost::shared_ptr<FSTTranslatorImpl> pimpl_; }; +class RescoreTranslatorImpl; +class RescoreTranslator : public Translator { + public: + RescoreTranslator(const boost::program_options::variables_map& conf); + private: + bool TranslateImpl(const std::string& src, + SentenceMetadata* smeta, + const std::vector<double>& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr<RescoreTranslatorImpl> pimpl_; +}; + #endif diff --git a/decoder/trule.cc b/decoder/trule.cc index 187a003d..896f9f3d 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -237,9 +237,9 @@ void TRule::ComputeArity() { string TRule::AsString(bool verbose) const { ostringstream os; int idx = 0; - if (lhs_ && verbose) { + if (lhs_) { os << '[' << TD::Convert(lhs_ * -1) << "] |||"; - } + } else { os << "NOLHS |||"; } for (unsigned i = 0; i < f_.size(); ++i) { const WordID& w = f_[i]; if (w < 0) { |