diff options
Diffstat (limited to 'extools/extractor.cc')
-rw-r--r-- | extools/extractor.cc | 88 |
1 files changed, 71 insertions, 17 deletions
diff --git a/extools/extractor.cc b/extools/extractor.cc index bc27e408..7279f745 100644 --- a/extools/extractor.cc +++ b/extools/extractor.cc @@ -23,7 +23,7 @@ using namespace std::tr1; namespace po = boost::program_options; static const size_t MAX_LINE_LENGTH = 100000; -WordID kBOS, kEOS, kDIVIDER, kGAP; +WordID kBOS, kEOS, kDIVIDER, kGAP, kSPLIT; int kCOUNT; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { @@ -34,6 +34,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("loose", "Use loose phrase extraction heuristic for base phrases") ("base_phrase,B", "Write base phrases") ("base_phrase_spans", "Write base sentences and phrase spans") + ("phrase_language", po::value<string>()->default_value("target"), "Extract phrase strings in source, target or both languages") + ("context_language", po::value<string>()->default_value("target"), "Extract context strings in source, target or both languages") ("bidir,b", "Extract bidirectional rules (for computing p(f|e) in addition to p(e|f))") ("combiner_size,c", po::value<size_t>()->default_value(800000), "Number of unique items to store in cache before writing rule counts. Set to 1 to disable cache. Set to 0 for no limit.") ("silent", "Write nothing to stderr except errors") @@ -140,27 +142,69 @@ struct CountCombiner { void WritePhraseContexts(const AnnotatedParallelSentence& sentence, const vector<ParallelSpan>& phrases, const int ctx_size, + bool phrase_s, bool phrase_t, + bool context_s, bool context_t, CountCombiner* o) { - vector<WordID> context(ctx_size * 2 + 1); - context[ctx_size] = kGAP; - vector<WordID> key; - key.reserve(100); + vector<WordID> context, context_f; + if (context_t) + { + context.resize(ctx_size * 2 + 1); + context[ctx_size] = kGAP; + } + if (context_s) + { + context_f.resize(ctx_size * 2 + 1); + context_f[ctx_size] = kGAP; + } + vector<WordID> key, key_f; + if (phrase_t) key.reserve(100); + if (phrase_s) key_f.reserve(100); + for (int it = 0; it < phrases.size(); ++it) { const ParallelSpan& phrase = phrases[it]; - // TODO, support src keys as well - key.resize(phrase.j2 - phrase.j1); - for (int j = phrase.j1; j < phrase.j2; ++j) - key[j - phrase.j1] = sentence.e[j]; + key.clear(); + for (int j = phrase.j1; j < phrase.j2 && phrase_t; ++j) + key.push_back(sentence.e[j]); + + if (context_t) + { + context.resize(ctx_size * 2 + 1); + for (int i = 0; i < ctx_size && context_t; ++i) { + int epos = phrase.j1 - 1 - i; + const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos]; + context[ctx_size - i - 1] = left_ctx; + epos = phrase.j2 + i; + const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos]; + context[ctx_size + i + 1] = right_ctx; + } + } + else + context.clear(); + + if (phrase_s) + { + key_f.clear(); + for (int i = phrase.i1; i < phrase.i2; ++i) + key_f.push_back(sentence.f[i]); + if (phrase_t) key.push_back(kSPLIT); + copy(key_f.begin(), key_f.end(), back_inserter(key)); + } - for (int i = 0; i < ctx_size; ++i) { - int epos = phrase.j1 - 1 - i; - const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos]; - context[ctx_size - i - 1] = left_ctx; - epos = phrase.j2 + i; - const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos]; - context[ctx_size + i + 1] = right_ctx; + if (context_s) + { + for (int i = 0; i < ctx_size; ++i) { + int fpos = phrase.i1 - 1 - i; + const WordID left_ctx = (fpos < 0) ? kBOS : sentence.f[fpos]; + context_f[ctx_size - i - 1] = left_ctx; + fpos = phrase.i2 + i; + const WordID right_ctx = (fpos >= sentence.f_len) ? kEOS : sentence.f[fpos]; + context_f[ctx_size + i + 1] = right_ctx; + } + if (context_t) context.push_back(kSPLIT); + copy(context_f.begin(), context_f.end(), back_inserter(context)); } + o->Count(key, context, kCOUNT, vector<pair<short,short> >()); } } @@ -298,6 +342,7 @@ int main(int argc, char** argv) { kDIVIDER = TD::Convert("|||"); kGAP = TD::Convert("<PHRASE>"); kCOUNT = FD::Convert("C"); + kSPLIT = TD::Convert("<SPLIT>"); WordID default_cat = 0; // 0 means no default- extraction will // fail if a phrase is extracted without a @@ -327,10 +372,19 @@ int main(int argc, char** argv) { const int num_categories = conf["topics"].as<int>(); const bool permit_adjacent_nonterminals = conf.count("permit_adjacent_nonterminals") > 0; const bool require_aligned_terminal = conf.count("no_required_aligned_terminal") == 0; + const string ps = conf["phrase_language"].as<string>(); + const bool phrase_s = ps == "source" || ps == "both"; + const bool phrase_t = ps == "target" || ps == "both"; + const string cs = conf["context_language"].as<string>(); + const bool context_s = cs == "source" || cs == "both"; + const bool context_t = cs == "target" || cs == "both"; int line = 0; CountCombiner cc(conf["combiner_size"].as<size_t>()); HadoopStreamingRuleObserver o(&cc, conf.count("bidir") > 0); + + assert(phrase_s || phrase_t); + assert(context_s || context_t); if(backoff) { for (int i=0;i < num_categories;++i) @@ -356,7 +410,7 @@ int main(int argc, char** argv) { continue; } if (write_phrase_contexts) { - WritePhraseContexts(sentence, phrases, ctx_size, &cc); + WritePhraseContexts(sentence, phrases, ctx_size, phrase_s, phrase_t, context_s, context_t, &cc); continue; } if (write_base_phrases) { |