diff options
-rw-r--r-- | extools/Makefile.am | 2 | ||||
-rw-r--r-- | extools/mr_stripe_rule_reduce.cc | 150 | ||||
-rw-r--r-- | extools/sentence_pair.cc | 4 | ||||
-rw-r--r-- | extools/sg_lexer.l | 83 | ||||
-rw-r--r-- | extools/striped_grammar.h | 2 |
5 files changed, 112 insertions, 129 deletions
diff --git a/extools/Makefile.am b/extools/Makefile.am index 807fe7d6..562599a3 100644 --- a/extools/Makefile.am +++ b/extools/Makefile.am @@ -28,7 +28,7 @@ build_lexical_translation_SOURCES = build_lexical_translation.cc extract.cc sent build_lexical_translation_LDADD = $(top_srcdir)/decoder/libcdec.a -lz build_lexical_translation_LDFLAGS = -all-static -mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc +mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc sg_lexer.cc mr_stripe_rule_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz mr_stripe_rule_reduce_LDFLAGS = -all-static diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc index 3298a801..8332a106 100644 --- a/extools/mr_stripe_rule_reduce.cc +++ b/extools/mr_stripe_rule_reduce.cc @@ -8,11 +8,11 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "striped_grammar.h" #include "tdict.h" #include "sentence_pair.h" #include "fdict.h" #include "extract.h" -#include "striped_grammar.h" using namespace std; using namespace std::tr1; @@ -22,13 +22,6 @@ static const size_t MAX_LINE_LENGTH = 64000000; bool use_hadoop_counters = false; -namespace { - inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - - inline void SkipWhitespace(const char* buf, int* ptr) { - while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } - } -} void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -50,8 +43,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; - void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) { for (ID2RuleStatistics::const_iterator it = v.begin(); it != v.end(); ++it) { RuleStatistics& dest = (*self)[it->first]; @@ -62,79 +53,6 @@ void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) { } } -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) { - static const WordID kDIV = TD::Convert("|||"); - int ptr = sstart; - while(ptr < end) { - while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } - int start = ptr; - while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } - if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } - const WordID w = TD::Convert(string(buf, start, ptr - start)); - if (w == kDIV) return ptr; - p->push_back(w); - } - assert(p->size() > 0); - return ptr; -} - -void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) { - static const WordID kDIV = TD::Convert("|||"); - counts->clear(); - int ptr = 0; - while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } - if (buf[ptr] != '\t') { - cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; - exit(1); - } - cur_key->clear(); - // key is: "[X] ||| word word word" - int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); - if (buf[tmpp] != '\t') { - cur_key->push_back(kDIV); - ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); - } - ++ptr; - int start = ptr; - int end = ptr; - int state = 0; // 0=reading label, 1=reading count - vector<WordID> name; - while(buf[ptr] != 0) { - while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - end = ptr - 3; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (start == end) { - cerr << "Got empty token!\n LINE=" << buf << endl; - exit(1); - } - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - SkipWhitespace(buf, &ptr); - start = ptr; - } - } - } - } - end=ptr; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (end > start) { - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - } -} - void WriteKeyValue(const vector<WordID>& key, const ID2RuleStatistics& val) { cout << TD::GetString(key) << '\t'; bool needdiv = false; @@ -201,44 +119,54 @@ void WriteWithInversions(const vector<WordID>& key, const ID2RuleStatistics& val } } +struct Reducer { + Reducer(bool phrase_marginals, bool bidir) : pm_(phrase_marginals), bidir_(bidir) {} + + void ProcessLine(const vector<WordID>& key, const ID2RuleStatistics& rules) { + if (cur_key_ != key) { + if (cur_key_.size() > 0) Emit(); + acc_.clear(); + cur_key_ = key; + } + PlusEquals(rules, &acc_); + } + + ~Reducer() { + Emit(); + } + + void Emit() { + if (pm_) + DoPhraseMarginals(cur_key_, bidir_, &acc_); + if (bidir_) + WriteWithInversions(cur_key_, acc_); + else + WriteKeyValue(cur_key_, acc_); + } + + const bool pm_; + const bool bidir_; + vector<WordID> cur_key_; + ID2RuleStatistics acc_; +}; + +void cb(const vector<WordID>& key, const ID2RuleStatistics& contexts, void* red) { + static_cast<Reducer*>(red)->ProcessLine(key, contexts); +} + + int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); char* buf = new char[MAX_LINE_LENGTH]; - ID2RuleStatistics acc, cur_counts; vector<WordID> key, cur_key; int line = 0; use_hadoop_counters = conf.count("use_hadoop_counters") > 0; const bool phrase_marginals = conf.count("phrase_marginals") > 0; const bool bidir = conf.count("bidir") > 0; - while(cin) { - ++line; - cin.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - if (cur_key != key) { - if (key.size() > 0) { - if (phrase_marginals) - DoPhraseMarginals(key, bidir, &acc); - if (bidir) - WriteWithInversions(key, acc); - else - WriteKeyValue(key, acc); - acc.clear(); - } - key = cur_key; - } - PlusEquals(cur_counts, &acc); - } - if (key.size() > 0) { - if (phrase_marginals) - DoPhraseMarginals(key, bidir, &acc); - if (bidir) - WriteWithInversions(key, acc); - else - WriteKeyValue(key, acc); - } + Reducer reducer(phrase_marginals, bidir); + StripedGrammarLexer::ReadContexts(&cin, cb, &reducer); return 0; } diff --git a/extools/sentence_pair.cc b/extools/sentence_pair.cc index b2881737..4cbcc98e 100644 --- a/extools/sentence_pair.cc +++ b/extools/sentence_pair.cc @@ -72,8 +72,8 @@ int AnnotatedParallelSentence::ReadAlignmentPoint(const char* buf, } (*b) = 0; while(ch < end && (c == 0 && (!permit_col || (permit_col && buf[ch] != ':')) || c != 0 && buf[ch] != '-')) { - if (buf[ch] < '0' || buf[ch] > '9') { - cerr << "Alignment point badly formed 4: " << string(buf, start, end-start) << endl << buf << endl; + if ((buf[ch] < '0') || (buf[ch] > '9')) { + cerr << "Alignment point badly formed 4: " << string(buf, start, end-start) << endl << buf << endl << buf[ch] << endl; exit(1); } (*b) *= 10; diff --git a/extools/sg_lexer.l b/extools/sg_lexer.l index f115e5bd..f82e8135 100644 --- a/extools/sg_lexer.l +++ b/extools/sg_lexer.l @@ -12,9 +12,12 @@ #include "striped_grammar.h" int lex_line = 0; +int read_contexts = 0; std::istream* sglex_stream = NULL; StripedGrammarLexer::GrammarCallback grammar_callback = NULL; +StripedGrammarLexer::ContextCallback context_callback = NULL; void* grammar_callback_extra = NULL; +void* context_callback_extra = NULL; #undef YY_INPUT #define YY_INPUT(buf, result, max_size) (result = sglex_stream->read(buf, max_size).gcount()) @@ -83,12 +86,39 @@ ALIGN [0-9]+-[0-9]+ %% <INITIAL>[ ] ; +<INITIAL>[\t] { + if (read_contexts) { + cur_options.clear(); + BEGIN(TRG); + } else { + std::cerr << "Unexpected tab while reading striped grammar\n"; + exit(1); + } + } <INITIAL>\[{NT}\] { - sglex_tmp_token.assign(yytext + 1, yyleng - 2); - sglex_lhs = -TD::Convert(sglex_tmp_token); - // std::cerr << sglex_tmp_token << "\n"; - BEGIN(LHS_END); + if (read_contexts) { + sglex_tmp_token.assign(yytext, yyleng); + sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token); + ++sglex_src_rhs_size; + } else { + sglex_tmp_token.assign(yytext + 1, yyleng - 2); + sglex_lhs = -TD::Convert(sglex_tmp_token); + // std::cerr << sglex_tmp_token << "\n"; + BEGIN(LHS_END); + } + } + +<INITIAL>[^ \t]+ { + if (read_contexts) { + // std::cerr << "Context: " << yytext << std::endl; + sglex_tmp_token.assign(yytext, yyleng); + sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token); + ++sglex_src_rhs_size; + } else { + std::cerr << "Unexpected input: " << yytext << " when NT expected\n"; + exit(1); + } } <SRC>\[{NT}\] { @@ -103,7 +133,8 @@ ALIGN [0-9]+-[0-9]+ sglex_reset(); BEGIN(SRC); } -<INITIAL,LHS_END>. { + +<LHS_END>. { std::cerr << "Line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl; exit(1); } @@ -136,21 +167,27 @@ ALIGN [0-9]+-[0-9]+ //std::cerr << "LHS=" << TD::Convert(-sglex_lhs) << " "; //std::cerr << " src_size: " << sglex_src_rhs_size << std::endl; //std::cerr << " src_arity: " << sglex_src_arity << std::endl; - memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int)); cur_options.clear(); + memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int)); sglex_trg_rhs_size = 0; BEGIN(TRG); } <TRG>\[[1-9][0-9]?\] { - int index = yytext[yyleng - 2] - '0'; - if (yyleng == 4) { - index += 10 * (yytext[yyleng - 3] - '0'); + if (read_contexts) { + sglex_tmp_token.assign(yytext, yyleng); + sglex_trg_rhs[sglex_trg_rhs_size] = TD::Convert(sglex_tmp_token); + ++sglex_trg_rhs_size; + } else { + int index = yytext[yyleng - 2] - '0'; + if (yyleng == 4) { + index += 10 * (yytext[yyleng - 3] - '0'); + } + ++sglex_trg_arity; + sanity_check_trg_index(index); + sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index; + ++sglex_trg_rhs_size; } - ++sglex_trg_arity; - sanity_check_trg_index(index); - sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index; - ++sglex_trg_rhs_size; } <TRG>\|\|\| { @@ -171,13 +208,18 @@ ALIGN [0-9]+-[0-9]+ <TRG>[ ]+ { ; } <FEATS>\n { - assert(sglex_lhs < 0); assert(sglex_src_rhs_size > 0); cur_src_rhs.resize(sglex_src_rhs_size); for (int i = 0; i < sglex_src_rhs_size; ++i) cur_src_rhs[i] = sglex_src_rhs[i]; - grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra); + if (read_contexts) { + context_callback(cur_src_rhs, cur_options, context_callback_extra); + } else { + assert(sglex_lhs < 0); + grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra); + } cur_options.clear(); + sglex_reset(); BEGIN(INITIAL); } <FEATS>[ ]+ { ; } @@ -233,6 +275,7 @@ ALIGN [0-9]+-[0-9]+ #include "filelib.h" void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra) { + read_contexts = 0; lex_line = 1; sglex_stream = in; grammar_callback_extra = extra; @@ -240,3 +283,13 @@ void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback f yylex(); } +void StripedGrammarLexer::ReadContexts(std::istream* in, ContextCallback func, void* extra) { + read_contexts = 1; + lex_line = 1; + sglex_stream = in; + context_callback_extra = extra; + context_callback = func; + yylex(); +} + + diff --git a/extools/striped_grammar.h b/extools/striped_grammar.h index cdf529d6..bf3aec7d 100644 --- a/extools/striped_grammar.h +++ b/extools/striped_grammar.h @@ -49,6 +49,8 @@ typedef std::tr1::unordered_map<std::vector<WordID>, RuleStatistics, boost::hash struct StripedGrammarLexer { typedef void (*GrammarCallback)(WordID lhs, const std::vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void *extra); static void ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra); + typedef void (*ContextCallback)(const std::vector<WordID>& phrase, const ID2RuleStatistics& rules, void *extra); + static void ReadContexts(std::istream* in, ContextCallback func, void* extra); }; #endif |