diff options
Diffstat (limited to 'extools/mr_stripe_rule_reduce.cc')
-rw-r--r-- | extools/mr_stripe_rule_reduce.cc | 150 |
1 files changed, 39 insertions, 111 deletions
diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc index 3298a801..8332a106 100644 --- a/extools/mr_stripe_rule_reduce.cc +++ b/extools/mr_stripe_rule_reduce.cc @@ -8,11 +8,11 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "striped_grammar.h" #include "tdict.h" #include "sentence_pair.h" #include "fdict.h" #include "extract.h" -#include "striped_grammar.h" using namespace std; using namespace std::tr1; @@ -22,13 +22,6 @@ static const size_t MAX_LINE_LENGTH = 64000000; bool use_hadoop_counters = false; -namespace { - inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - - inline void SkipWhitespace(const char* buf, int* ptr) { - while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } - } -} void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -50,8 +43,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; - void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) { for (ID2RuleStatistics::const_iterator it = v.begin(); it != v.end(); ++it) { RuleStatistics& dest = (*self)[it->first]; @@ -62,79 +53,6 @@ void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) { } } -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) { - static const WordID kDIV = TD::Convert("|||"); - int ptr = sstart; - while(ptr < end) { - while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } - int start = ptr; - while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } - if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } - const WordID w = TD::Convert(string(buf, start, ptr - start)); - if (w == kDIV) return ptr; - p->push_back(w); - } - assert(p->size() > 0); - return ptr; -} - -void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) { - static const WordID kDIV = TD::Convert("|||"); - counts->clear(); - int ptr = 0; - while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } - if (buf[ptr] != '\t') { - cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; - exit(1); - } - cur_key->clear(); - // key is: "[X] ||| word word word" - int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); - if (buf[tmpp] != '\t') { - cur_key->push_back(kDIV); - ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); - } - ++ptr; - int start = ptr; - int end = ptr; - int state = 0; // 0=reading label, 1=reading count - vector<WordID> name; - while(buf[ptr] != 0) { - while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - end = ptr - 3; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (start == end) { - cerr << "Got empty token!\n LINE=" << buf << endl; - exit(1); - } - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - SkipWhitespace(buf, &ptr); - start = ptr; - } - } - } - } - end=ptr; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (end > start) { - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - } -} - void WriteKeyValue(const vector<WordID>& key, const ID2RuleStatistics& val) { cout << TD::GetString(key) << '\t'; bool needdiv = false; @@ -201,44 +119,54 @@ void WriteWithInversions(const vector<WordID>& key, const ID2RuleStatistics& val } } +struct Reducer { + Reducer(bool phrase_marginals, bool bidir) : pm_(phrase_marginals), bidir_(bidir) {} + + void ProcessLine(const vector<WordID>& key, const ID2RuleStatistics& rules) { + if (cur_key_ != key) { + if (cur_key_.size() > 0) Emit(); + acc_.clear(); + cur_key_ = key; + } + PlusEquals(rules, &acc_); + } + + ~Reducer() { + Emit(); + } + + void Emit() { + if (pm_) + DoPhraseMarginals(cur_key_, bidir_, &acc_); + if (bidir_) + WriteWithInversions(cur_key_, acc_); + else + WriteKeyValue(cur_key_, acc_); + } + + const bool pm_; + const bool bidir_; + vector<WordID> cur_key_; + ID2RuleStatistics acc_; +}; + +void cb(const vector<WordID>& key, const ID2RuleStatistics& contexts, void* red) { + static_cast<Reducer*>(red)->ProcessLine(key, contexts); +} + + int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); char* buf = new char[MAX_LINE_LENGTH]; - ID2RuleStatistics acc, cur_counts; vector<WordID> key, cur_key; int line = 0; use_hadoop_counters = conf.count("use_hadoop_counters") > 0; const bool phrase_marginals = conf.count("phrase_marginals") > 0; const bool bidir = conf.count("bidir") > 0; - while(cin) { - ++line; - cin.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - if (cur_key != key) { - if (key.size() > 0) { - if (phrase_marginals) - DoPhraseMarginals(key, bidir, &acc); - if (bidir) - WriteWithInversions(key, acc); - else - WriteKeyValue(key, acc); - acc.clear(); - } - key = cur_key; - } - PlusEquals(cur_counts, &acc); - } - if (key.size() > 0) { - if (phrase_marginals) - DoPhraseMarginals(key, bidir, &acc); - if (bidir) - WriteWithInversions(key, acc); - else - WriteKeyValue(key, acc); - } + Reducer reducer(phrase_marginals, bidir); + StripedGrammarLexer::ReadContexts(&cin, cb, &reducer); return 0; } |