-rw-r--r--  .gitignore                   4
-rw-r--r--  decoder/cdec.cc              4
-rw-r--r--  decoder/grammar.cc          37
-rw-r--r--  decoder/grammar.h           13
-rw-r--r--  decoder/rule_lexer.h         2
-rw-r--r--  decoder/rule_lexer.l        48
-rw-r--r--  decoder/scfg_translator.cc 242
-rw-r--r--  decoder/trule.h              8
-rwxr-xr-x  extools/coarsen_grammar.pl 133
9 files changed, 428 insertions, 63 deletions
diff --git a/.gitignore b/.gitignore
index 1787b77e..7ed566e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,8 @@
*swp
+*.o
+vest/sentserver
+vest/sentclient
+gi/pyp-topics/src/contexts_lexer.cc
config.guess
config.sub
libtool
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 09760f6b..b6cc6f66 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -125,6 +125,10 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) {
("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)")
+ ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
+ ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
+ ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
+ ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit)")
diff --git a/decoder/grammar.cc b/decoder/grammar.cc
index 499e79fe..26efaf99 100644
--- a/decoder/grammar.cc
+++ b/decoder/grammar.cc
@@ -81,8 +81,14 @@ const GrammarIter* TextGrammar::GetRoot() const {
return &pimpl_->root_;
}
-void TextGrammar::AddRule(const TRulePtr& rule) {
- if (rule->IsUnary()) {
+void TextGrammar::AddRule(const TRulePtr& rule, const unsigned int ctf_level, const TRulePtr& coarse_rule) {
+ if (ctf_level > 0) {
+ // assume that coarse_rule is already in the tree (it would be safer to check)
+ if (!coarse_rule->fine_rules_)
+ coarse_rule->fine_rules_.reset(new std::vector<TRulePtr>());
+ coarse_rule->fine_rules_->push_back(rule);
+ ctf_levels_ = std::max(ctf_levels_, ctf_level);
+ } else if (rule->IsUnary()) {
rhs2unaries_[rule->f().front()].push_back(rule);
unaries_.push_back(rule);
} else {
@@ -95,8 +101,8 @@ void TextGrammar::AddRule(const TRulePtr& rule) {
}
}
-static void AddRuleHelper(const TRulePtr& new_rule, void* extra) {
- static_cast<TextGrammar*>(extra)->AddRule(new_rule);
+static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) {
+ static_cast<TextGrammar*>(extra)->AddRule(new_rule, ctf_level, coarse_rule);
}
void TextGrammar::ReadFromFile(const string& filename) {
@@ -110,22 +116,29 @@ bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const {
GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {}
-GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) {
- TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));
- TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["
- + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));
+void RefineRule(TRulePtr pt, const unsigned int ctf_level){
+ for (unsigned int i=0; i<ctf_level; ++i){
+ TRulePtr r(new TRule(*pt));
+ pt->fine_rules_.reset(new vector<TRulePtr>);
+ pt->fine_rules_->push_back(r);
+ pt = r;
+ }
+}
+GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const unsigned int ctf_level) {
+ TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));
AddRule(stop_glue);
+ RefineRule(stop_glue, ctf_level);
+ TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));
AddRule(glue);
- //cerr << "GLUE: " << stop_glue->AsString() << endl;
- //cerr << "GLUE: " << glue->AsString() << endl;
+ RefineRule(glue, ctf_level);
}
bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const {
return (i == 0);
}
-PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) :
+PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) :
has_rule_(input.size() + 1) {
for (int i = 0; i < input.size(); ++i) {
const vector<LatticeArc>& alts = input[i];
@@ -135,7 +148,7 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat)
const string& src = TD::Convert(alts[k].label);
TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1"));
AddRule(pt);
-// cerr << "PT: " << pt->AsString() << endl;
+ RefineRule(pt, ctf_level);
}
}
}
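RefineRule gives glue and pass-through rules a refinement chain so they survive every projection level. A minimal sketch of its effect (assuming trule.h and grammar.h are included; the rule text is an arbitrary example):

    TRulePtr pt(new TRule("[X] ||| haus ||| house ||| PassThrough=1"));
    RefineRule(pt, 2);
    // pt->fine_rules_ now holds one copy of the rule, whose own fine_rules_
    // holds another: a two-level chain of identical rules, so the rule
    // projects onto itself at each coarse-to-fine iteration.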
diff --git a/decoder/grammar.h b/decoder/grammar.h
index 46886d3a..b26eb912 100644
--- a/decoder/grammar.h
+++ b/decoder/grammar.h
@@ -1,6 +1,7 @@
#ifndef GRAMMAR_H_
#define GRAMMAR_H_
+#include <algorithm>
#include <vector>
#include <map>
#include <set>
@@ -26,11 +27,13 @@ struct GrammarIter {
struct Grammar {
typedef std::map<WordID, std::vector<TRulePtr> > Cat2Rules;
static const std::vector<TRulePtr> NO_RULES;
-
+
+ Grammar(): ctf_levels_(0) {}
virtual ~Grammar();
virtual const GrammarIter* GetRoot() const = 0;
virtual bool HasRuleForSpan(int i, int j, int distance) const;
const std::string GetGrammarName(){return grammar_name_;}
+ unsigned int GetCTFLevels(){ return ctf_levels_; }
void SetGrammarName(std::string n) {grammar_name_ = n; }
// cat is the category to be rewritten
inline const std::vector<TRulePtr>& GetAllUnaryRules() const {
@@ -50,6 +53,7 @@ struct Grammar {
Cat2Rules rhs2unaries_; // these must be filled in by subclasses!
std::vector<TRulePtr> unaries_;
std::string grammar_name_;
+ unsigned int ctf_levels_;
};
typedef boost::shared_ptr<Grammar> GrammarPtr;
@@ -61,7 +65,7 @@ struct TextGrammar : public Grammar {
void SetMaxSpan(int m) { max_span_ = m; }
virtual const GrammarIter* GetRoot() const;
- void AddRule(const TRulePtr& rule);
+ void AddRule(const TRulePtr& rule, const unsigned int ctf_level=0, const TRulePtr& coarse_parent=TRulePtr());
void ReadFromFile(const std::string& filename);
virtual bool HasRuleForSpan(int i, int j, int distance) const;
const std::vector<TRulePtr>& GetUnaryRules(const WordID& cat) const;
@@ -75,15 +79,16 @@ struct TextGrammar : public Grammar {
struct GlueGrammar : public TextGrammar {
// read glue grammar from file
explicit GlueGrammar(const std::string& file);
- GlueGrammar(const std::string& goal_nt, const std::string& default_nt); // "S", "X"
+ GlueGrammar(const std::string& goal_nt, const std::string& default_nt, const unsigned int ctf_level=0); // "S", "X"
virtual bool HasRuleForSpan(int i, int j, int distance) const;
};
struct PassThroughGrammar : public TextGrammar {
- PassThroughGrammar(const Lattice& input, const std::string& cat);
+ PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0);
virtual bool HasRuleForSpan(int i, int j, int distance) const;
private:
std::vector<std::set<int> > has_rule_; // index by [i][j]
};
+void RefineRule(TRulePtr pt, const unsigned int ctf_level);
#endif
diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h
index e5db4018..976ea02b 100644
--- a/decoder/rule_lexer.h
+++ b/decoder/rule_lexer.h
@@ -6,7 +6,7 @@
#include "trule.h"
struct RuleLexer {
- typedef void (*RuleCallback)(const TRulePtr& new_rule, void* extra);
+ typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra);
static void ReadRules(std::istream* in, RuleCallback func, void* extra);
};
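Every ReadRules client has to adopt the wider signature. A minimal conforming callback, mirroring AddRuleHelper in grammar.cc (hypothetical sketch, not part of this commit):

    static void CountByLevel(const TRulePtr& new_rule, const unsigned int ctf_level,
                             const TRulePtr& coarse_rule, void* extra) {
      // extra points at a caller-owned array of per-level rule counters
      static_cast<int*>(extra)[ctf_level]++;
      (void) new_rule; (void) coarse_rule;  // unused in this sketch
    }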
diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l
index e2acd752..0216b119 100644
--- a/decoder/rule_lexer.l
+++ b/decoder/rule_lexer.l
@@ -6,6 +6,7 @@
#include <sstream>
#include <cstring>
#include <cassert>
+#include <stack>
#include "tdict.h"
#include "fdict.h"
#include "trule.h"
@@ -45,7 +46,8 @@ int scfglex_nt_sanity[MAX_ARITY];
int scfglex_src_nts[MAX_ARITY];
float scfglex_nt_size_means[MAX_ARITY];
float scfglex_nt_size_vars[MAX_ARITY];
-
+std::stack<TRulePtr> ctf_rule_stack;
+unsigned int ctf_level = 0;
void sanity_check_trg_symbol(WordID nt, int index) {
if (scfglex_src_nts[index-1] != nt) {
@@ -77,6 +79,34 @@ void scfglex_reset() {
scfglex_trg_rhs_size = 0;
}
+void check_and_update_ctf_stack(const TRulePtr& rp) {
+ if (ctf_level > ctf_rule_stack.size()){
+ std::cerr << "Found rule at projection level " << ctf_level << " but previous rule was at level "
+ << ctf_rule_stack.size()-1 << " (cannot exceed previous level by more than one; line " << lex_line << ")" << std::endl;
+ abort();
+ }
+ while (ctf_rule_stack.size() > ctf_level)
+ ctf_rule_stack.pop();
+ // ensure that rule has the same signature as parent (coarse) rule. Rules may *only*
+ // differ by the rhs nonterminals, not terminals or permutation of nonterminals.
+ if (ctf_rule_stack.size() > 0) {
+ TRulePtr& coarse_rp = ctf_rule_stack.top();
+ if (rp->f_.size() != coarse_rp->f_.size() || rp->e_ != coarse_rp->e_) {
+ std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " <<
+ (coarse_rp->AsString()) << std::endl;
+ abort();
+ }
+ for (int i=0; i<rp->f_.size(); ++i) {
+ if (((rp->f_[i]<0) != (coarse_rp->f_[i]<0)) ||
+ ((rp->f_[i]>0) && (rp->f_[i] != coarse_rp->f_[i]))) {
+ std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " <<
+ (coarse_rp->AsString()) << std::endl;
+ abort();
+ }
+ }
+ }
+}
+
%}
REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf
@@ -85,7 +115,9 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]*
%x LHS_END SRC TRG FEATS FEATVAL ALIGNS
%%
-<INITIAL>[ \t] ;
+<INITIAL>[ \t] {
+ ctf_level++;
+ };
<INITIAL>\[{NT}\] {
scfglex_tmp_token.assign(yytext + 1, yyleng - 2);
@@ -182,12 +214,16 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]*
abort();
}
TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity));
- rule_callback(rp, rule_callback_extra);
+ check_and_update_ctf_stack(rp);
+ TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
+ rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra);
+ ctf_rule_stack.push(rp);
// std::cerr << rp->AsString() << std::endl;
num_rules++;
- lex_line++;
- if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
- if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
+ lex_line++;
+ if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
+ if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
+ ctf_level = 0;
BEGIN(INITIAL);
}
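With these changes the projection level of a rule is encoded by its leading whitespace, and each fine rule must immediately follow its coarse parent. An illustrative two-level grammar fragment (the second rule line begins with one literal tab, shown here as an extra indent; rule text and features are made up):

    [X] ||| [X,1] maison ||| [X,1] house ||| Prob=0.3
        [NP] ||| [NN,1] maison ||| [NN,1] house ||| Prob=0.3

check_and_update_ctf_stack enforces the projection constraint: a fine rule may relabel its left-hand side and right-hand-side nonterminals, but terminals, nonterminal positions, and the target side must match its coarse parent.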
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 866c2721..32acfd65 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -3,13 +3,20 @@
//TODO: grammar heuristic (min cost of reachable rule set) for binarizations (active edges) if we wish to prune those also
#include "translator.h"
-
+#include <algorithm>
#include <vector>
-
+#include <tr1/unordered_map>
+#include <boost/foreach.hpp>
+#include <boost/functional/hash.hpp>
#include "hg.h"
#include "grammar.h"
#include "bottom_up_parser.h"
#include "sentence_metadata.h"
+#include "tdict.h"
+#include "viterbi.h"
+
+#define foreach BOOST_FOREACH
+#define reverse_foreach BOOST_REVERSE_FOREACH
using namespace std;
static bool usingSentenceGrammar = false;
@@ -20,39 +27,68 @@ struct SCFGTranslatorImpl {
max_span_limit(conf["scfg_max_span_limit"].as<int>()),
add_pass_through_rules(conf.count("add_pass_through_rules")),
goal(conf["goal"].as<string>()),
- default_nt(conf["scfg_default_nt"].as<string>()) {
- if(conf.count("grammar"))
- {
- vector<string> gfiles = conf["grammar"].as<vector<string> >();
- for (int i = 0; i < gfiles.size(); ++i) {
- cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
- TextGrammar* g = new TextGrammar(gfiles[i]);
- g->SetMaxSpan(max_span_limit);
- g->SetGrammarName(gfiles[i]);
- grammars.push_back(GrammarPtr(g));
-
- }
- }
- if (!conf.count("scfg_no_hiero_glue_grammar"))
- {
- GlueGrammar* g = new GlueGrammar(goal, default_nt);
- g->SetGrammarName("GlueGrammar");
- grammars.push_back(GrammarPtr(g));
- cerr << "Adding glue grammar" << endl;
+ default_nt(conf["scfg_default_nt"].as<string>()),
+ use_ctf_(conf.count("coarse_to_fine_beam_prune"))
+ {
+ if(conf.count("grammar")){
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ for (int i = 0; i < gfiles.size(); ++i) {
+ cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+ TextGrammar* g = new TextGrammar(gfiles[i]);
+ g->SetMaxSpan(max_span_limit);
+ g->SetGrammarName(gfiles[i]);
+ grammars.push_back(GrammarPtr(g));
+ }
+ }
+ cerr << std::endl;
+ if (conf.count("scfg_extra_glue_grammar")) {
+ GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
+ g->SetGrammarName("ExtraGlueGrammar");
+ grammars.push_back(GrammarPtr(g));
+ cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl;
+ }
+ ctf_iterations_=0;
+ if (use_ctf_){
+ ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>();
+ foreach(GrammarPtr& gp, grammars){
+ ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_);
}
- if (conf.count("scfg_extra_glue_grammar"))
- {
- GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
- g->SetGrammarName("ExtraGlueGrammar");
- grammars.push_back(GrammarPtr(g));
- cerr << "Adding extra glue grammar" << endl;
+ foreach(GrammarPtr& gp, grammars){
+ if (gp->GetCTFLevels() != ctf_iterations_){
+ cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() <<
+ " but overall number of CTF levels is " << ctf_iterations_ << endl <<
+ "Mixing coarse-to-fine grammars of different granularities is not supported" << endl;
+ abort();
+ }
}
- }
+ show_tree_structure_ = conf.count("show_tree_structure");
+ ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>();
+ ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>();
+ ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0);
+ assert(ctf_wide_alpha_ > 1.0);
+ cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl;
+ cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl;
+ }
+ if (!conf.count("scfg_no_hiero_glue_grammar")){
+ GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_);
+ g->SetGrammarName("GlueGrammar");
+ grammars.push_back(GrammarPtr(g));
+ cerr << "Adding glue grammar for default nonterminal " << default_nt <<
+ " and goal nonterminal " << goal << endl;
+ }
+ }
const int max_span_limit;
const bool add_pass_through_rules;
const string goal;
const string default_nt;
+ const bool use_ctf_;
+ double ctf_alpha_;
+ double ctf_wide_alpha_;
+ int ctf_num_widenings_;
+ bool ctf_exhaustive_;
+ bool show_tree_structure_;
+ unsigned int ctf_iterations_;
vector<GrammarPtr> grammars;
bool Translate(const string& input,
@@ -64,29 +100,155 @@ struct SCFGTranslatorImpl {
LatticeTools::ConvertTextOrPLF(input, &lattice);
smeta->SetSourceLength(lattice.size());
if (add_pass_through_rules){
- PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt);
+ cerr << "Adding pass through grammar" << endl;
+ PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_);
g->SetGrammarName("PassThrough");
glist.push_back(GrammarPtr(g));
- cerr << "Adding pass through grammar" << endl;
}
-
-
-
- if(printGrammarsUsed){ //Iterate trough grammars we have for this sentence and list them
- for (int gi = 0; gi < glist.size(); ++gi)
- {
- cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
- }
+ for (int gi = 0; gi < glist.size(); ++gi) {
+ if(printGrammarsUsed)
+ cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;
}
-
+ cerr << "First pass parse... " << endl;
ExhaustiveBottomUpParser parser(goal, glist);
- if (!parser.Parse(lattice, forest))
+ if (!parser.Parse(lattice, forest)){
+ cerr << "parse failed." << endl;
return false;
+ } else {
+ cerr << "parse succeeded." << endl;
+ }
forest->Reweight(weights);
+ if (use_ctf_) {
+ Hypergraph::Node& goal_node = *(forest->nodes_.end()-1);
+ foreach(int edge_id, goal_node.in_edges_)
+ RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_);
+ double alpha = ctf_alpha_;
+ bool found_parse = false;
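+ // i starts at -1 so there is one pass at the base alpha before any widening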
+ for (int i=-1; i < ctf_num_widenings_; ++i) {
+ cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl;
+ found_parse = true;
+ Hypergraph refined_forest = *forest;
+ for (int j=0; j < ctf_iterations_; ++j) {
+ cerr << viterbi_stats(refined_forest," Coarse forest",true,show_tree_structure_);
+ cerr << " Iteration " << (j+1) << ": Pruning forest... ";
+ refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL);
+ cerr << "Refining forest...";
+ if (RefineForest(&refined_forest)) {
+ cerr << " Refinement succeeded." << endl;
+ refined_forest.Reweight(weights);
+ } else {
+ cerr << " Refinement failed. Widening beam." << endl;
+ found_parse = false;
+ break;
+ }
+ }
+ if (found_parse) {
+ forest->swap(refined_forest);
+ break;
+ }
+ alpha *= ctf_wide_alpha_;
+ }
+ if (!found_parse){
+ if (ctf_exhaustive_){
+ cerr << "Last resort: refining coarse forest without pruning...";
+ for (int j=0; j < ctf_iterations_; ++j) {
+ if (RefineForest(forest)){
+ cerr << " Refinement succeeded." << endl;
+ forest->Reweight(weights);
+ } else {
+ cerr << " Refinement failed. No parse found for this sentence." << endl;
+ return false;
+ }
+ }
+ } else
+ return false;
+ }
+ }
return true;
}
+
+ typedef std::pair<int, WordID> StateSplit;
+ typedef std::pair<StateSplit, int> StateSplitPair;
+ typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node;
+ typedef std::tr1::unordered_map<int, vector<WordID> > Splits;
+
+ bool RefineForest(Hypergraph* forest) {
+ Hypergraph refined_forest;
+ Split2Node s2n;
+ Splits splits;
+ Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1);
+ bool refined_goal_node = false;
+ foreach(Hypergraph::Node& node, forest->nodes_){
+ cerr << ".";
+ foreach(int edge_id, node.in_edges_) {
+ Hypergraph::Edge& edge = forest->edges_[edge_id];
+ std::vector<int> nt_positions;
+ TRulePtr& coarse_rule_ptr = edge.rule_;
+ for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){
+ if (coarse_rule_ptr->f_[i] < 0)
+ nt_positions.push_back(i);
+ }
+ if (!coarse_rule_ptr->fine_rules_) {
+ cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." <<
+ endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl;
+ abort();
+ }
+ // fine rules apply only if state splits on tail nodes match fine rule nonterminals
+ foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) {
+ Hypergraph::TailNodeVector tail;
+ for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){
+ WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]];
+ Split2Node::iterator it =
+ s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat));
+ if (it == s2n.end())
+ break;
+ else
+ tail.push_back(it->second);
+ }
+ if (tail.size() == nt_positions.size()) {
+ WordID cat = fine_rule_ptr->lhs_;
+ Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail);
+ new_edge->i_ = edge.i_;
+ new_edge->j_ = edge.j_;
+ new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues();
+ new_edge->feature_values_.set_value(FD::Convert("LatticeCost"),
+ edge.feature_values_[FD::Convert("LatticeCost")]);
+ Hypergraph::Node* head_node;
+ Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat));
+ if (it == s2n.end()){
+ head_node = refined_forest.AddNode(cat);
+ s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_));
+ splits[node.id_].push_back(cat);
+ if (&node == &coarse_goal_node)
+ refined_goal_node = true;
+ } else
+ head_node = &(refined_forest.nodes_[it->second]);
+ refined_forest.ConnectEdgeToHeadNode(new_edge, head_node);
+ }
+ }
+ }
+ }
+ cerr << endl;
+ forest->swap(refined_forest);
+ if (!refined_goal_node)
+ return false;
+ return true;
+ }
+ void OutputForest(Hypergraph* h) {
+ foreach(Hypergraph::Node& n, h->nodes_){
+ if (n.in_edges_.size() == 0){
+ cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl;
+ } else {
+ cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl;
+ }
+ foreach(int edge_id, n.in_edges_){
+ cerr << " " << h->edges_[edge_id].rule_->AsString() << endl;
+ }
+ }
+ }
};
+
/*
Called once from cdec.cc to setup the initial SCFG translation structure backend
*/
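The refinement pass is essentially a lazy state-splitting construction: a coarse node may be realized by several fine categories, and each (coarse node, fine category) pair becomes its own node in the refined forest. A schematic of the keying, with simplified types (sketch only; the real code above works on Hypergraph node ids and relies on the forest's bottom-up node order so that tails are split before heads):

    typedef std::pair<int, WordID> StateSplit;  // (coarse node id, fine LHS)
    std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > s2n;
    // A fine rule fires on a coarse edge only when every nonterminal tail
    // already has a split for the category the rule expects; the head
    // split is then created on demand (see RefineForest above).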
diff --git a/decoder/trule.h b/decoder/trule.h
index defdbeb9..3bc96165 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -34,6 +34,9 @@ class TRule {
TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs) :
e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {}
+ TRule(const TRule& other) :
+ e_(other.e_), f_(other.f_), lhs_(other.lhs_), scores_(other.scores_), arity_(other.arity_), prev_i(-1), prev_j(-1) {}
+
// deprecated - this will be private soon
explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) {
ReadFromString(text, strict, mono);
@@ -130,6 +133,8 @@ class TRule {
SparseVector<double> scores_;
char arity_;
+
+ // these attributes are application-specific and should probably be refactored
TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding
// this is only used when doing synchronous parsing
@@ -139,6 +144,9 @@ class TRule {
// may be null
boost::shared_ptr<NTSizeSummaryStatistics> nt_size_summary_;
+ // only for coarse-to-fine decoding
+ boost::shared_ptr<std::vector<TRulePtr> > fine_rules_;
+
private:
TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}
bool SanityCheck() const;
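Since fine_rules_ links each rule to its refinements, a refinement chain can be walked level by level. A small hypothetical helper (not part of this commit):

    void PrintRefinementChain(TRulePtr r) {
      // print the rule, then its first refinement at each successive level
      while (r) {
        std::cerr << r->AsString() << std::endl;
        r = (r->fine_rules_ && !r->fine_rules_->empty())
                ? (*r->fine_rules_)[0] : TRulePtr();
      }
    }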
diff --git a/extools/coarsen_grammar.pl b/extools/coarsen_grammar.pl
new file mode 100755
index 00000000..f2dd6689
--- /dev/null
+++ b/extools/coarsen_grammar.pl
@@ -0,0 +1,133 @@
+#!/usr/bin/perl
+
+# dumb grammar coarsener that maps every nonterminal to X (except S).
+
+use strict;
+
+unless (@ARGV > 1){
+ die "Usage: $0 <weight file> <grammar file> [<grammar file> ... <grammar file>] \n";
+}
+my $weight_file = shift @ARGV;
+
+$ENV{"LC_ALL"} = "C";
+local(*GRAMMAR, *OUT_GRAMMAR, *WEIGHTS);
+
+my %weights;
+unless (open(WEIGHTS, $weight_file)) {die "Could not open weight file $weight_file\n" }
+while (<WEIGHTS>){
+ if (/(.+) (.+)$/){
+ $weights{$1} = $2;
+ }
+}
+close(WEIGHTS);
+unless (keys(%weights)){
+ die "Could not find any PhraseModel features in weight file (perhaps you specified the wrong file?)\n\n".
+ "Usage: $0 <weight file> <grammar file> [<grammar file> ... <grammar file>] \n";
+}
+
+sub cleanup_and_die;
+$SIG{INT} = "cleanup_and_die";
+$SIG{TERM} = "cleanup_and_die";
+$SIG{HUP} = "cleanup_and_die";
+
+open(OUT_GRAMMAR, ">grammar.tmp");
+while (my $grammar_file = shift @ARGV){
+ unless (open(GRAMMAR, $grammar_file)) {die "Could not open grammar file $grammar_file\n"}
+ while (<GRAMMAR>){
+ if (/^((.*\|{3}){3})(.*)$/){
+ my $rule = $1;
+ my $rest = $3;
+ my $coarse_rule = $rule;
+ $coarse_rule =~ s/\[X[^\],]*/[X/g;
+ print OUT_GRAMMAR "$coarse_rule $rule $rest\n";
+ } else {
+ die "Unrecognized rule format: $_\n";
+ }
+ }
+ close(GRAMMAR);
+}
+close(OUT_GRAMMAR);
+
+`sort grammar.tmp > grammar.tmp.sorted`;
+sub dump_rules;
+sub compute_score;
+unless (open(GRAMMAR, "grammar.tmp.sorted")){ die "Something went wrong; could not open intermediate file grammar.tmp.sorted\n"};
+my $prev_coarse_rule = "";
+my $best_features = "";
+my $best_score = 0;
+my @rules = ();
+while (<GRAMMAR>){
+ if (/^\s*((\S.*\|{3}\s*){3})((\S.*\|{3}\s*){3})(.*)$/){
+ my $coarse_rule = $1;
+ my $fine_rule = $3;
+ my $features = $5; # This code does not correctly handle rules with other info (e.g. alignments)
+ if ($coarse_rule eq $prev_coarse_rule){
+ my $score = compute_score($features, %weights);
+ if ($score > $best_score){
+ $best_score = $score;
+ $best_features = $features;
+ }
+ } else {
+ dump_rules($prev_coarse_rule, $best_features, @rules);
+ $prev_coarse_rule = $coarse_rule;
+ $best_features = $features;
+ $best_score = compute_score($features, %weights);
+ @rules = ();
+ }
+ push(@rules, "$fine_rule$features\n");
+ } else {
+ die "Something went wrong during grammar projection: $_\n";
+ }
+}
+dump_rules($prev_coarse_rule, $best_features, @rules);
+close(GRAMMAR);
+cleanup();
+
+sub compute_score {
+ my($features, %weights) = @_;
+ my $score = 0;
+ if ($features =~ s/^\s*(\S.*\S)\s*$/$1/) {
+ my @features = split(/\s+/, $features);
+ my $pm=0;
+ for my $feature (@features) {
+ my $feature_name;
+ my $feature_val;
+ if ($feature =~ /(.*)=(.*)/){
+ $feature_name = $1;
+ $feature_val= $2;
+ } else {
+ $feature_name = "PhraseModel_" . $pm;
+ $feature_val= $feature;
+ }
+ $pm++;
+ if ($weights{$feature_name}){
+ $score += $weights{$feature_name} * $feature_val;
+ }
+ }
+ } else {
+ die "Unexpected feature value format: $features\n";
+ }
+ return $score;
+}
+
+sub dump_rules {
+ my($coarse_rule, $coarse_rule_scores, @fine_rules) = @_;
+ unless($coarse_rule){ return; }
+ print "$coarse_rule $coarse_rule_scores\n";
+ for my $rule (@fine_rules){
+ print "\t$rule";
+ }
+}
+
+sub cleanup_and_die {
+ cleanup();
+ die "\n";
+}
+
+sub cleanup {
+ `rm -rf grammar.tmp grammar.tmp.sorted`;
+}
+
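The script reads a weight file plus one or more grammar files and writes the coarsened grammar to stdout in exactly the layout rule_lexer.l expects: each coarse rule, scored with the features of its best fine rule under the given weights, followed by its tab-indented fine rules. A hypothetical invocation (file names are placeholders):

    ./coarsen_grammar.pl weights.final grammar.scfg > grammar.ctf

Note that it creates grammar.tmp and grammar.tmp.sorted in the working directory and removes them on exit.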