diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-13 16:15:33 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-13 16:15:33 +0000 |
commit | 2eeaa2eb91334bea11d70db1011f1a28ce3bb7d2 (patch) | |
tree | 42a7480c4d601064fa9a5b25a33d59eb12cdf166 | |
parent | bddde964a3686d79e77898def1ca7eb228c7caf1 (diff) |
major speed up using DFA parser
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@235 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | extools/featurize_grammar.cc | 236 | ||||
-rw-r--r-- | extools/lex_trans_tbl.h | 1 |
2 files changed, 58 insertions, 179 deletions
diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 547f390a..9a4af4d8 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -5,21 +5,17 @@ #include <sstream> #include <string> #include <map> -#include <set> #include <vector> #include <utility> #include <cstdlib> -#include <fstream> #include <tr1/unordered_map> -#include <boost/regex.hpp> -#include "suffix_tree.h" +#include "lex_trans_tbl.h" #include "sparse_vector.h" #include "sentence_pair.h" #include "extract.h" #include "fdict.h" #include "tdict.h" -#include "lex_trans_tbl.h" #include "filelib.h" #include "striped_grammar.h" @@ -29,7 +25,6 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> - using namespace std; using namespace std::tr1; using boost::shared_ptr; @@ -38,8 +33,6 @@ namespace po = boost::program_options; static string aligned_corpus; static const size_t MAX_LINE_LENGTH = 64000000; -typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; - // Data structures for indexing and counting rules //typedef boost::tuple< WordID, vector<WordID>, vector<WordID> > RuleTuple; struct RuleTuple { @@ -130,20 +123,6 @@ struct FreqCount { }; typedef FreqCount<RuleTuple> RuleFreqCount; -bool validate_non_terminal(const std::string& s) -{ - static const boost::regex r("\\[X\\d+,\\d+\\]|\\[\\d+\\]"); - return regex_match(s, r); -} - -namespace { - inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - inline void SkipWhitespace(const char* buf, int* ptr) { - while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } - } -} - - class FeatureExtractor; class FERegistry; struct FEFactoryBase { @@ -220,78 +199,7 @@ void InitCommandLine(const FERegistry& r, int argc, char** argv, po::variables_m } } -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) { - static const WordID kDIV = TD::Convert("|||"); - int ptr = sstart; - while(ptr < end) { - while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } - int start = ptr; - while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } - if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } - const WordID w = TD::Convert(string(buf, start, ptr - start)); - if (w == kDIV) return ptr; - p->push_back(w); - } - assert(p->size() > 0); - return ptr; -} - -void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) { - static const WordID kDIV = TD::Convert("|||"); - counts->clear(); - int ptr = 0; - while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } - if (buf[ptr] != '\t') { - cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; - exit(1); - } - cur_key->clear(); - // key is: "[X] ||| word word word" - int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); - if (buf[tmpp] != '\t') { - cur_key->push_back(kDIV); - ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); - } - ++ptr; - int start = ptr; - int end = ptr; - int state = 0; // 0=reading label, 1=reading count - vector<WordID> name; - while(buf[ptr] != 0) { - while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - end = ptr - 3; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (start == end) { - cerr << "Got empty token!\n LINE=" << buf << endl; - exit(1); - } - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - SkipWhitespace(buf, &ptr); - start = ptr; - } - } - } - } - end=ptr; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (end > start) { - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - } -} +static const bool DEBUG = false; void LexTranslationTable::createTTable(const char* buf){ AnnotatedParallelSentence sent; @@ -431,11 +339,7 @@ struct XFeatures: public FeatureExtractor { RuleTuple r(-1, src, trg); map_rule(r); rule_counts.inc(r, 0); - - normalise_string(r.source()); source_counts.inc(r.source(), 0); - - normalise_string(r.target()); target_counts.inc(r.target(), 0); } @@ -446,11 +350,7 @@ struct XFeatures: public FeatureExtractor { RuleTuple r(-1, src, trg); map_rule(r); rule_counts.inc_if_exists(r, info.counts.value(kCFE)); - - normalise_string(r.source()); source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); - - normalise_string(r.target()); target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); } @@ -463,11 +363,9 @@ struct XFeatures: public FeatureExtractor { map_rule(r); double l_r_freq = log(rule_counts(r)); - normalise_string(r.target()); result->set_value(fid_xfe, log(target_counts(r.target())) - l_r_freq); result->set_value(fid_labelledfe, log(target_counts(r.target())) - log(info.counts.value(kCFE))); - normalise_string(r.source()); result->set_value(fid_xef, log(source_counts(r.source())) - l_r_freq); result->set_value(fid_labelledef, log(source_counts(r.source())) - log(info.counts.value(kCFE))); } @@ -475,21 +373,15 @@ struct XFeatures: public FeatureExtractor { void map_rule(RuleTuple& r) const { vector<WordID> indexes; int i=0; for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) { - if (validate_non_terminal(TD::Convert(*it))) + if (*it <= 0) indexes.push_back(*it); } for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) { - if (validate_non_terminal(TD::Convert(*it))) + if (*it <= 0) *it = indexes.at(i++); } } - void normalise_string(vector<WordID>& r) const { - vector<WordID> indexes; - for (vector<WordID>::iterator it = r.begin(); it != r.end(); ++it) - if (validate_non_terminal(TD::Convert(*it))) *it = -1; - } - const int fid_xfe, fid_xef; const int fid_labelledfe, fid_labelledef; const int kCFE; @@ -508,10 +400,8 @@ struct LabelledRuleConditionals: public FeatureExtractor { const vector<WordID>& trg) { RuleTuple r(lhs, src, trg); rule_counts.inc(r, 0); - normalise_string(r.source()); source_counts.inc(r.source(), 0); - normalise_string(r.target()); target_counts.inc(r.target(), 0); } @@ -521,10 +411,8 @@ struct LabelledRuleConditionals: public FeatureExtractor { const RuleStatistics& info) { RuleTuple r(lhs, src, trg); rule_counts.inc_if_exists(r, info.counts.value(kCFE)); - normalise_string(r.source()); source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); - normalise_string(r.target()); target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); } @@ -535,18 +423,10 @@ struct LabelledRuleConditionals: public FeatureExtractor { SparseVector<float>* result) const { RuleTuple r(lhs, src, trg); double l_r_freq = log(rule_counts(r)); - normalise_string(r.target()); result->set_value(fid_fe, log(target_counts(r.target())) - l_r_freq); - normalise_string(r.source()); result->set_value(fid_ef, log(source_counts(r.source())) - l_r_freq); } - void normalise_string(vector<WordID>& r) const { - vector<WordID> indexes; - for (vector<WordID>::iterator it = r.begin(); it != r.end(); ++it) - if (validate_non_terminal(TD::Convert(*it))) *it = -1; - } - const int fid_fe, fid_ef; const int kCFE; RuleFreqCount rule_counts; @@ -643,9 +523,9 @@ struct LabellingShape: public FeatureExtractor { // Replace all terminals with generic -1 void map_rule(RuleTuple& r) const { for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) - if (!validate_non_terminal(TD::Convert(*it))) *it = -1; + if (*it <= 0) *it = -1; for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) - if (!validate_non_terminal(TD::Convert(*it))) *it = -1; + if (*it <= 0) *it = -1; } const int fid_, kCFE; @@ -758,6 +638,51 @@ struct LexProbExtractor : public FeatureExtractor { mutable LexTranslationTable table; }; +struct Featurizer { + Featurizer(const vector<boost::shared_ptr<FeatureExtractor> >& ex) : extractors(ex) { + } + void Callback1(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) { + for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ObserveFilteredRule(lhs, src, it->first); + } + } + void Callback2(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) { + for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); + } + } + void Callback3(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) { + for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { + SparseVector<float> feats; + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); + cout << '[' << TD::Convert(-lhs) << "] ||| "; + WriteNamed(src, &cout); + cout << " ||| "; + WriteAnonymous(it->first, &cout); + cout << " ||| "; + feats.Write(false, &cout); + cout << endl; + } + } + private: + vector<boost::shared_ptr<FeatureExtractor> > extractors; +}; + +void cb1(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) { + static_cast<Featurizer*>(extra)->Callback1(lhs, src_rhs, rules); +} + +void cb2(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) { + static_cast<Featurizer*>(extra)->Callback2(lhs, src_rhs, rules); +} + +void cb3(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) { + static_cast<Featurizer*>(extra)->Callback3(lhs, src_rhs, rules); +} + int main(int argc, char** argv){ FERegistry reg; reg.Register("LogRuleCount", new FEFactory<LogRuleCount>); @@ -778,65 +703,18 @@ int main(int argc, char** argv){ vector<boost::shared_ptr<FeatureExtractor> > extractors(feats.size()); for (int i = 0; i < feats.size(); ++i) extractors[i] = reg.Create(feats[i]); + Featurizer fizer(extractors); - //score unscored grammar cerr << "Reading filtered grammar to detect keys..." << endl; - char* buf = new char[MAX_LINE_LENGTH]; - - ID2RuleStatistics acc, cur_counts; - vector<WordID> key, cur_key,temp_key; - WordID lhs = 0; - vector<WordID> src; - - istream& fs1 = *fg1.stream(); - while(fs1) { - fs1.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); - for (int i = 0; i < src.size(); ++i) src.at(i) = cur_key.at(i+2); - - lhs = cur_key[0]; - for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - for (int i = 0; i < extractors.size(); ++i) - extractors[i]->ObserveFilteredRule(lhs, src, it->first); - } - } + StripedGrammarLexer::ReadStripedGrammar(fg1.stream(), cb1, &fizer); cerr << "Reading unfiltered grammar..." << endl; - while(cin) { - cin.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); - for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; - lhs = cur_key[0]; - for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - for (int i = 0; i < extractors.size(); ++i) - extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); - } - } + StripedGrammarLexer::ReadStripedGrammar(&cin, cb2, &fizer); ReadFile fg2(conf["filtered_grammar"].as<string>()); - istream& fs2 = *fg2.stream(); cerr << "Reading filtered grammar and adding features..." << endl; - while(fs2) { - fs2.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); - for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; - lhs = cur_key[0]; - - //loop over all the Target side phrases that this source aligns to - for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - SparseVector<float> feats; - for (int i = 0; i < extractors.size(); ++i) - extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); - cout << TD::Convert(lhs) << " ||| " << TD::GetString(src) << " ||| " << TD::GetString(it->first) << " ||| "; - feats.Write(false, &cout); - cout << endl; - } - } + StripedGrammarLexer::ReadStripedGrammar(fg2.stream(), cb3, &fizer); + + return 0; } diff --git a/extools/lex_trans_tbl.h b/extools/lex_trans_tbl.h index 81d6ccc9..161b4a0d 100644 --- a/extools/lex_trans_tbl.h +++ b/extools/lex_trans_tbl.h @@ -8,6 +8,7 @@ #ifndef LEX_TRANS_TBL_H_ #define LEX_TRANS_TBL_H_ +#include "wordid.h" #include <map> class LexTranslationTable |