diff options
author | philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-07 20:16:18 +0000 |
---|---|---|
committer | philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-07 20:16:18 +0000 |
commit | e1b840374b3f07185db38b6ada0384120ee166e9 (patch) | |
tree | 6385a49c15a1fdda1a77b96f84644233f893cc63 | |
parent | 1d1b21cf220f9f4d2612a7d399dee0316c945f2c (diff) |
In unfinished state. DO NOT USE
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@179 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | extools/featurize_grammar.cc | 259 |
1 files changed, 252 insertions, 7 deletions
diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 17f59e6e..8be057b0 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -10,6 +10,7 @@ #include <cstdlib> #include <fstream> #include <tr1/unordered_map> +#include <boost/regex.hpp> #include "suffix_tree.h" #include "sparse_vector.h" @@ -20,6 +21,7 @@ #include "lex_trans_tbl.h" #include "filelib.h" +#include <boost/tuple/tuple.hpp> #include <boost/shared_ptr.hpp> #include <boost/functional/hash.hpp> #include <boost/program_options.hpp> @@ -36,6 +38,102 @@ static const size_t MAX_LINE_LENGTH = 64000000; typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; +// Data structures for indexing and counting rules +//typedef boost::tuple< WordID, vector<WordID>, vector<WordID> > RuleTuple; +struct RuleTuple { + RuleTuple(const WordID& lhs, const vector<WordID>& s, const vector<WordID>& t) + : m_lhs(lhs), m_source(s), m_target(t) { + hash_value(); + m_dirty = false; + } + + size_t hash_value() const { +// if (m_dirty) { + size_t hash = 0; + boost::hash_combine(hash, m_lhs); + boost::hash_combine(hash, m_source); + boost::hash_combine(hash, m_target); +// } +// m_dirty = false; + return hash; + } + + bool operator==(RuleTuple const& b) const + { return m_lhs == b.m_lhs && m_source == b.m_source && m_target == b.m_target; } + + WordID& lhs() { m_dirty=true; return m_lhs; } + vector<WordID>& source() { m_dirty=true; return m_source; } + vector<WordID>& target() { m_dirty=true; return m_target; } + const WordID& lhs() const { return m_lhs; } + const vector<WordID>& source() const { return m_source; } + const vector<WordID>& target() const { return m_target; } + +// mutable size_t m_hash; +private: + WordID m_lhs; + vector<WordID> m_source, m_target; + mutable bool m_dirty; +}; +std::size_t hash_value(RuleTuple const& b) { return b.hash_value(); } +bool operator<(RuleTuple const& l, RuleTuple const& r) { + if (l.lhs() < r.lhs()) return true; + else if (l.lhs() == r.lhs()) { + if (l.source() < r.source()) return true; + else if (l.source() == r.source()) { + if (l.target() < r.target()) return true; + } + } + return false; +} + +ostream& operator<<(ostream& o, RuleTuple const& r) { + o << "(" << r.lhs() << "-->" << "<"; + for (vector<WordID>::const_iterator it=r.source().begin(); it!=r.source().end(); ++it) + o << TD::Convert(*it) << " "; + o << "|||"; + for (vector<WordID>::const_iterator it=r.target().begin(); it!=r.target().end(); ++it) + o << " " << TD::Convert(*it); + o << ">)"; + return o; +} + +template <typename Key> +struct FreqCount { + //typedef unordered_map<Key, int, boost::hash<Key> > Counts; + typedef map<Key, int> Counts; + Counts counts; + + int inc(const Key& r, int c=1) { + pair<typename Counts::iterator,bool> itb + = counts.insert(make_pair(r,c)); + if (!itb.second) + itb.first->second += c; + return itb.first->second; + } + + int inc_if_exists(const Key& r, int c=1) { + typename Counts::iterator it = counts.find(r); + if (it != counts.end()) + it->second += c; + return it->second; + } + + int count(const Key& r) const { + typename Counts::const_iterator it = counts.find(r); + if (it == counts.end()) return 0; + return it->second; + } + + int operator()(const Key& r) const { return count(r); } +}; +typedef FreqCount<RuleTuple> RuleFreqCount; + +bool validate_non_terminal(const std::string& s) +{ + static const boost::regex r("\\[X\\d+,\\d+\\]|\\[\\d+\\]"); + return regex_match(s, r); +} + namespace { inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } inline bool IsBracket(char c){return c == '[' || c == ']';} @@ -280,7 +378,8 @@ struct LogRuleCount : public FeatureExtractor { const RuleStatistics& info, SparseVector<float>* result) const { (void) lhs; (void) src; (void) trg; - result->set_value(fid_, log(info.counts.value(kCFE))); + //result->set_value(fid_, log(info.counts.value(kCFE))); + result->set_value(fid_, (info.counts.value(kCFE))); if (IsZero(info.counts.value(kCFE))) result->set_value(sfid_, 1); } @@ -289,6 +388,141 @@ struct LogRuleCount : public FeatureExtractor { const int kCFE; }; +// The negative log of the condition rule probs +// ignoring the identities of the non-terminals. +// i.e. the prob Hiero would assign. +struct XFeatures: public FeatureExtractor { + XFeatures() : + fid_fe(FD::Convert("XFE")), + fid_ef(FD::Convert("XEF")), + kCFE(FD::Convert("CFE")) {} + virtual void ObserveFilteredRule(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg) { + RuleTuple r(-1, src, trg); + map_rule(r); + rule_counts.inc(r, 0); + source_counts.inc(r.source(), 0); + target_counts.inc(r.target(), 0); + } + + // compute statistics over keys, the same lhs-src-trg tuple may be seen + // more than once + virtual void ObserveUnfilteredRule(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg, + const RuleStatistics& info) { + RuleTuple r(-1, src, trg); +// cerr << " ObserveUnfilteredRule() in:" << r << " " << hash_value(r) << endl; + map_rule(r); + rule_counts.inc_if_exists(r, info.counts.value(kCFE)); + source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); + target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); +// cerr << " ObserveUnfilteredRule() inc: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl; + } + + virtual void ExtractFeatures(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg, + const RuleStatistics& /*info*/, + SparseVector<float>* result) const { + RuleTuple r(-1, src, trg); + map_rule(r); + //result->set_value(fid_fe, log(target_counts(r.target())) - log(rule_counts(r))); + //result->set_value(fid_ef, log(source_counts(r.source())) - log(rule_counts(r))); + result->set_value(fid_ef, target_counts(r.target())); + result->set_value(fid_fe, rule_counts(r)); + //result->set_value(fid_fe, (source_counts(r.source()))); + } + + void map_rule(RuleTuple& r) const { + vector<WordID> indexes; int i=0; + for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) { + if (validate_non_terminal(TD::Convert(*it))) + indexes.push_back(*it); + } + for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) { + if (validate_non_terminal(TD::Convert(*it))) + *it = indexes.at(i++); + } + } + + const int fid_fe, fid_ef; + const int kCFE; + RuleFreqCount rule_counts; + FreqCount< vector<WordID> > source_counts, target_counts; +}; + +struct LabelledRuleConditionals: public FeatureExtractor { + LabelledRuleConditionals() : + fid_fe(FD::Convert("TLabelledFE")), + fid_ef(FD::Convert("TLabelledEF")), + kCFE(FD::Convert("CFE")) {} + virtual void ObserveFilteredRule(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg) { + RuleTuple r(-1, src, trg); + rule_counts.inc(r, 0); + cerr << " ObservefilteredRule() inc: " << r << " " << hash_value(r) << endl; +// map_rule(r); + source_counts.inc(r.source(), 0); + target_counts.inc(r.target(), 0); + } + + // compute statistics over keys, the same lhs-src-trg tuple may be seen + // more than once + virtual void ObserveUnfilteredRule(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg, + const RuleStatistics& info) { + RuleTuple r(-1, src, trg); + //cerr << " ObserveUnfilteredRule() in:" << r << " " << hash_value(r) << endl; + rule_counts.inc_if_exists(r, info.counts.value(kCFE)); + cerr << " ObserveUnfilteredRule() inc_if_exists: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl; +// map_rule(r); + source_counts.inc_if_exists(r.source(), info.counts.value(kCFE)); + target_counts.inc_if_exists(r.target(), info.counts.value(kCFE)); + } + + virtual void ExtractFeatures(const WordID /*lhs*/, + const vector<WordID>& src, + const vector<WordID>& trg, + const RuleStatistics& info, + SparseVector<float>* result) const { + RuleTuple r(-1, src, trg); + //cerr << " ExtractFeatures() in:" << " " << r.m_hash << endl; + int r_freq = rule_counts(r); + cerr << " ExtractFeatures() count: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " | " << rule_counts(r) << endl; + assert(r_freq == info.counts.value(kCFE)); + //cerr << " ExtractFeatures() after:" << " " << r.hash << endl; + //cerr << " ExtractFeatures() in:" << r << " " << r_freq << " " << hash_value(r) << endl; + //cerr << " ExtractFeatures() in:" << r << " " << r_freq << endl; +// map_rule(r); + //result->set_value(fid_fe, log(target_counts(r.target())) - log(r_freq)); + //result->set_value(fid_ef, log(source_counts(r.source())) - log(r_freq)); + result->set_value(fid_ef, target_counts(r.target())); + result->set_value(fid_fe, r_freq); + //result->set_value(fid_fe, (source_counts(r.source()))); + } + + void map_rule(RuleTuple& r) const { + vector<WordID> indexes; int i=0; + for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) { + if (validate_non_terminal(TD::Convert(*it))) + indexes.push_back(*it); + } + for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) { + if (validate_non_terminal(TD::Convert(*it))) + *it = indexes.at(i++); + } + } + + const int fid_fe, fid_ef; + const int kCFE; + RuleFreqCount rule_counts; + FreqCount< vector<WordID> > source_counts, target_counts; +}; + // this extracts the lexical translation prob features // in BOTH directions. struct LexProbExtractor : public FeatureExtractor { @@ -307,7 +541,7 @@ struct LexProbExtractor : public FeatureExtractor { delete[] buf; } - virtual void ExtractFeatures(const WordID lhs, + virtual void ExtractFeatures(const WordID /*lhs*/, const vector<WordID>& src, const vector<WordID>& trg, const RuleStatistics& info, @@ -397,6 +631,8 @@ int main(int argc, char** argv){ FERegistry reg; reg.Register("LogRuleCount", new FEFactory<LogRuleCount>); reg.Register("LexProb", new FEFactory<LexProbExtractor>); + reg.Register("XFeatures", new FEFactory<XFeatures>); + reg.Register("LabelledRuleConditionals", new FEFactory<LabelledRuleConditionals>); po::variables_map conf; InitCommandLine(reg, argc, argv, &conf); aligned_corpus = conf["aligned_corpus"].as<string>(); // GLOBAL VAR @@ -421,10 +657,20 @@ int main(int argc, char** argv){ fs1.getline(buf, MAX_LINE_LENGTH); if (buf[0] == 0) continue; ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 4); - for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; + //src.resize(cur_key.size() - 4); + src.resize(cur_key.size() - 3); + for (int i = 0; i < src.size(); ++i) src.at(i) = cur_key.at(i+2); + + cerr << "Key: "; for (vector<WordID>::const_iterator wit=cur_key.begin(); wit!=cur_key.end(); ++wit) cerr << TD::Convert(*wit) << " "; cerr << endl; + lhs = cur_key[0]; + cerr << buf << endl; for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { + + cerr << "READ: <"; for (vector<WordID>::const_iterator wit=src.begin(); wit!=src.end(); ++wit) cerr << TD::Convert(*wit) << " "; + cerr << "|||"; for (vector<WordID>::const_iterator wit=it->first.begin(); wit!=it->first.end(); ++wit) cerr << " " << TD::Convert(*wit); + cerr << ">\n"; + for (int i = 0; i < extractors.size(); ++i) extractors[i]->ObserveFilteredRule(lhs, src, it->first); } @@ -435,11 +681,10 @@ int main(int argc, char** argv){ cin.getline(buf, MAX_LINE_LENGTH); if (buf[0] == 0) continue; ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 4); + src.resize(cur_key.size() - 3); for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; lhs = cur_key[0]; for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - // TODO set lhs, src, trg for (int i = 0; i < extractors.size(); ++i) extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second); } @@ -452,7 +697,7 @@ int main(int argc, char** argv){ fs2.getline(buf, MAX_LINE_LENGTH); if (buf[0] == 0) continue; ParseLine(buf, &cur_key, &cur_counts); - src.resize(cur_key.size() - 4); + src.resize(cur_key.size() - 3); for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2]; lhs = cur_key[0]; |