summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-13 16:15:33 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-13 16:15:33 +0000
commit2eeaa2eb91334bea11d70db1011f1a28ce3bb7d2 (patch)
tree42a7480c4d601064fa9a5b25a33d59eb12cdf166
parentbddde964a3686d79e77898def1ca7eb228c7caf1 (diff)
major speed up using DFA parser
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@235 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--extools/featurize_grammar.cc236
-rw-r--r--extools/lex_trans_tbl.h1
2 files changed, 58 insertions, 179 deletions
diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc
index 547f390a..9a4af4d8 100644
--- a/extools/featurize_grammar.cc
+++ b/extools/featurize_grammar.cc
@@ -5,21 +5,17 @@
#include <sstream>
#include <string>
#include <map>
-#include <set>
#include <vector>
#include <utility>
#include <cstdlib>
-#include <fstream>
#include <tr1/unordered_map>
-#include <boost/regex.hpp>
-#include "suffix_tree.h"
+#include "lex_trans_tbl.h"
#include "sparse_vector.h"
#include "sentence_pair.h"
#include "extract.h"
#include "fdict.h"
#include "tdict.h"
-#include "lex_trans_tbl.h"
#include "filelib.h"
#include "striped_grammar.h"
@@ -29,7 +25,6 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
-
using namespace std;
using namespace std::tr1;
using boost::shared_ptr;
@@ -38,8 +33,6 @@ namespace po = boost::program_options;
static string aligned_corpus;
static const size_t MAX_LINE_LENGTH = 64000000;
-typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics;
-
// Data structures for indexing and counting rules
//typedef boost::tuple< WordID, vector<WordID>, vector<WordID> > RuleTuple;
struct RuleTuple {
@@ -130,20 +123,6 @@ struct FreqCount {
};
typedef FreqCount<RuleTuple> RuleFreqCount;
-bool validate_non_terminal(const std::string& s)
-{
- static const boost::regex r("\\[X\\d+,\\d+\\]|\\[\\d+\\]");
- return regex_match(s, r);
-}
-
-namespace {
- inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
- inline void SkipWhitespace(const char* buf, int* ptr) {
- while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); }
- }
-}
-
-
class FeatureExtractor;
class FERegistry;
struct FEFactoryBase {
@@ -220,78 +199,7 @@ void InitCommandLine(const FERegistry& r, int argc, char** argv, po::variables_m
}
}
-int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) {
- static const WordID kDIV = TD::Convert("|||");
- int ptr = sstart;
- while(ptr < end) {
- while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; }
- int start = ptr;
- while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; }
- if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; }
- const WordID w = TD::Convert(string(buf, start, ptr - start));
- if (w == kDIV) return ptr;
- p->push_back(w);
- }
- assert(p->size() > 0);
- return ptr;
-}
-
-void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) {
- static const WordID kDIV = TD::Convert("|||");
- counts->clear();
- int ptr = 0;
- while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; }
- if (buf[ptr] != '\t') {
- cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl;
- exit(1);
- }
- cur_key->clear();
- // key is: "[X] ||| word word word"
- int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key);
- if (buf[tmpp] != '\t') {
- cur_key->push_back(kDIV);
- ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key);
- }
- ++ptr;
- int start = ptr;
- int end = ptr;
- int state = 0; // 0=reading label, 1=reading count
- vector<WordID> name;
- while(buf[ptr] != 0) {
- while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; }
- if (buf[ptr] == '|') {
- ++ptr;
- if (buf[ptr] == '|') {
- ++ptr;
- if (buf[ptr] == '|') {
- ++ptr;
- end = ptr - 3;
- while (end > start && IsWhitespace(buf[end-1])) { --end; }
- if (start == end) {
- cerr << "Got empty token!\n LINE=" << buf << endl;
- exit(1);
- }
- switch (state) {
- case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
- case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
- default: cerr << "Can't happen\n"; abort();
- }
- SkipWhitespace(buf, &ptr);
- start = ptr;
- }
- }
- }
- }
- end=ptr;
- while (end > start && IsWhitespace(buf[end-1])) { --end; }
- if (end > start) {
- switch (state) {
- case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
- case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
- default: cerr << "Can't happen\n"; abort();
- }
- }
-}
+static const bool DEBUG = false;
void LexTranslationTable::createTTable(const char* buf){
AnnotatedParallelSentence sent;
@@ -431,11 +339,7 @@ struct XFeatures: public FeatureExtractor {
RuleTuple r(-1, src, trg);
map_rule(r);
rule_counts.inc(r, 0);
-
- normalise_string(r.source());
source_counts.inc(r.source(), 0);
-
- normalise_string(r.target());
target_counts.inc(r.target(), 0);
}
@@ -446,11 +350,7 @@ struct XFeatures: public FeatureExtractor {
RuleTuple r(-1, src, trg);
map_rule(r);
rule_counts.inc_if_exists(r, info.counts.value(kCFE));
-
- normalise_string(r.source());
source_counts.inc_if_exists(r.source(), info.counts.value(kCFE));
-
- normalise_string(r.target());
target_counts.inc_if_exists(r.target(), info.counts.value(kCFE));
}
@@ -463,11 +363,9 @@ struct XFeatures: public FeatureExtractor {
map_rule(r);
double l_r_freq = log(rule_counts(r));
- normalise_string(r.target());
result->set_value(fid_xfe, log(target_counts(r.target())) - l_r_freq);
result->set_value(fid_labelledfe, log(target_counts(r.target())) - log(info.counts.value(kCFE)));
- normalise_string(r.source());
result->set_value(fid_xef, log(source_counts(r.source())) - l_r_freq);
result->set_value(fid_labelledef, log(source_counts(r.source())) - log(info.counts.value(kCFE)));
}
@@ -475,21 +373,15 @@ struct XFeatures: public FeatureExtractor {
void map_rule(RuleTuple& r) const {
vector<WordID> indexes; int i=0;
for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it) {
- if (validate_non_terminal(TD::Convert(*it)))
+ if (*it <= 0)
indexes.push_back(*it);
}
for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it) {
- if (validate_non_terminal(TD::Convert(*it)))
+ if (*it <= 0)
*it = indexes.at(i++);
}
}
- void normalise_string(vector<WordID>& r) const {
- vector<WordID> indexes;
- for (vector<WordID>::iterator it = r.begin(); it != r.end(); ++it)
- if (validate_non_terminal(TD::Convert(*it))) *it = -1;
- }
-
const int fid_xfe, fid_xef;
const int fid_labelledfe, fid_labelledef;
const int kCFE;
@@ -508,10 +400,8 @@ struct LabelledRuleConditionals: public FeatureExtractor {
const vector<WordID>& trg) {
RuleTuple r(lhs, src, trg);
rule_counts.inc(r, 0);
- normalise_string(r.source());
source_counts.inc(r.source(), 0);
- normalise_string(r.target());
target_counts.inc(r.target(), 0);
}
@@ -521,10 +411,8 @@ struct LabelledRuleConditionals: public FeatureExtractor {
const RuleStatistics& info) {
RuleTuple r(lhs, src, trg);
rule_counts.inc_if_exists(r, info.counts.value(kCFE));
- normalise_string(r.source());
source_counts.inc_if_exists(r.source(), info.counts.value(kCFE));
- normalise_string(r.target());
target_counts.inc_if_exists(r.target(), info.counts.value(kCFE));
}
@@ -535,18 +423,10 @@ struct LabelledRuleConditionals: public FeatureExtractor {
SparseVector<float>* result) const {
RuleTuple r(lhs, src, trg);
double l_r_freq = log(rule_counts(r));
- normalise_string(r.target());
result->set_value(fid_fe, log(target_counts(r.target())) - l_r_freq);
- normalise_string(r.source());
result->set_value(fid_ef, log(source_counts(r.source())) - l_r_freq);
}
- void normalise_string(vector<WordID>& r) const {
- vector<WordID> indexes;
- for (vector<WordID>::iterator it = r.begin(); it != r.end(); ++it)
- if (validate_non_terminal(TD::Convert(*it))) *it = -1;
- }
-
const int fid_fe, fid_ef;
const int kCFE;
RuleFreqCount rule_counts;
@@ -643,9 +523,9 @@ struct LabellingShape: public FeatureExtractor {
// Replace all terminals with generic -1
void map_rule(RuleTuple& r) const {
for (vector<WordID>::iterator it = r.target().begin(); it != r.target().end(); ++it)
- if (!validate_non_terminal(TD::Convert(*it))) *it = -1;
+ if (*it <= 0) *it = -1;
for (vector<WordID>::iterator it = r.source().begin(); it != r.source().end(); ++it)
- if (!validate_non_terminal(TD::Convert(*it))) *it = -1;
+ if (*it <= 0) *it = -1;
}
const int fid_, kCFE;
@@ -758,6 +638,51 @@ struct LexProbExtractor : public FeatureExtractor {
mutable LexTranslationTable table;
};
+struct Featurizer {
+ Featurizer(const vector<boost::shared_ptr<FeatureExtractor> >& ex) : extractors(ex) {
+ }
+ void Callback1(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) {
+ for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) {
+ for (int i = 0; i < extractors.size(); ++i)
+ extractors[i]->ObserveFilteredRule(lhs, src, it->first);
+ }
+ }
+ void Callback2(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) {
+ for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) {
+ for (int i = 0; i < extractors.size(); ++i)
+ extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second);
+ }
+ }
+ void Callback3(WordID lhs, const vector<WordID>& src, const ID2RuleStatistics& trgs) {
+ for (ID2RuleStatistics::const_iterator it = trgs.begin(); it != trgs.end(); ++it) {
+ SparseVector<float> feats;
+ for (int i = 0; i < extractors.size(); ++i)
+ extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats);
+ cout << '[' << TD::Convert(-lhs) << "] ||| ";
+ WriteNamed(src, &cout);
+ cout << " ||| ";
+ WriteAnonymous(it->first, &cout);
+ cout << " ||| ";
+ feats.Write(false, &cout);
+ cout << endl;
+ }
+ }
+ private:
+ vector<boost::shared_ptr<FeatureExtractor> > extractors;
+};
+
+void cb1(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) {
+ static_cast<Featurizer*>(extra)->Callback1(lhs, src_rhs, rules);
+}
+
+void cb2(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) {
+ static_cast<Featurizer*>(extra)->Callback2(lhs, src_rhs, rules);
+}
+
+void cb3(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void* extra) {
+ static_cast<Featurizer*>(extra)->Callback3(lhs, src_rhs, rules);
+}
+
int main(int argc, char** argv){
FERegistry reg;
reg.Register("LogRuleCount", new FEFactory<LogRuleCount>);
@@ -778,65 +703,18 @@ int main(int argc, char** argv){
vector<boost::shared_ptr<FeatureExtractor> > extractors(feats.size());
for (int i = 0; i < feats.size(); ++i)
extractors[i] = reg.Create(feats[i]);
+ Featurizer fizer(extractors);
- //score unscored grammar
cerr << "Reading filtered grammar to detect keys..." << endl;
- char* buf = new char[MAX_LINE_LENGTH];
-
- ID2RuleStatistics acc, cur_counts;
- vector<WordID> key, cur_key,temp_key;
- WordID lhs = 0;
- vector<WordID> src;
-
- istream& fs1 = *fg1.stream();
- while(fs1) {
- fs1.getline(buf, MAX_LINE_LENGTH);
- if (buf[0] == 0) continue;
- ParseLine(buf, &cur_key, &cur_counts);
- src.resize(cur_key.size() - 2);
- for (int i = 0; i < src.size(); ++i) src.at(i) = cur_key.at(i+2);
-
- lhs = cur_key[0];
- for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) {
- for (int i = 0; i < extractors.size(); ++i)
- extractors[i]->ObserveFilteredRule(lhs, src, it->first);
- }
- }
+ StripedGrammarLexer::ReadStripedGrammar(fg1.stream(), cb1, &fizer);
cerr << "Reading unfiltered grammar..." << endl;
- while(cin) {
- cin.getline(buf, MAX_LINE_LENGTH);
- if (buf[0] == 0) continue;
- ParseLine(buf, &cur_key, &cur_counts);
- src.resize(cur_key.size() - 2);
- for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];
- lhs = cur_key[0];
- for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) {
- for (int i = 0; i < extractors.size(); ++i)
- extractors[i]->ObserveUnfilteredRule(lhs, src, it->first, it->second);
- }
- }
+ StripedGrammarLexer::ReadStripedGrammar(&cin, cb2, &fizer);
ReadFile fg2(conf["filtered_grammar"].as<string>());
- istream& fs2 = *fg2.stream();
cerr << "Reading filtered grammar and adding features..." << endl;
- while(fs2) {
- fs2.getline(buf, MAX_LINE_LENGTH);
- if (buf[0] == 0) continue;
- ParseLine(buf, &cur_key, &cur_counts);
- src.resize(cur_key.size() - 2);
- for (int i = 0; i < src.size(); ++i) src[i] = cur_key[i+2];
- lhs = cur_key[0];
-
- //loop over all the Target side phrases that this source aligns to
- for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) {
- SparseVector<float> feats;
- for (int i = 0; i < extractors.size(); ++i)
- extractors[i]->ExtractFeatures(lhs, src, it->first, it->second, &feats);
- cout << TD::Convert(lhs) << " ||| " << TD::GetString(src) << " ||| " << TD::GetString(it->first) << " ||| ";
- feats.Write(false, &cout);
- cout << endl;
- }
- }
+ StripedGrammarLexer::ReadStripedGrammar(fg2.stream(), cb3, &fizer);
+
+ return 0;
}
diff --git a/extools/lex_trans_tbl.h b/extools/lex_trans_tbl.h
index 81d6ccc9..161b4a0d 100644
--- a/extools/lex_trans_tbl.h
+++ b/extools/lex_trans_tbl.h
@@ -8,6 +8,7 @@
#ifndef LEX_TRANS_TBL_H_
#define LEX_TRANS_TBL_H_
+#include "wordid.h"
#include <map>
class LexTranslationTable