From 5845e886c65c90ab57f6bf46690167ba0f657e77 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Wed, 14 Jul 2010 18:48:46 +0000 Subject: Added facility to get source or both language phrases/contexts. Need to fix the scripts first before using this new feature. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@250 ec762483-ff6d-05da-a07a-a48fb63a330f --- extools/extractor.cc | 88 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 17 deletions(-) diff --git a/extools/extractor.cc b/extools/extractor.cc index bc27e408..7279f745 100644 --- a/extools/extractor.cc +++ b/extools/extractor.cc @@ -23,7 +23,7 @@ using namespace std::tr1; namespace po = boost::program_options; static const size_t MAX_LINE_LENGTH = 100000; -WordID kBOS, kEOS, kDIVIDER, kGAP; +WordID kBOS, kEOS, kDIVIDER, kGAP, kSPLIT; int kCOUNT; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { @@ -34,6 +34,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("loose", "Use loose phrase extraction heuristic for base phrases") ("base_phrase,B", "Write base phrases") ("base_phrase_spans", "Write base sentences and phrase spans") + ("phrase_language", po::value()->default_value("target"), "Extract phrase strings in source, target or both languages") + ("context_language", po::value()->default_value("target"), "Extract context strings in source, target or both languages") ("bidir,b", "Extract bidirectional rules (for computing p(f|e) in addition to p(e|f))") ("combiner_size,c", po::value()->default_value(800000), "Number of unique items to store in cache before writing rule counts. Set to 1 to disable cache. Set to 0 for no limit.") ("silent", "Write nothing to stderr except errors") @@ -140,27 +142,69 @@ struct CountCombiner { void WritePhraseContexts(const AnnotatedParallelSentence& sentence, const vector& phrases, const int ctx_size, + bool phrase_s, bool phrase_t, + bool context_s, bool context_t, CountCombiner* o) { - vector context(ctx_size * 2 + 1); - context[ctx_size] = kGAP; - vector key; - key.reserve(100); + vector context, context_f; + if (context_t) + { + context.resize(ctx_size * 2 + 1); + context[ctx_size] = kGAP; + } + if (context_s) + { + context_f.resize(ctx_size * 2 + 1); + context_f[ctx_size] = kGAP; + } + vector key, key_f; + if (phrase_t) key.reserve(100); + if (phrase_s) key_f.reserve(100); + for (int it = 0; it < phrases.size(); ++it) { const ParallelSpan& phrase = phrases[it]; - // TODO, support src keys as well - key.resize(phrase.j2 - phrase.j1); - for (int j = phrase.j1; j < phrase.j2; ++j) - key[j - phrase.j1] = sentence.e[j]; + key.clear(); + for (int j = phrase.j1; j < phrase.j2 && phrase_t; ++j) + key.push_back(sentence.e[j]); + + if (context_t) + { + context.resize(ctx_size * 2 + 1); + for (int i = 0; i < ctx_size && context_t; ++i) { + int epos = phrase.j1 - 1 - i; + const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos]; + context[ctx_size - i - 1] = left_ctx; + epos = phrase.j2 + i; + const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos]; + context[ctx_size + i + 1] = right_ctx; + } + } + else + context.clear(); + + if (phrase_s) + { + key_f.clear(); + for (int i = phrase.i1; i < phrase.i2; ++i) + key_f.push_back(sentence.f[i]); + if (phrase_t) key.push_back(kSPLIT); + copy(key_f.begin(), key_f.end(), back_inserter(key)); + } - for (int i = 0; i < ctx_size; ++i) { - int epos = phrase.j1 - 1 - i; - const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos]; - context[ctx_size - i - 1] = left_ctx; - epos = phrase.j2 + i; - const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos]; - context[ctx_size + i + 1] = right_ctx; + if (context_s) + { + for (int i = 0; i < ctx_size; ++i) { + int fpos = phrase.i1 - 1 - i; + const WordID left_ctx = (fpos < 0) ? kBOS : sentence.f[fpos]; + context_f[ctx_size - i - 1] = left_ctx; + fpos = phrase.i2 + i; + const WordID right_ctx = (fpos >= sentence.f_len) ? kEOS : sentence.f[fpos]; + context_f[ctx_size + i + 1] = right_ctx; + } + if (context_t) context.push_back(kSPLIT); + copy(context_f.begin(), context_f.end(), back_inserter(context)); } + o->Count(key, context, kCOUNT, vector >()); } } @@ -298,6 +342,7 @@ int main(int argc, char** argv) { kDIVIDER = TD::Convert("|||"); kGAP = TD::Convert(""); kCOUNT = FD::Convert("C"); + kSPLIT = TD::Convert(""); WordID default_cat = 0; // 0 means no default- extraction will // fail if a phrase is extracted without a @@ -327,10 +372,19 @@ int main(int argc, char** argv) { const int num_categories = conf["topics"].as(); const bool permit_adjacent_nonterminals = conf.count("permit_adjacent_nonterminals") > 0; const bool require_aligned_terminal = conf.count("no_required_aligned_terminal") == 0; + const string ps = conf["phrase_language"].as(); + const bool phrase_s = ps == "source" || ps == "both"; + const bool phrase_t = ps == "target" || ps == "both"; + const string cs = conf["context_language"].as(); + const bool context_s = cs == "source" || cs == "both"; + const bool context_t = cs == "target" || cs == "both"; int line = 0; CountCombiner cc(conf["combiner_size"].as()); HadoopStreamingRuleObserver o(&cc, conf.count("bidir") > 0); + + assert(phrase_s || phrase_t); + assert(context_s || context_t); if(backoff) { for (int i=0;i < num_categories;++i) @@ -356,7 +410,7 @@ int main(int argc, char** argv) { continue; } if (write_phrase_contexts) { - WritePhraseContexts(sentence, phrases, ctx_size, &cc); + WritePhraseContexts(sentence, phrases, ctx_size, phrase_s, phrase_t, context_s, context_t, &cc); continue; } if (write_base_phrases) { -- cgit v1.2.3