summaryrefslogtreecommitdiff
path: root/extools
diff options
context:
space:
mode:
authorolivia.buzek <olivia.buzek@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 21:59:50 +0000
committerolivia.buzek <olivia.buzek@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 21:59:50 +0000
commitc12e7241e8908def96943b1a4056e536ea91eded (patch)
treec24b9cf0d2a90239b01eb6432e683292c95bb06f /extools
parenta034f92b1fe0c6368ebb140bc691f0718dd23a23 (diff)
Adding backoff grammar and BackoffRule feature.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@191 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'extools')
-rw-r--r--extools/extract.cc44
-rw-r--r--extools/extract.h3
-rw-r--r--extools/extractor.cc14
-rw-r--r--extools/featurize_grammar.cc17
-rw-r--r--extools/score_grammar.cc352
5 files changed, 424 insertions, 6 deletions
diff --git a/extools/extract.cc b/extools/extract.cc
index 6ad124d2..c2c413e2 100644
--- a/extools/extract.cc
+++ b/extools/extract.cc
@@ -173,12 +173,15 @@ void Extract::ExtractConsistentRules(const AnnotatedParallelSentence& sentence,
const int max_syms,
const bool permit_adjacent_nonterminals,
const bool require_aligned_terminal,
- RuleObserver* observer) {
+ RuleObserver* observer,
+ vector<WordID>* all_cats) {
+ const char bkoff_mrkr = '_';
queue<RuleItem> q; // agenda for BFS
int max_len = -1;
unordered_map<pair<short, short>, vector<ParallelSpan>, boost::hash<pair<short, short> > > fspans;
vector<vector<ParallelSpan> > spans_by_start(sentence.f_len);
set<int> starts;
+ WordID bkoff;
for (int i = 0; i < phrases.size(); ++i) {
fspans[make_pair(phrases[i].i1,phrases[i].i2)].push_back(phrases[i]);
max_len = max(max_len, phrases[i].i2 - phrases[i].i1);
@@ -281,6 +284,42 @@ void Extract::ExtractConsistentRules(const AnnotatedParallelSentence& sentence,
if (cur_es[j] >= 0 && sentence.aligned(cur_fs[i],cur_es[j]))
cur_terminal_align.push_back(make_pair(i,j));
observer->CountRule(lhs, cur_rhs_f, cur_rhs_e, cur_terminal_align);
+
+ if(!all_cats->empty()) {
+ //produce the backoff grammar if the category wordIDs are available
+ for (int i = 0; i < cur_rhs_f.size(); ++i) {
+ if(cur_rhs_f[i] < 0) {
+ //cerr << cur_rhs_f[i] << ": (cats,f) |" << TD::Convert(-cur_rhs_f[i]) << endl;
+ string nonterm = TD::Convert(-cur_rhs_f[i]);
+ nonterm+=bkoff_mrkr;
+ bkoff = -TD::Convert(nonterm);
+ cur_rhs_f[i]=bkoff;
+ vector<WordID> rhs_f_bkoff;
+ vector<WordID> rhs_e_bkoff;
+ vector<pair<short,short> > bkoff_align;
+ bkoff_align.clear();
+ bkoff_align.push_back(make_pair(0,0));
+
+ for (int cat = 0; cat < all_cats->size(); ++cat) {
+ rhs_f_bkoff.clear();
+ rhs_e_bkoff.clear();
+ rhs_f_bkoff.push_back(-(*all_cats)[cat]);
+ rhs_e_bkoff.push_back(0);
+ observer->CountRule(bkoff,rhs_f_bkoff,rhs_e_bkoff,bkoff_align);
+
+ }
+ }//else
+ //cerr << cur_rhs_f[i] << ": (words,f) |" << TD::Convert(cur_rhs_f[i]) << endl;
+ }
+ /*for (int i=0; i < cur_rhs_e.size(); ++i)
+ if(cur_rhs_e[i] <= 0)
+ cerr << cur_rhs_e[i] << ": (cats,e) |" << TD::Convert(1-cur_rhs_e[i]) << endl;
+ else
+ cerr << cur_rhs_e[i] << ": (words,e) |" << TD::Convert(cur_rhs_e[i]) << endl;
+ */
+
+ observer->CountRule(lhs, cur_rhs_f, cur_rhs_e, cur_terminal_align);
+ }
}
}
}
@@ -337,5 +376,4 @@ ostream& operator<<(ostream& os, const RuleStatistics& s) {
}
}
return os;
-}
-
+} \ No newline at end of file
diff --git a/extools/extract.h b/extools/extract.h
index f87aa6cb..72017034 100644
--- a/extools/extract.h
+++ b/extools/extract.h
@@ -87,7 +87,8 @@ struct Extract {
const int max_syms,
const bool permit_adjacent_nonterminals,
const bool require_aligned_terminal,
- RuleObserver* observer);
+ RuleObserver* observer,
+ std::vector<WordID>* all_cats);
};
// represents statistics / information about a rule pair
diff --git a/extools/extractor.cc b/extools/extractor.cc
index 4f9b4dc6..7149bfd7 100644
--- a/extools/extractor.cc
+++ b/extools/extractor.cc
@@ -43,6 +43,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("max_vars,v", po::value<int>()->default_value(2), "Maximum number of nonterminal variables in final phrase size")
("permit_adjacent_nonterminals,A", "Permit adjacent nonterminals in source side of rules")
("no_required_aligned_terminal,n", "Do not require an aligned terminal")
+ ("topics,t", po::value<int>()->default_value(50), "Number of categories assigned during clustering")
+ ("backoff,g","Produce a backoff grammar")
("help,h", "Print this help message and exit");
po::options_description clo("Command line options");
po::options_description dcmdline_options;
@@ -299,6 +301,7 @@ int main(int argc, char** argv) {
WordID default_cat = 0; // 0 means no default- extraction will
// fail if a phrase is extracted without a
// category
+ const bool backoff = (conf.count("backoff") ? true : false);
if (conf.count("default_category")) {
string sdefault_cat = conf["default_category"].as<string>();
default_cat = -TD::Convert(sdefault_cat);
@@ -310,6 +313,7 @@ int main(int argc, char** argv) {
char buf[MAX_LINE_LENGTH];
AnnotatedParallelSentence sentence;
vector<ParallelSpan> phrases;
+ vector<WordID> all_cats;
const int max_base_phrase_size = conf["max_base_phrase_size"].as<int>();
const bool write_phrase_contexts = conf.count("phrase_context") > 0;
const bool write_base_phrases = conf.count("base_phrase") > 0;
@@ -319,12 +323,19 @@ int main(int argc, char** argv) {
const int max_syms = conf["max_syms"].as<int>();
const int max_vars = conf["max_vars"].as<int>();
const int ctx_size = conf["phrase_context_size"].as<int>();
+ const int num_categories = conf["topics"].as<int>();
const bool permit_adjacent_nonterminals = conf.count("permit_adjacent_nonterminals") > 0;
const bool require_aligned_terminal = conf.count("no_required_aligned_terminal") == 0;
int line = 0;
CountCombiner cc(conf["combiner_size"].as<size_t>());
HadoopStreamingRuleObserver o(&cc,
conf.count("bidir") > 0);
+
+ if(backoff) {
+ for (int i=0;i < num_categories;++i)
+ all_cats.push_back(TD::Convert("X"+boost::lexical_cast<string>(i)));
+ }
+
//SimpleRuleWriter o;
while(in) {
++line;
@@ -356,9 +367,8 @@ int main(int argc, char** argv) {
continue;
}
Extract::AnnotatePhrasesWithCategoryTypes(default_cat, sentence.span_types, &phrases);
- Extract::ExtractConsistentRules(sentence, phrases, max_vars, max_syms, permit_adjacent_nonterminals, require_aligned_terminal, &o);
+ Extract::ExtractConsistentRules(sentence, phrases, max_vars, max_syms, permit_adjacent_nonterminals, require_aligned_terminal, &o, &all_cats);
}
if (!silent) cerr << endl;
return 0;
}
-
diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc
index 0d054626..b387fe04 100644
--- a/extools/featurize_grammar.cc
+++ b/extools/featurize_grammar.cc
@@ -385,6 +385,22 @@ struct LogRuleCount : public FeatureExtractor {
const int kCFE;
};
+struct BackoffRule : public FeatureExtractor {
+ BackoffRule() :
+ fid_(FD::Convert("BackoffRule")) {}
+ virtual void ExtractFeatures(const WordID lhs,
+ const vector<WordID>& src,
+ const vector<WordID>& trg,
+ const RuleStatistics& info,
+ SparseVector<float>* result) const {
+ (void) lhs; (void) src; (void) trg;
+ string lhstr = TD::Convert(lhs);
+ if(lhstr.find('_')!=string::npos)
+ result->set_value(fid_, -1);
+ }
+ const int fid_;
+};
+
// The negative log of the condition rule probs
// ignoring the identities of the non-terminals.
// i.e. the prob Hiero would assign.
@@ -656,6 +672,7 @@ int main(int argc, char** argv){
reg.Register("LexProb", new FEFactory<LexProbExtractor>);
reg.Register("XFeatures", new FEFactory<XFeatures>);
reg.Register("LabelledRuleConditionals", new FEFactory<LabelledRuleConditionals>);
+ reg.Register("BackoffRule", new FEFactory<BackoffRule>);
po::variables_map conf;
InitCommandLine(reg, argc, argv, &conf);
aligned_corpus = conf["aligned_corpus"].as<string>(); // GLOBAL VAR
diff --git a/extools/score_grammar.cc b/extools/score_grammar.cc
new file mode 100644
index 00000000..0945e018
--- /dev/null
+++ b/extools/score_grammar.cc
@@ -0,0 +1,352 @@
+/*
+ * Score a grammar in striped format
+ * ./score_grammar <alignment> < filtered.grammar > scored.grammar
+ */
+#include <iostream>
+#include <string>
+#include <map>
+#include <vector>
+#include <utility>
+#include <cstdlib>
+#include <fstream>
+#include <tr1/unordered_map>
+
+#include "sentence_pair.h"
+#include "extract.h"
+#include "fdict.h"
+#include "tdict.h"
+#include "lex_trans_tbl.h"
+#include "filelib.h"
+
+#include <boost/functional/hash.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+using namespace std;
+using namespace std::tr1;
+
+
+static const size_t MAX_LINE_LENGTH = 64000000;
+
+typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics;
+
+
+namespace {
+ inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
+ inline bool IsBracket(char c){return c == '[' || c == ']';}
+ inline void SkipWhitespace(const char* buf, int* ptr) {
+ while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); }
+ }
+}
+
+int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) {
+ static const WordID kDIV = TD::Convert("|||");
+ int ptr = sstart;
+ while(ptr < end) {
+ while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; }
+ int start = ptr;
+ while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; }
+ if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; }
+ const WordID w = TD::Convert(string(buf, start, ptr - start));
+
+ if((IsBracket(buf[start]) and IsBracket(buf[ptr-1])) or( w == kDIV))
+ p->push_back(1 * w);
+ else {
+ if (w == kDIV) return ptr;
+ p->push_back(w);
+ }
+ }
+ return ptr;
+}
+
+
+void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) {
+ static const WordID kDIV = TD::Convert("|||");
+ counts->clear();
+ int ptr = 0;
+ while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; }
+ if (buf[ptr] != '\t') {
+ cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl;
+ exit(1);
+ }
+ cur_key->clear();
+ // key is: "[X] ||| word word word"
+ int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key);
+ cur_key->push_back(kDIV);
+ ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key);
+ ++ptr;
+ int start = ptr;
+ int end = ptr;
+ int state = 0; // 0=reading label, 1=reading count
+ vector<WordID> name;
+ while(buf[ptr] != 0) {
+ while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; }
+ if (buf[ptr] == '|') {
+ ++ptr;
+ if (buf[ptr] == '|') {
+ ++ptr;
+ if (buf[ptr] == '|') {
+ ++ptr;
+ end = ptr - 3;
+ while (end > start && IsWhitespace(buf[end-1])) { --end; }
+ if (start == end) {
+ cerr << "Got empty token!\n LINE=" << buf << endl;
+ exit(1);
+ }
+ switch (state) {
+ case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
+ case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
+ default: cerr << "Can't happen\n"; abort();
+ }
+ SkipWhitespace(buf, &ptr);
+ start = ptr;
+ }
+ }
+ }
+ }
+ end=ptr;
+ while (end > start && IsWhitespace(buf[end-1])) { --end; }
+ if (end > start) {
+ switch (state) {
+ case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break;
+ case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break;
+ default: cerr << "Can't happen\n"; abort();
+ }
+ }
+}
+
+
+
+void LexTranslationTable::createTTable(const char* buf){
+
+ bool DEBUG = false;
+
+ AnnotatedParallelSentence sent;
+
+ sent.ParseInputLine(buf);
+
+ //iterate over the alignment to compute aligned words
+
+ for(int i =0;i<sent.aligned.width();i++)
+ {
+ for (int j=0;j<sent.aligned.height();j++)
+ {
+ if (DEBUG) cerr << sent.aligned(i,j) << " ";
+ if( sent.aligned(i,j))
+ {
+ if (DEBUG) cerr << TD::Convert(sent.f[i]) << " aligned to " << TD::Convert(sent.e[j]);
+ ++word_translation[pair<WordID,WordID> (sent.f[i], sent.e[j])];
+ ++total_foreign[sent.f[i]];
+ ++total_english[sent.e[j]];
+ }
+ }
+ if (DEBUG) cerr << endl;
+ }
+ if (DEBUG) cerr << endl;
+
+ static const WordID NULL_ = TD::Convert("NULL");
+ //handle unaligned words - align them to null
+ for (int j =0; j < sent.e_len; j++)
+ {
+ if (sent.e_aligned[j]) continue;
+ ++word_translation[pair<WordID,WordID> (NULL_, sent.e[j])];
+ ++total_foreign[NULL_];
+ ++total_english[sent.e[j]];
+ }
+
+ for (int i =0; i < sent.f_len; i++)
+ {
+ if (sent.f_aligned[i]) continue;
+ ++word_translation[pair<WordID,WordID> (sent.f[i], NULL_)];
+ ++total_english[NULL_];
+ ++total_foreign[sent.f[i]];
+ }
+
+}
+
+
+inline float safenlog(float v) {
+ if (v == 1.0f) return 0.0f;
+ float res = -log(v);
+ if (res > 100.0f) res = 100.0f;
+ return res;
+}
+
+int main(int argc, char** argv){
+ bool DEBUG= false;
+ if (argc != 2) {
+ cerr << "Usage: " << argv[0] << " corpus.al < filtered.grammar\n";
+ return 1;
+ }
+ ifstream alignment (argv[1]);
+ istream& unscored_grammar = cin;
+ ostream& scored_grammar = cout;
+
+ //create lexical translation table
+ cerr << "Creating table..." << endl;
+ char* buf = new char[MAX_LINE_LENGTH];
+
+ LexTranslationTable table;
+
+ while(!alignment.eof())
+ {
+ alignment.getline(buf, MAX_LINE_LENGTH);
+ if (buf[0] == 0) continue;
+
+ table.createTTable(buf);
+ }
+
+ bool PRINT_TABLE=false;
+ if (PRINT_TABLE)
+ {
+ ofstream trans_table;
+ trans_table.open("lex_trans_table.out");
+ for(map < pair<WordID,WordID>,int >::iterator it = table.word_translation.begin(); it != table.word_translation.end(); ++it)
+ {
+ trans_table << TD::Convert(it->first.first) << "|||" << TD::Convert(it->first.second) << "==" << it->second << "//" << table.total_foreign[it->first.first] << "//" << table.total_english[it->first.second] << endl;
+ }
+
+ trans_table.close();
+ }
+
+
+ //score unscored grammar
+ cerr <<"Scoring grammar..." << endl;
+
+ ID2RuleStatistics acc, cur_counts;
+ vector<WordID> key, cur_key,temp_key;
+ vector< pair<short,short> > al;
+ vector< pair<short,short> >::iterator ita;
+ int line = 0;
+
+ static const int kCF = FD::Convert("CF");
+ static const int kCE = FD::Convert("CE");
+ static const int kCFE = FD::Convert("CFE");
+
+ while(!unscored_grammar.eof())
+ {
+ ++line;
+ unscored_grammar.getline(buf, MAX_LINE_LENGTH);
+ if (buf[0] == 0) continue;
+ ParseLine(buf, &cur_key, &cur_counts);
+
+ //loop over all the Target side phrases that this source aligns to
+ for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it)
+ {
+
+ /*Compute phrase translation prob.
+ Print out scores in this format:
+ Phrase trnaslation prob P(F|E)
+ Phrase translation prob P(E|F)
+ Lexical weighting prob lex(F|E)
+ Lexical weighting prob lex(E|F)
+ */
+
+ float pEF_ = it->second.counts.value(kCFE) / it->second.counts.value(kCF);
+ float pFE_ = it->second.counts.value(kCFE) / it->second.counts.value(kCE);
+
+ map <WordID, pair<int, float> > foreign_aligned;
+ map <WordID, pair<int, float> > english_aligned;
+
+ //Loop over all the alignment points to compute lexical translation probability
+ al = it->second.aligns;
+ for(ita = al.begin(); ita != al.end(); ++ita)
+ {
+
+ if (DEBUG)
+ {
+ cerr << "\nA:" << ita->first << "," << ita->second << "::";
+ cerr << TD::Convert(cur_key[ita->first + 2]) << "-" << TD::Convert(it->first[ita->second]);
+ }
+
+
+ //Lookup this alignment probability in the table
+ int temp = table.word_translation[pair<WordID,WordID> (cur_key[ita->first+2],it->first[ita->second])];
+ float f2e=0, e2f=0;
+ if ( table.total_foreign[cur_key[ita->first+2]] != 0)
+ f2e = (float) temp / table.total_foreign[cur_key[ita->first+2]];
+ if ( table.total_english[it->first[ita->second]] !=0 )
+ e2f = (float) temp / table.total_english[it->first[ita->second]];
+ if (DEBUG) printf (" %d %E %E\n", temp, f2e, e2f);
+
+
+ //local counts to keep track of which things haven't been aligned, to later compute their null alignment
+ if (foreign_aligned.count(cur_key[ita->first+2]))
+ {
+ foreign_aligned[ cur_key[ita->first+2] ].first++;
+ foreign_aligned[ cur_key[ita->first+2] ].second += e2f;
+ }
+ else
+ foreign_aligned [ cur_key[ita->first+2] ] = pair<int,float> (1,e2f);
+
+
+
+ if (english_aligned.count( it->first[ ita->second] ))
+ {
+ english_aligned[ it->first[ ita->second ]].first++;
+ english_aligned[ it->first[ ita->second] ].second += f2e;
+ }
+ else
+ english_aligned [ it->first[ ita->second] ] = pair<int,float> (1,f2e);
+
+
+
+
+ }
+
+ float final_lex_f2e=1, final_lex_e2f=1;
+ static const WordID NULL_ = TD::Convert("NULL");
+
+ //compute lexical weight P(F|E) and include unaligned foreign words
+ for(int i=0;i<cur_key.size(); i++)
+ {
+
+ if (!table.total_foreign.count(cur_key[i])) continue; //if we dont have it in the translation table, we won't know its lexical weight
+
+ if (foreign_aligned.count(cur_key[i]))
+ {
+ pair<int, float> temp_lex_prob = foreign_aligned[cur_key[i]];
+ final_lex_e2f *= temp_lex_prob.second / temp_lex_prob.first;
+ }
+ else //dealing with null alignment
+ {
+ int temp_count = table.word_translation[pair<WordID,WordID> (cur_key[i],NULL_)];
+ float temp_e2f = (float) temp_count / table.total_english[NULL_];
+ final_lex_e2f *= temp_e2f;
+ }
+
+ }
+
+ //compute P(E|F) unaligned english words
+ for(int j=0; j< it->first.size(); j++)
+ {
+ if (!table.total_english.count(it->first[j])) continue;
+
+ if (english_aligned.count(it->first[j]))
+ {
+ pair<int, float> temp_lex_prob = english_aligned[it->first[j]];
+ final_lex_f2e *= temp_lex_prob.second / temp_lex_prob.first;
+ }
+ else //dealing with null
+ {
+ int temp_count = table.word_translation[pair<WordID,WordID> (NULL_,it->first[j])];
+ float temp_f2e = (float) temp_count / table.total_foreign[NULL_];
+ final_lex_f2e *= temp_f2e;
+ }
+ }
+
+
+ scored_grammar << TD::GetString(cur_key);
+ string lhs = TD::Convert(cur_key[0]);
+ scored_grammar << " " << TD::GetString(it->first) << " |||";
+ if(lhs.find('_')!=string::npos) {
+ scored_grammar << " Bkoff=" << safenlog(3.0f);
+ } else {
+ scored_grammar << " FGivenE=" << safenlog(pFE_) << " EGivenF=" << safenlog(pEF_);
+ scored_grammar << " LexE2F=" << safenlog(final_lex_e2f) << " LexF2E=" << safenlog(final_lex_f2e);
+ }
+ scored_grammar << endl;
+ }
+ }
+}
+