diff options
Diffstat (limited to 'extools')
-rw-r--r-- | extools/featurize_grammar.cc | 156 |
1 files changed, 126 insertions, 30 deletions
diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index b387fe04..27c0dadf 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -5,6 +5,7 @@ #include <sstream> #include <string> #include <map> +#include <set> #include <vector> #include <utility> #include <cstdlib> @@ -376,7 +377,7 @@ struct LogRuleCount : public FeatureExtractor { SparseVector<float>* result) const { (void) lhs; (void) src; (void) trg; //result->set_value(fid_, log(info.counts.value(kCFE))); - result->set_value(fid_, (info.counts.value(kCFE))); + result->set_value(fid_, log(info.counts.value(kCFE))); if (IsZero(info.counts.value(kCFE))) result->set_value(sfid_, 1); } @@ -385,13 +386,25 @@ struct LogRuleCount : public FeatureExtractor { const int kCFE; }; +struct RulePenalty : public FeatureExtractor { + RulePenalty() : fid_(FD::Convert("RulePenalty")) {} + virtual void ExtractFeatures(const WordID /*lhs*/, + const vector<WordID>& /*src*/, + const vector<WordID>& /*trg*/, + const RuleStatistics& /*info*/, + SparseVector<float>* result) const + { result->set_value(fid_, 1); } + + const int fid_; +}; + struct BackoffRule : public FeatureExtractor { - BackoffRule() : - fid_(FD::Convert("BackoffRule")) {} + BackoffRule() : fid_(FD::Convert("BackoffRule")) {} + virtual void ExtractFeatures(const WordID lhs, const vector<WordID>& src, const vector<WordID>& trg, - const RuleStatistics& info, + const RuleStatistics& /*info*/, SparseVector<float>* result) const { (void) lhs; (void) src; (void) trg; string lhstr = TD::Convert(lhs); @@ -404,6 +417,7 @@ struct BackoffRule : public FeatureExtractor { // The negative log of the condition rule probs // ignoring the identities of the non-terminals. // i.e. the prob Hiero would assign. +// Also extracts Labelled features. struct XFeatures: public FeatureExtractor { XFeatures() : fid_xfe(FD::Convert("XFE")), @@ -425,14 +439,11 @@ struct XFeatures: public FeatureExtractor { target_counts.inc(r.target(), 0); } - // compute statistics over keys, the same lhs-src-trg tuple may be seen - // more than once virtual void ObserveUnfilteredRule(const WordID /*lhs*/, const vector<WordID>& src, const vector<WordID>& trg, const RuleStatistics& info) { RuleTuple r(-1, src, trg); -// cerr << " ObserveUnfilteredRule() in:" << r << " " << hash_value(r) << endl; map_rule(r); rule_counts.inc_if_exists(r, info.counts.value(kCFE)); @@ -441,7 +452,6 @@ struct XFeatures: public FeatureExtractor { normalise_string(r.target()); target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); -// cerr << " ObserveUnfilteredRule() inc: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl; } virtual void ExtractFeatures(const WordID /*lhs*/, @@ -451,18 +461,15 @@ struct XFeatures: public FeatureExtractor { SparseVector<float>* result) const { RuleTuple r(-1, src, trg); map_rule(r); - //result->set_value(fid_fe, rule_counts(r)); double l_r_freq = log(rule_counts(r)); normalise_string(r.target()); result->set_value(fid_xfe, log(target_counts(r.target())) - l_r_freq); result->set_value(fid_labelledfe, log(target_counts(r.target())) - log(info.counts.value(kCFE))); - //result->set_value(fid_labelledfe, target_counts(r.target())); normalise_string(r.source()); result->set_value(fid_xef, log(source_counts(r.source())) - l_r_freq); result->set_value(fid_labelledef, log(source_counts(r.source())) - log(info.counts.value(kCFE))); - //result->set_value(fid_labelledef, source_counts(r.source())); } void map_rule(RuleTuple& r) const { @@ -490,20 +497,17 @@ struct XFeatures: public FeatureExtractor { FreqCount< vector<WordID> > source_counts, target_counts; }; + struct LabelledRuleConditionals: public FeatureExtractor { LabelledRuleConditionals() : - fid_count(FD::Convert("TCount")), - fid_e(FD::Convert("TCountE")), - fid_f(FD::Convert("TCountF")), - fid_fe(FD::Convert("TLabelledFE")), - fid_ef(FD::Convert("TLabelledEF")), + fid_fe(FD::Convert("LabelledFE")), + fid_ef(FD::Convert("LabelledEF")), kCFE(FD::Convert("CFE")) {} virtual void ObserveFilteredRule(const WordID lhs, const vector<WordID>& src, const vector<WordID>& trg) { RuleTuple r(lhs, src, trg); rule_counts.inc(r, 0); - //cerr << " ObservefilteredRule() inc: " << r << " " << hash_value(r) << endl; normalise_string(r.source()); source_counts.inc(r.source(), 0); @@ -511,15 +515,12 @@ struct LabelledRuleConditionals: public FeatureExtractor { target_counts.inc(r.target(), 0); } - // compute statistics over keys, the same lhs-src-trg tuple may be seen - // more than once virtual void ObserveUnfilteredRule(const WordID lhs, const vector<WordID>& src, const vector<WordID>& trg, const RuleStatistics& info) { RuleTuple r(lhs, src, trg); rule_counts.inc_if_exists(r, info.counts.value(kCFE)); - //cerr << " ObserveUnfilteredRule() inc_if_exists: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl; normalise_string(r.source()); source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); @@ -534,19 +535,10 @@ struct LabelledRuleConditionals: public FeatureExtractor { SparseVector<float>* result) const { RuleTuple r(lhs, src, trg); double l_r_freq = log(rule_counts(r)); - //result->set_value(fid_count, rule_counts(r)); - //cerr << " ExtractFeatures() count: " << r << " " << info.counts.value(kCFE) << " | " << rule_counts(r) << endl; - //assert(l_r_freq == log(info.counts.value(kCFE))); - //cerr << " ExtractFeatures() after:" << " " << r.hash << endl; - //cerr << " ExtractFeatures() in:" << r << " " << r_freq << " " << hash_value(r) << endl; - //cerr << " ExtractFeatures() in:" << r << " " << r_freq << endl; normalise_string(r.target()); result->set_value(fid_fe, log(target_counts(r.target())) - l_r_freq); normalise_string(r.source()); result->set_value(fid_ef, log(source_counts(r.source())) - l_r_freq); - - //result->set_value(fid_e, target_counts(r.target())); - //result->set_value(fid_f, source_counts(r.source())); } void normalise_string(vector<WordID>& r) const { @@ -556,12 +548,112 @@ struct LabelledRuleConditionals: public FeatureExtractor { } const int fid_fe, fid_ef; - const int fid_count, fid_e, fid_f; const int kCFE; RuleFreqCount rule_counts; FreqCount< vector<WordID> > source_counts, target_counts; }; +struct LHSProb: public FeatureExtractor { + LHSProb() : fid_(FD::Convert("LHSProb")), kCFE(FD::Convert("CFE")), total_count(0) {} + + virtual void ObserveUnfilteredRule(const WordID lhs, + const vector<WordID>& /*src*/, + const vector<WordID>& /*trg*/, + const RuleStatistics& info) { + int count = info.counts.value(kCFE); + total_count += count; + lhs_counts.inc(lhs, count); + } + + virtual void ExtractFeatures(const WordID lhs, + const vector<WordID>& /*src*/, + const vector<WordID>& /*trg*/, + const RuleStatistics& /*info*/, + SparseVector<float>* result) const { + double lhs_log_prob = log(total_count) - log(lhs_counts(lhs)); + result->set_value(fid_, lhs_log_prob); + } + + const int fid_; + const int kCFE; + int total_count; + FreqCount<WordID> lhs_counts; +}; + +// Proper rule generative probability: p( s,t | lhs) +struct GenerativeProb: public FeatureExtractor { + GenerativeProb() : + fid_(FD::Convert("GenerativeProb")), + kCFE(FD::Convert("CFE")) {} + + virtual void ObserveUnfilteredRule(const WordID lhs, + const vector<WordID>& /*src*/, + const vector<WordID>& /*trg*/, + const RuleStatistics& info) + { lhs_counts.inc(lhs, info.counts.value(kCFE)); } + + virtual void ExtractFeatures(const WordID lhs, + const vector<WordID>& /*src*/, + const vector<WordID>& /*trg*/, + const RuleStatistics& info, + SparseVector<float>* result) const { + double log_prob = log(lhs_counts(lhs)) - log(info.counts.value(kCFE)); + result->set_value(fid_, log_prob); + } + + const int fid_; + const int kCFE; + FreqCount<WordID> lhs_counts; +}; + +// remove terminals from the rules before estimating the conditional prob +struct LabellingShape: public FeatureExtractor { + LabellingShape() : fid_(FD::Convert("LabellingShape")), kCFE(FD::Convert("CFE")) {} + + virtual void ObserveFilteredRule(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg) { + RuleTuple r(-1, src, trg); + map_rule(r); + rule_counts.inc(r, 0); + source_counts.inc(r.source(), 0); + } + + virtual void ObserveUnfilteredRule(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg, + const RuleStatistics& info) { + RuleTuple r(-1, src, trg); + map_rule(r); + rule_counts.inc_if_exists(r, info.counts.value(kCFE)); + source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); + } + + virtual void ExtractFeatures(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg, + const RuleStatistics& /*info*/, + SparseVector<float>* result) const { + RuleTuple r(-1, src, trg); + map_rule(r); + double l_r_freq = log(rule_counts(r)); + result->set_value(fid_, log(source_counts(r.source())) - l_r_freq); + } + + // Replace all terminals with generic -1 + void map_rule(RuleTuple& r) const { + for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) + if (!validate_non_terminal(TD::Convert(*it))) *it = -1; + for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) + if (!validate_non_terminal(TD::Convert(*it))) *it = -1; + } + + const int fid_, kCFE; + RuleFreqCount rule_counts; + FreqCount< vector<WordID> > source_counts; +}; + + // this extracts the lexical translation prob features // in BOTH directions. struct LexProbExtractor : public FeatureExtractor { @@ -673,6 +765,10 @@ int main(int argc, char** argv){ reg.Register("XFeatures", new FEFactory<XFeatures>); reg.Register("LabelledRuleConditionals", new FEFactory<LabelledRuleConditionals>); reg.Register("BackoffRule", new FEFactory<BackoffRule>); + reg.Register("RulePenalty", new FEFactory<RulePenalty>); + reg.Register("LHSProb", new FEFactory<LHSProb>); + reg.Register("LabellingShape", new FEFactory<LabellingShape>); + reg.Register("GenerativeProb", new FEFactory<GenerativeProb>); po::variables_map conf; InitCommandLine(reg, argc, argv, &conf); aligned_corpus = conf["aligned_corpus"].as<string>(); // GLOBAL VAR |