diff options
| author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 | 
|---|---|---|
| committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 | 
| commit | 0172721855098ca02b207231a654dffa5e4eb1c9 (patch) | |
| tree | 8069c3a62e2d72bd64a2cdeee9724b2679c8a56b /extools/mr_stripe_rule_reduce.cc | |
| parent | 37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff) | |
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'extools/mr_stripe_rule_reduce.cc')
| -rw-r--r-- | extools/mr_stripe_rule_reduce.cc | 242 | 
1 files changed, 242 insertions, 0 deletions
diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc new file mode 100644 index 00000000..eaf1b6d7 --- /dev/null +++ b/extools/mr_stripe_rule_reduce.cc @@ -0,0 +1,242 @@ +#include <iostream> +#include <vector> +#include <utility> +#include <cstdlib> +#include <tr1/unordered_map> + +#include <boost/functional/hash.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "tdict.h" +#include "sentence_pair.h" +#include "fdict.h" +#include "extract.h" + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +static const size_t MAX_LINE_LENGTH = 64000000; + +bool use_hadoop_counters = false; + +namespace { +  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } + +  inline void SkipWhitespace(const char* buf, int* ptr) { +    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } +  } +} +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +  po::options_description opts("Configuration options"); +  opts.add_options() +        ("phrase_marginals,p", "Compute phrase marginals") +	("use_hadoop_counters,C", "Enable this if running inside Hadoop") +        ("bidir,b", "Rules are tagged as being F->E or E->F, invert E rules in output") +        ("help,h", "Print this help message and exit"); +  po::options_description clo("Command line options"); +  po::options_description dcmdline_options; +  dcmdline_options.add(opts); + +  po::store(parse_command_line(argc, argv, dcmdline_options), *conf); +  po::notify(*conf); + +  if (conf->count("help")) { +    cerr << "\nUsage: mr_stripe_rule_reduce [-options]\n"; +    cerr << dcmdline_options << endl; +    exit(1); +  } +} + +typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; + +void PlusEquals(const ID2RuleStatistics& v, ID2RuleStatistics* self) { +  for (ID2RuleStatistics::const_iterator it = v.begin(); it != v.end(); ++it) { +    RuleStatistics& dest = (*self)[it->first]; +    dest += it->second; +    // TODO - do something smarter about alignments? +    if (dest.aligns.empty() && !it->second.aligns.empty()) +      dest.aligns = it->second.aligns; +  } +} + +int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) { +  static const WordID kDIV = TD::Convert("|||"); +  int ptr = sstart; +  while(ptr < end) { +    while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } +    int start = ptr; +    while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } +    if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } +    const WordID w = TD::Convert(string(buf, start, ptr - start)); +    if (w == kDIV) return ptr; +    p->push_back(w); +  } +  return ptr; +} + +void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) { +  static const WordID kDIV = TD::Convert("|||"); +  counts->clear(); +  int ptr = 0; +  while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } +  if (buf[ptr] != '\t') { +    cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; +    exit(1); +  } +  cur_key->clear(); +  // key is: "[X] ||| word word word" +  int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); +  if (buf[tmpp] != '\t') { +    cur_key->push_back(kDIV); +    ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); +  } +  ++ptr; +  int start = ptr; +  int end = ptr; +  int state = 0; // 0=reading label, 1=reading count +  vector<WordID> name; +  while(buf[ptr] != 0) { +    while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } +    if (buf[ptr] == '|') { +      ++ptr; +      if (buf[ptr] == '|') { +        ++ptr; +        if (buf[ptr] == '|') { +          ++ptr; +          end = ptr - 3; +          while (end > start && IsWhitespace(buf[end-1])) { --end; } +          if (start == end) { +            cerr << "Got empty token!\n  LINE=" << buf << endl; +            exit(1); +          } +          switch (state) { +            case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; +            case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; +            default: cerr << "Can't happen\n"; abort(); +          } +          SkipWhitespace(buf, &ptr); +          start = ptr; +        } +      } +    } +  } +  end=ptr; +  while (end > start && IsWhitespace(buf[end-1])) { --end; } +  if (end > start) { +    switch (state) { +      case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; +      case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; +      default: cerr << "Can't happen\n"; abort(); +    } +  } +} + +void WriteKeyValue(const vector<WordID>& key, const ID2RuleStatistics& val) { +  cout << TD::GetString(key) << '\t'; +  bool needdiv = false; +  for (ID2RuleStatistics::const_iterator it = val.begin(); it != val.end(); ++it) { +    if (needdiv) cout << " ||| "; else needdiv = true; +    cout << TD::GetString(it->first) << " ||| " << it->second; +  } +  cout << endl; +  if (use_hadoop_counters) cerr << "reporter:counter:UserCounters,RuleCount," << val.size() << endl; +} + +void DoPhraseMarginals(const vector<WordID>& key, const bool bidir, ID2RuleStatistics* val) { +  static const WordID kF = TD::Convert("F"); +  static const WordID kE = TD::Convert("E"); +  static const int kCF = FD::Convert("CF"); +  static const int kCE = FD::Convert("CE"); +  static const int kCFE = FD::Convert("CFE"); +  assert(key.size() > 0); +  int cur_marginal_id = kCF; +  if (bidir) { +    if (key[0] != kF && key[0] != kE) { +      cerr << "DoPhraseMarginals expects keys to have the from 'F|E [NT] word word word'\n"; +      cerr << "  but got: " << TD::GetString(key) << endl; +      exit(1); +    } +    if (key[0] == kE) cur_marginal_id = kCE; +  } +  double tot = 0; +  for (ID2RuleStatistics::iterator it = val->begin(); it != val->end(); ++it) +    tot += it->second.counts.value(kCFE); +  for (ID2RuleStatistics::iterator it = val->begin(); it != val->end(); ++it) { +    it->second.counts.set_value(cur_marginal_id, tot); + +    // prevent double counting of the joint +    if (cur_marginal_id == kCE) it->second.counts.clear_value(kCFE); +  } +} + +void WriteWithInversions(const vector<WordID>& key, const ID2RuleStatistics& val) { +  static const WordID kE = TD::Convert("E"); +  static const WordID kDIV = TD::Convert("|||"); +  vector<WordID> new_key(key.size() - 1); +  for (int i = 1; i < key.size(); ++i) +    new_key[i - 1] = key[i]; +  const bool do_invert = (key[0] == kE); +  if (!do_invert) { +    WriteKeyValue(new_key, val); +  } else { +    ID2RuleStatistics inv; +    assert(new_key.size() > 2); +    vector<WordID> tk(new_key.size() - 2); +    for (int i = 0; i < tk.size(); ++i) +      tk[i] = new_key[2 + i]; +    RuleStatistics& inv_stats = inv[tk]; +    for (ID2RuleStatistics::const_iterator it = val.begin(); it != val.end(); ++it) { +      inv_stats.counts = it->second.counts; +      vector<WordID> ekey(2 + it->first.size()); +      ekey[0] = key[1]; +      ekey[1] = kDIV; +      for (int i = 0; i < it->first.size(); ++i) +        ekey[2+i] = it->first[i]; +      WriteKeyValue(ekey, inv); +    } +  } +} + +int main(int argc, char** argv) { +  po::variables_map conf; +  InitCommandLine(argc, argv, &conf); + +  char* buf = new char[MAX_LINE_LENGTH]; +  ID2RuleStatistics acc, cur_counts; +  vector<WordID> key, cur_key; +  int line = 0; +  use_hadoop_counters = conf.count("use_hadoop_counters") > 0; +  const bool phrase_marginals = conf.count("phrase_marginals") > 0; +  const bool bidir = conf.count("bidir") > 0; +  while(cin) { +    ++line; +    cin.getline(buf, MAX_LINE_LENGTH); +    if (buf[0] == 0) continue; +    ParseLine(buf, &cur_key, &cur_counts); +    if (cur_key != key) { +      if (key.size() > 0) { +        if (phrase_marginals) +          DoPhraseMarginals(key, bidir, &acc); +        if (bidir) +          WriteWithInversions(key, acc); +        else +          WriteKeyValue(key, acc); +        acc.clear(); +      } +      key = cur_key; +    } +    PlusEquals(cur_counts, &acc); +  } +  if (key.size() > 0) { +    if (phrase_marginals) +      DoPhraseMarginals(key, bidir, &acc); +    if (bidir) +      WriteWithInversions(key, acc); +    else +      WriteKeyValue(key, acc); +  } +  return 0; +} +  | 
