diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/bottom_up_parser.cc | 80 | ||||
| -rw-r--r-- | decoder/cdec_ff.cc | 1 | ||||
| -rw-r--r-- | decoder/ff_ruleshape.cc | 138 | ||||
| -rw-r--r-- | decoder/ff_ruleshape.h | 46 | 
4 files changed, 254 insertions, 11 deletions
| diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 606b8d7e..8738c8f1 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -45,6 +45,7 @@ class PassiveChart {                   const float lattice_cost);    void ApplyUnaryRules(const int i, const int j); +  void TopoSortUnaries();    const vector<GrammarPtr>& grammars_;    const Lattice& input_; @@ -57,6 +58,7 @@ class PassiveChart {    TRulePtr goal_rule_;    int goal_idx_;             // index of goal node, if found    const int lc_fid_; +  vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars    static WordID kGOAL;       // [Goal]  }; @@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal,      goal_cat_(TD::Convert(goal) * -1),      goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),      goal_idx_(-1), -    lc_fid_(FD::Convert("LatticeCost")) { +    lc_fid_(FD::Convert("LatticeCost")), +    unaries_() {    act_chart_.resize(grammars_.size()); -  for (unsigned i = 0; i < grammars_.size(); ++i) +  for (unsigned i = 0; i < grammars_.size(); ++i) {      act_chart_[i] = new ActiveChart(forest, *this); +    const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules(); +    for (unsigned j = 0; j < u.size(); ++j) +      unaries_.push_back(u[j]); +  } +  TopoSortUnaries();    if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;    if (!SILENT) cerr << "  Goal category: [" << goal << ']' << endl;  } +static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { +  if (mark[node] == 1) { +    cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; +    return false; // cycle detected +  } else if (mark[node] == 2) { +    return true; // already been  +  } +  mark[node] = 1; +  const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); +  if (nit != g.end()) { +    const vector<TRulePtr>& edges = nit->second; +    vector<bool> okay(edges.size(), true); +    for (unsigned i = 0; i < edges.size(); ++i) { +      okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); +      if (!okay[i]) { +        cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; +      } +   } +    for (unsigned i = 0; i < edges.size(); ++i) { +      if (okay[i]) u.push_back(edges[i]); +      //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; +    } +  } +  mark[node] = 2; +  return true; +} + +void PassiveChart::TopoSortUnaries() { +  vector<TRulePtr> u(unaries_.size()); u.clear(); +  map<int, vector<TRulePtr> > g; +  map<int, int> mark; +  //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; +  mark[goal_cat_] = 2; +  for (unsigned i = 0; i < unaries_.size(); ++i) { +    //cerr << "Adding: " << unaries_[i]->AsString() << endl; +    g[unaries_[i]->f()[0]].push_back(unaries_[i]); +  } +    //m[unaries_[i]->lhs_].push_back(unaries_[i]); +  for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { +    //cerr << "PROC: " << TD::Convert(-it->first) << endl; +    if (mark[it->first] > 0) { +      //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; +    } else { +      TopoSortVisit(it->first, u, g, mark); +    } +  } +  unaries_.clear(); +  for (int i = u.size() - 1; i >= 0; --i) +    unaries_.push_back(u[i]); +} +  void PassiveChart::ApplyRule(const int i,                               const int j,                               const TRulePtr& r,                               const Hypergraph::TailNodeVector& ant_nodes,                               const float lattice_cost) {    Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); -  //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; +  // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;    new_edge->prev_i_ = r->prev_i;    new_edge->prev_j_ = r->prev_j;    new_edge->i_ = i; @@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i,  void PassiveChart::ApplyUnaryRules(const int i, const int j) {    const vector<int>& nodes = chart_(i,j);  // reference is important! -  for (unsigned gi = 0; gi < grammars_.size(); ++gi) { -    if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; -    for (unsigned di = 0; di < nodes.size(); ++di) { -      const WordID& cat = forest_->nodes_[nodes[di]].cat_; -      const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); -      for (unsigned ri = 0; ri < unaries.size(); ++ri) { -        // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; +  for (unsigned di = 0; di < nodes.size(); ++di) { +    const WordID& cat = forest_->nodes_[nodes[di]].cat_; +    for (unsigned ri = 0; ri < unaries_.size(); ++ri) { +      //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; +      if (unaries_[ri]->f()[0] == cat) { +        //cerr << "  --MATCH\n";          const Hypergraph::TailNodeVector ant(1, nodes[di]); -        ApplyRule(i, j, unaries[ri], ant, 0);  // may update nodes +        ApplyRule(i, j, unaries_[ri], ant, 0);  // may update nodes        }      }    } diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index b2541722..0411908f 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -58,6 +58,7 @@ void register_feature_functions() {    ff_registry.Register("KLanguageModel", new KLanguageModelFactory());    ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);    ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>); +  ff_registry.Register("RuleShape2", new FFFactory<RuleShapeFeatures2>);    ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);    ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);    ff_registry.Register("NewJump", new FFFactory<NewJump>); diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc index 7bb548c4..35b41c46 100644 --- a/decoder/ff_ruleshape.cc +++ b/decoder/ff_ruleshape.cc @@ -1,5 +1,8 @@  #include "ff_ruleshape.h" +#include "filelib.h" +#include "stringlib.h" +#include "verbose.h"  #include "trule.h"  #include "hg.h"  #include "fdict.h" @@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta *    features->set_value(cur->fid_, 1.0);  } +namespace { +void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) { +  vector<string> const& argv=SplitOnWhitespace(in); +  *emapfile = ""; +  *fmapfile = ""; +  *pfxsize = 0; +#define RSSPEC_NEXTARG if (i==argv.end()) {            \ +    cerr << "Missing argument for "<<*last<<". "; goto usage; \ +    } else { ++i; } + +  for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { +    string const& s=*i; +    if (s[0]=='-') { +      if (s.size()>2) goto fail; +      switch (s[1]) { +      case 'e': +        if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); } +        RSSPEC_NEXTARG; *emapfile=*i; +        break; +      case 'f': +        if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); } +        RSSPEC_NEXTARG; *fmapfile=*i; +        break; +      case 'p': +        RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str()); +        break; +#undef RSSPEC_NEXTARG +      default: +      fail: +        cerr<<"Unknown RuleShape2 option "<<s<<" ; "; +        goto usage; +      } +    } else { +      cerr << "RuleShape2 bad argument!\n"; +      abort(); +    } +  } +  return; +usage: +  cerr << "Bad parameters for RuleShape2\n"; +  abort(); +} + +inline void AddWordToClassMapping_(vector<WordID>* pv, unsigned f, unsigned t, unsigned pfx_size) { +  if (pfx_size) { +    const string& ts = TD::Convert(t); +    if (pfx_size < ts.size()) +      t = TD::Convert(ts.substr(0, pfx_size)); +  } +  if (f >= pv->size()) +    pv->resize((f + 1) * 1.2); +  (*pv)[f] = t; +} +} + +RuleShapeFeatures2::~RuleShapeFeatures2() {} + +RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("<unk>")) { +  string emap; +  string fmap; +  unsigned pfxsize = 0; +  ParseRSArgs(param, &emap, &fmap, &pfxsize); +  has_src_ = fmap.size(); +  has_trg_ = emap.size(); +  if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_); +  if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_); +  if (!has_trg_ && !has_src_) { +    cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n"; +    abort(); +  } +} + +void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector<WordID>* pv) { +  ReadFile rf(file); +  istream& in = *rf.stream(); +  string line; +  vector<WordID> dummy; +  int lc = 0; +  if (!SILENT) +    cerr << "  Loading word classes from " << file << " ...\n"; +  AddWordToClassMapping_(pv, TD::Convert("<s>"), TD::Convert("<s>"), 0); +  AddWordToClassMapping_(pv, TD::Convert("</s>"), TD::Convert("</s>"), 0); +  while(getline(in, line)) { +    dummy.clear(); +    TD::ConvertSentence(line, &dummy); +    ++lc; +    if (dummy.size() != 2 && dummy.size() != 3) { +      cerr << "    Class map file expects: CLASS WORD [freq]\n"; +      cerr << "    Format error in " << file << ", line " << lc << ": " << line << endl; +      abort(); +    } +    AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size); +  } +  if (!SILENT) +    cerr << "  Loaded word " << lc << " mapping rules.\n"; +} + +void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, +                                               const Hypergraph::Edge& edge, +                                               const vector<const void*>& /* ant_contexts */, +                                               SparseVector<double>* features, +                                               SparseVector<double>* /* estimated_features */, +                                               void* /* context */) const { +  const vector<int>& f = edge.rule_->f(); +  const vector<int>& e = edge.rule_->e(); +  Node* fid = &fidtree_; +  if (has_src_) { +    for (unsigned i = 0; i < f.size(); ++i) +      fid = &fid->next_[MapF(f[i])]; +  } +  if (has_trg_) { +    for (unsigned i = 0; i < e.size(); ++i) +      fid = &fid->next_[MapE(e[i])]; +  } +  if (!fid->fid_) { +    ostringstream os; +    os << "RS:"; +    if (has_src_) { +      for (unsigned i = 0; i < f.size(); ++i) { +        if (i) os << '_'; +        os << TD::Convert(MapF(f[i])); +      } +      if (has_trg_) os << "__"; +    } +    if (has_trg_) { +      for (unsigned i = 0; i < e.size(); ++i) { +        if (i) os << '_'; +        os << TD::Convert(MapE(e[i])); +      } +    } +    fid->fid_ = FD::Convert(os.str()); +  } +  features->set_value(fid->fid_, 1); +} + diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h index 9f20faf3..488cfd84 100644 --- a/decoder/ff_ruleshape.h +++ b/decoder/ff_ruleshape.h @@ -2,6 +2,7 @@  #define _FF_RULESHAPE_H_  #include <vector> +#include <map>  #include "ff.h"  class RuleShapeFeatures : public FeatureFunction { @@ -28,4 +29,49 @@ class RuleShapeFeatures : public FeatureFunction {    }  }; +class RuleShapeFeatures2 : public FeatureFunction { + public: +  ~RuleShapeFeatures2(); +  RuleShapeFeatures2(const std::string& param); + protected: +  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                     const HG::Edge& edge, +                                     const std::vector<const void*>& ant_contexts, +                                     SparseVector<double>* features, +                                     SparseVector<double>* estimated_features, +                                     void* context) const; + private: +  struct Node { +    int fid_; +    Node() : fid_() {} +    std::map<WordID, Node> next_; +  }; +  mutable Node fidtree_; + +  inline WordID MapE(WordID w) const { +    if (w <= 0) return kNT; +    unsigned res = 0; +    if (w < e2class_.size()) res = e2class_[w]; +    if (!res) res = kUNK; +    return res; +  } + +  inline WordID MapF(WordID w) const { +    if (w <= 0) return kNT; +    unsigned res = 0; +    if (w < f2class_.size()) res = f2class_[w]; +    if (!res) res = kUNK; +    return res; +  } + +  // prfx_size=0 => use full word classes otherwise truncate to specified length +  void LoadWordClasses(const std::string& fname, unsigned pfxsize, std::vector<WordID>* pv); +  const WordID kNT; +  const WordID kUNK; +  std::vector<WordID> e2class_; +  std::vector<WordID> f2class_; +  bool has_src_; +  bool has_trg_; +}; +  #endif | 
