Diffstat (limited to 'decoder/scfg_translator.cc')
-rw-r--r-- decoder/scfg_translator.cc | 242
1 file changed, 202 insertions(+), 40 deletions(-)
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 866c2721..32acfd65 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -3,13 +3,20 @@
 //TODO: grammar heuristic (min cost of reachable rule set) for binarizations (active edges) if we wish to prune those also
 
 #include "translator.h"
-
+#include <algorithm>
 #include <vector>
-
+#include <tr1/unordered_map>
+#include <boost/foreach.hpp>
+#include <boost/functional/hash.hpp>
 #include "hg.h"
 #include "grammar.h"
 #include "bottom_up_parser.h"
 #include "sentence_metadata.h"
+#include "tdict.h"
+#include "viterbi.h"
+
+#define foreach BOOST_FOREACH
+#define reverse_foreach BOOST_REVERSE_FOREACH
 
 using namespace std;
 static bool usingSentenceGrammar = false;
@@ -20,39 +27,68 @@ struct SCFGTranslatorImpl {
     max_span_limit(conf["scfg_max_span_limit"].as<int>()),
     add_pass_through_rules(conf.count("add_pass_through_rules")),
     goal(conf["goal"].as<string>()),
-    default_nt(conf["scfg_default_nt"].as<string>()) {
-    if(conf.count("grammar"))
-    {
-      vector<string> gfiles = conf["grammar"].as<vector<string> >();
-      for (int i = 0; i < gfiles.size(); ++i) {
-        cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
-        TextGrammar* g = new TextGrammar(gfiles[i]);
-        g->SetMaxSpan(max_span_limit);
-        g->SetGrammarName(gfiles[i]);
-        grammars.push_back(GrammarPtr(g));
-
-      }
-    }
-    if (!conf.count("scfg_no_hiero_glue_grammar"))
-    {
-      GlueGrammar* g = new GlueGrammar(goal, default_nt);
-      g->SetGrammarName("GlueGrammar");
-      grammars.push_back(GrammarPtr(g));
-      cerr << "Adding glue grammar" << endl;
+    default_nt(conf["scfg_default_nt"].as<string>()),
+    use_ctf_(conf.count("coarse_to_fine_beam_prune"))
+  {
+    if(conf.count("grammar")){
+      vector<string> gfiles = conf["grammar"].as<vector<string> >();
+      for (int i = 0; i < gfiles.size(); ++i) {
+        cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+        TextGrammar* g = new TextGrammar(gfiles[i]);
+        g->SetMaxSpan(max_span_limit);
+        g->SetGrammarName(gfiles[i]);
+        grammars.push_back(GrammarPtr(g));
+      }
+    }
+    cerr << std::endl;
+    if (conf.count("scfg_extra_glue_grammar")) {
+      GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
+      g->SetGrammarName("ExtraGlueGrammar");
+      grammars.push_back(GrammarPtr(g));
+      cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl;
+    }
+    ctf_iterations_=0;
+    if (use_ctf_){
+      ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>();
+      foreach(GrammarPtr& gp, grammars){
+        ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_);
       }
-    if (conf.count("scfg_extra_glue_grammar"))
-    {
-      GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
-      g->SetGrammarName("ExtraGlueGrammar");
-      grammars.push_back(GrammarPtr(g));
-      cerr << "Adding extra glue grammar" << endl;
+      foreach(GrammarPtr& gp, grammars){
+        if (gp->GetCTFLevels() != ctf_iterations_){
+          cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() <<
+            " but overall number of CTF levels is " << ctf_iterations_ << endl <<
+            "Mixing coarse-to-fine grammars of different granularities is not supported" << endl;
+          abort();
+        }
       }
-  }
+      show_tree_structure_ = conf.count("show_tree_structure");
+      ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>();
+      ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>();
+      ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0);
+      assert(ctf_wide_alpha_ > 1.0);
+      cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl;
+      cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl;
+    }
+    if (!conf.count("scfg_no_hiero_glue_grammar")){
+      GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_);
+      g->SetGrammarName("GlueGrammar");
+      grammars.push_back(GrammarPtr(g));
+      cerr << "Adding glue grammar for default nonterminal " << default_nt <<
+        " and goal nonterminal " << goal << endl;
+    }
+  }
 
   const int max_span_limit;
   const bool add_pass_through_rules;
   const string goal;
   const string default_nt;
+  const bool use_ctf_;
+  double ctf_alpha_;
+  double ctf_wide_alpha_;
+  int ctf_num_widenings_;
+  bool ctf_exhaustive_;
+  bool show_tree_structure_;
+  unsigned int ctf_iterations_;
   vector<GrammarPtr> grammars;
 
   bool Translate(const string& input,
@@ -64,29 +100,155 @@ struct SCFGTranslatorImpl {
     LatticeTools::ConvertTextOrPLF(input, &lattice);
     smeta->SetSourceLength(lattice.size());
     if (add_pass_through_rules){
-      PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt);
+      cerr << "Adding pass through grammar" << endl;
+      PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_);
       g->SetGrammarName("PassThrough");
       glist.push_back(GrammarPtr(g));
-      cerr << "Adding pass through grammar" << endl;
     }
-
-
-
-    if(printGrammarsUsed){  //Iterate trough grammars we have for this sentence and list them
-      for (int gi = 0; gi < glist.size(); ++gi)
-      {
-        cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
-      }
+    for (int gi = 0; gi < glist.size(); ++gi) {
+      if(printGrammarsUsed)
+        cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
     }
-
+    cerr << "First pass parse... " << endl;
     ExhaustiveBottomUpParser parser(goal, glist);
-    if (!parser.Parse(lattice, forest))
+    if (!parser.Parse(lattice, forest)){
+      cerr << "parse failed." << endl;
       return false;
+    } else {
+      cerr << "parse succeeded." << endl;
+    }
     forest->Reweight(weights);
+    if (use_ctf_) {
+      Hypergraph::Node& goal_node = *(forest->nodes_.end()-1);
+      foreach(int edge_id, goal_node.in_edges_)
+        RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_);
+      double alpha = ctf_alpha_;
+      bool found_parse;
+      for (int i=-1; i < ctf_num_widenings_; ++i) {
+        cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl;
+        found_parse = true;
+        Hypergraph refined_forest = *forest;
+        for (int j=0; j < ctf_iterations_; ++j) {
+          cerr << viterbi_stats(refined_forest," Coarse forest",true,show_tree_structure_);
+          cerr << " Iteration " << (j+1) << ": Pruning forest... ";
+          refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL);
+          cerr << "Refining forest...";
+          if (RefineForest(&refined_forest)) {
+            cerr << " Refinement succeeded." << endl;
+            refined_forest.Reweight(weights);
+          } else {
+            cerr << " Refinement failed. Widening beam." << endl;
+            found_parse = false;
+            break;
+          }
+        }
+        if (found_parse) {
+          forest->swap(refined_forest);
+          break;
+        }
+        alpha *= ctf_wide_alpha_;
+      }
+      if (!found_parse){
+        if (ctf_exhaustive_){
+          cerr << "Last resort: refining coarse forest without pruning...";
+          for (int j=0; j < ctf_iterations_; ++j) {
+            if (RefineForest(forest)){
+              cerr << " Refinement succeeded." << endl;
+              forest->Reweight(weights);
+            } else {
+              cerr << " Refinement failed. No parse found for this sentence." << endl;
+              return false;
+            }
+          }
+        } else
+          return false;
+      }
+    }
     return true;
   }
+
+  typedef std::pair<int, WordID> StateSplit;
+  typedef std::pair<StateSplit, int> StateSplitPair;
+  typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node;
+  typedef std::tr1::unordered_map<int, vector<WordID> > Splits;
+
+  bool RefineForest(Hypergraph* forest) {
+    Hypergraph refined_forest;
+    Split2Node s2n;
+    Splits splits;
+    Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1);
+    bool refined_goal_node = false;
+    foreach(Hypergraph::Node& node, forest->nodes_){
+      cerr << ".";
+      foreach(int edge_id, node.in_edges_) {
+        Hypergraph::Edge& edge = forest->edges_[edge_id];
+        std::vector<int> nt_positions;
+        TRulePtr& coarse_rule_ptr = edge.rule_;
+        for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){
+          if (coarse_rule_ptr->f_[i] < 0)
+            nt_positions.push_back(i);
+        }
+        if (coarse_rule_ptr->fine_rules_ == 0) {
+          cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." <<
+            endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl;
+          abort();
+        }
+        // fine rules apply only if state splits on tail nodes match fine rule nonterminals
+        foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) {
+          Hypergraph::TailNodeVector tail;
+          for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){
+            WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]];
+            Split2Node::iterator it =
+              s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat));
+            if (it == s2n.end())
+              break;
+            else
+              tail.push_back(it->second);
+          }
+          if (tail.size() == nt_positions.size()) {
+            WordID cat = fine_rule_ptr->lhs_;
+            Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail);
+            new_edge->i_ = edge.i_;
+            new_edge->j_ = edge.j_;
+            new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues();
+            new_edge->feature_values_.set_value(FD::Convert("LatticeCost"),
+              edge.feature_values_[FD::Convert("LatticeCost")]);
+            Hypergraph::Node* head_node;
+            Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat));
+            if (it == s2n.end()){
+              head_node = refined_forest.AddNode(cat);
+              s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_));
+              splits[node.id_].push_back(cat);
+              if (&node == &coarse_goal_node)
+                refined_goal_node = true;
+            } else
+              head_node = &(refined_forest.nodes_[it->second]);
+            refined_forest.ConnectEdgeToHeadNode(new_edge, head_node);
+          }
+        }
+      }
+    }
+    cerr << endl;
+    forest->swap(refined_forest);
+    if (!refined_goal_node)
+      return false;
+    return true;
+  }
+  void OutputForest(Hypergraph* h) {
+    foreach(Hypergraph::Node& n, h->nodes_){
+      if (n.in_edges_.size() == 0){
+        cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl;
+      } else {
+        cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl;
+      }
+      foreach(int edge_id, n.in_edges_){
+        cerr << " " << h->edges_[edge_id].rule_->AsString() << endl;
+      }
+    }
+  }
 };
+
 /* Called once from cdec.cc to setup the initial SCFG translation structure backend */
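The RefineForest step above keys every refined node on a (coarse node id, fine nonterminal) pair, so a fine rule fires only when all of its tail categories have already been split off the corresponding coarse nodes. Below is a minimal, self-contained sketch of that lookup pattern; it is not part of the commit, the node ids and category values are invented for illustration, and WordID is typedef'd to int here as a stand-in for cdec's vocabulary id type.

#include <iostream>
#include <utility>
#include <tr1/unordered_map>
#include <boost/functional/hash.hpp>

typedef int WordID;                        // nonterminal categories are negative ids in cdec
typedef std::pair<int, WordID> StateSplit; // (coarse node id, fine category)
typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node;

int main() {
  Split2Node s2n;
  // suppose coarse node 7 has already been refined with fine category -42 into fine node 0
  s2n.insert(std::make_pair(StateSplit(7, -42), 0));
  // a fine rule expecting category -42 under coarse node 7 finds its refined tail node...
  Split2Node::iterator hit = s2n.find(StateSplit(7, -42));
  std::cout << (hit != s2n.end() ? hit->second : -1) << std::endl;   // prints 0
  // ...while one expecting -43 does not, so that rule would be skipped
  Split2Node::iterator miss = s2n.find(StateSplit(7, -43));
  std::cout << (miss != s2n.end() ? miss->second : -1) << std::endl; // prints -1
  return 0;
}

Hashing the pair directly (boost::hash ships an overload for std::pair) avoids allocating a dense node-by-category table, which suits this use: after beam pruning, only a few fine categories survive at each coarse node.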