summary refs log tree commit diff
path: root/decoder/scfg_translator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/scfg_translator.cc')
-rw-r--r-- decoder/scfg_translator.cc 242
1 files changed, 202 insertions, 40 deletions
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 866c2721..32acfd65 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -3,13 +3,20 @@
//TODO: grammar heuristic (min cost of reachable rule set) for binarizations (active edges) if we wish to prune those also
#include "translator.h"
-
+#include <algorithm>
#include <vector>
-
+#include <tr1/unordered_map>
+#include <boost/foreach.hpp>
+#include <boost/functional/hash.hpp>
#include "hg.h"
#include "grammar.h"
#include "bottom_up_parser.h"
#include "sentence_metadata.h"
+#include "tdict.h"
+#include "viterbi.h"
+
+#define foreach BOOST_FOREACH
+#define reverse_foreach BOOST_REVERSE_FOREACH
using namespace std;
static bool usingSentenceGrammar = false;
@@ -20,39 +27,68 @@ struct SCFGTranslatorImpl {
max_span_limit(conf["scfg_max_span_limit"].as<int>()),
add_pass_through_rules(conf.count("add_pass_through_rules")),
goal(conf["goal"].as<string>()),
- default_nt(conf["scfg_default_nt"].as<string>()) {
- if(conf.count("grammar"))
- {
- vector<string> gfiles = conf["grammar"].as<vector<string> >();
- for (int i = 0; i < gfiles.size(); ++i) {
- cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
- TextGrammar* g = new TextGrammar(gfiles[i]);
- g->SetMaxSpan(max_span_limit);
- g->SetGrammarName(gfiles[i]);
- grammars.push_back(GrammarPtr(g));
-
- }
- }
- if (!conf.count("scfg_no_hiero_glue_grammar"))
- {
- GlueGrammar* g = new GlueGrammar(goal, default_nt);
- g->SetGrammarName("GlueGrammar");
- grammars.push_back(GrammarPtr(g));
- cerr << "Adding glue grammar" << endl;
+ default_nt(conf["scfg_default_nt"].as<string>()),
+ use_ctf_(conf.count("coarse_to_fine_beam_prune"))
+ {
+ if(conf.count("grammar")){
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ for (int i = 0; i < gfiles.size(); ++i) {
+ cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+ TextGrammar* g = new TextGrammar(gfiles[i]);
+ g->SetMaxSpan(max_span_limit);
+ g->SetGrammarName(gfiles[i]);
+ grammars.push_back(GrammarPtr(g));
+ }
+ }
+ cerr << std::endl;
+ if (conf.count("scfg_extra_glue_grammar")) {
+ GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
+ g->SetGrammarName("ExtraGlueGrammar");
+ grammars.push_back(GrammarPtr(g));
+ cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl;
+ }
+ ctf_iterations_=0;
+ if (use_ctf_){
+ ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>();
+ foreach(GrammarPtr& gp, grammars){
+ ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_);
}
- if (conf.count("scfg_extra_glue_grammar"))
- {
- GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
- g->SetGrammarName("ExtraGlueGrammar");
- grammars.push_back(GrammarPtr(g));
- cerr << "Adding extra glue grammar" << endl;
+ foreach(GrammarPtr& gp, grammars){
+ if (gp->GetCTFLevels() != ctf_iterations_){
+ cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() <<
+ " but overall number of CTF levels is " << ctf_iterations_ << endl <<
+ "Mixing coarse-to-fine grammars of different granularities is not supported" << endl;
+ abort();
+ }
}
- }
+ show_tree_structure_ = conf.count("show_tree_structure");
+ ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>();
+ ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>();
+ ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0);
+ assert(ctf_wide_alpha_ > 1.0);
+ cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl;
+ cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl;
+ }
+ if (!conf.count("scfg_no_hiero_glue_grammar")){
+ GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_);
+ g->SetGrammarName("GlueGrammar");
+ grammars.push_back(GrammarPtr(g));
+ cerr << "Adding glue grammar for default nonterminal " << default_nt <<
+ " and goal nonterminal " << goal << endl;
+ }
+ }
const int max_span_limit;
const bool add_pass_through_rules;
const string goal;
const string default_nt;
+ const bool use_ctf_;
+ double ctf_alpha_;
+ double ctf_wide_alpha_;
+ int ctf_num_widenings_;
+ bool ctf_exhaustive_;
+ bool show_tree_structure_;
+ unsigned int ctf_iterations_;
vector<GrammarPtr> grammars;
bool Translate(const string& input,
@@ -64,29 +100,155 @@ struct SCFGTranslatorImpl {
LatticeTools::ConvertTextOrPLF(input, &lattice);
smeta->SetSourceLength(lattice.size());
if (add_pass_through_rules){
- PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt);
+ cerr << "Adding pass through grammar" << endl;
+ PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_);
g->SetGrammarName("PassThrough");
glist.push_back(GrammarPtr(g));
- cerr << "Adding pass through grammar" << endl;
}
-
-
-
- if(printGrammarsUsed){ //Iterate trough grammars we have for this sentence and list them
- for (int gi = 0; gi < glist.size(); ++gi)
- {
- cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
- }
+ for (int gi = 0; gi < glist.size(); ++gi) {
+ if(printGrammarsUsed)
+ cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
}
-
+ cerr << "First pass parse... " << endl;
ExhaustiveBottomUpParser parser(goal, glist);
- if (!parser.Parse(lattice, forest))
+ if (!parser.Parse(lattice, forest)){
+ cerr << "parse failed." << endl;
return false;
+ } else {
+ cerr << "parse succeeded." << endl;
+ }
forest->Reweight(weights);
+ if (use_ctf_) {
+ Hypergraph::Node& goal_node = *(forest->nodes_.end()-1);
+ foreach(int edge_id, goal_node.in_edges_)
+ RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_);
+ double alpha = ctf_alpha_;
+ bool found_parse;
+ for (int i=-1; i < ctf_num_widenings_; ++i) {
+ cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl;
+ found_parse = true;
+ Hypergraph refined_forest = *forest;
+ for (int j=0; j < ctf_iterations_; ++j) {
+ cerr << viterbi_stats(refined_forest," Coarse forest",true,show_tree_structure_);
+ cerr << " Iteration " << (j+1) << ": Pruning forest... ";
+ refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL);
+ cerr << "Refining forest...";
+ if (RefineForest(&refined_forest)) {
+ cerr << " Refinement succeeded." << endl;
+ refined_forest.Reweight(weights);
+ } else {
+ cerr << " Refinement failed. Widening beam." << endl;
+ found_parse = false;
+ break;
+ }
+ }
+ if (found_parse) {
+ forest->swap(refined_forest);
+ break;
+ }
+ alpha *= ctf_wide_alpha_;
+ }
+ if (!found_parse){
+ if (ctf_exhaustive_){
+ cerr << "Last resort: refining coarse forest without pruning...";
+ for (int j=0; j < ctf_iterations_; ++j) {
+ if (RefineForest(forest)){
+ cerr << " Refinement succeeded." << endl;
+ forest->Reweight(weights);
+ } else {
+ cerr << " Refinement failed. No parse found for this sentence." << endl;
+ return false;
+ }
+ }
+ } else
+ return false;
+ }
+ }
return true;
}
+
+ // Key identifying one refinement of a coarse node: (coarse node id, fine category).
+ typedef std::pair<int, WordID> StateSplit;
+ // Entry type inserted into Split2Node: a StateSplit together with the id of
+ // the refined-forest node that was created for it.
+ typedef std::pair<StateSplit, int> StateSplitPair;
+ // Maps a StateSplit to the id of its node in the refined forest.
+ typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node;
+ // Maps a coarse node id to the list of fine categories for which refined
+ // nodes have been created.
+ typedef std::tr1::unordered_map<int, vector<WordID> > Splits;
+
+ // Replaces *forest with a refined forest built from the fine rules attached
+ // to each coarse edge's rule (rule_->fine_rules_).
+ //
+ // For every coarse edge, each fine rule is instantiated as a new edge only if
+ // a refined node already exists for every (coarse tail node, fine nonterminal)
+ // pair it needs. Refined head nodes are created on demand, one per
+ // (coarse node id, fine LHS category).
+ //
+ // Returns true iff at least one refined node was created for the last node in
+ // forest->nodes_, which is treated as the coarse goal node.
+ // Note: the swap into *forest happens unconditionally, so when false is
+ // returned *forest already holds the (goal-less) refined forest, not the
+ // coarse input.
+ // Aborts the process if a coarse rule has no fine_rules_.
+ // NOTE(review): tail lookups only succeed for nodes visited earlier in the
+ // loop, so this presumably relies on nodes_ being ordered with tails before
+ // heads -- confirm against Hypergraph's node ordering.
+ bool RefineForest(Hypergraph* forest) {
+ Hypergraph refined_forest;
+ // (coarse node id, fine category) -> node id in refined_forest
+ Split2Node s2n;
+ // coarse node id -> fine categories created so far (written below, not read
+ // in this function)
+ Splits splits;
+ // the last node of the coarse forest is taken to be its goal node
+ Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1);
+ bool refined_goal_node = false;
+ foreach(Hypergraph::Node& node, forest->nodes_){
+ // progress indicator: one dot per coarse node
+ cerr << ".";
+ foreach(int edge_id, node.in_edges_) {
+ Hypergraph::Edge& edge = forest->edges_[edge_id];
+ // positions in the coarse rule's f_ holding negative ids; these are
+ // treated as the nonterminal slots, in the same order as edge.tail_nodes_
+ std::vector<int> nt_positions;
+ TRulePtr& coarse_rule_ptr = edge.rule_;
+ for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){
+ if (coarse_rule_ptr->f_[i] < 0)
+ nt_positions.push_back(i);
+ }
+ // every coarse rule must carry a set of fine rules; otherwise give up
+ if (coarse_rule_ptr->fine_rules_ == 0) {
+ cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." <<
+ endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl;
+ abort();
+ }
+ // fine rules apply only if state splits on tail nodes match fine rule nonterminals
+ foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) {
+ // refined tail nodes for this fine rule; left incomplete (and the rule
+ // skipped) as soon as one required split has no refined node
+ Hypergraph::TailNodeVector tail;
+ for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){
+ // the fine rule's f_ is indexed with the coarse rule's nonterminal
+ // positions, i.e. both rules are assumed to share the same layout
+ WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]];
+ Split2Node::iterator it =
+ s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat));
+ if (it == s2n.end())
+ break;
+ else
+ tail.push_back(it->second);
+ }
+ // all tails resolved: add the fine edge to the refined forest
+ if (tail.size() == nt_positions.size()) {
+ WordID cat = fine_rule_ptr->lhs_;
+ Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail);
+ // the new edge keeps the coarse edge's span
+ new_edge->i_ = edge.i_;
+ new_edge->j_ = edge.j_;
+ // features come from the fine rule, except LatticeCost, which is
+ // carried over from the coarse edge
+ new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues();
+ new_edge->feature_values_.set_value(FD::Convert("LatticeCost"),
+ edge.feature_values_[FD::Convert("LatticeCost")]);
+ // find or create the refined head node for (coarse node, fine LHS)
+ Hypergraph::Node* head_node;
+ Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat));
+ if (it == s2n.end()){
+ head_node = refined_forest.AddNode(cat);
+ s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_));
+ splits[node.id_].push_back(cat);
+ // remember that the goal node received at least one refinement
+ if (&node == &coarse_goal_node)
+ refined_goal_node = true;
+ } else
+ head_node = &(refined_forest.nodes_[it->second]);
+ refined_forest.ConnectEdgeToHeadNode(new_edge, head_node);
+ }
+ }
+ }
+ }
+ cerr << endl;
+ // hand the refined forest back to the caller, even if refinement failed
+ forest->swap(refined_forest);
+ if (!refined_goal_node)
+ return false;
+ return true;
+ }
+ // Writes a textual listing of the hypergraph *h to cerr.
+ // For each node it prints "<category, i, j>", where i and j are taken from
+ // the node's first incoming edge ("?, ?" if the node has no incoming edges),
+ // followed by one indented line per incoming edge showing the edge's rule.
+ // NOTE(review): not called anywhere in the visible part of this file; appears
+ // to be a debugging aid.
+ void OutputForest(Hypergraph* h) {
+ foreach(Hypergraph::Node& n, h->nodes_){
+ if (n.in_edges_.size() == 0){
+ // no incoming edge to read a span from
+ cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl;
+ } else {
+ // span is reported from the first incoming edge only
+ cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl;
+ }
+ // list the rule of every incoming edge under the node
+ foreach(int edge_id, n.in_edges_){
+ cerr << " " << h->edges_[edge_id].rule_->AsString() << endl;
+ }
+ }
+ }
};
+
/*
Called once from cdec.cc to setup the initial SCFG translation structure backend
*/