diff options
| -rw-r--r-- | extools/extractor.cc | 88 | 
1 files changed, 71 insertions, 17 deletions
| diff --git a/extools/extractor.cc b/extools/extractor.cc index bc27e408..7279f745 100644 --- a/extools/extractor.cc +++ b/extools/extractor.cc @@ -23,7 +23,7 @@ using namespace std::tr1;  namespace po = boost::program_options;  static const size_t MAX_LINE_LENGTH = 100000; -WordID kBOS, kEOS, kDIVIDER, kGAP; +WordID kBOS, kEOS, kDIVIDER, kGAP, kSPLIT;  int kCOUNT;  void InitCommandLine(int argc, char** argv, po::variables_map* conf) { @@ -34,6 +34,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {          ("loose", "Use loose phrase extraction heuristic for base phrases")          ("base_phrase,B", "Write base phrases")          ("base_phrase_spans", "Write base sentences and phrase spans") +        ("phrase_language", po::value<string>()->default_value("target"), "Extract phrase strings in source, target or both languages") +        ("context_language", po::value<string>()->default_value("target"), "Extract context strings in source, target or both languages")          ("bidir,b", "Extract bidirectional rules (for computing p(f|e) in addition to p(e|f))")          ("combiner_size,c", po::value<size_t>()->default_value(800000), "Number of unique items to store in cache before writing rule counts. Set to 1 to disable cache. Set to 0 for no limit.")          ("silent", "Write nothing to stderr except errors") @@ -140,27 +142,69 @@ struct CountCombiner {  void WritePhraseContexts(const AnnotatedParallelSentence& sentence,                           const vector<ParallelSpan>& phrases,                           const int ctx_size, +                         bool phrase_s, bool phrase_t, +                         bool context_s, bool context_t,                           CountCombiner* o) { -  vector<WordID> context(ctx_size * 2 + 1); -  context[ctx_size] = kGAP; -  vector<WordID> key; -  key.reserve(100); +  vector<WordID> context, context_f; +  if (context_t)  +  { +      context.resize(ctx_size * 2 + 1);  +      context[ctx_size] = kGAP; +  } +  if (context_s)  +  { +      context_f.resize(ctx_size * 2 + 1); +      context_f[ctx_size] = kGAP; +  } +  vector<WordID> key, key_f; +  if (phrase_t) key.reserve(100); +  if (phrase_s) key_f.reserve(100); +    for (int it = 0; it < phrases.size(); ++it) {      const ParallelSpan& phrase = phrases[it]; -    // TODO, support src keys as well -    key.resize(phrase.j2 - phrase.j1); -    for (int j = phrase.j1; j < phrase.j2; ++j) -      key[j - phrase.j1] = sentence.e[j]; +    key.clear(); +    for (int j = phrase.j1; j < phrase.j2 && phrase_t; ++j) +      key.push_back(sentence.e[j]); + +    if (context_t) +    { +        context.resize(ctx_size * 2 + 1);  +        for (int i = 0; i < ctx_size && context_t; ++i) { +          int epos = phrase.j1 - 1 - i; +          const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos]; +          context[ctx_size - i - 1] = left_ctx; +          epos = phrase.j2 + i; +          const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos]; +          context[ctx_size + i + 1] = right_ctx; +        } +    } +    else +        context.clear(); + +    if (phrase_s) +    { +        key_f.clear(); +        for (int i = phrase.i1; i < phrase.i2; ++i) +          key_f.push_back(sentence.f[i]); +        if (phrase_t) key.push_back(kSPLIT); +        copy(key_f.begin(), key_f.end(), back_inserter(key)); +    } -    for (int i = 0; i < ctx_size; ++i) { -      int epos = phrase.j1 - 1 - i; -      const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos]; -      context[ctx_size - i - 1] = left_ctx; -      epos = phrase.j2 + i; -      const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos]; -      context[ctx_size + i + 1] = right_ctx; +    if (context_s) +    { +        for (int i = 0; i < ctx_size; ++i) { +          int fpos = phrase.i1 - 1 - i; +          const WordID left_ctx = (fpos < 0) ? kBOS : sentence.f[fpos]; +          context_f[ctx_size - i - 1] = left_ctx; +          fpos = phrase.i2 + i; +          const WordID right_ctx = (fpos >= sentence.f_len) ? kEOS : sentence.f[fpos]; +          context_f[ctx_size + i + 1] = right_ctx; +        } +        if (context_t) context.push_back(kSPLIT); +        copy(context_f.begin(), context_f.end(), back_inserter(context));      } +      o->Count(key, context, kCOUNT, vector<pair<short,short> >());    }  } @@ -298,6 +342,7 @@ int main(int argc, char** argv) {    kDIVIDER = TD::Convert("|||");    kGAP = TD::Convert("<PHRASE>");    kCOUNT = FD::Convert("C"); +  kSPLIT = TD::Convert("<SPLIT>");    WordID default_cat = 0;  // 0 means no default- extraction will                             // fail if a phrase is extracted without a @@ -327,10 +372,19 @@ int main(int argc, char** argv) {    const int num_categories = conf["topics"].as<int>();    const bool permit_adjacent_nonterminals = conf.count("permit_adjacent_nonterminals") > 0;    const bool require_aligned_terminal = conf.count("no_required_aligned_terminal") == 0; +  const string ps = conf["phrase_language"].as<string>(); +  const bool phrase_s = ps == "source" || ps == "both"; +  const bool phrase_t = ps == "target" || ps == "both"; +  const string cs = conf["context_language"].as<string>(); +  const bool context_s = cs == "source" || cs == "both"; +  const bool context_t = cs == "target" || cs == "both";    int line = 0;    CountCombiner cc(conf["combiner_size"].as<size_t>());    HadoopStreamingRuleObserver o(&cc,                                  conf.count("bidir") > 0); + +  assert(phrase_s || phrase_t); +  assert(context_s || context_t);    if(backoff) {      for (int i=0;i < num_categories;++i) @@ -356,7 +410,7 @@ int main(int argc, char** argv) {        continue;      }      if (write_phrase_contexts) { -      WritePhraseContexts(sentence, phrases, ctx_size, &cc); +      WritePhraseContexts(sentence, phrases, ctx_size, phrase_s, phrase_t, context_s, context_t, &cc);        continue;      }      if (write_base_phrases) { | 
