diff options
Diffstat (limited to 'extools')
| -rw-r--r-- | extools/featurize_grammar.cc | 259 | 
1 files changed, 252 insertions, 7 deletions
| diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 17f59e6e..8be057b0 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -10,6 +10,7 @@  #include <cstdlib>  #include <fstream>  #include <tr1/unordered_map> +#include <boost/regex.hpp>  #include "suffix_tree.h"  #include "sparse_vector.h" @@ -20,6 +21,7 @@  #include "lex_trans_tbl.h"  #include "filelib.h" +#include <boost/tuple/tuple.hpp>  #include <boost/shared_ptr.hpp>  #include <boost/functional/hash.hpp>  #include <boost/program_options.hpp> @@ -36,6 +38,102 @@ static const size_t MAX_LINE_LENGTH = 64000000;  typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; +// Data structures for indexing and counting rules +//typedef boost::tuple< WordID, vector<WordID>, vector<WordID> > RuleTuple; +struct RuleTuple { +  RuleTuple(const WordID& lhs, const vector<WordID>& s, const vector<WordID>& t)  +  : m_lhs(lhs), m_source(s), m_target(t) { +    hash_value(); +    m_dirty = false; +  } +   +  size_t hash_value() const { +//    if (m_dirty) { +      size_t hash = 0; +      boost::hash_combine(hash, m_lhs); +      boost::hash_combine(hash, m_source); +      boost::hash_combine(hash, m_target); +//    } +//    m_dirty = false; +    return hash; +  } + +  bool operator==(RuleTuple const& b) const +  { return m_lhs == b.m_lhs && m_source == b.m_source && m_target == b.m_target; } + +  WordID& lhs() { m_dirty=true; return m_lhs; } +  vector<WordID>& source() { m_dirty=true; return m_source; } +  vector<WordID>& target() { m_dirty=true; return m_target; } +  const WordID& lhs() const { return m_lhs; } +  const vector<WordID>& source() const { return m_source; } +  const vector<WordID>& target() const { return m_target; } + +//  mutable size_t m_hash; +private: +  WordID m_lhs; +  vector<WordID> m_source, m_target; +  mutable bool m_dirty; +}; +std::size_t hash_value(RuleTuple const& b) { return b.hash_value(); } +bool operator<(RuleTuple const& l, RuleTuple const& r) { +  if (l.lhs() < r.lhs()) return true; +  else if (l.lhs() == r.lhs()) { +    if (l.source() < r.source()) return true; +    else if (l.source() == r.source()) { +      if (l.target() < r.target()) return true; +    } +  } +  return false; +} + +ostream& operator<<(ostream& o, RuleTuple const& r) { +  o << "(" << r.lhs() << "-->" << "<"; +  for (vector<WordID>::const_iterator it=r.source().begin(); it!=r.source().end(); ++it) +    o << TD::Convert(*it) << " "; +  o << "|||"; +  for (vector<WordID>::const_iterator it=r.target().begin(); it!=r.target().end(); ++it) +    o << " " << TD::Convert(*it); +  o << ">)"; +  return o; +} + +template <typename Key> +struct FreqCount { +  //typedef unordered_map<Key, int, boost::hash<Key> > Counts; +  typedef map<Key, int> Counts; +  Counts counts; + +  int inc(const Key& r, int c=1) { +    pair<typename Counts::iterator,bool> itb  +      = counts.insert(make_pair(r,c)); +    if (!itb.second)  +      itb.first->second += c;  +    return itb.first->second; +  } + +  int inc_if_exists(const Key& r, int c=1) { +    typename Counts::iterator it = counts.find(r); +    if (it != counts.end())  +      it->second += c;  +    return it->second; +  } + +  int count(const Key& r) const { +    typename Counts::const_iterator it = counts.find(r); +    if (it == counts.end()) return 0; +    return it->second; +  } + +  int operator()(const Key& r) const { return count(r); } +}; +typedef FreqCount<RuleTuple> RuleFreqCount; + +bool validate_non_terminal(const std::string& s) +{ +  static const boost::regex r("\\[X\\d+,\\d+\\]|\\[\\d+\\]"); +  return regex_match(s, r); +} +  namespace {    inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }    inline bool IsBracket(char c){return c == '[' || c == ']';} @@ -280,7 +378,8 @@ struct LogRuleCount : public FeatureExtractor {                                 const RuleStatistics& info,                                 SparseVector<float>* result) const {      (void) lhs; (void) src; (void) trg; -    result->set_value(fid_, log(info.counts.value(kCFE))); +    //result->set_value(fid_, log(info.counts.value(kCFE))); +    result->set_value(fid_, (info.counts.value(kCFE)));      if (IsZero(info.counts.value(kCFE)))        result->set_value(sfid_, 1);    } @@ -289,6 +388,141 @@ struct LogRuleCount : public FeatureExtractor {    const int kCFE;  }; +// The negative log of the condition rule probs  +// ignoring the identities of the  non-terminals.  +// i.e. the prob Hiero would assign. +struct XFeatures: public FeatureExtractor { +  XFeatures() : +    fid_fe(FD::Convert("XFE")), +    fid_ef(FD::Convert("XEF")), +    kCFE(FD::Convert("CFE")) {} +  virtual void ObserveFilteredRule(const WordID /*lhs*/, +                                   const vector<WordID>& src, +                                   const vector<WordID>& trg) { +    RuleTuple r(-1, src, trg); +    map_rule(r); +    rule_counts.inc(r, 0); +    source_counts.inc(r.source(), 0); +    target_counts.inc(r.target(), 0); +  } + +  // compute statistics over keys, the same lhs-src-trg tuple may be seen +  // more than once +  virtual void ObserveUnfilteredRule(const WordID /*lhs*/, +                                     const vector<WordID>& src, +                                     const vector<WordID>& trg, +                                     const RuleStatistics& info) { +    RuleTuple r(-1, src, trg); +//    cerr << "   ObserveUnfilteredRule() in:" << r << " " << hash_value(r) << endl; +    map_rule(r); +    rule_counts.inc_if_exists(r, info.counts.value(kCFE)); +    source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); +    target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); +//    cerr << "   ObserveUnfilteredRule() inc: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl; +  } + +  virtual void ExtractFeatures(const WordID /*lhs*/, +                               const vector<WordID>& src, +                               const vector<WordID>& trg, +                               const RuleStatistics& /*info*/, +                               SparseVector<float>* result) const { +    RuleTuple r(-1, src, trg); +    map_rule(r); +    //result->set_value(fid_fe, log(target_counts(r.target())) - log(rule_counts(r))); +    //result->set_value(fid_ef, log(source_counts(r.source())) - log(rule_counts(r))); +    result->set_value(fid_ef, target_counts(r.target())); +    result->set_value(fid_fe, rule_counts(r)); +    //result->set_value(fid_fe, (source_counts(r.source()))); +  } + +  void map_rule(RuleTuple& r) const { +    vector<WordID> indexes; int i=0; +    for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) { +      if (validate_non_terminal(TD::Convert(*it))) +        indexes.push_back(*it); +    } +    for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) { +      if (validate_non_terminal(TD::Convert(*it))) +        *it = indexes.at(i++); +    } +  } + +  const int fid_fe, fid_ef; +  const int kCFE; +  RuleFreqCount rule_counts; +  FreqCount< vector<WordID> > source_counts, target_counts; +}; + +struct LabelledRuleConditionals: public FeatureExtractor { +  LabelledRuleConditionals() : +    fid_fe(FD::Convert("TLabelledFE")), +    fid_ef(FD::Convert("TLabelledEF")), +    kCFE(FD::Convert("CFE")) {} +  virtual void ObserveFilteredRule(const WordID /*lhs*/, +                                   const vector<WordID>& src, +                                   const vector<WordID>& trg) { +    RuleTuple r(-1, src, trg); +    rule_counts.inc(r, 0); +    cerr << "   ObservefilteredRule() inc: " << r << " " << hash_value(r) << endl; +//    map_rule(r); +    source_counts.inc(r.source(), 0); +    target_counts.inc(r.target(), 0); +  } + +  // compute statistics over keys, the same lhs-src-trg tuple may be seen +  // more than once +  virtual void ObserveUnfilteredRule(const WordID /*lhs*/, +                                     const vector<WordID>& src, +                                     const vector<WordID>& trg, +                                     const RuleStatistics& info) { +    RuleTuple r(-1, src, trg); +    //cerr << "   ObserveUnfilteredRule() in:" << r << " " << hash_value(r) << endl; +    rule_counts.inc_if_exists(r, info.counts.value(kCFE)); +    cerr << "   ObserveUnfilteredRule() inc_if_exists: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl; +//    map_rule(r); +    source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); +    target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); +  } + +  virtual void ExtractFeatures(const WordID /*lhs*/, +                               const vector<WordID>& src, +                               const vector<WordID>& trg, +                               const RuleStatistics& info, +                               SparseVector<float>* result) const { +    RuleTuple r(-1, src, trg); +    //cerr << "   ExtractFeatures() in:" << " " << r.m_hash << endl; +    int r_freq = rule_counts(r); +    cerr << "   ExtractFeatures() count: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " | " << rule_counts(r) << endl; +    assert(r_freq == info.counts.value(kCFE)); +    //cerr << "   ExtractFeatures() after:" << " " << r.hash << endl; +    //cerr << "   ExtractFeatures() in:" << r << " " << r_freq << " " << hash_value(r) << endl; +    //cerr << "   ExtractFeatures() in:" << r << " " << r_freq << endl; +//    map_rule(r); +    //result->set_value(fid_fe, log(target_counts(r.target())) - log(r_freq)); +    //result->set_value(fid_ef, log(source_counts(r.source())) - log(r_freq)); +    result->set_value(fid_ef, target_counts(r.target())); +    result->set_value(fid_fe, r_freq); +    //result->set_value(fid_fe, (source_counts(r.source()))); +  } + +  void map_rule(RuleTuple& r) const { +    vector<WordID> indexes; int i=0; +    for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) { +      if (validate_non_terminal(TD::Convert(*it))) +        indexes.push_back(*it); +    } +    for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) { +      if (validate_non_terminal(TD::Convert(*it))) +        *it = indexes.at(i++); +    } +  } + +  const int fid_fe, fid_ef; +  const int kCFE; +  RuleFreqCount rule_counts; +  FreqCount< vector<WordID> > source_counts, target_counts; +}; +  // this extracts the lexical translation prob features  // in BOTH directions.  struct LexProbExtractor : public FeatureExtractor { @@ -307,7 +541,7 @@ struct LexProbExtractor : public FeatureExtractor {      delete[] buf;    } -  virtual void ExtractFeatures(const WordID lhs, +  virtual void ExtractFeatures(const WordID /*lhs*/,                                 const vector<WordID>& src,                                 const vector<WordID>& trg,                                 const RuleStatistics& info, @@ -397,6 +631,8 @@ int main(int argc, char** argv){    FERegistry reg;    reg.Register("LogRuleCount", new FEFactory<LogRuleCount>);    reg.Register("LexProb", new FEFactory<LexProbExtractor>); +  reg.Register("XFeatures", new FEFactory<XFeatures>); +  reg.Register("LabelledRuleConditionals", new FEFactory<LabelledRuleConditionals>);    po::variables_map conf;    InitCommandLine(reg, argc, argv, &conf);    aligned_corpus = conf["aligned_corpus"].as<string>();  // GLOBAL VAR @@ -421,10 +657,20 @@ int main(int argc, char** argv){      fs1.getline(buf, MAX_LINE_LENGTH);      if (buf[0] == 0) continue;      ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 4); -    for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; +    //src.resize(cur_key.size() - 4); +    src.resize(cur_key.size() - 3); +    for (int i = 0; i < src.size(); ++i) src.at(i) = cur_key.at(i+2); + +    cerr << "Key: "; for (vector<WordID>::const_iterator wit=cur_key.begin(); wit!=cur_key.end(); ++wit) cerr << TD::Convert(*wit) << " "; cerr << endl; +      lhs = cur_key[0]; +    cerr << buf << endl;      for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { + +      cerr << "READ: <"; for (vector<WordID>::const_iterator wit=src.begin(); wit!=src.end(); ++wit) cerr << TD::Convert(*wit) << " "; +      cerr << "|||"; for (vector<WordID>::const_iterator wit=it->first.begin(); wit!=it->first.end(); ++wit) cerr << " " << TD::Convert(*wit); +      cerr << ">\n"; +        for (int i = 0; i < extractors.size(); ++i)          extractors[i]->ObserveFilteredRule(lhs, src, it->first);      } @@ -435,11 +681,10 @@ int main(int argc, char** argv){      cin.getline(buf, MAX_LINE_LENGTH);      if (buf[0] == 0) continue;      ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 4); +    src.resize(cur_key.size() - 3);      for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];      lhs = cur_key[0];      for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { -      // TODO set lhs, src, trg        for (int i = 0; i < extractors.size(); ++i)          extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second);      } @@ -452,7 +697,7 @@ int main(int argc, char** argv){      fs2.getline(buf, MAX_LINE_LENGTH);      if (buf[0] == 0) continue;      ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 4); +    src.resize(cur_key.size() - 3);      for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];      lhs = cur_key[0]; | 
