summaryrefslogtreecommitdiff
path: root/extools/mr_stripe_rule_reduce.cc
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 18:57:02 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 18:57:02 +0000
commitbcf8d448430312fcf6270e3ba2e304ac58650312 (patch)
treee4c6c9dd12ec55d2d6b6606e8b5b5b14b5d95c43 /extools/mr_stripe_rule_reduce.cc
parent49e4f80136dd573c8b08c06426724de2d51bb784 (diff)
use lexer instead of handwritten parser
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@319 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'extools/mr_stripe_rule_reduce.cc')
-rw-r--r--extools/mr_stripe_rule_reduce.cc150
1 files changed, 39 insertions, 111 deletions
diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc
index 3298a801..8332a106 100644
--- a/extools/mr_stripe_rule_reduce.cc
+++ b/extools/mr_stripe_rule_reduce.cc
@@ -8,11 +8,11 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "striped_grammar.h"
#include "tdict.h"
#include "sentence_pair.h"
#include "fdict.h"
#include "extract.h"
-#include "striped_grammar.h"
using namespace std;
using namespace std::tr1;
@@ -22,13 +22,6 @@ static const size_t MAX_LINE_LENGTH = 64000000;
bool use_hadoop_counters = false;
-namespace {
- inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
-
- inline void SkipWhitespace(const char* buf, int* ptr) {
- while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); }
- }
-}
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
@@ -50,8 +43,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
}
-typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics;
-
void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) {
for (ID2RuleStatistics::const_iterator it = v.begin(); it != v.end(); ++it) {
RuleStatistics& dest = (*self)[it->first];
@@ -62,79 +53,6 @@ void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) {
}
}
-int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) {
- static const WordID kDIV = TD::Convert("|||");
- int ptr = sstart;
- while(ptr < end) {
- while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; }
- int start = ptr;
- while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; }
- if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; }
- const WordID w = TD::Convert(string(buf, start, ptr - start));
- if (w == kDIV) return ptr;
- p->push_back(w);
- }
- assert(p->size() > 0);
- return ptr;
-}
-
-void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) {
- static const WordID kDIV = TD::Convert("|||");
- counts->clear();
- int ptr = 0;
- while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; }
- if (buf[ptr] != '\t') {
- cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl;
- exit(1);
- }
- cur_key->clear();
- // key is: "[X] ||| word word word"
- int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key);
- if (buf[tmpp] != '\t') {
- cur_key->push_back(kDIV);
- ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key);
- }
- ++ptr;
- int start = ptr;
- int end = ptr;
- int state = 0; // 0=reading label, 1=reading count
- vector<WordID> name;
- while(buf[ptr] != 0) {
- while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; }
- if (buf[ptr] == '|') {
- ++ptr;
- if (buf[ptr] == '|') {
- ++ptr;
- if (buf[ptr] == '|') {
- ++ptr;
- end = ptr - 3;
- while (end > start && IsWhitespace(buf[end-1])) { --end; }
- if (start == end) {
- cerr << "Got empty token!\n LINE=" << buf << endl;
- exit(1);
- }
- switch (state) {
- case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
- case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
- default: cerr << "Can't happen\n"; abort();
- }
- SkipWhitespace(buf, &ptr);
- start = ptr;
- }
- }
- }
- }
- end=ptr;
- while (end > start && IsWhitespace(buf[end-1])) { --end; }
- if (end > start) {
- switch (state) {
- case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
- case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
- default: cerr << "Can't happen\n"; abort();
- }
- }
-}
-
void WriteKeyValue(const vector<WordID>& key, const ID2RuleStatistics& val) {
cout << TD::GetString(key) << '\t';
bool needdiv = false;
@@ -201,44 +119,54 @@ void WriteWithInversions(const vector<WordID>& key, const ID2RuleStatistics& val
}
}
+struct Reducer {
+ Reducer(bool phrase_marginals, bool bidir) : pm_(phrase_marginals), bidir_(bidir) {}
+
+ void ProcessLine(const vector<WordID>& key, const ID2RuleStatistics& rules) {
+ if (cur_key_ != key) {
+ if (cur_key_.size() > 0) Emit();
+ acc_.clear();
+ cur_key_ = key;
+ }
+ PlusEquals(rules, &acc_);
+ }
+
+ ~Reducer() {
+ Emit();
+ }
+
+ void Emit() {
+ if (pm_)
+ DoPhraseMarginals(cur_key_, bidir_, &acc_);
+ if (bidir_)
+ WriteWithInversions(cur_key_, acc_);
+ else
+ WriteKeyValue(cur_key_, acc_);
+ }
+
+ const bool pm_;
+ const bool bidir_;
+ vector<WordID> cur_key_;
+ ID2RuleStatistics acc_;
+};
+
+void cb(const vector<WordID>& key, const ID2RuleStatistics& contexts, void* red) {
+ static_cast<Reducer*>(red)->ProcessLine(key, contexts);
+}
+
+
int main(int argc, char** argv) {
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
char* buf = new char[MAX_LINE_LENGTH];
- ID2RuleStatistics acc, cur_counts;
vector<WordID> key, cur_key;
int line = 0;
use_hadoop_counters = conf.count("use_hadoop_counters") > 0;
const bool phrase_marginals = conf.count("phrase_marginals") > 0;
const bool bidir = conf.count("bidir") > 0;
- while(cin) {
- ++line;
- cin.getline(buf, MAX_LINE_LENGTH);
- if (buf[0] == 0) continue;
- ParseLine(buf, &cur_key, &cur_counts);
- if (cur_key != key) {
- if (key.size() > 0) {
- if (phrase_marginals)
- DoPhraseMarginals(key, bidir, &acc);
- if (bidir)
- WriteWithInversions(key, acc);
- else
- WriteKeyValue(key, acc);
- acc.clear();
- }
- key = cur_key;
- }
- PlusEquals(cur_counts, &acc);
- }
- if (key.size() > 0) {
- if (phrase_marginals)
- DoPhraseMarginals(key, bidir, &acc);
- if (bidir)
- WriteWithInversions(key, acc);
- else
- WriteKeyValue(key, acc);
- }
+ Reducer reducer(phrase_marginals, bidir);
+ StripedGrammarLexer::ReadContexts(&cin, cb, &reducer);
return 0;
}