diff options
-rw-r--r-- | .gitignore | 4 | ||||
-rw-r--r-- | decoder/cdec.cc | 4 | ||||
-rw-r--r-- | decoder/grammar.cc | 37 | ||||
-rw-r--r-- | decoder/grammar.h | 13 | ||||
-rw-r--r-- | decoder/rule_lexer.h | 2 | ||||
-rw-r--r-- | decoder/rule_lexer.l | 48 | ||||
-rw-r--r-- | decoder/scfg_translator.cc | 242 | ||||
-rw-r--r-- | decoder/trule.h | 8 | ||||
-rwxr-xr-x | extools/coarsen_grammar.pl | 133 |
9 files changed, 428 insertions, 63 deletions
@@ -1,4 +1,8 @@ *swp +*.o +vest/sentserver +vest/sentclient +gi/pyp-topics/src/contexts_lexer.cc config.guess config.sub libtool diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 09760f6b..b6cc6f66 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -125,6 +125,10 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) { ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)") + ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") + ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") + ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse") + ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails") ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit)") diff --git a/decoder/grammar.cc b/decoder/grammar.cc index 499e79fe..26efaf99 100644 --- a/decoder/grammar.cc +++ b/decoder/grammar.cc @@ -81,8 +81,14 @@ const GrammarIter* TextGrammar::GetRoot() const { return &pimpl_->root_; } -void TextGrammar::AddRule(const TRulePtr& rule) { - if (rule->IsUnary()) { +void TextGrammar::AddRule(const TRulePtr& rule, const unsigned int ctf_level, const TRulePtr& coarse_rule) { + if (ctf_level > 0) { + // assume that coarse_rule is already in tree (would be safer to check) + if (coarse_rule->fine_rules_ == 0) + coarse_rule->fine_rules_.reset(new std::vector<TRulePtr>()); + coarse_rule->fine_rules_->push_back(rule); + ctf_levels_ = std::max(ctf_levels_, ctf_level); + } else if (rule->IsUnary()) { rhs2unaries_[rule->f().front()].push_back(rule); unaries_.push_back(rule); } else { @@ -95,8 +101,8 @@ void TextGrammar::AddRule(const TRulePtr& rule) { } } -static void AddRuleHelper(const TRulePtr& new_rule, void* extra) { - static_cast<TextGrammar*>(extra)->AddRule(new_rule); +static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { + static_cast<TextGrammar*>(extra)->AddRule(new_rule, ctf_level, coarse_rule); } void TextGrammar::ReadFromFile(const string& filename) { @@ -110,22 +116,29 @@ bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const { GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {} -GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) { - TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]")); - TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] [" - + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1")); +void RefineRule(TRulePtr pt, const unsigned int ctf_level){ + for (unsigned int i=0; i<ctf_level; ++i){ + TRulePtr r(new TRule(*pt)); + pt->fine_rules_.reset(new vector<TRulePtr>); + pt->fine_rules_->push_back(r); + pt = r; + } +} +GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const unsigned int ctf_level) { + TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]")); AddRule(stop_glue); + RefineRule(stop_glue, ctf_level); + TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1")); AddRule(glue); - //cerr << "GLUE: " << stop_glue->AsString() << endl; - //cerr << "GLUE: " << glue->AsString() << endl; + RefineRule(glue, ctf_level); } bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const { return (i == 0); } -PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) : +PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) : has_rule_(input.size() + 1) { for (int i = 0; i < input.size(); ++i) { const vector<LatticeArc>& alts = input[i]; @@ -135,7 +148,7 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) const string& src = TD::Convert(alts[k].label); TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1")); AddRule(pt); -// cerr << "PT: " << pt->AsString() << endl; + RefineRule(pt, ctf_level); } } } diff --git a/decoder/grammar.h b/decoder/grammar.h index 46886d3a..b26eb912 100644 --- a/decoder/grammar.h +++ b/decoder/grammar.h @@ -1,6 +1,7 @@ #ifndef GRAMMAR_H_ #define GRAMMAR_H_ +#include <algorithm> #include <vector> #include <map> #include <set> @@ -26,11 +27,13 @@ struct GrammarIter { struct Grammar { typedef std::map<WordID, std::vector<TRulePtr> > Cat2Rules; static const std::vector<TRulePtr> NO_RULES; - + + Grammar(): ctf_levels_(0) {} virtual ~Grammar(); virtual const GrammarIter* GetRoot() const = 0; virtual bool HasRuleForSpan(int i, int j, int distance) const; const std::string GetGrammarName(){return grammar_name_;} + unsigned int GetCTFLevels(){ return ctf_levels_; } void SetGrammarName(std::string n) {grammar_name_ = n; } // cat is the category to be rewritten inline const std::vector<TRulePtr>& GetAllUnaryRules() const { @@ -50,6 +53,7 @@ struct Grammar { Cat2Rules rhs2unaries_; // these must be filled in by subclasses! std::vector<TRulePtr> unaries_; std::string grammar_name_; + unsigned int ctf_levels_; }; typedef boost::shared_ptr<Grammar> GrammarPtr; @@ -61,7 +65,7 @@ struct TextGrammar : public Grammar { void SetMaxSpan(int m) { max_span_ = m; } virtual const GrammarIter* GetRoot() const; - void AddRule(const TRulePtr& rule); + void AddRule(const TRulePtr& rule, const unsigned int ctf_level=0, const TRulePtr& coarse_parent=TRulePtr()); void ReadFromFile(const std::string& filename); virtual bool HasRuleForSpan(int i, int j, int distance) const; const std::vector<TRulePtr>& GetUnaryRules(const WordID& cat) const; @@ -75,15 +79,16 @@ struct TextGrammar : public Grammar { struct GlueGrammar : public TextGrammar { // read glue grammar from file explicit GlueGrammar(const std::string& file); - GlueGrammar(const std::string& goal_nt, const std::string& default_nt); // "S", "X" + GlueGrammar(const std::string& goal_nt, const std::string& default_nt, const unsigned int ctf_level=0); // "S", "X" virtual bool HasRuleForSpan(int i, int j, int distance) const; }; struct PassThroughGrammar : public TextGrammar { - PassThroughGrammar(const Lattice& input, const std::string& cat); + PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0); virtual bool HasRuleForSpan(int i, int j, int distance) const; private: std::vector<std::set<int> > has_rule_; // index by [i][j] }; +void RefineRule(TRulePtr pt, const unsigned int ctf_level); #endif diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index e5db4018..976ea02b 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -6,7 +6,7 @@ #include "trule.h" struct RuleLexer { - typedef void (*RuleCallback)(const TRulePtr& new_rule, void* extra); + typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra); static void ReadRules(std::istream* in, RuleCallback func, void* extra); }; diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l index e2acd752..0216b119 100644 --- a/decoder/rule_lexer.l +++ b/decoder/rule_lexer.l @@ -6,6 +6,7 @@ #include <sstream> #include <cstring> #include <cassert> +#include <stack> #include "tdict.h" #include "fdict.h" #include "trule.h" @@ -45,7 +46,8 @@ int scfglex_nt_sanity[MAX_ARITY]; int scfglex_src_nts[MAX_ARITY]; float scfglex_nt_size_means[MAX_ARITY]; float scfglex_nt_size_vars[MAX_ARITY]; - +std::stack<TRulePtr> ctf_rule_stack; +unsigned int ctf_level = 0; void sanity_check_trg_symbol(WordID nt, int index) { if (scfglex_src_nts[index-1] != nt) { @@ -77,6 +79,34 @@ void scfglex_reset() { scfglex_trg_rhs_size = 0; } +void check_and_update_ctf_stack(const TRulePtr& rp) { + if (ctf_level > ctf_rule_stack.size()){ + std::cerr << "Found rule at projection level " << ctf_level << " but previous rule was at level " + << ctf_rule_stack.size()-1 << " (cannot exceed previous level by more than one; line " << lex_line << ")" << std::endl; + abort(); + } + while (ctf_rule_stack.size() > ctf_level) + ctf_rule_stack.pop(); + // ensure that rule has the same signature as parent (coarse) rule. Rules may *only* + // differ by the rhs nonterminals, not terminals or permutation of nonterminals. + if (ctf_rule_stack.size() > 0) { + TRulePtr& coarse_rp = ctf_rule_stack.top(); + if (rp->f_.size() != coarse_rp->f_.size() || rp->e_ != coarse_rp->e_) { + std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " << + (coarse_rp->AsString()) << std::endl; + abort(); + } + for (int i=0; i<rp->f_.size(); ++i) { + if (((rp->f_[i]<0) != (coarse_rp->f_[i]<0)) || + ((rp->f_[i]>0) && (rp->f_[i] != coarse_rp->f_[i]))) { + std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " << + (coarse_rp->AsString()) << std::endl; + abort(); + } + } + } +} + %} REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf @@ -85,7 +115,9 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]* %x LHS_END SRC TRG FEATS FEATVAL ALIGNS %% -<INITIAL>[ \t] ; +<INITIAL>[ \t] { + ctf_level++; + }; <INITIAL>\[{NT}\] { scfglex_tmp_token.assign(yytext + 1, yyleng - 2); @@ -182,12 +214,16 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]* abort(); } TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity)); - rule_callback(rp, rule_callback_extra); + check_and_update_ctf_stack(rp); + TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); + rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra); + ctf_rule_stack.push(rp); // std::cerr << rp->AsString() << std::endl; num_rules++; - lex_line++; - if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; } - if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } + lex_line++; + if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; } + if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } + ctf_level = 0; BEGIN(INITIAL); } diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 866c2721..32acfd65 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -3,13 +3,20 @@ //TODO: grammar heuristic (min cost of reachable rule set) for binarizations (active edges) if we wish to prune those also #include "translator.h" - +#include <algorithm> #include <vector> - +#include <tr1/unordered_map> +#include <boost/foreach.hpp> +#include <boost/functional/hash.hpp> #include "hg.h" #include "grammar.h" #include "bottom_up_parser.h" #include "sentence_metadata.h" +#include "tdict.h" +#include "viterbi.h" + +#define foreach BOOST_FOREACH +#define reverse_foreach BOOST_REVERSE_FOREACH using namespace std; static bool usingSentenceGrammar = false; @@ -20,39 +27,68 @@ struct SCFGTranslatorImpl { max_span_limit(conf["scfg_max_span_limit"].as<int>()), add_pass_through_rules(conf.count("add_pass_through_rules")), goal(conf["goal"].as<string>()), - default_nt(conf["scfg_default_nt"].as<string>()) { - if(conf.count("grammar")) - { - vector<string> gfiles = conf["grammar"].as<vector<string> >(); - for (int i = 0; i < gfiles.size(); ++i) { - cerr << "Reading SCFG grammar from " << gfiles[i] << endl; - TextGrammar* g = new TextGrammar(gfiles[i]); - g->SetMaxSpan(max_span_limit); - g->SetGrammarName(gfiles[i]); - grammars.push_back(GrammarPtr(g)); - - } - } - if (!conf.count("scfg_no_hiero_glue_grammar")) - { - GlueGrammar* g = new GlueGrammar(goal, default_nt); - g->SetGrammarName("GlueGrammar"); - grammars.push_back(GrammarPtr(g)); - cerr << "Adding glue grammar" << endl; + default_nt(conf["scfg_default_nt"].as<string>()), + use_ctf_(conf.count("coarse_to_fine_beam_prune")) + { + if(conf.count("grammar")){ + vector<string> gfiles = conf["grammar"].as<vector<string> >(); + for (int i = 0; i < gfiles.size(); ++i) { + cerr << "Reading SCFG grammar from " << gfiles[i] << endl; + TextGrammar* g = new TextGrammar(gfiles[i]); + g->SetMaxSpan(max_span_limit); + g->SetGrammarName(gfiles[i]); + grammars.push_back(GrammarPtr(g)); + } + } + cerr << std::endl; + if (conf.count("scfg_extra_glue_grammar")) { + GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); + g->SetGrammarName("ExtraGlueGrammar"); + grammars.push_back(GrammarPtr(g)); + cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl; + } + ctf_iterations_=0; + if (use_ctf_){ + ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>(); + foreach(GrammarPtr& gp, grammars){ + ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_); } - if (conf.count("scfg_extra_glue_grammar")) - { - GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); - g->SetGrammarName("ExtraGlueGrammar"); - grammars.push_back(GrammarPtr(g)); - cerr << "Adding extra glue grammar" << endl; + foreach(GrammarPtr& gp, grammars){ + if (gp->GetCTFLevels() != ctf_iterations_){ + cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() << + " but overall number of CTF levels is " << ctf_iterations_ << endl << + "Mixing coarse-to-fine grammars of different granularities is not supported" << endl; + abort(); + } } - } + show_tree_structure_ = conf.count("show_tree_structure"); + ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>(); + ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>(); + ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0); + assert(ctf_wide_alpha_ > 1.0); + cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl; + cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl; + } + if (!conf.count("scfg_no_hiero_glue_grammar")){ + GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_); + g->SetGrammarName("GlueGrammar"); + grammars.push_back(GrammarPtr(g)); + cerr << "Adding glue grammar for default nonterminal " << default_nt << + " and goal nonterminal " << goal << endl; + } + } const int max_span_limit; const bool add_pass_through_rules; const string goal; const string default_nt; + const bool use_ctf_; + double ctf_alpha_; + double ctf_wide_alpha_; + int ctf_num_widenings_; + bool ctf_exhaustive_; + bool show_tree_structure_; + unsigned int ctf_iterations_; vector<GrammarPtr> grammars; bool Translate(const string& input, @@ -64,29 +100,155 @@ struct SCFGTranslatorImpl { LatticeTools::ConvertTextOrPLF(input, &lattice); smeta->SetSourceLength(lattice.size()); if (add_pass_through_rules){ - PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt); + cerr << "Adding pass through grammar" << endl; + PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_); g->SetGrammarName("PassThrough"); glist.push_back(GrammarPtr(g)); - cerr << "Adding pass through grammar" << endl; } - - - - if(printGrammarsUsed){ //Iterate trough grammars we have for this sentence and list them - for (int gi = 0; gi < glist.size(); ++gi) - { - cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl; - } + for (int gi = 0; gi < glist.size(); ++gi) { + if(printGrammarsUsed) + cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl; } - + cerr << "First pass parse... " << endl; ExhaustiveBottomUpParser parser(goal, glist); - if (!parser.Parse(lattice, forest)) + if (!parser.Parse(lattice, forest)){ + cerr << "parse failed." << endl; return false; + } else { + cerr << "parse succeeded." << endl; + } forest->Reweight(weights); + if (use_ctf_) { + Hypergraph::Node& goal_node = *(forest->nodes_.end()-1); + foreach(int edge_id, goal_node.in_edges_) + RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_); + double alpha = ctf_alpha_; + bool found_parse; + for (int i=-1; i < ctf_num_widenings_; ++i) { + cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl; + found_parse = true; + Hypergraph refined_forest = *forest; + for (int j=0; j < ctf_iterations_; ++j) { + cerr << viterbi_stats(refined_forest," Coarse forest",true,show_tree_structure_); + cerr << " Iteration " << (j+1) << ": Pruning forest... "; + refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL); + cerr << "Refining forest..."; + if (RefineForest(&refined_forest)) { + cerr << " Refinement succeeded." << endl; + refined_forest.Reweight(weights); + } else { + cerr << " Refinement failed. Widening beam." << endl; + found_parse = false; + break; + } + } + if (found_parse) { + forest->swap(refined_forest); + break; + } + alpha *= ctf_wide_alpha_; + } + if (!found_parse){ + if (ctf_exhaustive_){ + cerr << "Last resort: refining coarse forest without pruning..."; + for (int j=0; j < ctf_iterations_; ++j) { + if (RefineForest(forest)){ + cerr << " Refinement succeeded." << endl; + forest->Reweight(weights); + } else { + cerr << " Refinement failed. No parse found for this sentence." << endl; + return false; + } + } + } else + return false; + } + } return true; } + + typedef std::pair<int, WordID> StateSplit; + typedef std::pair<StateSplit, int> StateSplitPair; + typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node; + typedef std::tr1::unordered_map<int, vector<WordID> > Splits; + + bool RefineForest(Hypergraph* forest) { + Hypergraph refined_forest; + Split2Node s2n; + Splits splits; + Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1); + bool refined_goal_node = false; + foreach(Hypergraph::Node& node, forest->nodes_){ + cerr << "."; + foreach(int edge_id, node.in_edges_) { + Hypergraph::Edge& edge = forest->edges_[edge_id]; + std::vector<int> nt_positions; + TRulePtr& coarse_rule_ptr = edge.rule_; + for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){ + if (coarse_rule_ptr->f_[i] < 0) + nt_positions.push_back(i); + } + if (coarse_rule_ptr->fine_rules_ == 0) { + cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." << + endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl; + abort(); + } + // fine rules apply only if state splits on tail nodes match fine rule nonterminals + foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) { + Hypergraph::TailNodeVector tail; + for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){ + WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]]; + Split2Node::iterator it = + s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat)); + if (it == s2n.end()) + break; + else + tail.push_back(it->second); + } + if (tail.size() == nt_positions.size()) { + WordID cat = fine_rule_ptr->lhs_; + Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail); + new_edge->i_ = edge.i_; + new_edge->j_ = edge.j_; + new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues(); + new_edge->feature_values_.set_value(FD::Convert("LatticeCost"), + edge.feature_values_[FD::Convert("LatticeCost")]); + Hypergraph::Node* head_node; + Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat)); + if (it == s2n.end()){ + head_node = refined_forest.AddNode(cat); + s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_)); + splits[node.id_].push_back(cat); + if (&node == &coarse_goal_node) + refined_goal_node = true; + } else + head_node = &(refined_forest.nodes_[it->second]); + refined_forest.ConnectEdgeToHeadNode(new_edge, head_node); + } + } + } + } + cerr << endl; + forest->swap(refined_forest); + if (!refined_goal_node) + return false; + return true; + } + void OutputForest(Hypergraph* h) { + foreach(Hypergraph::Node& n, h->nodes_){ + if (n.in_edges_.size() == 0){ + cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl; + } else { + cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl; + } + foreach(int edge_id, n.in_edges_){ + cerr << " " << h->edges_[edge_id].rule_->AsString() << endl; + } + } + } }; + /* Called once from cdec.cc to setup the initial SCFG translation structure backend */ diff --git a/decoder/trule.h b/decoder/trule.h index defdbeb9..3bc96165 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -34,6 +34,9 @@ class TRule { TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs) : e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {} + TRule(const TRule& other) : + e_(other.e_), f_(other.f_), lhs_(other.lhs_), scores_(other.scores_), arity_(other.arity_), prev_i(-1), prev_j(-1) {} + // deprecated - this will be private soon explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) { ReadFromString(text, strict, mono); @@ -130,6 +133,8 @@ class TRule { SparseVector<double> scores_; char arity_; + + // these attributes are application-specific and should probably be refactored TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding // this is only used when doing synchronous parsing @@ -139,6 +144,9 @@ class TRule { // may be null boost::shared_ptr<NTSizeSummaryStatistics> nt_size_summary_; + // only for coarse-to-fine decoding + boost::shared_ptr<std::vector<TRulePtr> > fine_rules_; + private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} bool SanityCheck() const; diff --git a/extools/coarsen_grammar.pl b/extools/coarsen_grammar.pl new file mode 100755 index 00000000..f2dd6689 --- /dev/null +++ b/extools/coarsen_grammar.pl @@ -0,0 +1,133 @@ +#!/usr/bin/perl + +# dumb grammar coarsener that maps every nonterminal to X (except S). + +use strict; + +unless (@ARGV > 1){ + die "Usage: $0 <weight file> <grammar file> [<grammar file> ... <grammar file>] \n"; +} +my $weight_file = shift @ARGV; + +$ENV{"LC_ALL"} = "C"; +local(*GRAMMAR, *OUT_GRAMMAR, *WEIGHTS); + +my %weights; +unless (open(WEIGHTS, $weight_file)) {die "Could not open weight file $weight_file\n" } +while (<WEIGHTS>){ + if (/(.+) (.+)$/){ + $weights{$1} = $2; + } +} +close(WEIGHTS); +unless (keys(%weights)){ + die "Could not find any PhraseModel features in weight file (perhaps you specified the wrong file?)\n\n". + "Usage: $0 <weight file> <grammar file> [<grammar file> ... <grammar file>] \n"; +} + +sub cleanup_and_die; +$SIG{INT} = "cleanup_and_die"; +$SIG{TERM} = "cleanup_and_die"; +$SIG{HUP} = "cleanup_and_die"; + +open(OUT_GRAMMAR, ">grammar.tmp"); +while (my $grammar_file = shift @ARGV){ + unless (open(GRAMMAR, $grammar_file)) {die "Could not open grammar file $grammar_file\n"} + while (<GRAMMAR>){ + if (/^((.*\|{3}){3})(.*)$/){ + my $rule = $1; + my $rest = $3; + my $coarse_rule = $rule; + $coarse_rule =~ s/\[X[^\],]*/[X/g; + print OUT_GRAMMAR "$coarse_rule $rule $rest\n"; + } else { + die "Unrecognized rule format: $_\n"; + } + } + close(GRAMMAR); +} +close(OUT_GRAMMAR); + +`sort grammar.tmp > grammar.tmp.sorted`; +sub dump_rules; +sub compute_score; +unless (open(GRAMMAR, "grammar.tmp.sorted")){ die "Something went wrong; could not open intermediate file grammar.tmp.sorted\n"}; +my $prev_coarse_rule = ""; +my $best_features = ""; +my $best_score = 0; +my @rules = (); +while (<GRAMMAR>){ + if (/^\s*((\S.*\|{3}\s*){3})((\S.*\|{3}\s*){3})(.*)$/){ + my $coarse_rule = $1; + my $fine_rule = $3; + my $features = $5; # This code does not correctly handle rules with other info (e.g. alignments) + if ($coarse_rule eq $prev_coarse_rule){ + my $score = compute_score($features, %weights); + if ($score > $best_score){ + $best_score = $score; + $best_features = $features; + } + } else { + dump_rules($prev_coarse_rule, $best_features, @rules); + $prev_coarse_rule = $coarse_rule; + $best_features = $features; + $best_score = compute_score($features, %weights); + @rules = (); + } + push(@rules, "$fine_rule$features\n"); + } else { + die "Something went wrong during grammar projection: $_\n"; + } +} +dump_rules($prev_coarse_rule, $best_features, @rules); +close(GRAMMAR); +cleanup(); + +sub compute_score { + my($features, %weights) = @_; + my $score = 0; + if ($features =~ s/^\s*(\S.*\S)\s*$/$1/) { + my @features = split(/\s+/, $features); + my $pm=0; + for my $feature (@features) { + my $feature_name; + my $feature_val; + if ($feature =~ /(.*)=(.*)/){ + $feature_name = $1; + $feature_val= $2; + } else { + $feature_name = "PhraseModel_" . $pm; + $feature_val= $feature; + } + $pm++; + if ($weights{$feature_name}){ + $score += $weights{$feature_name} * $feature_val; + } + } + } else { + die "Unexpected feature value format: $features\n"; + } + return $score; +} + +sub dump_rules { + my($coarse_rule, $coarse_rule_scores, @fine_rules) = @_; + unless($coarse_rule){ return; } + print "$coarse_rule $coarse_rule_scores\n"; + for my $rule (@fine_rules){ + print "\t$rule"; + } +} + +sub cleanup_and_die { + cleanup(); + die "\n"; +} + +sub cleanup { + `rm -rf grammar.tmp grammar.tmp.sorted`; +} + + + + |