diff options
author | adam.d.lopez <adam.d.lopez@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-13 03:27:59 +0000 |
---|---|---|
committer | adam.d.lopez <adam.d.lopez@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-13 03:27:59 +0000 |
commit | 47c75319638866609f669346e15663c5ba43af7f (patch) | |
tree | 50bc1867bdb1ce27e679fb8981b1b805c4897cd6 /decoder/scfg_translator.cc | |
parent | a2360c873ba8b72744e16752a067276a46d63645 (diff) |
cdec now supports coarse-to-fine decoding (for SCFG only).
CTF has several options:
-coarse_to_fine_beam_prune=<double> (required to activate CTF)
assign an alpha parameter for pruning the coarse foreast
-ctf_beam_widen=<double> (optional, defaults to 2.0):
ratio to widen coarse pruning beam if fine parse fails.
-ctf_num_widenings=<int> (optional, defaults to 2):
number of times to widen coarse beam before defaulting to exhaustive
source parsing
-ctf_no_exhaustive (optional)
do not attempt exhaustive parse if CTF fails to find a parse.
Additionally, script extools/coarsen_grammar.pl will create a
coarse-to-fine grammar (for X?? categories *only*). cdec will
read CTF grammars in a format identical to the original, in which
refinements of a rule immediately follow the coarse projection,
preceded by an additional whitespace character.
Not fully tested, but should be backwards compatible. Also not
yet integrated into pipelines, but should work on the command line.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@231 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/scfg_translator.cc')
-rw-r--r-- | decoder/scfg_translator.cc | 242 |
1 files changed, 202 insertions, 40 deletions
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 866c2721..32acfd65 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -3,13 +3,20 @@ //TODO: grammar heuristic (min cost of reachable rule set) for binarizations (active edges) if we wish to prune those also #include "translator.h" - +#include <algorithm> #include <vector> - +#include <tr1/unordered_map> +#include <boost/foreach.hpp> +#include <boost/functional/hash.hpp> #include "hg.h" #include "grammar.h" #include "bottom_up_parser.h" #include "sentence_metadata.h" +#include "tdict.h" +#include "viterbi.h" + +#define foreach BOOST_FOREACH +#define reverse_foreach BOOST_REVERSE_FOREACH using namespace std; static bool usingSentenceGrammar = false; @@ -20,39 +27,68 @@ struct SCFGTranslatorImpl { max_span_limit(conf["scfg_max_span_limit"].as<int>()), add_pass_through_rules(conf.count("add_pass_through_rules")), goal(conf["goal"].as<string>()), - default_nt(conf["scfg_default_nt"].as<string>()) { - if(conf.count("grammar")) - { - vector<string> gfiles = conf["grammar"].as<vector<string> >(); - for (int i = 0; i < gfiles.size(); ++i) { - cerr << "Reading SCFG grammar from " << gfiles[i] << endl; - TextGrammar* g = new TextGrammar(gfiles[i]); - g->SetMaxSpan(max_span_limit); - g->SetGrammarName(gfiles[i]); - grammars.push_back(GrammarPtr(g)); - - } - } - if (!conf.count("scfg_no_hiero_glue_grammar")) - { - GlueGrammar* g = new GlueGrammar(goal, default_nt); - g->SetGrammarName("GlueGrammar"); - grammars.push_back(GrammarPtr(g)); - cerr << "Adding glue grammar" << endl; + default_nt(conf["scfg_default_nt"].as<string>()), + use_ctf_(conf.count("coarse_to_fine_beam_prune")) + { + if(conf.count("grammar")){ + vector<string> gfiles = conf["grammar"].as<vector<string> >(); + for (int i = 0; i < gfiles.size(); ++i) { + cerr << "Reading SCFG grammar from " << gfiles[i] << endl; + TextGrammar* g = new TextGrammar(gfiles[i]); + g->SetMaxSpan(max_span_limit); + g->SetGrammarName(gfiles[i]); + grammars.push_back(GrammarPtr(g)); + } + } + cerr << std::endl; + if (conf.count("scfg_extra_glue_grammar")) { + GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); + g->SetGrammarName("ExtraGlueGrammar"); + grammars.push_back(GrammarPtr(g)); + cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl; + } + ctf_iterations_=0; + if (use_ctf_){ + ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>(); + foreach(GrammarPtr& gp, grammars){ + ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_); } - if (conf.count("scfg_extra_glue_grammar")) - { - GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); - g->SetGrammarName("ExtraGlueGrammar"); - grammars.push_back(GrammarPtr(g)); - cerr << "Adding extra glue grammar" << endl; + foreach(GrammarPtr& gp, grammars){ + if (gp->GetCTFLevels() != ctf_iterations_){ + cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() << + " but overall number of CTF levels is " << ctf_iterations_ << endl << + "Mixing coarse-to-fine grammars of different granularities is not supported" << endl; + abort(); + } } - } + show_tree_structure_ = conf.count("show_tree_structure"); + ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>(); + ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>(); + ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0); + assert(ctf_wide_alpha_ > 1.0); + cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl; + cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl; + } + if (!conf.count("scfg_no_hiero_glue_grammar")){ + GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_); + g->SetGrammarName("GlueGrammar"); + grammars.push_back(GrammarPtr(g)); + cerr << "Adding glue grammar for default nonterminal " << default_nt << + " and goal nonterminal " << goal << endl; + } + } const int max_span_limit; const bool add_pass_through_rules; const string goal; const string default_nt; + const bool use_ctf_; + double ctf_alpha_; + double ctf_wide_alpha_; + int ctf_num_widenings_; + bool ctf_exhaustive_; + bool show_tree_structure_; + unsigned int ctf_iterations_; vector<GrammarPtr> grammars; bool Translate(const string& input, @@ -64,29 +100,155 @@ struct SCFGTranslatorImpl { LatticeTools::ConvertTextOrPLF(input, &lattice); smeta->SetSourceLength(lattice.size()); if (add_pass_through_rules){ - PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt); + cerr << "Adding pass through grammar" << endl; + PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_); g->SetGrammarName("PassThrough"); glist.push_back(GrammarPtr(g)); - cerr << "Adding pass through grammar" << endl; } - - - - if(printGrammarsUsed){ //Iterate trough grammars we have for this sentence and list them - for (int gi = 0; gi < glist.size(); ++gi) - { - cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl; - } + for (int gi = 0; gi < glist.size(); ++gi) { + if(printGrammarsUsed) + cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl; } - + cerr << "First pass parse... " << endl; ExhaustiveBottomUpParser parser(goal, glist); - if (!parser.Parse(lattice, forest)) + if (!parser.Parse(lattice, forest)){ + cerr << "parse failed." << endl; return false; + } else { + cerr << "parse succeeded." << endl; + } forest->Reweight(weights); + if (use_ctf_) { + Hypergraph::Node& goal_node = *(forest->nodes_.end()-1); + foreach(int edge_id, goal_node.in_edges_) + RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_); + double alpha = ctf_alpha_; + bool found_parse; + for (int i=-1; i < ctf_num_widenings_; ++i) { + cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl; + found_parse = true; + Hypergraph refined_forest = *forest; + for (int j=0; j < ctf_iterations_; ++j) { + cerr << viterbi_stats(refined_forest," Coarse forest",true,show_tree_structure_); + cerr << " Iteration " << (j+1) << ": Pruning forest... "; + refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL); + cerr << "Refining forest..."; + if (RefineForest(&refined_forest)) { + cerr << " Refinement succeeded." << endl; + refined_forest.Reweight(weights); + } else { + cerr << " Refinement failed. Widening beam." << endl; + found_parse = false; + break; + } + } + if (found_parse) { + forest->swap(refined_forest); + break; + } + alpha *= ctf_wide_alpha_; + } + if (!found_parse){ + if (ctf_exhaustive_){ + cerr << "Last resort: refining coarse forest without pruning..."; + for (int j=0; j < ctf_iterations_; ++j) { + if (RefineForest(forest)){ + cerr << " Refinement succeeded." << endl; + forest->Reweight(weights); + } else { + cerr << " Refinement failed. No parse found for this sentence." << endl; + return false; + } + } + } else + return false; + } + } return true; } + + typedef std::pair<int, WordID> StateSplit; + typedef std::pair<StateSplit, int> StateSplitPair; + typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node; + typedef std::tr1::unordered_map<int, vector<WordID> > Splits; + + bool RefineForest(Hypergraph* forest) { + Hypergraph refined_forest; + Split2Node s2n; + Splits splits; + Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1); + bool refined_goal_node = false; + foreach(Hypergraph::Node& node, forest->nodes_){ + cerr << "."; + foreach(int edge_id, node.in_edges_) { + Hypergraph::Edge& edge = forest->edges_[edge_id]; + std::vector<int> nt_positions; + TRulePtr& coarse_rule_ptr = edge.rule_; + for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){ + if (coarse_rule_ptr->f_[i] < 0) + nt_positions.push_back(i); + } + if (coarse_rule_ptr->fine_rules_ == 0) { + cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." << + endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl; + abort(); + } + // fine rules apply only if state splits on tail nodes match fine rule nonterminals + foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) { + Hypergraph::TailNodeVector tail; + for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){ + WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]]; + Split2Node::iterator it = + s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat)); + if (it == s2n.end()) + break; + else + tail.push_back(it->second); + } + if (tail.size() == nt_positions.size()) { + WordID cat = fine_rule_ptr->lhs_; + Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail); + new_edge->i_ = edge.i_; + new_edge->j_ = edge.j_; + new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues(); + new_edge->feature_values_.set_value(FD::Convert("LatticeCost"), + edge.feature_values_[FD::Convert("LatticeCost")]); + Hypergraph::Node* head_node; + Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat)); + if (it == s2n.end()){ + head_node = refined_forest.AddNode(cat); + s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_)); + splits[node.id_].push_back(cat); + if (&node == &coarse_goal_node) + refined_goal_node = true; + } else + head_node = &(refined_forest.nodes_[it->second]); + refined_forest.ConnectEdgeToHeadNode(new_edge, head_node); + } + } + } + } + cerr << endl; + forest->swap(refined_forest); + if (!refined_goal_node) + return false; + return true; + } + void OutputForest(Hypergraph* h) { + foreach(Hypergraph::Node& n, h->nodes_){ + if (n.in_edges_.size() == 0){ + cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl; + } else { + cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl; + } + foreach(int edge_id, n.in_edges_){ + cerr << " " << h->edges_[edge_id].rule_->AsString() << endl; + } + } + } }; + /* Called once from cdec.cc to setup the initial SCFG translation structure backend */ |