#include <algorithm> #include <vector> #include <boost/foreach.hpp> #include <boost/functional/hash.hpp> #include "hash.h" #include "translator.h" #include "hg.h" #include "grammar.h" #include "bottom_up_parser.h" #include "sentence_metadata.h" #include "stringlib.h" #include "tdict.h" #include "viterbi.h" #include "verbose.h" #define foreach BOOST_FOREACH #define reverse_foreach BOOST_REVERSE_FOREACH using namespace std; static bool printGrammarsUsed = false; struct GlueGrammar : public TextGrammar { // read glue grammar from file explicit GlueGrammar(const std::string& file); GlueGrammar(const std::string& goal_nt, const std::string& default_nt, const unsigned int ctf_level=0); // "S", "X" virtual bool HasRuleForSpan(int i, int j, int distance) const; }; struct PassThroughGrammar : public TextGrammar { PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0); virtual bool HasRuleForSpan(int i, int j, int distance) const; }; GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {} static void RefineRule(TRulePtr pt, const unsigned int ctf_level){ for (unsigned int i=0; i<ctf_level; ++i){ TRulePtr r(new TRule(*pt)); pt->fine_rules_.reset(new vector<TRulePtr>); pt->fine_rules_->push_back(r); pt = r; } } GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const unsigned int ctf_level) { TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [1]")); AddRule(stop_glue); RefineRule(stop_glue, ctf_level); TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [1] [2] ||| Glue=1")); AddRule(glue); RefineRule(glue, ctf_level); } bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const { return (i == 0); } PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) { unordered_set<WordID> ss; for (int i = 0; i < input.size(); ++i) { const vector<LatticeArc>& alts = input[i]; for (int k = 0; k < alts.size(); ++k) { const int j = alts[k].dist2next + i; const string& src = TD::Convert(alts[k].label); if (ss.count(alts[k].label) == 0) { int length = static_cast<int>(log(UTF8StringLen(src)) / log(1.6)) + 1; if (length > 6) length = 6; string len_feat = "PassThrough_0=1"; len_feat[12] += length; TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 " + len_feat)); pt->a_.push_back(AlignmentPoint(0,0)); AddRule(pt); RefineRule(pt, ctf_level); ss.insert(alts[k].label); } } } } bool PassThroughGrammar::HasRuleForSpan(int, int, int distance) const { return (distance < 4); // TODO this isn't great, but helps with EPS lattices } struct SCFGTranslatorImpl { SCFGTranslatorImpl(const boost::program_options::variables_map& conf) : max_span_limit(conf["scfg_max_span_limit"].as<int>()), add_pass_through_rules(conf.count("add_pass_through_rules")), goal(conf["goal"].as<string>()), default_nt(conf["scfg_default_nt"].as<string>()), use_ctf_(conf.count("coarse_to_fine_beam_prune")) { if(conf.count("grammar")){ vector<string> gfiles = conf["grammar"].as<vector<string> >(); for (unsigned i = 0; i < gfiles.size(); ++i) { if (!SILENT) cerr << "Reading SCFG grammar from " << gfiles[i] << endl; TextGrammar* g = new TextGrammar(gfiles[i]); g->SetMaxSpan(max_span_limit); g->SetGrammarName(gfiles[i]); grammars.push_back(GrammarPtr(g)); } if (!SILENT) cerr << endl; } if (conf.count("scfg_extra_glue_grammar")) { GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); g->SetGrammarName("ExtraGlueGrammar"); grammars.push_back(GrammarPtr(g)); if (!SILENT) cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl; } ctf_iterations_=0; if (use_ctf_){ ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>(); foreach(GrammarPtr& gp, grammars){ ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_); } foreach(GrammarPtr& gp, grammars){ if (gp->GetCTFLevels() != ctf_iterations_){ cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() << " but overall number of CTF levels is " << ctf_iterations_ << endl << "Mixing coarse-to-fine grammars of different granularities is not supported" << endl; abort(); } } show_tree_structure_ = conf.count("show_tree_structure"); ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>(); ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>(); ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0); assert(ctf_wide_alpha_ > 1.0); cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl; cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl; } if (!conf.count("scfg_no_hiero_glue_grammar")){ GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_); g->SetGrammarName("GlueGrammar"); grammars.push_back(GrammarPtr(g)); if (!SILENT) cerr << "Adding glue grammar for default nonterminal " << default_nt << " and goal nonterminal " << goal << endl; } } const int max_span_limit; const bool add_pass_through_rules; const string goal; const string default_nt; const bool use_ctf_; double ctf_alpha_; double ctf_wide_alpha_; int ctf_num_widenings_; bool ctf_exhaustive_; bool show_tree_structure_; unsigned int ctf_iterations_; vector<GrammarPtr> grammars; set<GrammarPtr> sup_grammars_; struct ContainedIn { ContainedIn(const set<GrammarPtr>& gs) : gs_(gs) {} bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); } const set<GrammarPtr>& gs_; }; void AddSupplementalGrammarFromString(const std::string& grammar_string) { grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); istringstream in(grammar_string); TextGrammar* sent_grammar = new TextGrammar(&in); sent_grammar->SetMaxSpan(max_span_limit); sent_grammar->SetGrammarName("SupFromString"); AddSupplementalGrammar(GrammarPtr(sent_grammar)); } void AddSupplementalGrammar(GrammarPtr gp) { sup_grammars_.insert(gp); grammars.push_back(gp); } void RemoveSupplementalGrammars() { grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); sup_grammars_.clear(); } bool Translate(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* forest) { vector<GrammarPtr> glist = grammars; Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextOrPLF(input, &lattice); smeta->SetSourceLength(lattice.size()); if (add_pass_through_rules){ if (!SILENT) cerr << "Adding pass through grammar" << endl; PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_); g->SetGrammarName("PassThrough"); glist.push_back(GrammarPtr(g)); } for (unsigned gi = 0; gi < glist.size(); ++gi) { if(printGrammarsUsed) cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl; } if (!SILENT) cerr << "First pass parse... " << endl; ExhaustiveBottomUpParser parser(goal, glist); if (!parser.Parse(lattice, forest)){ if (!SILENT) cerr << " parse failed." << endl; return false; } else { // if (!SILENT) cerr << " parse succeeded." << endl; } forest->Reweight(weights); if (use_ctf_) { Hypergraph::Node& goal_node = *(forest->nodes_.end()-1); foreach(unsigned edge_id, goal_node.in_edges_) RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_); double alpha = ctf_alpha_; bool found_parse=false; for (int i=-1; i < ctf_num_widenings_; ++i) { cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl; found_parse = true; Hypergraph refined_forest = *forest; for (unsigned j=0; j < ctf_iterations_; ++j) { cerr << viterbi_stats(refined_forest," Coarse forest",true,show_tree_structure_); cerr << " Iteration " << (j+1) << ": Pruning forest... "; refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL); cerr << "Refining forest..."; if (RefineForest(&refined_forest)) { cerr << " Refinement succeeded." << endl; refined_forest.Reweight(weights); } else { cerr << " Refinement failed. Widening beam." << endl; found_parse = false; break; } } if (found_parse) { forest->swap(refined_forest); break; } alpha *= ctf_wide_alpha_; } if (!found_parse){ if (ctf_exhaustive_){ cerr << "Last resort: refining coarse forest without pruning..."; for (unsigned j=0; j < ctf_iterations_; ++j) { if (RefineForest(forest)){ cerr << " Refinement succeeded." << endl; forest->Reweight(weights); } else { cerr << " Refinement failed. No parse found for this sentence." << endl; return false; } } } else return false; } } return true; } typedef std::pair<int, WordID> StateSplit; typedef std::pair<StateSplit, int> StateSplitPair; typedef HASH_MAP<StateSplit, int, boost::hash<StateSplit> > Split2Node; typedef HASH_MAP<int, vector<WordID> > Splits; bool RefineForest(Hypergraph* forest) { Hypergraph refined_forest; Split2Node s2n; HASH_MAP_RESERVED(s2n,StateSplit(-1,-1),StateSplit(-2,-2)); Splits splits; HASH_MAP_RESERVED(splits,-1,-2); Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1); bool refined_goal_node = false; foreach(Hypergraph::Node& node, forest->nodes_){ cerr << "."; foreach(int edge_id, node.in_edges_) { Hypergraph::Edge& edge = forest->edges_[edge_id]; std::vector<int> nt_positions; TRulePtr& coarse_rule_ptr = edge.rule_; for(unsigned i=0; i< coarse_rule_ptr->f_.size(); ++i){ if (coarse_rule_ptr->f_[i] < 0) nt_positions.push_back(i); } if (coarse_rule_ptr->fine_rules_ == 0) { cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." << endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl; abort(); } // fine rules apply only if state splits on tail nodes match fine rule nonterminals foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) { Hypergraph::TailNodeVector tail; for (unsigned pos_i=0; pos_i<nt_positions.size(); ++pos_i){ WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]]; Split2Node::iterator it = s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat)); if (it == s2n.end()) break; else tail.push_back(it->second); } if (tail.size() == nt_positions.size()) { WordID cat = fine_rule_ptr->lhs_; Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail); new_edge->i_ = edge.i_; new_edge->j_ = edge.j_; new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues(); new_edge->feature_values_.set_value(FD::Convert("LatticeCost"), edge.feature_values_.value(FD::Convert("LatticeCost"))); Hypergraph::Node* head_node; Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat)); if (it == s2n.end()){ head_node = refined_forest.AddNode(cat); s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_)); splits[node.id_].push_back(cat); if (&node == &coarse_goal_node) refined_goal_node = true; } else head_node = &(refined_forest.nodes_[it->second]); refined_forest.ConnectEdgeToHeadNode(new_edge, head_node); } } } } cerr << endl; forest->swap(refined_forest); if (!refined_goal_node) return false; return true; } void OutputForest(Hypergraph* h) { foreach(Hypergraph::Node& n, h->nodes_){ if (n.in_edges_.size() == 0){ cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl; } else { cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl; } foreach(int edge_id, n.in_edges_){ cerr << " " << h->edges_[edge_id].rule_->AsString() << endl; } } } }; /* Called once from cdec.cc to setup the initial SCFG translation structure backend */ SCFGTranslator::SCFGTranslator(const boost::program_options::variables_map& conf) : pimpl_(new SCFGTranslatorImpl(conf)) {} /* Called for each sentence to perform translation using the SCFG backend */ bool SCFGTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { return pimpl_->Translate(input, smeta, weights, minus_lm_forest); } /* Check for grammar pointer in the sentence markup, for use with sentence specific grammars */ void SCFGTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { map<string,string>::const_iterator it = kv.find("grammar"); if (it != kv.end()) { TextGrammar* sentGrammar = new TextGrammar(it->second); sentGrammar->SetMaxSpan(pimpl_->max_span_limit); sentGrammar->SetGrammarName(it->second); pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar)); } } void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) { pimpl_->AddSupplementalGrammarFromString(grammar); } void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) { pimpl_->AddSupplementalGrammar(grammar); } void SCFGTranslator::SentenceCompleteImpl() { pimpl_->RemoveSupplementalGrammars(); } std::string SCFGTranslator::GetDecoderType() const { return "SCFG"; }