summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authoradam.d.lopez <adam.d.lopez@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-13 03:27:59 +0000
committeradam.d.lopez <adam.d.lopez@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-13 03:27:59 +0000
commit47c75319638866609f669346e15663c5ba43af7f (patch)
tree50bc1867bdb1ce27e679fb8981b1b805c4897cd6 /decoder
parenta2360c873ba8b72744e16752a067276a46d63645 (diff)
cdec now supports coarse-to-fine decoding (for SCFG only).
CTF has several options: -coarse_to_fine_beam_prune=<double> (required to activate CTF) assign an alpha parameter for pruning the coarse foreast -ctf_beam_widen=<double> (optional, defaults to 2.0): ratio to widen coarse pruning beam if fine parse fails. -ctf_num_widenings=<int> (optional, defaults to 2): number of times to widen coarse beam before defaulting to exhaustive source parsing -ctf_no_exhaustive (optional) do not attempt exhaustive parse if CTF fails to find a parse. Additionally, script extools/coarsen_grammar.pl will create a coarse-to-fine grammar (for X?? categories *only*). cdec will read CTF grammars in a format identical to the original, in which refinements of a rule immediately follow the coarse projection, preceded by an additional whitespace character. Not fully tested, but should be backwards compatible. Also not yet integrated into pipelines, but should work on the command line. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@231 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rw-r--r--decoder/cdec.cc4
-rw-r--r--decoder/grammar.cc37
-rw-r--r--decoder/grammar.h13
-rw-r--r--decoder/rule_lexer.h2
-rw-r--r--decoder/rule_lexer.l48
-rw-r--r--decoder/scfg_translator.cc242
-rw-r--r--decoder/trule.h8
7 files changed, 291 insertions, 63 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 09760f6b..b6cc6f66 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -125,6 +125,10 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) {
("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)")
+ ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
+ ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
+ ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
+ ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit)")
diff --git a/decoder/grammar.cc b/decoder/grammar.cc
index 499e79fe..26efaf99 100644
--- a/decoder/grammar.cc
+++ b/decoder/grammar.cc
@@ -81,8 +81,14 @@ const GrammarIter* TextGrammar::GetRoot() const {
return &pimpl_->root_;
}
-void TextGrammar::AddRule(const TRulePtr& rule) {
- if (rule->IsUnary()) {
+void TextGrammar::AddRule(const TRulePtr& rule, const unsigned int ctf_level, const TRulePtr& coarse_rule) {
+ if (ctf_level > 0) {
+ // assume that coarse_rule is already in tree (would be safer to check)
+ if (coarse_rule->fine_rules_ == 0)
+ coarse_rule->fine_rules_.reset(new std::vector<TRulePtr>());
+ coarse_rule->fine_rules_->push_back(rule);
+ ctf_levels_ = std::max(ctf_levels_, ctf_level);
+ } else if (rule->IsUnary()) {
rhs2unaries_[rule->f().front()].push_back(rule);
unaries_.push_back(rule);
} else {
@@ -95,8 +101,8 @@ void TextGrammar::AddRule(const TRulePtr& rule) {
}
}
-static void AddRuleHelper(const TRulePtr& new_rule, void* extra) {
- static_cast<TextGrammar*>(extra)->AddRule(new_rule);
+static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) {
+ static_cast<TextGrammar*>(extra)->AddRule(new_rule, ctf_level, coarse_rule);
}
void TextGrammar::ReadFromFile(const string& filename) {
@@ -110,22 +116,29 @@ bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const {
GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {}
-GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) {
- TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));
- TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["
- + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));
+void RefineRule(TRulePtr pt, const unsigned int ctf_level){
+ for (unsigned int i=0; i<ctf_level; ++i){
+ TRulePtr r(new TRule(*pt));
+ pt->fine_rules_.reset(new vector<TRulePtr>);
+ pt->fine_rules_->push_back(r);
+ pt = r;
+ }
+}
+GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const unsigned int ctf_level) {
+ TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));
AddRule(stop_glue);
+ RefineRule(stop_glue, ctf_level);
+ TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));
AddRule(glue);
- //cerr << "GLUE: " << stop_glue->AsString() << endl;
- //cerr << "GLUE: " << glue->AsString() << endl;
+ RefineRule(glue, ctf_level);
}
bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const {
return (i == 0);
}
-PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) :
+PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) :
has_rule_(input.size() + 1) {
for (int i = 0; i < input.size(); ++i) {
const vector<LatticeArc>& alts = input[i];
@@ -135,7 +148,7 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat)
const string& src = TD::Convert(alts[k].label);
TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1"));
AddRule(pt);
-// cerr << "PT: " << pt->AsString() << endl;
+ RefineRule(pt, ctf_level);
}
}
}
diff --git a/decoder/grammar.h b/decoder/grammar.h
index 46886d3a..b26eb912 100644
--- a/decoder/grammar.h
+++ b/decoder/grammar.h
@@ -1,6 +1,7 @@
#ifndef GRAMMAR_H_
#define GRAMMAR_H_
+#include <algorithm>
#include <vector>
#include <map>
#include <set>
@@ -26,11 +27,13 @@ struct GrammarIter {
struct Grammar {
typedef std::map<WordID, std::vector<TRulePtr> > Cat2Rules;
static const std::vector<TRulePtr> NO_RULES;
-
+
+ Grammar(): ctf_levels_(0) {}
virtual ~Grammar();
virtual const GrammarIter* GetRoot() const = 0;
virtual bool HasRuleForSpan(int i, int j, int distance) const;
const std::string GetGrammarName(){return grammar_name_;}
+ unsigned int GetCTFLevels(){ return ctf_levels_; }
void SetGrammarName(std::string n) {grammar_name_ = n; }
// cat is the category to be rewritten
inline const std::vector<TRulePtr>& GetAllUnaryRules() const {
@@ -50,6 +53,7 @@ struct Grammar {
Cat2Rules rhs2unaries_; // these must be filled in by subclasses!
std::vector<TRulePtr> unaries_;
std::string grammar_name_;
+ unsigned int ctf_levels_;
};
typedef boost::shared_ptr<Grammar> GrammarPtr;
@@ -61,7 +65,7 @@ struct TextGrammar : public Grammar {
void SetMaxSpan(int m) { max_span_ = m; }
virtual const GrammarIter* GetRoot() const;
- void AddRule(const TRulePtr& rule);
+ void AddRule(const TRulePtr& rule, const unsigned int ctf_level=0, const TRulePtr& coarse_parent=TRulePtr());
void ReadFromFile(const std::string& filename);
virtual bool HasRuleForSpan(int i, int j, int distance) const;
const std::vector<TRulePtr>& GetUnaryRules(const WordID& cat) const;
@@ -75,15 +79,16 @@ struct TextGrammar : public Grammar {
struct GlueGrammar : public TextGrammar {
// read glue grammar from file
explicit GlueGrammar(const std::string& file);
- GlueGrammar(const std::string& goal_nt, const std::string& default_nt); // "S", "X"
+ GlueGrammar(const std::string& goal_nt, const std::string& default_nt, const unsigned int ctf_level=0); // "S", "X"
virtual bool HasRuleForSpan(int i, int j, int distance) const;
};
struct PassThroughGrammar : public TextGrammar {
- PassThroughGrammar(const Lattice& input, const std::string& cat);
+ PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0);
virtual bool HasRuleForSpan(int i, int j, int distance) const;
private:
std::vector<std::set<int> > has_rule_; // index by [i][j]
};
+void RefineRule(TRulePtr pt, const unsigned int ctf_level);
#endif
diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h
index e5db4018..976ea02b 100644
--- a/decoder/rule_lexer.h
+++ b/decoder/rule_lexer.h
@@ -6,7 +6,7 @@
#include "trule.h"
struct RuleLexer {
- typedef void (*RuleCallback)(const TRulePtr& new_rule, void* extra);
+ typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra);
static void ReadRules(std::istream* in, RuleCallback func, void* extra);
};
diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l
index e2acd752..0216b119 100644
--- a/decoder/rule_lexer.l
+++ b/decoder/rule_lexer.l
@@ -6,6 +6,7 @@
#include <sstream>
#include <cstring>
#include <cassert>
+#include <stack>
#include "tdict.h"
#include "fdict.h"
#include "trule.h"
@@ -45,7 +46,8 @@ int scfglex_nt_sanity[MAX_ARITY];
int scfglex_src_nts[MAX_ARITY];
float scfglex_nt_size_means[MAX_ARITY];
float scfglex_nt_size_vars[MAX_ARITY];
-
+std::stack<TRulePtr> ctf_rule_stack;
+unsigned int ctf_level = 0;
void sanity_check_trg_symbol(WordID nt, int index) {
if (scfglex_src_nts[index-1] != nt) {
@@ -77,6 +79,34 @@ void scfglex_reset() {
scfglex_trg_rhs_size = 0;
}
+void check_and_update_ctf_stack(const TRulePtr& rp) {
+ if (ctf_level > ctf_rule_stack.size()){
+ std::cerr << "Found rule at projection level " << ctf_level << " but previous rule was at level "
+ << ctf_rule_stack.size()-1 << " (cannot exceed previous level by more than one; line " << lex_line << ")" << std::endl;
+ abort();
+ }
+ while (ctf_rule_stack.size() > ctf_level)
+ ctf_rule_stack.pop();
+ // ensure that rule has the same signature as parent (coarse) rule. Rules may *only*
+ // differ by the rhs nonterminals, not terminals or permutation of nonterminals.
+ if (ctf_rule_stack.size() > 0) {
+ TRulePtr& coarse_rp = ctf_rule_stack.top();
+ if (rp->f_.size() != coarse_rp->f_.size() || rp->e_ != coarse_rp->e_) {
+ std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " <<
+ (coarse_rp->AsString()) << std::endl;
+ abort();
+ }
+ for (int i=0; i<rp->f_.size(); ++i) {
+ if (((rp->f_[i]<0) != (coarse_rp->f_[i]<0)) ||
+ ((rp->f_[i]>0) && (rp->f_[i] != coarse_rp->f_[i]))) {
+ std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " <<
+ (coarse_rp->AsString()) << std::endl;
+ abort();
+ }
+ }
+ }
+}
+
%}
REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf
@@ -85,7 +115,9 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]*
%x LHS_END SRC TRG FEATS FEATVAL ALIGNS
%%
-<INITIAL>[ \t] ;
+<INITIAL>[ \t] {
+ ctf_level++;
+ };
<INITIAL>\[{NT}\] {
scfglex_tmp_token.assign(yytext + 1, yyleng - 2);
@@ -182,12 +214,16 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]*
abort();
}
TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity));
- rule_callback(rp, rule_callback_extra);
+ check_and_update_ctf_stack(rp);
+ TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
+ rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra);
+ ctf_rule_stack.push(rp);
// std::cerr << rp->AsString() << std::endl;
num_rules++;
- lex_line++;
- if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
- if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
+ lex_line++;
+ if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
+ if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
+ ctf_level = 0;
BEGIN(INITIAL);
}
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 866c2721..32acfd65 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -3,13 +3,20 @@
//TODO: grammar heuristic (min cost of reachable rule set) for binarizations (active edges) if we wish to prune those also
#include "translator.h"
-
+#include <algorithm>
#include <vector>
-
+#include <tr1/unordered_map>
+#include <boost/foreach.hpp>
+#include <boost/functional/hash.hpp>
#include "hg.h"
#include "grammar.h"
#include "bottom_up_parser.h"
#include "sentence_metadata.h"
+#include "tdict.h"
+#include "viterbi.h"
+
+#define foreach BOOST_FOREACH
+#define reverse_foreach BOOST_REVERSE_FOREACH
using namespace std;
static bool usingSentenceGrammar = false;
@@ -20,39 +27,68 @@ struct SCFGTranslatorImpl {
max_span_limit(conf["scfg_max_span_limit"].as<int>()),
add_pass_through_rules(conf.count("add_pass_through_rules")),
goal(conf["goal"].as<string>()),
- default_nt(conf["scfg_default_nt"].as<string>()) {
- if(conf.count("grammar"))
- {
- vector<string> gfiles = conf["grammar"].as<vector<string> >();
- for (int i = 0; i < gfiles.size(); ++i) {
- cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
- TextGrammar* g = new TextGrammar(gfiles[i]);
- g->SetMaxSpan(max_span_limit);
- g->SetGrammarName(gfiles[i]);
- grammars.push_back(GrammarPtr(g));
-
- }
- }
- if (!conf.count("scfg_no_hiero_glue_grammar"))
- {
- GlueGrammar* g = new GlueGrammar(goal, default_nt);
- g->SetGrammarName("GlueGrammar");
- grammars.push_back(GrammarPtr(g));
- cerr << "Adding glue grammar" << endl;
+ default_nt(conf["scfg_default_nt"].as<string>()),
+ use_ctf_(conf.count("coarse_to_fine_beam_prune"))
+ {
+ if(conf.count("grammar")){
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ for (int i = 0; i < gfiles.size(); ++i) {
+ cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+ TextGrammar* g = new TextGrammar(gfiles[i]);
+ g->SetMaxSpan(max_span_limit);
+ g->SetGrammarName(gfiles[i]);
+ grammars.push_back(GrammarPtr(g));
+ }
+ }
+ cerr << std::endl;
+ if (conf.count("scfg_extra_glue_grammar")) {
+ GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
+ g->SetGrammarName("ExtraGlueGrammar");
+ grammars.push_back(GrammarPtr(g));
+ cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl;
+ }
+ ctf_iterations_=0;
+ if (use_ctf_){
+ ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>();
+ foreach(GrammarPtr& gp, grammars){
+ ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_);
}
- if (conf.count("scfg_extra_glue_grammar"))
- {
- GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
- g->SetGrammarName("ExtraGlueGrammar");
- grammars.push_back(GrammarPtr(g));
- cerr << "Adding extra glue grammar" << endl;
+ foreach(GrammarPtr& gp, grammars){
+ if (gp->GetCTFLevels() != ctf_iterations_){
+ cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() <<
+ " but overall number of CTF levels is " << ctf_iterations_ << endl <<
+ "Mixing coarse-to-fine grammars of different granularities is not supported" << endl;
+ abort();
+ }
}
- }
+ show_tree_structure_ = conf.count("show_tree_structure");
+ ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>();
+ ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>();
+ ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0);
+ assert(ctf_wide_alpha_ > 1.0);
+ cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl;
+ cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl;
+ }
+ if (!conf.count("scfg_no_hiero_glue_grammar")){
+ GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_);
+ g->SetGrammarName("GlueGrammar");
+ grammars.push_back(GrammarPtr(g));
+ cerr << "Adding glue grammar for default nonterminal " << default_nt <<
+ " and goal nonterminal " << goal << endl;
+ }
+ }
const int max_span_limit;
const bool add_pass_through_rules;
const string goal;
const string default_nt;
+ const bool use_ctf_;
+ double ctf_alpha_;
+ double ctf_wide_alpha_;
+ int ctf_num_widenings_;
+ bool ctf_exhaustive_;
+ bool show_tree_structure_;
+ unsigned int ctf_iterations_;
vector<GrammarPtr> grammars;
bool Translate(const string& input,
@@ -64,29 +100,155 @@ struct SCFGTranslatorImpl {
LatticeTools::ConvertTextOrPLF(input, &lattice);
smeta->SetSourceLength(lattice.size());
if (add_pass_through_rules){
- PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt);
+ cerr << "Adding pass through grammar" << endl;
+ PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_);
g->SetGrammarName("PassThrough");
glist.push_back(GrammarPtr(g));
- cerr << "Adding pass through grammar" << endl;
}
-
-
-
- if(printGrammarsUsed){ //Iterate trough grammars we have for this sentence and list them
- for (int gi = 0; gi < glist.size(); ++gi)
- {
- cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
- }
+ for (int gi = 0; gi < glist.size(); ++gi) {
+ if(printGrammarsUsed)
+ cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
}
-
+ cerr << "First pass parse... " << endl;
ExhaustiveBottomUpParser parser(goal, glist);
- if (!parser.Parse(lattice, forest))
+ if (!parser.Parse(lattice, forest)){
+ cerr << "parse failed." << endl;
return false;
+ } else {
+ cerr << "parse succeeded." << endl;
+ }
forest->Reweight(weights);
+ if (use_ctf_) {
+ Hypergraph::Node& goal_node = *(forest->nodes_.end()-1);
+ foreach(int edge_id, goal_node.in_edges_)
+ RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_);
+ double alpha = ctf_alpha_;
+ bool found_parse;
+ for (int i=-1; i < ctf_num_widenings_; ++i) {
+ cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl;
+ found_parse = true;
+ Hypergraph refined_forest = *forest;
+ for (int j=0; j < ctf_iterations_; ++j) {
+ cerr << viterbi_stats(refined_forest," Coarse forest",true,show_tree_structure_);
+ cerr << " Iteration " << (j+1) << ": Pruning forest... ";
+ refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL);
+ cerr << "Refining forest...";
+ if (RefineForest(&refined_forest)) {
+ cerr << " Refinement succeeded." << endl;
+ refined_forest.Reweight(weights);
+ } else {
+ cerr << " Refinement failed. Widening beam." << endl;
+ found_parse = false;
+ break;
+ }
+ }
+ if (found_parse) {
+ forest->swap(refined_forest);
+ break;
+ }
+ alpha *= ctf_wide_alpha_;
+ }
+ if (!found_parse){
+ if (ctf_exhaustive_){
+ cerr << "Last resort: refining coarse forest without pruning...";
+ for (int j=0; j < ctf_iterations_; ++j) {
+ if (RefineForest(forest)){
+ cerr << " Refinement succeeded." << endl;
+ forest->Reweight(weights);
+ } else {
+ cerr << " Refinement failed. No parse found for this sentence." << endl;
+ return false;
+ }
+ }
+ } else
+ return false;
+ }
+ }
return true;
}
+
+ typedef std::pair<int, WordID> StateSplit;
+ typedef std::pair<StateSplit, int> StateSplitPair;
+ typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node;
+ typedef std::tr1::unordered_map<int, vector<WordID> > Splits;
+
+ bool RefineForest(Hypergraph* forest) {
+ Hypergraph refined_forest;
+ Split2Node s2n;
+ Splits splits;
+ Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1);
+ bool refined_goal_node = false;
+ foreach(Hypergraph::Node& node, forest->nodes_){
+ cerr << ".";
+ foreach(int edge_id, node.in_edges_) {
+ Hypergraph::Edge& edge = forest->edges_[edge_id];
+ std::vector<int> nt_positions;
+ TRulePtr& coarse_rule_ptr = edge.rule_;
+ for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){
+ if (coarse_rule_ptr->f_[i] < 0)
+ nt_positions.push_back(i);
+ }
+ if (coarse_rule_ptr->fine_rules_ == 0) {
+ cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." <<
+ endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl;
+ abort();
+ }
+ // fine rules apply only if state splits on tail nodes match fine rule nonterminals
+ foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) {
+ Hypergraph::TailNodeVector tail;
+ for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){
+ WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]];
+ Split2Node::iterator it =
+ s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat));
+ if (it == s2n.end())
+ break;
+ else
+ tail.push_back(it->second);
+ }
+ if (tail.size() == nt_positions.size()) {
+ WordID cat = fine_rule_ptr->lhs_;
+ Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail);
+ new_edge->i_ = edge.i_;
+ new_edge->j_ = edge.j_;
+ new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues();
+ new_edge->feature_values_.set_value(FD::Convert("LatticeCost"),
+ edge.feature_values_[FD::Convert("LatticeCost")]);
+ Hypergraph::Node* head_node;
+ Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat));
+ if (it == s2n.end()){
+ head_node = refined_forest.AddNode(cat);
+ s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_));
+ splits[node.id_].push_back(cat);
+ if (&node == &coarse_goal_node)
+ refined_goal_node = true;
+ } else
+ head_node = &(refined_forest.nodes_[it->second]);
+ refined_forest.ConnectEdgeToHeadNode(new_edge, head_node);
+ }
+ }
+ }
+ }
+ cerr << endl;
+ forest->swap(refined_forest);
+ if (!refined_goal_node)
+ return false;
+ return true;
+ }
+ void OutputForest(Hypergraph* h) {
+ foreach(Hypergraph::Node& n, h->nodes_){
+ if (n.in_edges_.size() == 0){
+ cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl;
+ } else {
+ cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl;
+ }
+ foreach(int edge_id, n.in_edges_){
+ cerr << " " << h->edges_[edge_id].rule_->AsString() << endl;
+ }
+ }
+ }
};
+
/*
Called once from cdec.cc to setup the initial SCFG translation structure backend
*/
diff --git a/decoder/trule.h b/decoder/trule.h
index defdbeb9..3bc96165 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -34,6 +34,9 @@ class TRule {
TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs) :
e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {}
+ TRule(const TRule& other) :
+ e_(other.e_), f_(other.f_), lhs_(other.lhs_), scores_(other.scores_), arity_(other.arity_), prev_i(-1), prev_j(-1) {}
+
// deprecated - this will be private soon
explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) {
ReadFromString(text, strict, mono);
@@ -130,6 +133,8 @@ class TRule {
SparseVector<double> scores_;
char arity_;
+
+ // these attributes are application-specific and should probably be refactored
TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding
// this is only used when doing synchronous parsing
@@ -139,6 +144,9 @@ class TRule {
// may be null
boost::shared_ptr<NTSizeSummaryStatistics> nt_size_summary_;
+ // only for coarse-to-fine decoding
+ boost::shared_ptr<std::vector<TRulePtr> > fine_rules_;
+
private:
TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}
bool SanityCheck() const;