From c04ba5eed5049569d327dfb6162d91bec0a3aec8 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 6 Jul 2010 17:45:09 +0000 Subject: featurizer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@154 ec762483-ff6d-05da-a07a-a48fb63a330f --- extools/Makefile.am | 10 + extools/featurize_grammar.cc | 410 ++++++++++++++++++++++++++++++++++++++++ extools/filter_grammar.cc | 235 +++++++++++++++++++++++ extools/filter_score_grammar.cc | 57 ++++-- 4 files changed, 700 insertions(+), 12 deletions(-) create mode 100644 extools/featurize_grammar.cc create mode 100644 extools/filter_grammar.cc diff --git a/extools/Makefile.am b/extools/Makefile.am index bce6c404..fc02f831 100644 --- a/extools/Makefile.am +++ b/extools/Makefile.am @@ -2,6 +2,8 @@ bin_PROGRAMS = \ extractor \ mr_stripe_rule_reduce \ build_lexical_translation \ + filter_grammar \ + featurize_grammar \ filter_score_grammar noinst_PROGRAMS = @@ -10,6 +12,14 @@ filter_score_grammar_SOURCES = filter_score_grammar.cc extract.cc sentence_pair. filter_score_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz filter_score_grammar_LDFLAGS = -all-static +filter_grammar_SOURCES = filter_grammar.cc extract.cc sentence_pair.cc +filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz +filter_grammar_LDFLAGS = -all-static + +featurize_grammar_SOURCES = featurize_grammar.cc extract.cc sentence_pair.cc +featurize_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz +featurize_grammar_LDFLAGS = -all-static + build_lexical_translation_SOURCES = build_lexical_translation.cc extract.cc sentence_pair.cc build_lexical_translation_LDADD = $(top_srcdir)/decoder/libcdec.a -lz build_lexical_translation_LDFLAGS = -all-static diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc new file mode 100644 index 00000000..1ca20a4b --- /dev/null +++ b/extools/featurize_grammar.cc @@ -0,0 +1,410 @@ +/* + * Featurize a grammar in striped format + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "suffix_tree.h" +#include "sparse_vector.h" +#include "sentence_pair.h" +#include "extract.h" +#include "fdict.h" +#include "tdict.h" +#include "lex_trans_tbl.h" +#include "filelib.h" + +#include +#include +#include +#include + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +static const size_t MAX_LINE_LENGTH = 64000000; + +typedef unordered_map, RuleStatistics, boost::hash > > ID2RuleStatistics; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("filtered_grammar,g", po::value(), "Grammar to add features to") + ("aligned_corpus,c", po::value(), "Aligned corpus (single line format)") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("aligned_corpus")==0) { + cerr << "\nUsage: featurize_grammar -g FILTERED-GRAMMAR.gz -c ALIGNED_CORPUS.fr-en-al [-options] < UNFILTERED-GRAMMAR\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +namespace { + inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } + inline bool IsBracket(char c){return c == '[' || c == ']';} + inline void SkipWhitespace(const char* buf, int* ptr) { + while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } + } +} + +int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector* p) { + static const WordID kDIV = TD::Convert("|||"); + int ptr = sstart; + while(ptr < end) { + while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } + int start = ptr; + while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } + if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } + const WordID w = TD::Convert(string(buf, start, ptr - start)); + + if((IsBracket(buf[start]) and IsBracket(buf[ptr-1])) or( w == kDIV)) + p->push_back(1 * w); + else { + if (w == kDIV) return ptr; + p->push_back(w); + } + } + return ptr; +} + +void ParseLine(const char* buf, vector* cur_key, ID2RuleStatistics* counts) { + static const WordID kDIV = TD::Convert("|||"); + counts->clear(); + int ptr = 0; + while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } + if (buf[ptr] != '\t') { + cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; + exit(1); + } + cur_key->clear(); + // key is: "[X] ||| word word word" + int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); + cur_key->push_back(kDIV); + ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); + ++ptr; + int start = ptr; + int end = ptr; + int state = 0; // 0=reading label, 1=reading count + vector name; + while(buf[ptr] != 0) { + while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } + if (buf[ptr] == '|') { + ++ptr; + if (buf[ptr] == '|') { + ++ptr; + if (buf[ptr] == '|') { + ++ptr; + end = ptr - 3; + while (end > start && IsWhitespace(buf[end-1])) { --end; } + if (start == end) { + cerr << "Got empty token!\n LINE=" << buf << endl; + exit(1); + } + switch (state) { + case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; + case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; + default: cerr << "Can't happen\n"; abort(); + } + SkipWhitespace(buf, &ptr); + start = ptr; + } + } + } + } + end=ptr; + while (end > start && IsWhitespace(buf[end-1])) { --end; } + if (end > start) { + switch (state) { + case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; + case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; + default: cerr << "Can't happen\n"; abort(); + } + } +} + + +void LexTranslationTable::createTTable(const char* buf){ + AnnotatedParallelSentence sent; + sent.ParseInputLine(buf); + + //iterate over the alignment to compute aligned words + + for(int i =0;i (sent.f[i], sent.e[j])]; + ++total_foreign[sent.f[i]]; + ++total_english[sent.e[j]]; + } + } + if (DEBUG) cerr << endl; + } + if (DEBUG) cerr << endl; + + const WordID NULL_ = TD::Convert("NULL"); + //handle unaligned words - align them to null + for (int j =0; j < sent.e_len; j++) { + if (sent.e_aligned[j]) continue; + ++word_translation[pair (NULL_, sent.e[j])]; + ++total_foreign[NULL_]; + ++total_english[sent.e[j]]; + } + + for (int i =0; i < sent.f_len; i++) { + if (sent.f_aligned[i]) continue; + ++word_translation[pair (sent.f[i], NULL_)]; + ++total_english[NULL_]; + ++total_foreign[sent.f[i]]; + } +} + +inline float safenlog(float v) { + if (v == 1.0f) return 0.0f; + float res = -log(v); + if (res > 100.0f) res = 100.0f; + return res; +} + +static bool IsZero(float f) { return (f > 0.999 && f < 1.001); } + +struct FeatureExtractor { + // create any keys necessary + virtual void ObserveFilteredRule(const WordID lhs, + const vector& src, + const vector& trg) {} + + // compute statistics over keys, the same lhs-src-trg tuple may be seen + // more than once + virtual void ObserveUnfilteredRule(const WordID lhs, + const vector& src, + const vector& trg, + const RuleStatistics& info) {} + + // compute features, a unique lhs-src-trg tuple will be seen exactly once + virtual void ExtractFeatures(const WordID lhs, + const vector& src, + const vector& trg, + const RuleStatistics& info, + SparseVector* result) const = 0; + + virtual ~FeatureExtractor() {} +}; + +struct LogRuleCount : public FeatureExtractor { + LogRuleCount() : + fid_(FD::Convert("LogRuleCount")), + sfid_(FD::Convert("SingletonRule")), + kCFE(FD::Convert("CFE")) {} + virtual void ExtractFeatures(const WordID lhs, + const vector& src, + const vector& trg, + const RuleStatistics& info, + SparseVector* result) const { + (void) lhs; (void) src; (void) trg; + result->set_value(fid_, log(info.counts.value(kCFE))); + if (IsZero(info.counts.value(kCFE))) + result->set_value(sfid_, 1); + } + const int fid_; + const int sfid_; + const int kCFE; +}; + +// this extracts the lexical translation prob features +// in BOTH directions. +struct LexProbExtractor : public FeatureExtractor { + LexProbExtractor(const std::string& corpus) : + e2f_(FD::Convert("LexE2F")), f2e_(FD::Convert("LexF2E")) { + ReadFile rf(corpus); + //create lexical translation table + cerr << "Computing lexical translation probabilities from " << corpus << "..." << endl; + char* buf = new char[MAX_LINE_LENGTH]; + istream& alignment = *rf.stream(); + while(alignment) { + alignment.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + table.createTTable(buf); + } + delete[] buf; + } + + virtual void ExtractFeatures(const WordID lhs, + const vector& src, + const vector& trg, + const RuleStatistics& info, + SparseVector* result) const { + map > foreign_aligned; + map > english_aligned; + + //Loop over all the alignment points to compute lexical translation probability + const vector< pair >& al = info.aligns; + vector< pair >::const_iterator ita; + for (ita = al.begin(); ita != al.end(); ++ita) { + if (DEBUG) { + cerr << "\nA:" << ita->first << "," << ita->second << "::"; + cerr << TD::Convert(src[ita->first]) << "-" << TD::Convert(trg[ita->second]); + } + + //Lookup this alignment probability in the table + int temp = table.word_translation[pair (src[ita->first],trg[ita->second])]; + float f2e=0, e2f=0; + if ( table.total_foreign[src[ita->first]] != 0) + f2e = (float) temp / table.total_foreign[src[ita->first]]; + if ( table.total_english[trg[ita->second]] !=0 ) + e2f = (float) temp / table.total_english[trg[ita->second]]; + if (DEBUG) printf (" %d %E %E\n", temp, f2e, e2f); + + //local counts to keep track of which things haven't been aligned, to later compute their null alignment + if (foreign_aligned.count(src[ita->first])) { + foreign_aligned[ src[ita->first] ].first++; + foreign_aligned[ src[ita->first] ].second += e2f; + } else { + foreign_aligned[ src[ita->first] ] = pair (1,e2f); + } + + if (english_aligned.count( trg[ ita->second] )) { + english_aligned[ trg[ ita->second] ].first++; + english_aligned[ trg[ ita->second] ].second += f2e; + } else { + english_aligned[ trg[ ita->second] ] = pair (1,f2e); + } + } + + float final_lex_f2e=1, final_lex_e2f=1; + static const WordID NULL_ = TD::Convert("NULL"); + + //compute lexical weight P(F|E) and include unaligned foreign words + for(int i=0;i temp_lex_prob = foreign_aligned[src[i]]; + final_lex_e2f *= temp_lex_prob.second / temp_lex_prob.first; + } + else //dealing with null alignment + { + int temp_count = table.word_translation[pair (src[i],NULL_)]; + float temp_e2f = (float) temp_count / table.total_english[NULL_]; + final_lex_e2f *= temp_e2f; + } + + } + + //compute P(E|F) unaligned english words + for(int j=0; j< trg.size(); j++) { + if (!table.total_english.count(trg[j])) continue; + + if (english_aligned.count(trg[j])) + { + pair temp_lex_prob = english_aligned[trg[j]]; + final_lex_f2e *= temp_lex_prob.second / temp_lex_prob.first; + } + else //dealing with null + { + int temp_count = table.word_translation[pair (NULL_,trg[j])]; + float temp_f2e = (float) temp_count / table.total_foreign[NULL_]; + final_lex_f2e *= temp_f2e; + } + } + result->set_value(e2f_, safenlog(final_lex_e2f)); + result->set_value(f2e_, safenlog(final_lex_f2e)); + } + const int e2f_, f2e_; + mutable LexTranslationTable table; +}; + +int main(int argc, char** argv){ + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + ifstream alignment (conf["aligned_corpus"].as().c_str()); + ReadFile fg1(conf["filtered_grammar"].as()); + + istream& fs1 = *fg1.stream(); + + // TODO make this list configurable + vector > extractors; + extractors.push_back(boost::shared_ptr(new LogRuleCount)); + extractors.push_back(boost::shared_ptr(new LexProbExtractor(conf["aligned_corpus"].as()))); + + //score unscored grammar + cerr << "Reading filtered grammar to detect keys..." << endl; + char* buf = new char[MAX_LINE_LENGTH]; + + ID2RuleStatistics acc, cur_counts; + vector key, cur_key,temp_key; + WordID lhs = 0; + vector src; + +#if 0 + int line = 0; + while(fs1) { + fs1.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + ParseLine(buf, &cur_key, &cur_counts); + src.resize(cur_key.size() - 2); + for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; + lhs = cur_key[0]; + for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ObserveFilteredRule(lhs, src, it->first); + } + } + + cerr << "Reading unfiltered grammar..." << endl; + while(cin) { + cin.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + ParseLine(buf, &cur_key, &cur_counts); + src.resize(cur_key.size() - 2); + for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; + lhs = cur_key[0]; + for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { + // TODO set lhs, src, trg + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); + } + } +#endif + + ReadFile fg2(conf["filtered_grammar"].as()); + istream& fs2 = *fg2.stream(); + cerr << "Reading filtered grammar and adding features..." << endl; + while(fs2) { + fs2.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + ParseLine(buf, &cur_key, &cur_counts); + src.resize(cur_key.size() - 2); + for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; + lhs = cur_key[0]; + + //loop over all the Target side phrases that this source aligns to + for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { + SparseVector feats; + for (int i = 0; i < extractors.size(); ++i) + extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats); + cout << TD::GetString(cur_key) << " ||| " << TD::GetString(it->first) << " ||| "; + feats.Write(false, &cout); + cout << endl; + } + } +} + diff --git a/extools/filter_grammar.cc b/extools/filter_grammar.cc new file mode 100644 index 00000000..a2992f7d --- /dev/null +++ b/extools/filter_grammar.cc @@ -0,0 +1,235 @@ +/* + * Filter a grammar in striped format + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "suffix_tree.h" +#include "sparse_vector.h" +#include "sentence_pair.h" +#include "extract.h" +#include "fdict.h" +#include "tdict.h" +#include "lex_trans_tbl.h" +#include "filelib.h" + +#include +#include +#include +#include + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +static const size_t MAX_LINE_LENGTH = 64000000; + +typedef unordered_map, RuleStatistics, boost::hash > > ID2RuleStatistics; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("test_set,t", po::value(), "Filter for this test set") + ("top_e_given_f,n", po::value()->default_value(30), "Keep top N rules, according to p(e|f). 0 for all") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("test_set")==0) { + cerr << "\nUsage: filter_grammar -t TEST-SET.fr [-options] < grammar\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} +namespace { + inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } + inline bool IsBracket(char c){return c == '[' || c == ']';} + inline void SkipWhitespace(const char* buf, int* ptr) { + while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } + } +} + +int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector* p) { + static const WordID kDIV = TD::Convert("|||"); + int ptr = sstart; + while(ptr < end) { + while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } + int start = ptr; + while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } + if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } + const WordID w = TD::Convert(string(buf, start, ptr - start)); + + if((IsBracket(buf[start]) and IsBracket(buf[ptr-1])) or( w == kDIV)) + p->push_back(1 * w); + else { + if (w == kDIV) return ptr; + p->push_back(w); + } + } + return ptr; +} + + +void ParseLine(const char* buf, vector* cur_key, ID2RuleStatistics* counts) { + static const WordID kDIV = TD::Convert("|||"); + counts->clear(); + int ptr = 0; + while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } + if (buf[ptr] != '\t') { + cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; + exit(1); + } + cur_key->clear(); + // key is: "[X] ||| word word word" + int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); + cur_key->push_back(kDIV); + ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); + ++ptr; + int start = ptr; + int end = ptr; + int state = 0; // 0=reading label, 1=reading count + vector name; + while(buf[ptr] != 0) { + while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } + if (buf[ptr] == '|') { + ++ptr; + if (buf[ptr] == '|') { + ++ptr; + if (buf[ptr] == '|') { + ++ptr; + end = ptr - 3; + while (end > start && IsWhitespace(buf[end-1])) { --end; } + if (start == end) { + cerr << "Got empty token!\n LINE=" << buf << endl; + exit(1); + } + switch (state) { + case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; + case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; + default: cerr << "Can't happen\n"; abort(); + } + SkipWhitespace(buf, &ptr); + start = ptr; + } + } + } + } + end=ptr; + while (end > start && IsWhitespace(buf[end-1])) { --end; } + if (end > start) { + switch (state) { + case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; + case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; + default: cerr << "Can't happen\n"; abort(); + } + } +} + + +struct SourceFilter { + // return true to keep the rule, otherwise false + virtual bool Matches(const vector& key) const = 0; + virtual ~SourceFilter() {} +}; + +struct DumbSuffixTreeFilter : SourceFilter { + DumbSuffixTreeFilter(const string& corpus) : + kDIV(TD::Convert("|||")) { + cerr << "Build suffix tree from test set in " << corpus << endl; + assert(FileExists(corpus)); + ReadFile rfts(corpus); + istream& testSet = *rfts.stream(); + char* buf = new char[MAX_LINE_LENGTH]; + AnnotatedParallelSentence sent; + + /* process the data set to build suffix tree + */ + while(!testSet.eof()) { + testSet.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + + //hack to read in the test set using AnnotatedParallelSentence + strcat(buf," ||| fake ||| 0-0"); + sent.ParseInputLine(buf); + + //add each successive suffix to the tree + for(int i=0; i& key) const { + const Node* curnode = &root; + const int ks = key.size() - 1; + for(int i=0; i < ks; i++) { + const string& word = TD::Convert(key[i]); + if (key[i] == kDIV || (word[0] == '[' && word[word.size() - 1] == ']')) { // non-terminal + curnode = &root; + } else if (curnode) { + curnode = curnode->Extend(key[i]); + if (!curnode) return false; + } + } + return true; + } + const WordID kDIV; + Node root; +}; + +int main(int argc, char** argv){ + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const int max_options = conf["top_e_given_f"].as();; + istream& unscored_grammar = cin; + + cerr << "Loading test set " << conf["test_set"].as() << "...\n"; + boost::shared_ptr filter; + filter.reset(new DumbSuffixTreeFilter(conf["test_set"].as())); + + cerr << "Filtering...\n"; + //score unscored grammar + char* buf = new char[MAX_LINE_LENGTH]; + + ID2RuleStatistics acc, cur_counts; + vector key, cur_key,temp_key; + int line = 0; + + multimap options; + const int kCOUNT = FD::Convert("CFE"); + while(!unscored_grammar.eof()) + { + ++line; + options.clear(); + unscored_grammar.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + ParseLine(buf, &cur_key, &cur_counts); + if (!filter || filter->Matches(cur_key)) { + // sort by counts + for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { + options.insert(make_pair(-it->second.counts.value(kCOUNT), it)); + } + int ocount = 0; + cout << TD::GetString(cur_key) << '\t'; + + bool first = true; + for (multimap::iterator it = options.begin(); it != options.end(); ++it) { + if (first) { first = false; } else { cout << " ||| "; } + cout << TD::GetString(it->second->first) << " ||| " << it->second->second; + ++ocount; + if (ocount == max_options) break; + } + cout << endl; + } + } +} + diff --git a/extools/filter_score_grammar.cc b/extools/filter_score_grammar.cc index f34b240d..fe9a2a07 100644 --- a/extools/filter_score_grammar.cc +++ b/extools/filter_score_grammar.cc @@ -37,7 +37,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { opts.add_options() ("test_set,t", po::value(), "Filter for this test set (not specified = no filtering)") ("top_e_given_f,n", po::value()->default_value(30), "Keep top N rules, according to p(e|f). 0 for all") - ("hiero_features", "Use 'Hiero' features") + ("backoff_features", "Extract backoff X-features, assumes E, F, EF counts") // ("feature,f", po::value >()->composing(), "List of features to compute") ("aligned_corpus,c", po::value(), "Aligned corpus (single line format)") ("help,h", "Print this help message and exit"); @@ -247,36 +247,66 @@ struct FeatureExtractor { const string extractor_name; }; +static bool IsZero(float f) { return (f > 0.999 && f < 1.001); } + struct LogRuleCount : public FeatureExtractor { LogRuleCount() : FeatureExtractor("LogRuleCount"), - fid_(FD::Convert("LogRuleCount")), kCFE(FD::Convert("CFE")) {} + fid_(FD::Convert("LogRuleCount")), + sfid_(FD::Convert("SingletonRule")), + kCFE(FD::Convert("CFE")) {} virtual void ExtractFeatures(const vector& lhs_src, const vector& trg, const RuleStatistics& info, SparseVector* result) const { (void) lhs_src; (void) trg; result->set_value(fid_, log(info.counts.value(kCFE))); + if (IsZero(info.counts.value(kCFE))) + result->set_value(sfid_, 1); } const int fid_; + const int sfid_; const int kCFE; }; -struct SingletonRule : public FeatureExtractor { - SingletonRule() : - FeatureExtractor("SingletonRule"), - fid_(FD::Convert("SingletonRule")), kCFE(FD::Convert("CFE")) {} +struct LogECount : public FeatureExtractor { + LogECount() : + FeatureExtractor("LogECount"), + sfid_(FD::Convert("SingletonE")), + fid_(FD::Convert("LogECount")), kCE(FD::Convert("CE")) {} virtual void ExtractFeatures(const vector& lhs_src, const vector& trg, const RuleStatistics& info, SparseVector* result) const { (void) lhs_src; (void) trg; - if (info.counts.value(kCFE) > 0.999 && info.counts.value(kCFE) < 1.001) { - result->set_value(fid_, 1.0); - } + assert(info.counts.value(kCE) > 0); + result->set_value(fid_, log(info.counts.value(kCE))); + if (IsZero(info.counts.value(kCE))) + result->set_value(sfid_, 1); } + const int sfid_; const int fid_; - const int kCFE; + const int kCE; +}; + +struct LogFCount : public FeatureExtractor { + LogFCount() : + FeatureExtractor("LogFCount"), + sfid_(FD::Convert("SingletonF")), + fid_(FD::Convert("LogFCount")), kCF(FD::Convert("CF")) {} + virtual void ExtractFeatures(const vector& lhs_src, + const vector& trg, + const RuleStatistics& info, + SparseVector* result) const { + (void) lhs_src; (void) trg; + assert(info.counts.value(kCF) > 0); + result->set_value(fid_, log(info.counts.value(kCF))); + if (IsZero(info.counts.value(kCF))) + result->set_value(sfid_, 1); + } + const int sfid_; + const int fid_; + const int kCF; }; struct EGivenFExtractor : public FeatureExtractor { @@ -437,13 +467,16 @@ int main(int argc, char** argv){ // TODO make this list configurable vector > extractors; - if (conf.count("hiero_features")) { + if (conf.count("backoff_features")) { + extractors.push_back(boost::shared_ptr(new LogRuleCount)); + extractors.push_back(boost::shared_ptr(new LogECount)); + extractors.push_back(boost::shared_ptr(new LogFCount)); extractors.push_back(boost::shared_ptr(new EGivenFExtractor)); extractors.push_back(boost::shared_ptr(new FGivenEExtractor)); extractors.push_back(boost::shared_ptr(new LexProbExtractor(conf["aligned_corpus"].as()))); } else { extractors.push_back(boost::shared_ptr(new LogRuleCount)); - extractors.push_back(boost::shared_ptr(new SingletonRule)); + extractors.push_back(boost::shared_ptr(new LogFCount)); extractors.push_back(boost::shared_ptr(new LexProbExtractor(conf["aligned_corpus"].as()))); } -- cgit v1.2.3