diff options
| -rw-r--r-- | extools/Makefile.am | 2 | ||||
| -rw-r--r-- | extools/mr_stripe_rule_reduce.cc | 150 | ||||
| -rw-r--r-- | extools/sentence_pair.cc | 4 | ||||
| -rw-r--r-- | extools/sg_lexer.l | 83 | ||||
| -rw-r--r-- | extools/striped_grammar.h | 2 | 
5 files changed, 112 insertions, 129 deletions
| diff --git a/extools/Makefile.am b/extools/Makefile.am index 807fe7d6..562599a3 100644 --- a/extools/Makefile.am +++ b/extools/Makefile.am @@ -28,7 +28,7 @@ build_lexical_translation_SOURCES = build_lexical_translation.cc extract.cc sent  build_lexical_translation_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  build_lexical_translation_LDFLAGS = -all-static -mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc +mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc sg_lexer.cc  mr_stripe_rule_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  mr_stripe_rule_reduce_LDFLAGS = -all-static diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc index 3298a801..8332a106 100644 --- a/extools/mr_stripe_rule_reduce.cc +++ b/extools/mr_stripe_rule_reduce.cc @@ -8,11 +8,11 @@  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> +#include "striped_grammar.h"  #include "tdict.h"  #include "sentence_pair.h"  #include "fdict.h"  #include "extract.h" -#include "striped_grammar.h"  using namespace std;  using namespace std::tr1; @@ -22,13 +22,6 @@ static const size_t MAX_LINE_LENGTH = 64000000;  bool use_hadoop_counters = false; -namespace { -  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - -  inline void SkipWhitespace(const char* buf, int* ptr) { -    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } -  } -}  void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options");    opts.add_options() @@ -50,8 +43,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    }  } -typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; -  void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) {    for (ID2RuleStatistics::const_iterator it = v.begin(); it != v.end(); ++it) {      RuleStatistics& dest = (*self)[it->first]; @@ -62,79 +53,6 @@ void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) {    }  } -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) { -  static const WordID kDIV = TD::Convert("|||"); -  int ptr = sstart; -  while(ptr < end) { -    while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } -    int start = ptr; -    while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } -    if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } -    const WordID w = TD::Convert(string(buf, start, ptr - start)); -    if (w == kDIV) return ptr; -    p->push_back(w); -  } -  assert(p->size() > 0); -  return ptr; -} - -void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) { -  static const WordID kDIV = TD::Convert("|||"); -  counts->clear(); -  int ptr = 0; -  while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } -  if (buf[ptr] != '\t') { -    cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; -    exit(1); -  } -  cur_key->clear(); -  // key is: "[X] ||| word word word" -  int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); -  if (buf[tmpp] != '\t') { -    cur_key->push_back(kDIV); -    ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); -  } -  ++ptr; -  int start = ptr; -  int end = ptr; -  int state = 0; // 0=reading label, 1=reading count -  vector<WordID> name; -  while(buf[ptr] != 0) { -    while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } -    if (buf[ptr] == '|') { -      ++ptr; -      if (buf[ptr] == '|') { -        ++ptr; -        if (buf[ptr] == '|') { -          ++ptr; -          end = ptr - 3; -          while (end > start && IsWhitespace(buf[end-1])) { --end; } -          if (start == end) { -            cerr << "Got empty token!\n  LINE=" << buf << endl; -            exit(1); -          } -          switch (state) { -            case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; -            case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; -            default: cerr << "Can't happen\n"; abort(); -          } -          SkipWhitespace(buf, &ptr); -          start = ptr; -        } -      } -    } -  } -  end=ptr; -  while (end > start && IsWhitespace(buf[end-1])) { --end; } -  if (end > start) { -    switch (state) { -      case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; -      case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; -      default: cerr << "Can't happen\n"; abort(); -    } -  } -} -  void WriteKeyValue(const vector<WordID>& key, const ID2RuleStatistics& val) {    cout << TD::GetString(key) << '\t';    bool needdiv = false; @@ -201,44 +119,54 @@ void WriteWithInversions(const vector<WordID>& key, const ID2RuleStatistics& val    }  } +struct Reducer { +  Reducer(bool phrase_marginals, bool bidir) : pm_(phrase_marginals), bidir_(bidir) {} + +  void ProcessLine(const vector<WordID>& key, const ID2RuleStatistics& rules) { +    if (cur_key_ != key) { +      if (cur_key_.size() > 0) Emit(); +      acc_.clear(); +      cur_key_ = key; +    } +    PlusEquals(rules, &acc_); +  } + +  ~Reducer() { +    Emit(); +  } + +  void Emit() { +    if (pm_) +      DoPhraseMarginals(cur_key_, bidir_, &acc_); +    if (bidir_) +      WriteWithInversions(cur_key_, acc_); +    else +      WriteKeyValue(cur_key_, acc_); +  } + +  const bool pm_; +  const bool bidir_; +  vector<WordID> cur_key_; +  ID2RuleStatistics acc_; +}; + +void cb(const vector<WordID>& key, const ID2RuleStatistics& contexts, void* red) { +  static_cast<Reducer*>(red)->ProcessLine(key, contexts); +} + +  int main(int argc, char** argv) {    po::variables_map conf;    InitCommandLine(argc, argv, &conf);    char* buf = new char[MAX_LINE_LENGTH]; -  ID2RuleStatistics acc, cur_counts;    vector<WordID> key, cur_key;    int line = 0;    use_hadoop_counters = conf.count("use_hadoop_counters") > 0;    const bool phrase_marginals = conf.count("phrase_marginals") > 0;    const bool bidir = conf.count("bidir") > 0; -  while(cin) { -    ++line; -    cin.getline(buf, MAX_LINE_LENGTH); -    if (buf[0] == 0) continue; -    ParseLine(buf, &cur_key, &cur_counts); -    if (cur_key != key) { -      if (key.size() > 0) { -        if (phrase_marginals) -          DoPhraseMarginals(key, bidir, &acc); -        if (bidir) -          WriteWithInversions(key, acc); -        else -          WriteKeyValue(key, acc); -        acc.clear(); -      } -      key = cur_key; -    } -    PlusEquals(cur_counts, &acc); -  } -  if (key.size() > 0) { -    if (phrase_marginals) -      DoPhraseMarginals(key, bidir, &acc); -    if (bidir) -      WriteWithInversions(key, acc); -    else -      WriteKeyValue(key, acc); -  } +  Reducer reducer(phrase_marginals, bidir); +  StripedGrammarLexer::ReadContexts(&cin, cb, &reducer);    return 0;  } diff --git a/extools/sentence_pair.cc b/extools/sentence_pair.cc index b2881737..4cbcc98e 100644 --- a/extools/sentence_pair.cc +++ b/extools/sentence_pair.cc @@ -72,8 +72,8 @@ int AnnotatedParallelSentence::ReadAlignmentPoint(const char* buf,    }    (*b) = 0;    while(ch < end && (c == 0 && (!permit_col || (permit_col && buf[ch] != ':')) || c != 0 && buf[ch] != '-')) { -    if (buf[ch] < '0' || buf[ch] > '9') { -      cerr << "Alignment point badly formed 4: " << string(buf, start, end-start) << endl << buf << endl; +    if ((buf[ch] < '0') || (buf[ch] > '9')) { +      cerr << "Alignment point badly formed 4: " << string(buf, start, end-start) << endl << buf << endl << buf[ch] << endl;        exit(1);      }      (*b) *= 10; diff --git a/extools/sg_lexer.l b/extools/sg_lexer.l index f115e5bd..f82e8135 100644 --- a/extools/sg_lexer.l +++ b/extools/sg_lexer.l @@ -12,9 +12,12 @@  #include "striped_grammar.h"  int lex_line = 0; +int read_contexts = 0;  std::istream* sglex_stream = NULL;  StripedGrammarLexer::GrammarCallback grammar_callback = NULL; +StripedGrammarLexer::ContextCallback context_callback = NULL;  void* grammar_callback_extra = NULL; +void* context_callback_extra = NULL;  #undef YY_INPUT  #define YY_INPUT(buf, result, max_size) (result = sglex_stream->read(buf, max_size).gcount()) @@ -83,12 +86,39 @@ ALIGN [0-9]+-[0-9]+  %%  <INITIAL>[ ]	; +<INITIAL>[\t]	{ +		if (read_contexts) { +			cur_options.clear(); +			BEGIN(TRG); +		} else { +			std::cerr << "Unexpected tab while reading striped grammar\n"; +			exit(1); +		} +		}  <INITIAL>\[{NT}\]   { -		sglex_tmp_token.assign(yytext + 1, yyleng - 2); -		sglex_lhs = -TD::Convert(sglex_tmp_token); -		// std::cerr << sglex_tmp_token << "\n"; -  		BEGIN(LHS_END); +		if (read_contexts) { +			sglex_tmp_token.assign(yytext, yyleng); +			sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token); +			++sglex_src_rhs_size; +		} else { +			sglex_tmp_token.assign(yytext + 1, yyleng - 2); +			sglex_lhs = -TD::Convert(sglex_tmp_token); +			// std::cerr << sglex_tmp_token << "\n"; +  			BEGIN(LHS_END); +			} +		} + +<INITIAL>[^ \t]+ { +		if (read_contexts) { +			// std::cerr << "Context: " << yytext << std::endl; +			sglex_tmp_token.assign(yytext, yyleng); +			sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token); +			++sglex_src_rhs_size; +		} else { +			std::cerr << "Unexpected input: " << yytext << " when NT expected\n"; +			exit(1); +		}  		}  <SRC>\[{NT}\]   { @@ -103,7 +133,8 @@ ALIGN [0-9]+-[0-9]+  		sglex_reset();  		BEGIN(SRC);  		} -<INITIAL,LHS_END>.	{ + +<LHS_END>.	{  		std::cerr << "Line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl;  		exit(1);  		} @@ -136,21 +167,27 @@ ALIGN [0-9]+-[0-9]+  		//std::cerr << "LHS=" << TD::Convert(-sglex_lhs) << " ";  		//std::cerr << "  src_size: " << sglex_src_rhs_size << std::endl;  		//std::cerr << "  src_arity: " << sglex_src_arity << std::endl; -		memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int));  		cur_options.clear(); +		memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int));  		sglex_trg_rhs_size = 0;  		BEGIN(TRG);  		}  <TRG>\[[1-9][0-9]?\]   { -		int index = yytext[yyleng - 2] - '0'; -		if (yyleng == 4) { -		  index += 10 * (yytext[yyleng - 3] - '0'); +		if (read_contexts) { +			sglex_tmp_token.assign(yytext, yyleng); +			sglex_trg_rhs[sglex_trg_rhs_size] = TD::Convert(sglex_tmp_token); +			++sglex_trg_rhs_size; +		} else { +			int index = yytext[yyleng - 2] - '0'; +			if (yyleng == 4) { +			  index += 10 * (yytext[yyleng - 3] - '0'); +			} +			++sglex_trg_arity; +			sanity_check_trg_index(index); +			sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index; +			++sglex_trg_rhs_size;  		} -		++sglex_trg_arity; -		sanity_check_trg_index(index); -		sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index; -		++sglex_trg_rhs_size;  }  <TRG>\|\|\|	{ @@ -171,13 +208,18 @@ ALIGN [0-9]+-[0-9]+  <TRG>[ ]+	{ ; }  <FEATS>\n	{ -		assert(sglex_lhs < 0);  		assert(sglex_src_rhs_size > 0);  		cur_src_rhs.resize(sglex_src_rhs_size);  		for (int i = 0; i < sglex_src_rhs_size; ++i)  			cur_src_rhs[i] = sglex_src_rhs[i]; -		grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra); +		if (read_contexts) { +			context_callback(cur_src_rhs, cur_options, context_callback_extra); +		} else { +			assert(sglex_lhs < 0); +			grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra); +		}  		cur_options.clear(); +		sglex_reset();  		BEGIN(INITIAL);  		}  <FEATS>[ ]+	{ ; } @@ -233,6 +275,7 @@ ALIGN [0-9]+-[0-9]+  #include "filelib.h"  void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra) { +  read_contexts = 0;    lex_line = 1;    sglex_stream = in;    grammar_callback_extra = extra; @@ -240,3 +283,13 @@ void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback f    yylex();  } +void StripedGrammarLexer::ReadContexts(std::istream* in, ContextCallback func, void* extra) { +  read_contexts = 1; +  lex_line = 1; +  sglex_stream = in; +  context_callback_extra = extra; +  context_callback = func; +  yylex(); +} + + diff --git a/extools/striped_grammar.h b/extools/striped_grammar.h index cdf529d6..bf3aec7d 100644 --- a/extools/striped_grammar.h +++ b/extools/striped_grammar.h @@ -49,6 +49,8 @@ typedef std::tr1::unordered_map<std::vector<WordID>, RuleStatistics, boost::hash  struct StripedGrammarLexer {    typedef void (*GrammarCallback)(WordID lhs, const std::vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void *extra);    static void ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra); +  typedef void (*ContextCallback)(const std::vector<WordID>& phrase, const ID2RuleStatistics& rules, void *extra); +  static void ReadContexts(std::istream* in, ContextCallback func, void* extra);  };  #endif | 
