From 29d0b2bcd7263010e27224c1344ccc9a2f30b623 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 6 Jul 2010 20:24:06 +0000 Subject: more git-svn-id: https://ws10smt.googlecode.com/svn/trunk@164 ec762483-ff6d-05da-a07a-a48fb63a330f --- extools/featurize_grammar.cc | 116 ++++++++++++++++++++++++++++++++----------- 1 file changed, 88 insertions(+), 28 deletions(-) (limited to 'extools') diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 4c9821ec..17f59e6e 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -2,6 +2,7 @@ * Featurize a grammar in striped format */ #include +#include #include #include #include @@ -24,18 +25,86 @@ #include #include + using namespace std; using namespace std::tr1; +using boost::shared_ptr; namespace po = boost::program_options; +static string aligned_corpus; static const size_t MAX_LINE_LENGTH = 64000000; typedef unordered_map, RuleStatistics, boost::hash > > ID2RuleStatistics; -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +namespace { + inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } + inline bool IsBracket(char c){return c == '[' || c == ']';} + inline void SkipWhitespace(const char* buf, int* ptr) { + while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } + } +} + + +class FeatureExtractor; +class FERegistry; +struct FEFactoryBase { + virtual ~FEFactoryBase() {} + virtual boost::shared_ptr Create() const = 0; +}; + + +class FERegistry { + friend class FEFactoryBase; + public: + FERegistry() {} + boost::shared_ptr Create(const std::string& ffname) const { + map >::const_iterator it = reg_.find(ffname); + shared_ptr res; + if (it == reg_.end()) { + cerr << "I don't know how to create feature " << ffname << endl; + } else { + res = it->second->Create(); + } + return res; + } + void DisplayList(ostream* out) const { + bool first = true; + for (map >::const_iterator it = reg_.begin(); + it != reg_.end(); ++it) { + if (first) {first=false;} else {*out << ' ';} + *out << it->first; + } + } + + void Register(const std::string& ffname, FEFactoryBase* factory) { + if (reg_.find(ffname) != reg_.end()) { + cerr << "Duplicate registration of FeatureExtractor with name " << ffname << "!\n"; + exit(1); + } + reg_[ffname].reset(factory); + } + + private: + std::map > reg_; +}; + +template +class FEFactory : public FEFactoryBase { + boost::shared_ptr Create() const { + return boost::shared_ptr(new FE); + } +}; + +void InitCommandLine(const FERegistry& r, int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); + ostringstream feats; + feats << "[multiple] Features to extract ("; + r.DisplayList(&feats); + feats << ")"; opts.add_options() ("filtered_grammar,g", po::value(), "Grammar to add features to") + ("list_features,L", "List extractable features") + ("feature,f", po::value >()->composing(), feats.str().c_str()) ("aligned_corpus,c", po::value(), "Aligned corpus (single line format)") ("help,h", "Print this help message and exit"); po::options_description clo("Command line options"); @@ -45,21 +114,13 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::store(parse_command_line(argc, argv, dcmdline_options), *conf); po::notify(*conf); - if (conf->count("help") || conf->count("aligned_corpus")==0) { - cerr << "\nUsage: featurize_grammar -g FILTERED-GRAMMAR.gz -c ALIGNED_CORPUS.fr-en-al [-options] < UNFILTERED-GRAMMAR\n"; + if (conf->count("help") || conf->count("aligned_corpus")==0 || conf->count("feature") == 0) { + cerr << "\nUsage: featurize_grammar -g FILTERED-GRAMMAR.gz -c ALIGNED_CORPUS.fr-en-al -f Feat1 -f Feat2 ... < UNFILTERED-GRAMMAR\n"; cerr << dcmdline_options << endl; exit(1); } } -namespace { - inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - inline bool IsBracket(char c){return c == '[' || c == ']';} - inline void SkipWhitespace(const char* buf, int* ptr) { - while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } - } -} - int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector* p) { static const WordID kDIV = TD::Convert("|||"); int ptr = sstart; @@ -231,11 +292,11 @@ struct LogRuleCount : public FeatureExtractor { // this extracts the lexical translation prob features // in BOTH directions. struct LexProbExtractor : public FeatureExtractor { - LexProbExtractor(const std::string& corpus) : + LexProbExtractor() : e2f_(FD::Convert("LexE2F")), f2e_(FD::Convert("LexF2E")) { - ReadFile rf(corpus); + ReadFile rf(aligned_corpus); //create lexical translation table - cerr << "Computing lexical translation probabilities from " << corpus << "..." << endl; + cerr << "Computing lexical translation probabilities from " << aligned_corpus << "..." << endl; char* buf = new char[MAX_LINE_LENGTH]; istream& alignment = *rf.stream(); while(alignment) { @@ -333,16 +394,18 @@ struct LexProbExtractor : public FeatureExtractor { }; int main(int argc, char** argv){ + FERegistry reg; + reg.Register("LogRuleCount", new FEFactory); + reg.Register("LexProb", new FEFactory); po::variables_map conf; - InitCommandLine(argc, argv, &conf); - ifstream alignment (conf["aligned_corpus"].as().c_str()); + InitCommandLine(reg, argc, argv, &conf); + aligned_corpus = conf["aligned_corpus"].as(); // GLOBAL VAR ReadFile fg1(conf["filtered_grammar"].as()); - - // TODO make this list configurable - vector > extractors; - extractors.push_back(boost::shared_ptr(new LogRuleCount)); - extractors.push_back(boost::shared_ptr(new LexProbExtractor(conf["aligned_corpus"].as()))); + vector feats = conf["feature"].as >(); + vector > extractors(feats.size()); + for (int i = 0; i < feats.size(); ++i) + extractors[i] = reg.Create(feats[i]); //score unscored grammar cerr << "Reading filtered grammar to detect keys..." << endl; @@ -353,14 +416,12 @@ int main(int argc, char** argv){ WordID lhs = 0; vector src; -#if 0 istream& fs1 = *fg1.stream(); - int line = 0; while(fs1) { fs1.getline(buf, MAX_LINE_LENGTH); if (buf[0] == 0) continue; ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); + src.resize(cur_key.size() - 4); for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; lhs = cur_key[0]; for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { @@ -374,7 +435,7 @@ int main(int argc, char** argv){ cin.getline(buf, MAX_LINE_LENGTH); if (buf[0] == 0) continue; ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); + src.resize(cur_key.size() - 4); for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; lhs = cur_key[0]; for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { @@ -383,7 +444,6 @@ int main(int argc, char** argv){ extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); } } -#endif ReadFile fg2(conf["filtered_grammar"].as()); istream& fs2 = *fg2.stream(); @@ -392,7 +452,7 @@ int main(int argc, char** argv){ fs2.getline(buf, MAX_LINE_LENGTH); if (buf[0] == 0) continue; ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); + src.resize(cur_key.size() - 4); for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; lhs = cur_key[0]; @@ -401,7 +461,7 @@ int main(int argc, char** argv){ SparseVector feats; for (int i = 0; i < extractors.size(); ++i) extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); - cout << TD::GetString(cur_key) << " ||| " << TD::GetString(it->first) << " ||| "; + cout << TD::Convert(lhs) << " ||| " << TD::GetString(src) << " ||| " << TD::GetString(it->first) << " ||| "; feats.Write(false, &cout); cout << endl; } -- cgit v1.2.3