summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extools/Makefile.am2
-rw-r--r--extools/mr_stripe_rule_reduce.cc150
-rw-r--r--extools/sentence_pair.cc4
-rw-r--r--extools/sg_lexer.l83
-rw-r--r--extools/striped_grammar.h2
5 files changed, 112 insertions, 129 deletions
diff --git a/extools/Makefile.am b/extools/Makefile.am
index 807fe7d6..562599a3 100644
--- a/extools/Makefile.am
+++ b/extools/Makefile.am
@@ -28,7 +28,7 @@ build_lexical_translation_SOURCES = build_lexical_translation.cc extract.cc sent
build_lexical_translation_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
build_lexical_translation_LDFLAGS = -all-static
-mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc
+mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc sg_lexer.cc
mr_stripe_rule_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
mr_stripe_rule_reduce_LDFLAGS = -all-static
diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc
index 3298a801..8332a106 100644
--- a/extools/mr_stripe_rule_reduce.cc
+++ b/extools/mr_stripe_rule_reduce.cc
@@ -8,11 +8,11 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "striped_grammar.h"
#include "tdict.h"
#include "sentence_pair.h"
#include "fdict.h"
#include "extract.h"
-#include "striped_grammar.h"
using namespace std;
using namespace std::tr1;
@@ -22,13 +22,6 @@ static const size_t MAX_LINE_LENGTH = 64000000;
bool use_hadoop_counters = false;
-namespace {
- inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
-
- inline void SkipWhitespace(const char* buf, int* ptr) {
- while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); }
- }
-}
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
@@ -50,8 +43,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
}
-typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics;
-
void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) {
for (ID2RuleStatistics::const_iterator it = v.begin(); it != v.end(); ++it) {
RuleStatistics& dest = (*self)[it->first];
@@ -62,79 +53,6 @@ void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) {
}
}
-int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) {
- static const WordID kDIV = TD::Convert("|||");
- int ptr = sstart;
- while(ptr < end) {
- while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; }
- int start = ptr;
- while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; }
- if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; }
- const WordID w = TD::Convert(string(buf, start, ptr - start));
- if (w == kDIV) return ptr;
- p->push_back(w);
- }
- assert(p->size() > 0);
- return ptr;
-}
-
-void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) {
- static const WordID kDIV = TD::Convert("|||");
- counts->clear();
- int ptr = 0;
- while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; }
- if (buf[ptr] != '\t') {
- cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl;
- exit(1);
- }
- cur_key->clear();
- // key is: "[X] ||| word word word"
- int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key);
- if (buf[tmpp] != '\t') {
- cur_key->push_back(kDIV);
- ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key);
- }
- ++ptr;
- int start = ptr;
- int end = ptr;
- int state = 0; // 0=reading label, 1=reading count
- vector<WordID> name;
- while(buf[ptr] != 0) {
- while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; }
- if (buf[ptr] == '|') {
- ++ptr;
- if (buf[ptr] == '|') {
- ++ptr;
- if (buf[ptr] == '|') {
- ++ptr;
- end = ptr - 3;
- while (end > start && IsWhitespace(buf[end-1])) { --end; }
- if (start == end) {
- cerr << "Got empty token!\n LINE=" << buf << endl;
- exit(1);
- }
- switch (state) {
- case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
- case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
- default: cerr << "Can't happen\n"; abort();
- }
- SkipWhitespace(buf, &ptr);
- start = ptr;
- }
- }
- }
- }
- end=ptr;
- while (end > start && IsWhitespace(buf[end-1])) { --end; }
- if (end > start) {
- switch (state) {
- case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
- case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
- default: cerr << "Can't happen\n"; abort();
- }
- }
-}
-
void WriteKeyValue(const vector<WordID>& key, const ID2RuleStatistics& val) {
cout << TD::GetString(key) << '\t';
bool needdiv = false;
@@ -201,44 +119,54 @@ void WriteWithInversions(const vector<WordID>& key, const ID2RuleStatistics& val
}
}
+struct Reducer {
+ Reducer(bool phrase_marginals, bool bidir) : pm_(phrase_marginals), bidir_(bidir) {}
+
+ void ProcessLine(const vector<WordID>& key, const ID2RuleStatistics& rules) {
+ if (cur_key_ != key) {
+ if (cur_key_.size() > 0) Emit();
+ acc_.clear();
+ cur_key_ = key;
+ }
+ PlusEquals(rules, &acc_);
+ }
+
+ ~Reducer() {
+ Emit();
+ }
+
+ void Emit() {
+ if (pm_)
+ DoPhraseMarginals(cur_key_, bidir_, &acc_);
+ if (bidir_)
+ WriteWithInversions(cur_key_, acc_);
+ else
+ WriteKeyValue(cur_key_, acc_);
+ }
+
+ const bool pm_;
+ const bool bidir_;
+ vector<WordID> cur_key_;
+ ID2RuleStatistics acc_;
+};
+
+void cb(const vector<WordID>& key, const ID2RuleStatistics& contexts, void* red) {
+ static_cast<Reducer*>(red)->ProcessLine(key, contexts);
+}
+
+
int main(int argc, char** argv) {
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
char* buf = new char[MAX_LINE_LENGTH];
- ID2RuleStatistics acc, cur_counts;
vector<WordID> key, cur_key;
int line = 0;
use_hadoop_counters = conf.count("use_hadoop_counters") > 0;
const bool phrase_marginals = conf.count("phrase_marginals") > 0;
const bool bidir = conf.count("bidir") > 0;
- while(cin) {
- ++line;
- cin.getline(buf, MAX_LINE_LENGTH);
- if (buf[0] == 0) continue;
- ParseLine(buf, &cur_key, &cur_counts);
- if (cur_key != key) {
- if (key.size() > 0) {
- if (phrase_marginals)
- DoPhraseMarginals(key, bidir, &acc);
- if (bidir)
- WriteWithInversions(key, acc);
- else
- WriteKeyValue(key, acc);
- acc.clear();
- }
- key = cur_key;
- }
- PlusEquals(cur_counts, &acc);
- }
- if (key.size() > 0) {
- if (phrase_marginals)
- DoPhraseMarginals(key, bidir, &acc);
- if (bidir)
- WriteWithInversions(key, acc);
- else
- WriteKeyValue(key, acc);
- }
+ Reducer reducer(phrase_marginals, bidir);
+ StripedGrammarLexer::ReadContexts(&cin, cb, &reducer);
return 0;
}
diff --git a/extools/sentence_pair.cc b/extools/sentence_pair.cc
index b2881737..4cbcc98e 100644
--- a/extools/sentence_pair.cc
+++ b/extools/sentence_pair.cc
@@ -72,8 +72,8 @@ int AnnotatedParallelSentence::ReadAlignmentPoint(const char* buf,
}
(*b) = 0;
while(ch < end && (c == 0 && (!permit_col || (permit_col && buf[ch] != ':')) || c != 0 && buf[ch] != '-')) {
- if (buf[ch] < '0' || buf[ch] > '9') {
- cerr << "Alignment point badly formed 4: " << string(buf, start, end-start) << endl << buf << endl;
+ if ((buf[ch] < '0') || (buf[ch] > '9')) {
+ cerr << "Alignment point badly formed 4: " << string(buf, start, end-start) << endl << buf << endl << buf[ch] << endl;
exit(1);
}
(*b) *= 10;
diff --git a/extools/sg_lexer.l b/extools/sg_lexer.l
index f115e5bd..f82e8135 100644
--- a/extools/sg_lexer.l
+++ b/extools/sg_lexer.l
@@ -12,9 +12,12 @@
#include "striped_grammar.h"
int lex_line = 0;
+int read_contexts = 0;
std::istream* sglex_stream = NULL;
StripedGrammarLexer::GrammarCallback grammar_callback = NULL;
+StripedGrammarLexer::ContextCallback context_callback = NULL;
void* grammar_callback_extra = NULL;
+void* context_callback_extra = NULL;
#undef YY_INPUT
#define YY_INPUT(buf, result, max_size) (result = sglex_stream->read(buf, max_size).gcount())
@@ -83,12 +86,39 @@ ALIGN [0-9]+-[0-9]+
%%
<INITIAL>[ ] ;
+<INITIAL>[\t] {
+ if (read_contexts) {
+ cur_options.clear();
+ BEGIN(TRG);
+ } else {
+ std::cerr << "Unexpected tab while reading striped grammar\n";
+ exit(1);
+ }
+ }
<INITIAL>\[{NT}\] {
- sglex_tmp_token.assign(yytext + 1, yyleng - 2);
- sglex_lhs = -TD::Convert(sglex_tmp_token);
- // std::cerr << sglex_tmp_token << "\n";
- BEGIN(LHS_END);
+ if (read_contexts) {
+ sglex_tmp_token.assign(yytext, yyleng);
+ sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token);
+ ++sglex_src_rhs_size;
+ } else {
+ sglex_tmp_token.assign(yytext + 1, yyleng - 2);
+ sglex_lhs = -TD::Convert(sglex_tmp_token);
+ // std::cerr << sglex_tmp_token << "\n";
+ BEGIN(LHS_END);
+ }
+ }
+
+<INITIAL>[^ \t]+ {
+ if (read_contexts) {
+ // std::cerr << "Context: " << yytext << std::endl;
+ sglex_tmp_token.assign(yytext, yyleng);
+ sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token);
+ ++sglex_src_rhs_size;
+ } else {
+ std::cerr << "Unexpected input: " << yytext << " when NT expected\n";
+ exit(1);
+ }
}
<SRC>\[{NT}\] {
@@ -103,7 +133,8 @@ ALIGN [0-9]+-[0-9]+
sglex_reset();
BEGIN(SRC);
}
-<INITIAL,LHS_END>. {
+
+<LHS_END>. {
std::cerr << "Line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl;
exit(1);
}
@@ -136,21 +167,27 @@ ALIGN [0-9]+-[0-9]+
//std::cerr << "LHS=" << TD::Convert(-sglex_lhs) << " ";
//std::cerr << " src_size: " << sglex_src_rhs_size << std::endl;
//std::cerr << " src_arity: " << sglex_src_arity << std::endl;
- memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int));
cur_options.clear();
+ memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int));
sglex_trg_rhs_size = 0;
BEGIN(TRG);
}
<TRG>\[[1-9][0-9]?\] {
- int index = yytext[yyleng - 2] - '0';
- if (yyleng == 4) {
- index += 10 * (yytext[yyleng - 3] - '0');
+ if (read_contexts) {
+ sglex_tmp_token.assign(yytext, yyleng);
+ sglex_trg_rhs[sglex_trg_rhs_size] = TD::Convert(sglex_tmp_token);
+ ++sglex_trg_rhs_size;
+ } else {
+ int index = yytext[yyleng - 2] - '0';
+ if (yyleng == 4) {
+ index += 10 * (yytext[yyleng - 3] - '0');
+ }
+ ++sglex_trg_arity;
+ sanity_check_trg_index(index);
+ sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index;
+ ++sglex_trg_rhs_size;
}
- ++sglex_trg_arity;
- sanity_check_trg_index(index);
- sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index;
- ++sglex_trg_rhs_size;
}
<TRG>\|\|\| {
@@ -171,13 +208,18 @@ ALIGN [0-9]+-[0-9]+
<TRG>[ ]+ { ; }
<FEATS>\n {
- assert(sglex_lhs < 0);
assert(sglex_src_rhs_size > 0);
cur_src_rhs.resize(sglex_src_rhs_size);
for (int i = 0; i < sglex_src_rhs_size; ++i)
cur_src_rhs[i] = sglex_src_rhs[i];
- grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra);
+ if (read_contexts) {
+ context_callback(cur_src_rhs, cur_options, context_callback_extra);
+ } else {
+ assert(sglex_lhs < 0);
+ grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra);
+ }
cur_options.clear();
+ sglex_reset();
BEGIN(INITIAL);
}
<FEATS>[ ]+ { ; }
@@ -233,6 +275,7 @@ ALIGN [0-9]+-[0-9]+
#include "filelib.h"
void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra) {
+ read_contexts = 0;
lex_line = 1;
sglex_stream = in;
grammar_callback_extra = extra;
@@ -240,3 +283,13 @@ void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback f
yylex();
}
+void StripedGrammarLexer::ReadContexts(std::istream* in, ContextCallback func, void* extra) {
+ read_contexts = 1;
+ lex_line = 1;
+ sglex_stream = in;
+ context_callback_extra = extra;
+ context_callback = func;
+ yylex();
+}
+
+
diff --git a/extools/striped_grammar.h b/extools/striped_grammar.h
index cdf529d6..bf3aec7d 100644
--- a/extools/striped_grammar.h
+++ b/extools/striped_grammar.h
@@ -49,6 +49,8 @@ typedef std::tr1::unordered_map<std::vector<WordID>, RuleStatistics, boost::hash
struct StripedGrammarLexer {
typedef void (*GrammarCallback)(WordID lhs, const std::vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void *extra);
static void ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra);
+ typedef void (*ContextCallback)(const std::vector<WordID>& phrase, const ID2RuleStatistics& rules, void *extra);
+ static void ReadContexts(std::istream* in, ContextCallback func, void* extra);
};
#endif