From 2eeaa2eb91334bea11d70db1011f1a28ce3bb7d2 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 13 Jul 2010 16:15:33 +0000 Subject: major speed up using DFA parser git-svn-id: https://ws10smt.googlecode.com/svn/trunk@235 ec762483-ff6d-05da-a07a-a48fb63a330f --- extools/featurize_grammar.cc | 236 +++++++++++-------------------------------- extools/lex_trans_tbl.h | 1 + 2 files changed, 58 insertions(+), 179 deletions(-) diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 547f390a..9a4af4d8 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -5,21 +5,17 @@ #include #include #include -#include #include #include #include -#include #include -#include -#include "suffix_tree.h" +#include "lex_trans_tbl.h" #include "sparse_vector.h" #include "sentence_pair.h" #include "extract.h" #include "fdict.h" #include "tdict.h" -#include "lex_trans_tbl.h" #include "filelib.h" #include "striped_grammar.h" @@ -29,7 +25,6 @@ #include #include - using namespace std; using namespace std::tr1; using boost::shared_ptr; @@ -38,8 +33,6 @@ namespace po = boost::program_options; static string aligned_corpus; static const size_t MAX_LINE_LENGTH = 64000000; -typedef unordered_map, RuleStatistics, boost::hash > > ID2RuleStatistics; - // Data structures for indexing and counting rules //typedef boost::tuple< WordID, vector, vector > RuleTuple; struct RuleTuple { @@ -130,20 +123,6 @@ struct FreqCount { }; typedef FreqCount RuleFreqCount; -bool validate_non_terminal(const std::string& s) -{ - static const boost::regex r("\\[X\\d+,\\d+\\]|\\[\\d+\\]"); - return regex_match(s, r); -} - -namespace { - inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - inline void SkipWhitespace(const char* buf, int* ptr) { - while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } - } -} - - class FeatureExtractor; class FERegistry; struct FEFactoryBase { @@ -220,78 +199,7 @@ void InitCommandLine(const FERegistry& r, int argc, char** argv, po::variables_m } } -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector* p) { - static const WordID kDIV = TD::Convert("|||"); - int ptr = sstart; - while(ptr < end) { - while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } - int start = ptr; - while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } - if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } - const WordID w = TD::Convert(string(buf, start, ptr - start)); - if (w == kDIV) return ptr; - p->push_back(w); - } - assert(p->size() > 0); - return ptr; -} - -void ParseLine(const char* buf, vector* cur_key, ID2RuleStatistics* counts) { - static const WordID kDIV = TD::Convert("|||"); - counts->clear(); - int ptr = 0; - while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } - if (buf[ptr] != '\t') { - cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; - exit(1); - } - cur_key->clear(); - // key is: "[X] ||| word word word" - int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); - if (buf[tmpp] != '\t') { - cur_key->push_back(kDIV); - ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); - } - ++ptr; - int start = ptr; - int end = ptr; - int state = 0; // 0=reading label, 1=reading count - vector name; - while(buf[ptr] != 0) { - while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - end = ptr - 3; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (start == end) { - cerr << "Got empty token!\n LINE=" << buf << endl; - exit(1); - } - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - SkipWhitespace(buf, &ptr); - start = ptr; - } - } - } - } - end=ptr; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (end > start) { - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - } -} +static const bool DEBUG = false; void LexTranslationTable::createTTable(const char* buf){ AnnotatedParallelSentence sent; @@ -431,11 +339,7 @@ struct XFeatures: public FeatureExtractor { RuleTuple r(-1, src, trg); map_rule(r); rule_counts.inc(r, 0); - - normalise_string(r.source()); source_counts.inc(r.source(), 0); - - normalise_string(r.target()); target_counts.inc(r.target(), 0); } @@ -446,11 +350,7 @@ struct XFeatures: public FeatureExtractor { RuleTuple r(-1, src, trg); map_rule(r); rule_counts.inc_if_exists(r, info.counts.value(kCFE)); - - normalise_string(r.source()); source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); - - normalise_string(r.target()); target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); } @@ -463,11 +363,9 @@ struct XFeatures: public FeatureExtractor { map_rule(r); double l_r_freq = log(rule_counts(r)); - normalise_string(r.target()); result->set_value(fid_xfe, log(target_counts(r.target())) - l_r_freq); result->set_value(fid_labelledfe, log(target_counts(r.target())) - log(info.counts.value(kCFE))); - normalise_string(r.source()); result->set_value(fid_xef, log(source_counts(r.source())) - l_r_freq); result->set_value(fid_labelledef, log(source_counts(r.source())) - log(info.counts.value(kCFE))); } @@ -475,21 +373,15 @@ struct XFeatures: public FeatureExtractor { void map_rule(RuleTuple& r) const { vector indexes; int i=0; for (vector::iterator it = r.target().begin(); it != r.target().end(); ++it) { - if (validate_non_terminal(TD::Convert(*it))) + if (*it <= 0) indexes.push_back(*it); } for (vector::iterator it = r.source().begin(); it != r.source().end(); ++it) { - if (validate_non_terminal(TD::Convert(*it))) + if (*it <= 0) *it = indexes.at(i++); } } - void normalise_string(vector& r) const { - vector indexes; - for (vector::iterator it = r.begin(); it != r.end(); ++it) - if (validate_non_terminal(TD::Convert(*it))) *it = -1; - } - const int fid_xfe, fid_xef; const int fid_labelledfe, fid_labelledef; const int kCFE; @@ -508,10 +400,8 @@ struct LabelledRuleConditionals: public FeatureExtractor { const vector& trg) { RuleTuple r(lhs, src, trg); rule_counts.inc(r, 0); - normalise_string(r.source()); source_counts.inc(r.source(), 0); - normalise_string(r.target()); target_counts.inc(r.target(), 0); } @@ -521,10 +411,8 @@ struct LabelledRuleConditionals: public FeatureExtractor { const RuleStatistics& info) { RuleTuple r(lhs, src, trg); rule_counts.inc_if_exists(r, info.counts.value(kCFE)); - normalise_string(r.source()); source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); - normalise_string(r.target()); target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); } @@ -535,18 +423,10 @@ struct LabelledRuleConditionals: public FeatureExtractor { SparseVector* result) const { RuleTuple r(lhs, src, trg); double l_r_freq = log(rule_counts(r)); - normalise_string(r.target()); result->set_value(fid_fe, log(target_counts(r.target())) - l_r_freq); - normalise_string(r.source()); result->set_value(fid_ef, log(source_counts(r.source())) - l_r_freq); } - void normalise_string(vector& r) const { - vector indexes; - for (vector::iterator it = r.begin(); it != r.end(); ++it) - if (validate_non_terminal(TD::Convert(*it))) *it = -1; - } - const int fid_fe, fid_ef; const int kCFE; RuleFreqCount rule_counts; @@ -643,9 +523,9 @@ struct LabellingShape: public FeatureExtractor { // Replace all terminals with generic -1 void map_rule(RuleTuple& r) const { for (vector::iterator it = r.target().begin(); it != r.target().end(); ++it) - if (!validate_non_terminal(TD::Convert(*it))) *it = -1; + if (*it <= 0) *it = -1; for (vector::iterator it = r.source().begin(); it != r.source().end(); ++it) - if (!validate_non_terminal(TD::Convert(*it))) *it = -1; + if (*it <= 0) *it = -1; } const int fid_, kCFE; @@ -758,6 +638,51 @@ struct LexProbExtractor : public FeatureExtractor { mutable LexTranslationTable table; }; +struct Featurizer { + Featurizer(const vector >& ex) : extractors(ex) { + } + void Callback1(WordID lhs, const vector& src, const ID2RuleStatistics& trgs) { + for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ObserveFilteredRule(lhs, src, it->first); + } + } + void Callback2(WordID lhs, const vector& src, const ID2RuleStatistics& trgs) { + for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); + } + } + void Callback3(WordID lhs, const vector& src, const ID2RuleStatistics& trgs) { + for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) { + SparseVector feats; + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); + cout << '[' << TD::Convert(-lhs) << "] ||| "; + WriteNamed(src, &cout); + cout << " ||| "; + WriteAnonymous(it->first, &cout); + cout << " ||| "; + feats.Write(false, &cout); + cout << endl; + } + } + private: + vector > extractors; +}; + +void cb1(WordID lhs, const vector& src_rhs, const ID2RuleStatistics& rules, void* extra) { + static_cast(extra)->Callback1(lhs, src_rhs, rules); +} + +void cb2(WordID lhs, const vector& src_rhs, const ID2RuleStatistics& rules, void* extra) { + static_cast(extra)->Callback2(lhs, src_rhs, rules); +} + +void cb3(WordID lhs, const vector& src_rhs, const ID2RuleStatistics& rules, void* extra) { + static_cast(extra)->Callback3(lhs, src_rhs, rules); +} + int main(int argc, char** argv){ FERegistry reg; reg.Register("LogRuleCount", new FEFactory); @@ -778,65 +703,18 @@ int main(int argc, char** argv){ vector > extractors(feats.size()); for (int i = 0; i < feats.size(); ++i) extractors[i] = reg.Create(feats[i]); + Featurizer fizer(extractors); - //score unscored grammar cerr << "Reading filtered grammar to detect keys..." << endl; - char* buf = new char[MAX_LINE_LENGTH]; - - ID2RuleStatistics acc, cur_counts; - vector key, cur_key,temp_key; - WordID lhs = 0; - vector src; - - istream& fs1 = *fg1.stream(); - while(fs1) { - fs1.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); - for (int i = 0; i < src.size(); ++i) src.at(i) = cur_key.at(i+2); - - lhs = cur_key[0]; - for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - for (int i = 0; i < extractors.size(); ++i) - extractors[i]->ObserveFilteredRule(lhs, src, it->first); - } - } + StripedGrammarLexer::ReadStripedGrammar(fg1.stream(), cb1, &fizer); cerr << "Reading unfiltered grammar..." << endl; - while(cin) { - cin.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); - for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; - lhs = cur_key[0]; - for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - for (int i = 0; i < extractors.size(); ++i) - extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); - } - } + StripedGrammarLexer::ReadStripedGrammar(&cin, cb2, &fizer); ReadFile fg2(conf["filtered_grammar"].as()); - istream& fs2 = *fg2.stream(); cerr << "Reading filtered grammar and adding features..." << endl; - while(fs2) { - fs2.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 2); - for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; - lhs = cur_key[0]; - - //loop over all the Target side phrases that this source aligns to - for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - SparseVector feats; - for (int i = 0; i < extractors.size(); ++i) - extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); - cout << TD::Convert(lhs) << " ||| " << TD::GetString(src) << " ||| " << TD::GetString(it->first) << " ||| "; - feats.Write(false, &cout); - cout << endl; - } - } + StripedGrammarLexer::ReadStripedGrammar(fg2.stream(), cb3, &fizer); + + return 0; } diff --git a/extools/lex_trans_tbl.h b/extools/lex_trans_tbl.h index 81d6ccc9..161b4a0d 100644 --- a/extools/lex_trans_tbl.h +++ b/extools/lex_trans_tbl.h @@ -8,6 +8,7 @@ #ifndef LEX_TRANS_TBL_H_ #define LEX_TRANS_TBL_H_ +#include "wordid.h" #include class LexTranslationTable -- cgit v1.2.3