diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/cdec.cc | 4 | ||||
| -rw-r--r-- | decoder/grammar.cc | 37 | ||||
| -rw-r--r-- | decoder/grammar.h | 13 | ||||
| -rw-r--r-- | decoder/rule_lexer.h | 2 | ||||
| -rw-r--r-- | decoder/rule_lexer.l | 48 | ||||
| -rw-r--r-- | decoder/scfg_translator.cc | 242 | ||||
| -rw-r--r-- | decoder/trule.h | 8 | 
7 files changed, 291 insertions, 63 deletions
| diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 09760f6b..b6cc6f66 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -125,6 +125,10 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) {      ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")      ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")          ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)") +        ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") +        ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") +        ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse") +        ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")          ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")      ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")      ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams.  0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit)") diff --git a/decoder/grammar.cc b/decoder/grammar.cc index 499e79fe..26efaf99 100644 --- a/decoder/grammar.cc +++ b/decoder/grammar.cc @@ -81,8 +81,14 @@ const GrammarIter* TextGrammar::GetRoot() const {    return &pimpl_->root_;  } -void TextGrammar::AddRule(const TRulePtr& rule) { -  if (rule->IsUnary()) { +void TextGrammar::AddRule(const TRulePtr& rule, const unsigned int ctf_level, const TRulePtr& coarse_rule) { +  if (ctf_level > 0) { +    // assume that coarse_rule is already in tree (would be safer to check) +    if (coarse_rule->fine_rules_ == 0)  +      coarse_rule->fine_rules_.reset(new std::vector<TRulePtr>()); +    coarse_rule->fine_rules_->push_back(rule); +    ctf_levels_ = std::max(ctf_levels_, ctf_level); +  } else if (rule->IsUnary()) {      rhs2unaries_[rule->f().front()].push_back(rule);      unaries_.push_back(rule);    } else { @@ -95,8 +101,8 @@ void TextGrammar::AddRule(const TRulePtr& rule) {    }  } -static void AddRuleHelper(const TRulePtr& new_rule, void* extra) { -  static_cast<TextGrammar*>(extra)->AddRule(new_rule); +static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { +  static_cast<TextGrammar*>(extra)->AddRule(new_rule, ctf_level, coarse_rule);  }  void TextGrammar::ReadFromFile(const string& filename) { @@ -110,22 +116,29 @@ bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const {  GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {} -GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) { -  TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]")); -  TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] [" -    + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1")); +void RefineRule(TRulePtr pt, const unsigned int ctf_level){  +  for (unsigned int i=0; i<ctf_level; ++i){ +    TRulePtr r(new TRule(*pt)); +    pt->fine_rules_.reset(new vector<TRulePtr>); +    pt->fine_rules_->push_back(r); +    pt = r; +  } +} +GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const unsigned int ctf_level) { +  TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));    AddRule(stop_glue); +  RefineRule(stop_glue, ctf_level); +  TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));    AddRule(glue); -  //cerr << "GLUE: " << stop_glue->AsString() << endl; -  //cerr << "GLUE: " << glue->AsString() << endl; +  RefineRule(glue, ctf_level);  }  bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const {    return (i == 0);  } -PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) : +PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) :      has_rule_(input.size() + 1) {    for (int i = 0; i < input.size(); ++i) {      const vector<LatticeArc>& alts = input[i]; @@ -135,7 +148,7 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat)        const string& src = TD::Convert(alts[k].label);        TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1"));        AddRule(pt); -//      cerr << "PT: " << pt->AsString() << endl; +      RefineRule(pt, ctf_level);      }    }  } diff --git a/decoder/grammar.h b/decoder/grammar.h index 46886d3a..b26eb912 100644 --- a/decoder/grammar.h +++ b/decoder/grammar.h @@ -1,6 +1,7 @@  #ifndef GRAMMAR_H_  #define GRAMMAR_H_ +#include <algorithm>  #include <vector>  #include <map>  #include <set> @@ -26,11 +27,13 @@ struct GrammarIter {  struct Grammar {    typedef std::map<WordID, std::vector<TRulePtr> > Cat2Rules;    static const std::vector<TRulePtr> NO_RULES; - +   +  Grammar(): ctf_levels_(0) {}    virtual ~Grammar();    virtual const GrammarIter* GetRoot() const = 0;    virtual bool HasRuleForSpan(int i, int j, int distance) const;    const std::string GetGrammarName(){return grammar_name_;} +  unsigned int GetCTFLevels(){ return ctf_levels_; }    void SetGrammarName(std::string n) {grammar_name_ = n; }    // cat is the category to be rewritten    inline const std::vector<TRulePtr>& GetAllUnaryRules() const { @@ -50,6 +53,7 @@ struct Grammar {    Cat2Rules rhs2unaries_;     // these must be filled in by subclasses!    std::vector<TRulePtr> unaries_;    std::string grammar_name_;  +  unsigned int ctf_levels_;  };  typedef boost::shared_ptr<Grammar> GrammarPtr; @@ -61,7 +65,7 @@ struct TextGrammar : public Grammar {    void SetMaxSpan(int m) { max_span_ = m; }    virtual const GrammarIter* GetRoot() const; -  void AddRule(const TRulePtr& rule); +  void AddRule(const TRulePtr& rule, const unsigned int ctf_level=0, const TRulePtr& coarse_parent=TRulePtr());    void ReadFromFile(const std::string& filename);    virtual bool HasRuleForSpan(int i, int j, int distance) const;    const std::vector<TRulePtr>& GetUnaryRules(const WordID& cat) const; @@ -75,15 +79,16 @@ struct TextGrammar : public Grammar {  struct GlueGrammar : public TextGrammar {    // read glue grammar from file    explicit GlueGrammar(const std::string& file); -  GlueGrammar(const std::string& goal_nt, const std::string& default_nt);  // "S", "X" +  GlueGrammar(const std::string& goal_nt, const std::string& default_nt, const unsigned int ctf_level=0);  // "S", "X"    virtual bool HasRuleForSpan(int i, int j, int distance) const;  };  struct PassThroughGrammar : public TextGrammar { -  PassThroughGrammar(const Lattice& input, const std::string& cat); +  PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0);    virtual bool HasRuleForSpan(int i, int j, int distance) const;   private:    std::vector<std::set<int> > has_rule_;  // index by [i][j]  }; +void RefineRule(TRulePtr pt, const unsigned int ctf_level);  #endif diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index e5db4018..976ea02b 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -6,7 +6,7 @@  #include "trule.h"  struct RuleLexer { -  typedef void (*RuleCallback)(const TRulePtr& new_rule, void* extra); +  typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra);    static void ReadRules(std::istream* in, RuleCallback func, void* extra);  }; diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l index e2acd752..0216b119 100644 --- a/decoder/rule_lexer.l +++ b/decoder/rule_lexer.l @@ -6,6 +6,7 @@  #include <sstream>  #include <cstring>  #include <cassert> +#include <stack>  #include "tdict.h"  #include "fdict.h"  #include "trule.h" @@ -45,7 +46,8 @@ int scfglex_nt_sanity[MAX_ARITY];  int scfglex_src_nts[MAX_ARITY];  float scfglex_nt_size_means[MAX_ARITY];  float scfglex_nt_size_vars[MAX_ARITY]; - +std::stack<TRulePtr> ctf_rule_stack; +unsigned int ctf_level = 0;  void sanity_check_trg_symbol(WordID nt, int index) {    if (scfglex_src_nts[index-1] != nt) { @@ -77,6 +79,34 @@ void scfglex_reset() {    scfglex_trg_rhs_size = 0;  } +void check_and_update_ctf_stack(const TRulePtr& rp) { +  if (ctf_level > ctf_rule_stack.size()){ +    std::cerr << "Found rule at projection level " << ctf_level << " but previous rule was at level "  +      << ctf_rule_stack.size()-1 << " (cannot exceed previous level by more than one; line " << lex_line << ")" << std::endl; +    abort(); +  } +  while (ctf_rule_stack.size() > ctf_level) +    ctf_rule_stack.pop(); +  // ensure that rule has the same signature as parent (coarse) rule.  Rules may *only* +  // differ by the rhs nonterminals, not terminals or permutation of nonterminals. +  if (ctf_rule_stack.size() > 0) { +    TRulePtr& coarse_rp = ctf_rule_stack.top(); +    if (rp->f_.size() != coarse_rp->f_.size() || rp->e_ != coarse_rp->e_) { +      std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " << +        (coarse_rp->AsString()) << std::endl; +      abort(); +    } +    for (int i=0; i<rp->f_.size(); ++i) { +      if (((rp->f_[i]<0) != (coarse_rp->f_[i]<0)) || +          ((rp->f_[i]>0) && (rp->f_[i] != coarse_rp->f_[i]))) { +        std::cerr << "Rule " << (rp->AsString()) << " is not a projection of " << +          (coarse_rp->AsString()) << std::endl; +        abort(); +      } +    } +  } +} +  %}  REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf @@ -85,7 +115,9 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]*  %x LHS_END SRC TRG FEATS FEATVAL ALIGNS  %% -<INITIAL>[ \t]	; +<INITIAL>[ \t]	{  +  ctf_level++;  +  };  <INITIAL>\[{NT}\]   {  		scfglex_tmp_token.assign(yytext + 1, yyleng - 2); @@ -182,12 +214,16 @@ NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]*                    abort();                  }  		TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity)); -		rule_callback(rp, rule_callback_extra); +    check_and_update_ctf_stack(rp); +    TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); +		rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra); +    ctf_rule_stack.push(rp);  		// std::cerr << rp->AsString() << std::endl;  		num_rules++; -                lex_line++; -                if (num_rules %   50000 == 0) { std::cerr << '.' << std::flush; fl = true; } -                if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } +    lex_line++; +    if (num_rules %   50000 == 0) { std::cerr << '.' << std::flush; fl = true; } +    if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } +    ctf_level = 0;  		BEGIN(INITIAL);  		} diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 866c2721..32acfd65 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -3,13 +3,20 @@  //TODO: grammar heuristic (min cost of reachable rule set) for binarizations (active edges) if we wish to prune those also  #include "translator.h" - +#include <algorithm>  #include <vector> - +#include <tr1/unordered_map> +#include <boost/foreach.hpp> +#include <boost/functional/hash.hpp>  #include "hg.h"  #include "grammar.h"  #include "bottom_up_parser.h"  #include "sentence_metadata.h" +#include "tdict.h" +#include "viterbi.h" + +#define foreach         BOOST_FOREACH +#define reverse_foreach BOOST_REVERSE_FOREACH  using namespace std;  static bool usingSentenceGrammar = false; @@ -20,39 +27,68 @@ struct SCFGTranslatorImpl {        max_span_limit(conf["scfg_max_span_limit"].as<int>()),        add_pass_through_rules(conf.count("add_pass_through_rules")),        goal(conf["goal"].as<string>()), -      default_nt(conf["scfg_default_nt"].as<string>()) { -    if(conf.count("grammar")) -      { -	vector<string> gfiles = conf["grammar"].as<vector<string> >(); -	for (int i = 0; i < gfiles.size(); ++i) { -	  cerr << "Reading SCFG grammar from " << gfiles[i] << endl; -	  TextGrammar* g = new TextGrammar(gfiles[i]); -	  g->SetMaxSpan(max_span_limit); -	  g->SetGrammarName(gfiles[i]); -	  grammars.push_back(GrammarPtr(g)); - -	} -      } -    if (!conf.count("scfg_no_hiero_glue_grammar")) -      { -	GlueGrammar* g = new GlueGrammar(goal, default_nt); -	g->SetGrammarName("GlueGrammar"); -	grammars.push_back(GrammarPtr(g)); -	cerr << "Adding glue grammar" << endl; +      default_nt(conf["scfg_default_nt"].as<string>()),  +      use_ctf_(conf.count("coarse_to_fine_beam_prune")) +  { +    if(conf.count("grammar")){ +      vector<string> gfiles = conf["grammar"].as<vector<string> >(); +      for (int i = 0; i < gfiles.size(); ++i) { +    	  cerr << "Reading SCFG grammar from " << gfiles[i] << endl; +    	  TextGrammar* g = new TextGrammar(gfiles[i]); +    	  g->SetMaxSpan(max_span_limit); +    	  g->SetGrammarName(gfiles[i]); +    	  grammars.push_back(GrammarPtr(g)); +	    } +    } +    cerr << std::endl; +    if (conf.count("scfg_extra_glue_grammar")) { +      GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); +      g->SetGrammarName("ExtraGlueGrammar");		 +      grammars.push_back(GrammarPtr(g)); +      cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl; +    } +    ctf_iterations_=0; +    if (use_ctf_){ +      ctf_alpha_ = conf["coarse_to_fine_beam_prune"].as<double>(); +      foreach(GrammarPtr& gp, grammars){ +        ctf_iterations_ = std::max(gp->GetCTFLevels(), ctf_iterations_);        } -    if (conf.count("scfg_extra_glue_grammar")) -      { -	GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); -	g->SetGrammarName("ExtraGlueGrammar"); -	grammars.push_back(GrammarPtr(g)); -	cerr << "Adding extra glue grammar" << endl; +      foreach(GrammarPtr& gp, grammars){ +        if (gp->GetCTFLevels() != ctf_iterations_){ +          cerr << "Grammar " << gp->GetGrammarName() << " has CTF level " << gp->GetCTFLevels() << +            " but overall number of CTF levels is " << ctf_iterations_ << endl << +            "Mixing coarse-to-fine grammars of different granularities is not supported" << endl; +          abort(); +        }        } -  } +      show_tree_structure_ = conf.count("show_tree_structure"); +      ctf_wide_alpha_ = conf["ctf_beam_widen"].as<double>(); +      ctf_num_widenings_ = conf["ctf_num_widenings"].as<int>(); +      ctf_exhaustive_ = (conf.count("ctf_no_exhaustive") == 0); +      assert(ctf_wide_alpha_ > 1.0); +      cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl; +      cerr << "  Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl; +    } +    if (!conf.count("scfg_no_hiero_glue_grammar")){  +      GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_); +      g->SetGrammarName("GlueGrammar"); +      grammars.push_back(GrammarPtr(g)); +      cerr << "Adding glue grammar for default nonterminal " << default_nt <<  +        " and goal nonterminal " << goal << endl; +    } + }    const int max_span_limit;    const bool add_pass_through_rules;    const string goal;    const string default_nt; +  const bool use_ctf_; +  double ctf_alpha_; +  double ctf_wide_alpha_; +  int ctf_num_widenings_; +  bool ctf_exhaustive_; +  bool show_tree_structure_; +  unsigned int ctf_iterations_;    vector<GrammarPtr> grammars;    bool Translate(const string& input, @@ -64,29 +100,155 @@ struct SCFGTranslatorImpl {      LatticeTools::ConvertTextOrPLF(input, &lattice);      smeta->SetSourceLength(lattice.size());      if (add_pass_through_rules){ -      PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt); +      cerr << "Adding pass through grammar" << endl; +      PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_);        g->SetGrammarName("PassThrough");        glist.push_back(GrammarPtr(g)); -      cerr << "Adding pass through grammar" << endl;      } - - - -    if(printGrammarsUsed){    //Iterate trough grammars we have for this sentence and list them -      for (int gi = 0; gi < glist.size(); ++gi) -	{ -	  cerr << "Using grammar::" << 	 glist[gi]->GetGrammarName() << endl; -	} +    for (int gi = 0; gi < glist.size(); ++gi) { +      if(printGrammarsUsed) +        cerr << "Using grammar::" << glist[gi]->GetGrammarName() << endl;      } - +    cerr << "First pass parse... " << endl;      ExhaustiveBottomUpParser parser(goal, glist); -    if (!parser.Parse(lattice, forest)) +    if (!parser.Parse(lattice, forest)){ +      cerr << "parse failed." << endl;        return false; +    } else { +      cerr << "parse succeeded." << endl; +    }      forest->Reweight(weights); +    if (use_ctf_) { +      Hypergraph::Node& goal_node = *(forest->nodes_.end()-1); +      foreach(int edge_id, goal_node.in_edges_) +        RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_); +      double alpha = ctf_alpha_; +      bool found_parse; +      for (int i=-1; i < ctf_num_widenings_; ++i) { +        cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl;  +        found_parse = true; +        Hypergraph refined_forest = *forest; +        for (int j=0; j < ctf_iterations_; ++j) { +          cerr << viterbi_stats(refined_forest,"  Coarse forest",true,show_tree_structure_); +          cerr << "  Iteration " << (j+1) << ": Pruning forest... "; +          refined_forest.BeamPruneInsideOutside(1.0, false, alpha, NULL); +          cerr << "Refining forest..."; +          if (RefineForest(&refined_forest)) { +            cerr << "  Refinement succeeded." << endl; +            refined_forest.Reweight(weights); +          } else {  +            cerr << "  Refinement failed. Widening beam." << endl; +            found_parse = false; +            break; +          } +        } +        if (found_parse) { +          forest->swap(refined_forest); +          break; +        } +        alpha *= ctf_wide_alpha_; +      } +      if (!found_parse){ +        if (ctf_exhaustive_){ +          cerr << "Last resort: refining coarse forest without pruning..."; +          for (int j=0; j < ctf_iterations_; ++j) { +            if (RefineForest(forest)){  +              cerr << "  Refinement succeeded." << endl; +              forest->Reweight(weights); +            } else { +              cerr << "  Refinement failed.  No parse found for this sentence." << endl; +              return false; +            } +          }  +        } else  +          return false; +      } +    }      return true;    } +    +  typedef std::pair<int, WordID> StateSplit; +  typedef std::pair<StateSplit, int> StateSplitPair; +  typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node; +  typedef std::tr1::unordered_map<int, vector<WordID> > Splits; + +  bool RefineForest(Hypergraph* forest) { +    Hypergraph refined_forest; +    Split2Node s2n; +    Splits splits; +    Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1); +    bool refined_goal_node = false; +    foreach(Hypergraph::Node& node, forest->nodes_){ +      cerr << ".";  +      foreach(int edge_id, node.in_edges_) { +        Hypergraph::Edge& edge = forest->edges_[edge_id]; +        std::vector<int> nt_positions; +        TRulePtr& coarse_rule_ptr = edge.rule_; +        for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){ +          if (coarse_rule_ptr->f_[i] < 0)  +            nt_positions.push_back(i);  +        } +        if (coarse_rule_ptr->fine_rules_ == 0) { +          cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." <<  +            endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl; +          abort(); +        } +        // fine rules apply only if state splits on tail nodes match fine rule nonterminals +        foreach(TRulePtr& fine_rule_ptr, *(coarse_rule_ptr->fine_rules_)) { +          Hypergraph::TailNodeVector tail; +          for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){ +            WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]]; +            Split2Node::iterator it =  +              s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat)); +            if (it == s2n.end())  +              break; +            else  +              tail.push_back(it->second); +          } +          if (tail.size() == nt_positions.size()) { +            WordID cat = fine_rule_ptr->lhs_; +            Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail);  +            new_edge->i_ = edge.i_; +            new_edge->j_ = edge.j_; +            new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues(); +            new_edge->feature_values_.set_value(FD::Convert("LatticeCost"),  +              edge.feature_values_[FD::Convert("LatticeCost")]); +            Hypergraph::Node* head_node; +            Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat)); +            if (it == s2n.end()){ +              head_node = refined_forest.AddNode(cat); +              s2n.insert(StateSplitPair(StateSplit(node.id_, cat), head_node->id_)); +              splits[node.id_].push_back(cat); +              if (&node == &coarse_goal_node) +                refined_goal_node = true; +            } else  +              head_node = &(refined_forest.nodes_[it->second]); +            refined_forest.ConnectEdgeToHeadNode(new_edge, head_node); +          } +        } +      } +    } +    cerr << endl; +    forest->swap(refined_forest); +    if (!refined_goal_node) +      return false; +    return true; +  } +  void OutputForest(Hypergraph* h) { +    foreach(Hypergraph::Node& n, h->nodes_){ +      if (n.in_edges_.size() == 0){ +        cerr << "<" << TD::Convert(-n.cat_) << ", ?, ?>" << endl; +      } else { +        cerr << "<" << TD::Convert(-n.cat_) << ", " << h->edges_[n.in_edges_[0]].i_ << ", " << h->edges_[n.in_edges_[0]].j_ << ">" << endl; +      } +      foreach(int edge_id, n.in_edges_){ +        cerr << "    " << h->edges_[edge_id].rule_->AsString() << endl; +      } +    } +  }  }; +  /*  Called once from cdec.cc to setup the initial SCFG translation structure backend  */ diff --git a/decoder/trule.h b/decoder/trule.h index defdbeb9..3bc96165 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -34,6 +34,9 @@ class TRule {    TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs) :      e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {} +  TRule(const TRule& other) : +    e_(other.e_), f_(other.f_), lhs_(other.lhs_), scores_(other.scores_), arity_(other.arity_), prev_i(-1), prev_j(-1) {} +    // deprecated - this will be private soon    explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) {      ReadFromString(text, strict, mono); @@ -130,6 +133,8 @@ class TRule {    SparseVector<double> scores_;    char arity_; +   +  // these attributes are application-specific and should probably be refactored     TRulePtr parent_rule_;  // usually NULL, except when doing constrained decoding    // this is only used when doing synchronous parsing @@ -139,6 +144,9 @@ class TRule {    // may be null    boost::shared_ptr<NTSizeSummaryStatistics> nt_size_summary_; +  // only for coarse-to-fine decoding +  boost::shared_ptr<std::vector<TRulePtr> > fine_rules_; +   private:    TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}    bool SanityCheck() const; | 
