summary refs log tree commit diff
path: root/extools/extractor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extools/extractor.cc')
-rw-r--r-- extools/extractor.cc 88
1 files changed, 71 insertions, 17 deletions
diff --git a/extools/extractor.cc b/extools/extractor.cc
index bc27e408..7279f745 100644
--- a/extools/extractor.cc
+++ b/extools/extractor.cc
@@ -23,7 +23,7 @@ using namespace std::tr1;
namespace po = boost::program_options;
static const size_t MAX_LINE_LENGTH = 100000;
-WordID kBOS, kEOS, kDIVIDER, kGAP;
+WordID kBOS, kEOS, kDIVIDER, kGAP, kSPLIT;
int kCOUNT;
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
@@ -34,6 +34,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("loose", "Use loose phrase extraction heuristic for base phrases")
("base_phrase,B", "Write base phrases")
("base_phrase_spans", "Write base sentences and phrase spans")
+ ("phrase_language", po::value<string>()->default_value("target"), "Extract phrase strings in source, target or both languages")
+ ("context_language", po::value<string>()->default_value("target"), "Extract context strings in source, target or both languages")
("bidir,b", "Extract bidirectional rules (for computing p(f|e) in addition to p(e|f))")
("combiner_size,c", po::value<size_t>()->default_value(800000), "Number of unique items to store in cache before writing rule counts. Set to 1 to disable cache. Set to 0 for no limit.")
("silent", "Write nothing to stderr except errors")
@@ -140,27 +142,69 @@ struct CountCombiner {
void WritePhraseContexts(const AnnotatedParallelSentence& sentence,
const vector<ParallelSpan>& phrases,
const int ctx_size,
+ bool phrase_s, bool phrase_t,
+ bool context_s, bool context_t,
CountCombiner* o) {
- vector<WordID> context(ctx_size * 2 + 1);
- context[ctx_size] = kGAP;
- vector<WordID> key;
- key.reserve(100);
+ vector<WordID> context, context_f;
+ if (context_t)
+ {
+ context.resize(ctx_size * 2 + 1);
+ context[ctx_size] = kGAP;
+ }
+ if (context_s)
+ {
+ context_f.resize(ctx_size * 2 + 1);
+ context_f[ctx_size] = kGAP;
+ }
+ vector<WordID> key, key_f;
+ if (phrase_t) key.reserve(100);
+ if (phrase_s) key_f.reserve(100);
+
for (int it = 0; it < phrases.size(); ++it) {
const ParallelSpan& phrase = phrases[it];
- // TODO, support src keys as well
- key.resize(phrase.j2 - phrase.j1);
- for (int j = phrase.j1; j < phrase.j2; ++j)
- key[j - phrase.j1] = sentence.e[j];
+ key.clear();
+ for (int j = phrase.j1; j < phrase.j2 && phrase_t; ++j)
+ key.push_back(sentence.e[j]);
+
+ if (context_t)
+ {
+ context.resize(ctx_size * 2 + 1);
+ for (int i = 0; i < ctx_size && context_t; ++i) {
+ int epos = phrase.j1 - 1 - i;
+ const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos];
+ context[ctx_size - i - 1] = left_ctx;
+ epos = phrase.j2 + i;
+ const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos];
+ context[ctx_size + i + 1] = right_ctx;
+ }
+ }
+ else
+ context.clear();
+
+ if (phrase_s)
+ {
+ key_f.clear();
+ for (int i = phrase.i1; i < phrase.i2; ++i)
+ key_f.push_back(sentence.f[i]);
+ if (phrase_t) key.push_back(kSPLIT);
+ copy(key_f.begin(), key_f.end(), back_inserter(key));
+ }
- for (int i = 0; i < ctx_size; ++i) {
- int epos = phrase.j1 - 1 - i;
- const WordID left_ctx = (epos < 0) ? kBOS : sentence.e[epos];
- context[ctx_size - i - 1] = left_ctx;
- epos = phrase.j2 + i;
- const WordID right_ctx = (epos >= sentence.e_len) ? kEOS : sentence.e[epos];
- context[ctx_size + i + 1] = right_ctx;
+ if (context_s)
+ {
+ for (int i = 0; i < ctx_size; ++i) {
+ int fpos = phrase.i1 - 1 - i;
+ const WordID left_ctx = (fpos < 0) ? kBOS : sentence.f[fpos];
+ context_f[ctx_size - i - 1] = left_ctx;
+ fpos = phrase.i2 + i;
+ const WordID right_ctx = (fpos >= sentence.f_len) ? kEOS : sentence.f[fpos];
+ context_f[ctx_size + i + 1] = right_ctx;
+ }
+ if (context_t) context.push_back(kSPLIT);
+ copy(context_f.begin(), context_f.end(), back_inserter(context));
}
+
o->Count(key, context, kCOUNT, vector<pair<short,short> >());
}
}
@@ -298,6 +342,7 @@ int main(int argc, char** argv) {
kDIVIDER = TD::Convert("|||");
kGAP = TD::Convert("<PHRASE>");
kCOUNT = FD::Convert("C");
+ kSPLIT = TD::Convert("<SPLIT>");
WordID default_cat = 0; // 0 means no default- extraction will
// fail if a phrase is extracted without a
@@ -327,10 +372,19 @@ int main(int argc, char** argv) {
const int num_categories = conf["topics"].as<int>();
const bool permit_adjacent_nonterminals = conf.count("permit_adjacent_nonterminals") > 0;
const bool require_aligned_terminal = conf.count("no_required_aligned_terminal") == 0;
+ const string ps = conf["phrase_language"].as<string>();
+ const bool phrase_s = ps == "source" || ps == "both";
+ const bool phrase_t = ps == "target" || ps == "both";
+ const string cs = conf["context_language"].as<string>();
+ const bool context_s = cs == "source" || cs == "both";
+ const bool context_t = cs == "target" || cs == "both";
int line = 0;
CountCombiner cc(conf["combiner_size"].as<size_t>());
HadoopStreamingRuleObserver o(&cc,
conf.count("bidir") > 0);
+
+ assert(phrase_s || phrase_t);
+ assert(context_s || context_t);
if(backoff) {
for (int i=0;i < num_categories;++i)
@@ -356,7 +410,7 @@ int main(int argc, char** argv) {
continue;
}
if (write_phrase_contexts) {
- WritePhraseContexts(sentence, phrases, ctx_size, &cc);
+ WritePhraseContexts(sentence, phrases, ctx_size, phrase_s, phrase_t, context_s, context_t, &cc);
continue;
}
if (write_base_phrases) {