summaryrefslogtreecommitdiff
path: root/extools
diff options
context:
space:
mode:
Diffstat (limited to 'extools')
-rw-r--r--extools/featurize_grammar.cc259
1 files changed, 252 insertions, 7 deletions
diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc
index 17f59e6e..8be057b0 100644
--- a/extools/featurize_grammar.cc
+++ b/extools/featurize_grammar.cc
@@ -10,6 +10,7 @@
#include <cstdlib>
#include <fstream>
#include <tr1/unordered_map>
+#include <boost/regex.hpp>
#include "suffix_tree.h"
#include "sparse_vector.h"
@@ -20,6 +21,7 @@
#include "lex_trans_tbl.h"
#include "filelib.h"
+#include <boost/tuple/tuple.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/functional/hash.hpp>
#include <boost/program_options.hpp>
@@ -36,6 +38,102 @@ static const size_t MAX_LINE_LENGTH = 64000000;
typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics;
+// Data structures for indexing and counting rules
+//typedef boost::tuple< WordID, vector<WordID>, vector<WordID> > RuleTuple;
+struct RuleTuple {
+ RuleTuple(const WordID& lhs, const vector<WordID>& s, const vector<WordID>& t)
+ : m_lhs(lhs), m_source(s), m_target(t) {
+ hash_value();
+ m_dirty = false;
+ }
+
+ size_t hash_value() const {
+// if (m_dirty) {
+ size_t hash = 0;
+ boost::hash_combine(hash, m_lhs);
+ boost::hash_combine(hash, m_source);
+ boost::hash_combine(hash, m_target);
+// }
+// m_dirty = false;
+ return hash;
+ }
+
+ bool operator==(RuleTuple const& b) const
+ { return m_lhs == b.m_lhs && m_source == b.m_source && m_target == b.m_target; }
+
+ WordID& lhs() { m_dirty=true; return m_lhs; }
+ vector<WordID>& source() { m_dirty=true; return m_source; }
+ vector<WordID>& target() { m_dirty=true; return m_target; }
+ const WordID& lhs() const { return m_lhs; }
+ const vector<WordID>& source() const { return m_source; }
+ const vector<WordID>& target() const { return m_target; }
+
+// mutable size_t m_hash;
+private:
+ WordID m_lhs;
+ vector<WordID> m_source, m_target;
+ mutable bool m_dirty;
+};
+std::size_t hash_value(RuleTuple const& b) { return b.hash_value(); }
+bool operator<(RuleTuple const& l, RuleTuple const& r) {
+ if (l.lhs() < r.lhs()) return true;
+ else if (l.lhs() == r.lhs()) {
+ if (l.source() < r.source()) return true;
+ else if (l.source() == r.source()) {
+ if (l.target() < r.target()) return true;
+ }
+ }
+ return false;
+}
+
+ostream& operator<<(ostream& o, RuleTuple const& r) {
+ o << "(" << r.lhs() << "-->" << "<";
+ for (vector<WordID>::const_iterator it=r.source().begin(); it!=r.source().end(); ++it)
+ o << TD::Convert(*it) << " ";
+ o << "|||";
+ for (vector<WordID>::const_iterator it=r.target().begin(); it!=r.target().end(); ++it)
+ o << " " << TD::Convert(*it);
+ o << ">)";
+ return o;
+}
+
+template <typename Key>
+struct FreqCount {
+ //typedef unordered_map<Key, int, boost::hash<Key> > Counts;
+ typedef map<Key, int> Counts;
+ Counts counts;
+
+ int inc(const Key& r, int c=1) {
+ pair<typename Counts::iterator,bool> itb
+ = counts.insert(make_pair(r,c));
+ if (!itb.second)
+ itb.first->second += c;
+ return itb.first->second;
+ }
+
+ int inc_if_exists(const Key& r, int c=1) {
+ typename Counts::iterator it = counts.find(r);
+ if (it != counts.end())
+ it->second += c;
+ return it->second;
+ }
+
+ int count(const Key& r) const {
+ typename Counts::const_iterator it = counts.find(r);
+ if (it == counts.end()) return 0;
+ return it->second;
+ }
+
+ int operator()(const Key& r) const { return count(r); }
+};
+typedef FreqCount<RuleTuple> RuleFreqCount;
+
+bool validate_non_terminal(const std::string& s)
+{
+ static const boost::regex r("\\[X\\d+,\\d+\\]|\\[\\d+\\]");
+ return regex_match(s, r);
+}
+
namespace {
inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
inline bool IsBracket(char c){return c == '[' || c == ']';}
@@ -280,7 +378,8 @@ struct LogRuleCount : public FeatureExtractor {
const RuleStatistics& info,
SparseVector<float>* result) const {
(void) lhs; (void) src; (void) trg;
- result->set_value(fid_, log(info.counts.value(kCFE)));
+ //result->set_value(fid_, log(info.counts.value(kCFE)));
+ result->set_value(fid_, (info.counts.value(kCFE)));
if (IsZero(info.counts.value(kCFE)))
result->set_value(sfid_, 1);
}
@@ -289,6 +388,141 @@ struct LogRuleCount : public FeatureExtractor {
const int kCFE;
};
+// The negative log of the condition rule probs
+// ignoring the identities of the non-terminals.
+// i.e. the prob Hiero would assign.
+struct XFeatures: public FeatureExtractor {
+ XFeatures() :
+ fid_fe(FD::Convert("XFE")),
+ fid_ef(FD::Convert("XEF")),
+ kCFE(FD::Convert("CFE")) {}
+ virtual void ObserveFilteredRule(const WordID /*lhs*/,
+ const vector<WordID>& src,
+ const vector<WordID>& trg) {
+ RuleTuple r(-1, src, trg);
+ map_rule(r);
+ rule_counts.inc(r, 0);
+ source_counts.inc(r.source(), 0);
+ target_counts.inc(r.target(), 0);
+ }
+
+ // compute statistics over keys, the same lhs-src-trg tuple may be seen
+ // more than once
+ virtual void ObserveUnfilteredRule(const WordID /*lhs*/,
+ const vector<WordID>& src,
+ const vector<WordID>& trg,
+ const RuleStatistics& info) {
+ RuleTuple r(-1, src, trg);
+// cerr << " ObserveUnfilteredRule() in:" << r << " " << hash_value(r) << endl;
+ map_rule(r);
+ rule_counts.inc_if_exists(r, info.counts.value(kCFE));
+ source_counts.inc_if_exists(r.source(), info.counts.value(kCFE));
+ target_counts.inc_if_exists(r.target(), info.counts.value(kCFE));
+// cerr << " ObserveUnfilteredRule() inc: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl;
+ }
+
+ virtual void ExtractFeatures(const WordID /*lhs*/,
+ const vector<WordID>& src,
+ const vector<WordID>& trg,
+ const RuleStatistics& /*info*/,
+ SparseVector<float>* result) const {
+ RuleTuple r(-1, src, trg);
+ map_rule(r);
+ //result->set_value(fid_fe, log(target_counts(r.target())) - log(rule_counts(r)));
+ //result->set_value(fid_ef, log(source_counts(r.source())) - log(rule_counts(r)));
+ result->set_value(fid_ef, target_counts(r.target()));
+ result->set_value(fid_fe, rule_counts(r));
+ //result->set_value(fid_fe, (source_counts(r.source())));
+ }
+
+ void map_rule(RuleTuple& r) const {
+ vector<WordID> indexes; int i=0;
+ for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) {
+ if (validate_non_terminal(TD::Convert(*it)))
+ indexes.push_back(*it);
+ }
+ for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) {
+ if (validate_non_terminal(TD::Convert(*it)))
+ *it = indexes.at(i++);
+ }
+ }
+
+ const int fid_fe, fid_ef;
+ const int kCFE;
+ RuleFreqCount rule_counts;
+ FreqCount< vector<WordID> > source_counts, target_counts;
+};
+
+struct LabelledRuleConditionals: public FeatureExtractor {
+ LabelledRuleConditionals() :
+ fid_fe(FD::Convert("TLabelledFE")),
+ fid_ef(FD::Convert("TLabelledEF")),
+ kCFE(FD::Convert("CFE")) {}
+ virtual void ObserveFilteredRule(const WordID /*lhs*/,
+ const vector<WordID>& src,
+ const vector<WordID>& trg) {
+ RuleTuple r(-1, src, trg);
+ rule_counts.inc(r, 0);
+ cerr << " ObservefilteredRule() inc: " << r << " " << hash_value(r) << endl;
+// map_rule(r);
+ source_counts.inc(r.source(), 0);
+ target_counts.inc(r.target(), 0);
+ }
+
+ // compute statistics over keys, the same lhs-src-trg tuple may be seen
+ // more than once
+ virtual void ObserveUnfilteredRule(const WordID /*lhs*/,
+ const vector<WordID>& src,
+ const vector<WordID>& trg,
+ const RuleStatistics& info) {
+ RuleTuple r(-1, src, trg);
+ //cerr << " ObserveUnfilteredRule() in:" << r << " " << hash_value(r) << endl;
+ rule_counts.inc_if_exists(r, info.counts.value(kCFE));
+ cerr << " ObserveUnfilteredRule() inc_if_exists: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " to " << rule_counts(r) << endl;
+// map_rule(r);
+ source_counts.inc_if_exists(r.source(), info.counts.value(kCFE));
+ target_counts.inc_if_exists(r.target(), info.counts.value(kCFE));
+ }
+
+ virtual void ExtractFeatures(const WordID /*lhs*/,
+ const vector<WordID>& src,
+ const vector<WordID>& trg,
+ const RuleStatistics& info,
+ SparseVector<float>* result) const {
+ RuleTuple r(-1, src, trg);
+ //cerr << " ExtractFeatures() in:" << " " << r.m_hash << endl;
+ int r_freq = rule_counts(r);
+ cerr << " ExtractFeatures() count: " << r << " " << hash_value(r) << " " << info.counts.value(kCFE) << " | " << rule_counts(r) << endl;
+ assert(r_freq == info.counts.value(kCFE));
+ //cerr << " ExtractFeatures() after:" << " " << r.hash << endl;
+ //cerr << " ExtractFeatures() in:" << r << " " << r_freq << " " << hash_value(r) << endl;
+ //cerr << " ExtractFeatures() in:" << r << " " << r_freq << endl;
+// map_rule(r);
+ //result->set_value(fid_fe, log(target_counts(r.target())) - log(r_freq));
+ //result->set_value(fid_ef, log(source_counts(r.source())) - log(r_freq));
+ result->set_value(fid_ef, target_counts(r.target()));
+ result->set_value(fid_fe, r_freq);
+ //result->set_value(fid_fe, (source_counts(r.source())));
+ }
+
+ void map_rule(RuleTuple& r) const {
+ vector<WordID> indexes; int i=0;
+ for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) {
+ if (validate_non_terminal(TD::Convert(*it)))
+ indexes.push_back(*it);
+ }
+ for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) {
+ if (validate_non_terminal(TD::Convert(*it)))
+ *it = indexes.at(i++);
+ }
+ }
+
+ const int fid_fe, fid_ef;
+ const int kCFE;
+ RuleFreqCount rule_counts;
+ FreqCount< vector<WordID> > source_counts, target_counts;
+};
+
// this extracts the lexical translation prob features
// in BOTH directions.
struct LexProbExtractor : public FeatureExtractor {
@@ -307,7 +541,7 @@ struct LexProbExtractor : public FeatureExtractor {
delete[] buf;
}
- virtual void ExtractFeatures(const WordID lhs,
+ virtual void ExtractFeatures(const WordID /*lhs*/,
const vector<WordID>& src,
const vector<WordID>& trg,
const RuleStatistics& info,
@@ -397,6 +631,8 @@ int main(int argc, char** argv){
FERegistry reg;
reg.Register("LogRuleCount", new FEFactory<LogRuleCount>);
reg.Register("LexProb", new FEFactory<LexProbExtractor>);
+ reg.Register("XFeatures", new FEFactory<XFeatures>);
+ reg.Register("LabelledRuleConditionals", new FEFactory<LabelledRuleConditionals>);
po::variables_map conf;
InitCommandLine(reg, argc, argv, &conf);
aligned_corpus = conf["aligned_corpus"].as<string>(); // GLOBAL VAR
@@ -421,10 +657,20 @@ int main(int argc, char** argv){
fs1.getline(buf, MAX_LINE_LENGTH);
if (buf[0] == 0) continue;
ParseLine(buf, &cur_key, &cur_counts);
- src.resize(cur_key.size() - 4);
- for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];
+ //src.resize(cur_key.size() - 4);
+ src.resize(cur_key.size() - 3);
+ for (int i = 0; i < src.size(); ++i) src.at(i) = cur_key.at(i+2);
+
+ cerr << "Key: "; for (vector<WordID>::const_iterator wit=cur_key.begin(); wit!=cur_key.end(); ++wit) cerr << TD::Convert(*wit) << " "; cerr << endl;
+
lhs = cur_key[0];
+ cerr << buf << endl;
for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) {
+
+ cerr << "READ: <"; for (vector<WordID>::const_iterator wit=src.begin(); wit!=src.end(); ++wit) cerr << TD::Convert(*wit) << " ";
+ cerr << "|||"; for (vector<WordID>::const_iterator wit=it->first.begin(); wit!=it->first.end(); ++wit) cerr << " " << TD::Convert(*wit);
+ cerr << ">\n";
+
for (int i = 0; i < extractors.size(); ++i)
extractors[i]->ObserveFilteredRule(lhs, src, it->first);
}
@@ -435,11 +681,10 @@ int main(int argc, char** argv){
cin.getline(buf, MAX_LINE_LENGTH);
if (buf[0] == 0) continue;
ParseLine(buf, &cur_key, &cur_counts);
- src.resize(cur_key.size() - 4);
+ src.resize(cur_key.size() - 3);
for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];
lhs = cur_key[0];
for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) {
- // TODO set lhs, src, trg
for (int i = 0; i < extractors.size(); ++i)
extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second);
}
@@ -452,7 +697,7 @@ int main(int argc, char** argv){
fs2.getline(buf, MAX_LINE_LENGTH);
if (buf[0] == 0) continue;
ParseLine(buf, &cur_key, &cur_counts);
- src.resize(cur_key.size() - 4);
+ src.resize(cur_key.size() - 3);
for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];
lhs = cur_key[0];