diff options
Diffstat (limited to 'extools')
| -rw-r--r-- | extools/featurize_grammar.cc | 236 | ||||
| -rw-r--r-- | extools/lex_trans_tbl.h | 1 | 
2 files changed, 58 insertions, 179 deletions
| diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 547f390a..9a4af4d8 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -5,21 +5,17 @@  #include <sstream>  #include <string>  #include <map> -#include <set>  #include <vector>  #include <utility>  #include <cstdlib> -#include <fstream>  #include <tr1/unordered_map> -#include <boost/regex.hpp> -#include "suffix_tree.h" +#include "lex_trans_tbl.h"  #include "sparse_vector.h"  #include "sentence_pair.h"  #include "extract.h"  #include "fdict.h"  #include "tdict.h" -#include "lex_trans_tbl.h"  #include "filelib.h"  #include "striped_grammar.h" @@ -29,7 +25,6 @@  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> -  using namespace std;  using namespace std::tr1;  using boost::shared_ptr; @@ -38,8 +33,6 @@ namespace po = boost::program_options;  static string aligned_corpus;  static const size_t MAX_LINE_LENGTH = 64000000; -typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; -  // Data structures for indexing and counting rules  //typedef boost::tuple< WordID, vector<WordID>, vector<WordID> > RuleTuple;  struct RuleTuple { @@ -130,20 +123,6 @@ struct FreqCount {  };  typedef FreqCount<RuleTuple> RuleFreqCount; -bool validate_non_terminal(const std::string& s) -{ -  static const boost::regex r("\\[X\\d+,\\d+\\]|\\[\\d+\\]"); -  return regex_match(s, r); -} - -namespace { -  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } -  inline void SkipWhitespace(const char* buf, int* ptr) { -    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } -  } -} - -  class FeatureExtractor;  class FERegistry;  struct FEFactoryBase { @@ -220,78 +199,7 @@ void InitCommandLine(const FERegistry& r, int argc, char** argv, po::variables_m    }  } -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) { -  static const WordID kDIV = TD::Convert("|||"); -  int ptr = sstart; -  while(ptr < end) { -    while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } -    int start = ptr; -    while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } -    if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } -    const WordID w = TD::Convert(string(buf, start, ptr - start)); -    if (w == kDIV) return ptr; -    p->push_back(w); -  } -  assert(p->size() > 0); -  return ptr;   -} - -void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) { -  static const WordID kDIV = TD::Convert("|||"); -  counts->clear(); -  int ptr = 0; -  while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } -  if (buf[ptr] != '\t') { -    cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; -    exit(1); -  } -  cur_key->clear(); -  // key is: "[X] ||| word word word" -  int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); -  if (buf[tmpp] != '\t') { -    cur_key->push_back(kDIV); -    ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); -  } -  ++ptr; -  int start = ptr; -  int end = ptr; -  int state = 0; // 0=reading label, 1=reading count -  vector<WordID> name; -  while(buf[ptr] != 0) { -    while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } -    if (buf[ptr] == '|') { -      ++ptr; -      if (buf[ptr] == '|') { -        ++ptr; -        if (buf[ptr] == '|') { -          ++ptr; -          end = ptr - 3; -          while (end > start && IsWhitespace(buf[end-1])) { --end; } -          if (start == end) { -            cerr << "Got empty token!\n  LINE=" << buf << endl; -            exit(1); -          } -          switch (state) { -            case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; -            case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; -            default: cerr << "Can't happen\n"; abort(); -          } -          SkipWhitespace(buf, &ptr); -          start = ptr; -        } -      } -    } -  } -  end=ptr; -  while (end > start && IsWhitespace(buf[end-1])) { --end; } -  if (end > start) { -    switch (state) { -      case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; -      case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; -      default: cerr << "Can't happen\n"; abort(); -    } -  } -} +static const bool DEBUG = false;  void LexTranslationTable::createTTable(const char* buf){    AnnotatedParallelSentence sent; @@ -431,11 +339,7 @@ struct XFeatures: public FeatureExtractor {      RuleTuple r(-1, src, trg);      map_rule(r);      rule_counts.inc(r, 0); - -    normalise_string(r.source());      source_counts.inc(r.source(), 0); - -    normalise_string(r.target());      target_counts.inc(r.target(), 0);    } @@ -446,11 +350,7 @@ struct XFeatures: public FeatureExtractor {      RuleTuple r(-1, src, trg);      map_rule(r);      rule_counts.inc_if_exists(r, info.counts.value(kCFE)); - -    normalise_string(r.source());      source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); - -    normalise_string(r.target());      target_counts.inc_if_exists(r.target(), info.counts.value(kCFE));    } @@ -463,11 +363,9 @@ struct XFeatures: public FeatureExtractor {      map_rule(r);      double l_r_freq = log(rule_counts(r)); -    normalise_string(r.target());      result->set_value(fid_xfe, log(target_counts(r.target())) - l_r_freq);      result->set_value(fid_labelledfe, log(target_counts(r.target())) - log(info.counts.value(kCFE))); -    normalise_string(r.source());      result->set_value(fid_xef, log(source_counts(r.source())) - l_r_freq);      result->set_value(fid_labelledef, log(source_counts(r.source())) - log(info.counts.value(kCFE)));    } @@ -475,21 +373,15 @@ struct XFeatures: public FeatureExtractor {    void map_rule(RuleTuple& r) const {      vector<WordID> indexes; int i=0;      for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) { -      if (validate_non_terminal(TD::Convert(*it))) +      if (*it <= 0)          indexes.push_back(*it);      }      for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) { -      if (validate_non_terminal(TD::Convert(*it))) +      if (*it <= 0)          *it = indexes.at(i++);      }    } -  void normalise_string(vector<WordID>& r) const { -    vector<WordID> indexes; -    for (vector<WordID>::iterator it = r.begin(); it != r.end(); ++it) -      if (validate_non_terminal(TD::Convert(*it))) *it = -1; -  } -    const int fid_xfe, fid_xef;    const int fid_labelledfe, fid_labelledef;    const int kCFE; @@ -508,10 +400,8 @@ struct LabelledRuleConditionals: public FeatureExtractor {                                     const vector<WordID>& trg) {      RuleTuple r(lhs, src, trg);      rule_counts.inc(r, 0); -    normalise_string(r.source());      source_counts.inc(r.source(), 0); -    normalise_string(r.target());      target_counts.inc(r.target(), 0);    } @@ -521,10 +411,8 @@ struct LabelledRuleConditionals: public FeatureExtractor {                                       const RuleStatistics& info) {      RuleTuple r(lhs, src, trg);      rule_counts.inc_if_exists(r, info.counts.value(kCFE)); -    normalise_string(r.source());      source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); -    normalise_string(r.target());      target_counts.inc_if_exists(r.target(), info.counts.value(kCFE));    } @@ -535,18 +423,10 @@ struct LabelledRuleConditionals: public FeatureExtractor {                                 SparseVector<float>* result) const {      RuleTuple r(lhs, src, trg);      double l_r_freq = log(rule_counts(r)); -    normalise_string(r.target());      result->set_value(fid_fe, log(target_counts(r.target())) - l_r_freq); -    normalise_string(r.source());      result->set_value(fid_ef, log(source_counts(r.source())) - l_r_freq);    } -  void normalise_string(vector<WordID>& r) const { -    vector<WordID> indexes; -    for (vector<WordID>::iterator it = r.begin(); it != r.end(); ++it) -      if (validate_non_terminal(TD::Convert(*it))) *it = -1; -  } -    const int fid_fe, fid_ef;    const int kCFE;    RuleFreqCount rule_counts; @@ -643,9 +523,9 @@ struct LabellingShape: public FeatureExtractor {    // Replace all terminals with generic -1    void map_rule(RuleTuple& r) const {      for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it)  -      if (!validate_non_terminal(TD::Convert(*it))) *it = -1; +      if (*it <= 0) *it = -1;      for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it)  -      if (!validate_non_terminal(TD::Convert(*it))) *it = -1; +      if (*it <= 0) *it = -1;    }    const int fid_, kCFE; @@ -758,6 +638,51 @@ struct LexProbExtractor : public FeatureExtractor {    mutable LexTranslationTable table;  }; +struct Featurizer { +  Featurizer(const vector<boost::shared_ptr<FeatureExtractor> >& ex) : extractors(ex) { +  } +  void Callback1(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) { +    for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { +      for (int i = 0; i < extractors.size(); ++i) +        extractors[i]->ObserveFilteredRule(lhs, src, it->first); +    } +  } +  void Callback2(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) { +    for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { +      for (int i = 0; i < extractors.size(); ++i) +        extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); +    } +  } +  void Callback3(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) { +    for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { +      SparseVector<float> feats; +      for (int i = 0; i < extractors.size(); ++i) +        extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); +      cout << '[' << TD::Convert(-lhs) << "] ||| "; +      WriteNamed(src, &cout); +      cout << " ||| "; +      WriteAnonymous(it->first, &cout); +      cout << " ||| "; +      feats.Write(false, &cout); +      cout << endl; +    } +  } + private: +  vector<boost::shared_ptr<FeatureExtractor> > extractors; +}; + +void cb1(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) { +  static_cast<Featurizer*>(extra)->Callback1(lhs, src_rhs, rules); +} + +void cb2(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) { +  static_cast<Featurizer*>(extra)->Callback2(lhs, src_rhs, rules); +} + +void cb3(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) { +  static_cast<Featurizer*>(extra)->Callback3(lhs, src_rhs, rules); +} +  int main(int argc, char** argv){    FERegistry reg;    reg.Register("LogRuleCount", new FEFactory<LogRuleCount>); @@ -778,65 +703,18 @@ int main(int argc, char** argv){    vector<boost::shared_ptr<FeatureExtractor> > extractors(feats.size());    for (int i = 0; i < feats.size(); ++i)      extractors[i] = reg.Create(feats[i]); +  Featurizer fizer(extractors); -  //score unscored grammar    cerr << "Reading filtered grammar to detect keys..." << endl; -  char* buf = new char[MAX_LINE_LENGTH]; - -  ID2RuleStatistics acc, cur_counts; -  vector<WordID> key, cur_key,temp_key; -  WordID lhs = 0; -  vector<WordID> src; - -  istream& fs1 = *fg1.stream(); -  while(fs1) { -    fs1.getline(buf, MAX_LINE_LENGTH); -    if (buf[0] == 0) continue; -    ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 2); -    for (int i = 0; i < src.size(); ++i) src.at(i) = cur_key.at(i+2); - -    lhs = cur_key[0]; -    for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { -      for (int i = 0; i < extractors.size(); ++i) -        extractors[i]->ObserveFilteredRule(lhs, src, it->first); -    } -  } +  StripedGrammarLexer::ReadStripedGrammar(fg1.stream(), cb1, &fizer);    cerr << "Reading unfiltered grammar..." << endl; -  while(cin) { -    cin.getline(buf, MAX_LINE_LENGTH); -    if (buf[0] == 0) continue; -    ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 2); -    for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; -    lhs = cur_key[0]; -    for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { -      for (int i = 0; i < extractors.size(); ++i) -        extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); -    } -  } +  StripedGrammarLexer::ReadStripedGrammar(&cin, cb2, &fizer);    ReadFile fg2(conf["filtered_grammar"].as<string>()); -  istream& fs2 = *fg2.stream();    cerr << "Reading filtered grammar and adding features..." << endl; -  while(fs2) { -    fs2.getline(buf, MAX_LINE_LENGTH); -    if (buf[0] == 0) continue; -    ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 2); -    for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; -    lhs = cur_key[0]; - -    //loop over all the Target side phrases that this source aligns to -    for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { -      SparseVector<float> feats; -      for (int i = 0; i < extractors.size(); ++i) -        extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); -      cout << TD::Convert(lhs) << " ||| " << TD::GetString(src) << " ||| " << TD::GetString(it->first) << " ||| "; -      feats.Write(false, &cout); -      cout << endl; -    } -  } +  StripedGrammarLexer::ReadStripedGrammar(fg2.stream(), cb3, &fizer); + +  return 0;  } diff --git a/extools/lex_trans_tbl.h b/extools/lex_trans_tbl.h index 81d6ccc9..161b4a0d 100644 --- a/extools/lex_trans_tbl.h +++ b/extools/lex_trans_tbl.h @@ -8,6 +8,7 @@  #ifndef LEX_TRANS_TBL_H_  #define LEX_TRANS_TBL_H_ +#include "wordid.h"  #include <map>  class LexTranslationTable | 
