diff options
Diffstat (limited to 'extools')
| -rw-r--r-- | extools/featurize_grammar.cc | 116 | 
1 files changed, 88 insertions, 28 deletions
| diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 4c9821ec..17f59e6e 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -2,6 +2,7 @@   * Featurize a grammar in striped format   */  #include <iostream> +#include <sstream>  #include <string>  #include <map>  #include <vector> @@ -24,18 +25,86 @@  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> +  using namespace std;  using namespace std::tr1; +using boost::shared_ptr;  namespace po = boost::program_options; +static string aligned_corpus;  static const size_t MAX_LINE_LENGTH = 64000000;  typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +namespace { +  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } +  inline bool IsBracket(char c){return c == '[' || c == ']';} +  inline void SkipWhitespace(const char* buf, int* ptr) { +    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } +  } +} + + +class FeatureExtractor; +class FERegistry; +struct FEFactoryBase { +  virtual ~FEFactoryBase() {} +  virtual boost::shared_ptr<FeatureExtractor> Create() const = 0; +}; + + +class FERegistry { +  friend class FEFactoryBase; + public: +  FERegistry() {} +  boost::shared_ptr<FeatureExtractor> Create(const std::string& ffname) const { +    map<string, shared_ptr<FEFactoryBase> >::const_iterator it = reg_.find(ffname); +    shared_ptr<FeatureExtractor> res; +    if (it == reg_.end()) { +      cerr << "I don't know how to create feature " << ffname << endl; +    } else { +      res = it->second->Create(); +    } +    return res; +  } +  void DisplayList(ostream* out) const { +    bool first = true; +    for (map<string, shared_ptr<FEFactoryBase> >::const_iterator it = reg_.begin(); +        it != reg_.end(); ++it) { +      if (first) {first=false;} else {*out << ' ';} +      *out << it->first; +    } +  } + +  void Register(const std::string& ffname, FEFactoryBase* factory) { +    if (reg_.find(ffname) != reg_.end()) { +      cerr << "Duplicate registration of FeatureExtractor with name " << ffname << "!\n"; +      exit(1); +    } +    reg_[ffname].reset(factory); +  } + + private: +  std::map<std::string, boost::shared_ptr<FEFactoryBase> > reg_; +}; + +template<class FE> +class FEFactory : public FEFactoryBase { +  boost::shared_ptr<FeatureExtractor> Create() const { +    return boost::shared_ptr<FeatureExtractor>(new FE); +  } +}; + +void InitCommandLine(const FERegistry& r, int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options"); +  ostringstream feats; +  feats << "[multiple] Features to extract ("; +  r.DisplayList(&feats); +  feats << ")";    opts.add_options()          ("filtered_grammar,g", po::value<string>(), "Grammar to add features to") +        ("list_features,L", "List extractable features") +        ("feature,f", po::value<vector<string> >()->composing(), feats.str().c_str())          ("aligned_corpus,c", po::value<string>(), "Aligned corpus (single line format)")          ("help,h", "Print this help message and exit");    po::options_description clo("Command line options"); @@ -45,21 +114,13 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::store(parse_command_line(argc, argv, dcmdline_options), *conf);    po::notify(*conf); -  if (conf->count("help") || conf->count("aligned_corpus")==0) { -    cerr << "\nUsage: featurize_grammar -g FILTERED-GRAMMAR.gz -c ALIGNED_CORPUS.fr-en-al [-options] < UNFILTERED-GRAMMAR\n"; +  if (conf->count("help") || conf->count("aligned_corpus")==0 || conf->count("feature") == 0) { +    cerr << "\nUsage: featurize_grammar -g FILTERED-GRAMMAR.gz -c ALIGNED_CORPUS.fr-en-al -f Feat1 -f Feat2 ... < UNFILTERED-GRAMMAR\n";      cerr << dcmdline_options << endl;      exit(1);    }  } -namespace { -  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } -  inline bool IsBracket(char c){return c == '[' || c == ']';} -  inline void SkipWhitespace(const char* buf, int* ptr) { -    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } -  } -} -  int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) {    static const WordID kDIV = TD::Convert("|||");    int ptr = sstart; @@ -231,11 +292,11 @@ struct LogRuleCount : public FeatureExtractor {  // this extracts the lexical translation prob features  // in BOTH directions.  struct LexProbExtractor : public FeatureExtractor { -  LexProbExtractor(const std::string& corpus) : +  LexProbExtractor() :        e2f_(FD::Convert("LexE2F")), f2e_(FD::Convert("LexF2E")) { -    ReadFile rf(corpus); +    ReadFile rf(aligned_corpus);      //create lexical translation table -    cerr << "Computing lexical translation probabilities from " << corpus << "..." << endl; +    cerr << "Computing lexical translation probabilities from " << aligned_corpus << "..." << endl;      char* buf = new char[MAX_LINE_LENGTH];      istream& alignment = *rf.stream();      while(alignment) { @@ -333,16 +394,18 @@ struct LexProbExtractor : public FeatureExtractor {  };  int main(int argc, char** argv){ +  FERegistry reg; +  reg.Register("LogRuleCount", new FEFactory<LogRuleCount>); +  reg.Register("LexProb", new FEFactory<LexProbExtractor>);    po::variables_map conf; -  InitCommandLine(argc, argv, &conf); -  ifstream alignment (conf["aligned_corpus"].as<string>().c_str()); +  InitCommandLine(reg, argc, argv, &conf); +  aligned_corpus = conf["aligned_corpus"].as<string>();  // GLOBAL VAR    ReadFile fg1(conf["filtered_grammar"].as<string>()); - -  // TODO make this list configurable -  vector<boost::shared_ptr<FeatureExtractor> > extractors; -  extractors.push_back(boost::shared_ptr<FeatureExtractor>(new LogRuleCount)); -  extractors.push_back(boost::shared_ptr<FeatureExtractor>(new LexProbExtractor(conf["aligned_corpus"].as<string>()))); +  vector<string> feats = conf["feature"].as<vector<string> >(); +  vector<boost::shared_ptr<FeatureExtractor> > extractors(feats.size()); +  for (int i = 0; i < feats.size(); ++i) +    extractors[i] = reg.Create(feats[i]);    //score unscored grammar    cerr << "Reading filtered grammar to detect keys..." << endl; @@ -353,14 +416,12 @@ int main(int argc, char** argv){    WordID lhs = 0;    vector<WordID> src; -#if 0    istream& fs1 = *fg1.stream(); -  int line = 0;    while(fs1) {      fs1.getline(buf, MAX_LINE_LENGTH);      if (buf[0] == 0) continue;      ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 2); +    src.resize(cur_key.size() - 4);      for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];      lhs = cur_key[0];      for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { @@ -374,7 +435,7 @@ int main(int argc, char** argv){      cin.getline(buf, MAX_LINE_LENGTH);      if (buf[0] == 0) continue;      ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 2); +    src.resize(cur_key.size() - 4);      for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];      lhs = cur_key[0];      for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { @@ -383,7 +444,6 @@ int main(int argc, char** argv){          extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second);      }    } -#endif    ReadFile fg2(conf["filtered_grammar"].as<string>());    istream& fs2 = *fg2.stream(); @@ -392,7 +452,7 @@ int main(int argc, char** argv){      fs2.getline(buf, MAX_LINE_LENGTH);      if (buf[0] == 0) continue;      ParseLine(buf, &cur_key, &cur_counts); -    src.resize(cur_key.size() - 2); +    src.resize(cur_key.size() - 4);      for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];      lhs = cur_key[0]; @@ -401,7 +461,7 @@ int main(int argc, char** argv){        SparseVector<float> feats;        for (int i = 0; i < extractors.size(); ++i)          extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); -      cout << TD::GetString(cur_key) << " ||| " << TD::GetString(it->first) << " ||| "; +      cout << TD::Convert(lhs) << " ||| " << TD::GetString(src) << " ||| " << TD::GetString(it->first) << " ||| ";        feats.Write(false, &cout);        cout << endl;      } | 
