From 11f476baf7d855198413b091cae775bde4ea41ed Mon Sep 17 00:00:00 2001 From: "olivia.buzek" Date: Thu, 8 Jul 2010 21:59:50 +0000 Subject: Adding backoff grammar and BackoffRule feature. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@191 ec762483-ff6d-05da-a07a-a48fb63a330f --- extools/extract.cc | 44 +++++- extools/extract.h | 3 +- extools/extractor.cc | 14 +- extools/featurize_grammar.cc | 17 +++ extools/score_grammar.cc | 352 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 424 insertions(+), 6 deletions(-) create mode 100644 extools/score_grammar.cc (limited to 'extools') diff --git a/extools/extract.cc b/extools/extract.cc index 6ad124d2..c2c413e2 100644 --- a/extools/extract.cc +++ b/extools/extract.cc @@ -173,12 +173,15 @@ void Extract::ExtractConsistentRules(const AnnotatedParallelSentence& sentence, const int max_syms, const bool permit_adjacent_nonterminals, const bool require_aligned_terminal, - RuleObserver* observer) { + RuleObserver* observer, + vector* all_cats) { + const char bkoff_mrkr = '_'; queue q; // agenda for BFS int max_len = -1; unordered_map, vector, boost::hash > > fspans; vector > spans_by_start(sentence.f_len); set starts; + WordID bkoff; for (int i = 0; i < phrases.size(); ++i) { fspans[make_pair(phrases[i].i1,phrases[i].i2)].push_back(phrases[i]); max_len = max(max_len, phrases[i].i2 - phrases[i].i1); @@ -281,6 +284,42 @@ void Extract::ExtractConsistentRules(const AnnotatedParallelSentence& sentence, if (cur_es[j] >= 0 && sentence.aligned(cur_fs[i],cur_es[j])) cur_terminal_align.push_back(make_pair(i,j)); observer->CountRule(lhs, cur_rhs_f, cur_rhs_e, cur_terminal_align); + + if(!all_cats->empty()) { + //produce the backoff grammar if the category wordIDs are available + for (int i = 0; i < cur_rhs_f.size(); ++i) { + if(cur_rhs_f[i] < 0) { + //cerr << cur_rhs_f[i] << ": (cats,f) |" << TD::Convert(-cur_rhs_f[i]) << endl; + string nonterm = TD::Convert(-cur_rhs_f[i]); + nonterm+=bkoff_mrkr; + bkoff = -TD::Convert(nonterm); + cur_rhs_f[i]=bkoff; + vector rhs_f_bkoff; + vector rhs_e_bkoff; + vector > bkoff_align; + bkoff_align.clear(); + bkoff_align.push_back(make_pair(0,0)); + + for (int cat = 0; cat < all_cats->size(); ++cat) { + rhs_f_bkoff.clear(); + rhs_e_bkoff.clear(); + rhs_f_bkoff.push_back(-(*all_cats)[cat]); + rhs_e_bkoff.push_back(0); + observer->CountRule(bkoff,rhs_f_bkoff,rhs_e_bkoff,bkoff_align); + + } + }//else + //cerr << cur_rhs_f[i] << ": (words,f) |" << TD::Convert(cur_rhs_f[i]) << endl; + } + /*for (int i=0; i < cur_rhs_e.size(); ++i) + if(cur_rhs_e[i] <= 0) + cerr << cur_rhs_e[i] << ": (cats,e) |" << TD::Convert(1-cur_rhs_e[i]) << endl; + else + cerr << cur_rhs_e[i] << ": (words,e) |" << TD::Convert(cur_rhs_e[i]) << endl; + */ + + observer->CountRule(lhs, cur_rhs_f, cur_rhs_e, cur_terminal_align); + } } } } @@ -337,5 +376,4 @@ ostream& operator<<(ostream& os, const RuleStatistics& s) { } } return os; -} - +} \ No newline at end of file diff --git a/extools/extract.h b/extools/extract.h index f87aa6cb..72017034 100644 --- a/extools/extract.h +++ b/extools/extract.h @@ -87,7 +87,8 @@ struct Extract { const int max_syms, const bool permit_adjacent_nonterminals, const bool require_aligned_terminal, - RuleObserver* observer); + RuleObserver* observer, + std::vector* all_cats); }; // represents statistics / information about a rule pair diff --git a/extools/extractor.cc b/extools/extractor.cc index 4f9b4dc6..7149bfd7 100644 --- a/extools/extractor.cc +++ b/extools/extractor.cc @@ -43,6 +43,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_vars,v", po::value()->default_value(2), "Maximum number of nonterminal variables in final phrase size") ("permit_adjacent_nonterminals,A", "Permit adjacent nonterminals in source side of rules") ("no_required_aligned_terminal,n", "Do not require an aligned terminal") + ("topics,t", po::value()->default_value(50), "Number of categories assigned during clustering") + ("backoff,g","Produce a backoff grammar") ("help,h", "Print this help message and exit"); po::options_description clo("Command line options"); po::options_description dcmdline_options; @@ -299,6 +301,7 @@ int main(int argc, char** argv) { WordID default_cat = 0; // 0 means no default- extraction will // fail if a phrase is extracted without a // category + const bool backoff = (conf.count("backoff") ? true : false); if (conf.count("default_category")) { string sdefault_cat = conf["default_category"].as(); default_cat = -TD::Convert(sdefault_cat); @@ -310,6 +313,7 @@ int main(int argc, char** argv) { char buf[MAX_LINE_LENGTH]; AnnotatedParallelSentence sentence; vector phrases; + vector all_cats; const int max_base_phrase_size = conf["max_base_phrase_size"].as(); const bool write_phrase_contexts = conf.count("phrase_context") > 0; const bool write_base_phrases = conf.count("base_phrase") > 0; @@ -319,12 +323,19 @@ int main(int argc, char** argv) { const int max_syms = conf["max_syms"].as(); const int max_vars = conf["max_vars"].as(); const int ctx_size = conf["phrase_context_size"].as(); + const int num_categories = conf["topics"].as(); const bool permit_adjacent_nonterminals = conf.count("permit_adjacent_nonterminals") > 0; const bool require_aligned_terminal = conf.count("no_required_aligned_terminal") == 0; int line = 0; CountCombiner cc(conf["combiner_size"].as()); HadoopStreamingRuleObserver o(&cc, conf.count("bidir") > 0); + + if(backoff) { + for (int i=0;i < num_categories;++i) + all_cats.push_back(TD::Convert("X"+boost::lexical_cast(i))); + } + //SimpleRuleWriter o; while(in) { ++line; @@ -356,9 +367,8 @@ int main(int argc, char** argv) { continue; } Extract::AnnotatePhrasesWithCategoryTypes(default_cat, sentence.span_types, &phrases); - Extract::ExtractConsistentRules(sentence, phrases, max_vars, max_syms, permit_adjacent_nonterminals, require_aligned_terminal, &o); + Extract::ExtractConsistentRules(sentence, phrases, max_vars, max_syms, permit_adjacent_nonterminals, require_aligned_terminal, &o, &all_cats); } if (!silent) cerr << endl; return 0; } - diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 0d054626..b387fe04 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -385,6 +385,22 @@ struct LogRuleCount : public FeatureExtractor { const int kCFE; }; +struct BackoffRule : public FeatureExtractor { + BackoffRule() : + fid_(FD::Convert("BackoffRule")) {} + virtual void ExtractFeatures(const WordID lhs, + const vector& src, + const vector& trg, + const RuleStatistics& info, + SparseVector* result) const { + (void) lhs; (void) src; (void) trg; + string lhstr = TD::Convert(lhs); + if(lhstr.find('_')!=string::npos) + result->set_value(fid_, -1); + } + const int fid_; +}; + // The negative log of the condition rule probs // ignoring the identities of the non-terminals. // i.e. the prob Hiero would assign. @@ -656,6 +672,7 @@ int main(int argc, char** argv){ reg.Register("LexProb", new FEFactory); reg.Register("XFeatures", new FEFactory); reg.Register("LabelledRuleConditionals", new FEFactory); + reg.Register("BackoffRule", new FEFactory); po::variables_map conf; InitCommandLine(reg, argc, argv, &conf); aligned_corpus = conf["aligned_corpus"].as(); // GLOBAL VAR diff --git a/extools/score_grammar.cc b/extools/score_grammar.cc new file mode 100644 index 00000000..0945e018 --- /dev/null +++ b/extools/score_grammar.cc @@ -0,0 +1,352 @@ +/* + * Score a grammar in striped format + * ./score_grammar < filtered.grammar > scored.grammar + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sentence_pair.h" +#include "extract.h" +#include "fdict.h" +#include "tdict.h" +#include "lex_trans_tbl.h" +#include "filelib.h" + +#include +#include +#include + +using namespace std; +using namespace std::tr1; + + +static const size_t MAX_LINE_LENGTH = 64000000; + +typedef unordered_map, RuleStatistics, boost::hash > > ID2RuleStatistics; + + +namespace { + inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } + inline bool IsBracket(char c){return c == '[' || c == ']';} + inline void SkipWhitespace(const char* buf, int* ptr) { + while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } + } +} + +int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector* p) { + static const WordID kDIV = TD::Convert("|||"); + int ptr = sstart; + while(ptr < end) { + while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } + int start = ptr; + while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } + if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } + const WordID w = TD::Convert(string(buf, start, ptr - start)); + + if((IsBracket(buf[start]) and IsBracket(buf[ptr-1])) or( w == kDIV)) + p->push_back(1 * w); + else { + if (w == kDIV) return ptr; + p->push_back(w); + } + } + return ptr; +} + + +void ParseLine(const char* buf, vector* cur_key, ID2RuleStatistics* counts) { + static const WordID kDIV = TD::Convert("|||"); + counts->clear(); + int ptr = 0; + while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } + if (buf[ptr] != '\t') { + cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; + exit(1); + } + cur_key->clear(); + // key is: "[X] ||| word word word" + int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); + cur_key->push_back(kDIV); + ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); + ++ptr; + int start = ptr; + int end = ptr; + int state = 0; // 0=reading label, 1=reading count + vector name; + while(buf[ptr] != 0) { + while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } + if (buf[ptr] == '|') { + ++ptr; + if (buf[ptr] == '|') { + ++ptr; + if (buf[ptr] == '|') { + ++ptr; + end = ptr - 3; + while (end > start && IsWhitespace(buf[end-1])) { --end; } + if (start == end) { + cerr << "Got empty token!\n LINE=" << buf << endl; + exit(1); + } + switch (state) { + case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; + case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; + default: cerr << "Can't happen\n"; abort(); + } + SkipWhitespace(buf, &ptr); + start = ptr; + } + } + } + } + end=ptr; + while (end > start && IsWhitespace(buf[end-1])) { --end; } + if (end > start) { + switch (state) { + case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; + case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; + default: cerr << "Can't happen\n"; abort(); + } + } +} + + + +void LexTranslationTable::createTTable(const char* buf){ + + bool DEBUG = false; + + AnnotatedParallelSentence sent; + + sent.ParseInputLine(buf); + + //iterate over the alignment to compute aligned words + + for(int i =0;i (sent.f[i], sent.e[j])]; + ++total_foreign[sent.f[i]]; + ++total_english[sent.e[j]]; + } + } + if (DEBUG) cerr << endl; + } + if (DEBUG) cerr << endl; + + static const WordID NULL_ = TD::Convert("NULL"); + //handle unaligned words - align them to null + for (int j =0; j < sent.e_len; j++) + { + if (sent.e_aligned[j]) continue; + ++word_translation[pair (NULL_, sent.e[j])]; + ++total_foreign[NULL_]; + ++total_english[sent.e[j]]; + } + + for (int i =0; i < sent.f_len; i++) + { + if (sent.f_aligned[i]) continue; + ++word_translation[pair (sent.f[i], NULL_)]; + ++total_english[NULL_]; + ++total_foreign[sent.f[i]]; + } + +} + + +inline float safenlog(float v) { + if (v == 1.0f) return 0.0f; + float res = -log(v); + if (res > 100.0f) res = 100.0f; + return res; +} + +int main(int argc, char** argv){ + bool DEBUG= false; + if (argc != 2) { + cerr << "Usage: " << argv[0] << " corpus.al < filtered.grammar\n"; + return 1; + } + ifstream alignment (argv[1]); + istream& unscored_grammar = cin; + ostream& scored_grammar = cout; + + //create lexical translation table + cerr << "Creating table..." << endl; + char* buf = new char[MAX_LINE_LENGTH]; + + LexTranslationTable table; + + while(!alignment.eof()) + { + alignment.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + + table.createTTable(buf); + } + + bool PRINT_TABLE=false; + if (PRINT_TABLE) + { + ofstream trans_table; + trans_table.open("lex_trans_table.out"); + for(map < pair,int >::iterator it = table.word_translation.begin(); it != table.word_translation.end(); ++it) + { + trans_table << TD::Convert(it->first.first) << "|||" << TD::Convert(it->first.second) << "==" << it->second << "//" << table.total_foreign[it->first.first] << "//" << table.total_english[it->first.second] << endl; + } + + trans_table.close(); + } + + + //score unscored grammar + cerr <<"Scoring grammar..." << endl; + + ID2RuleStatistics acc, cur_counts; + vector key, cur_key,temp_key; + vector< pair > al; + vector< pair >::iterator ita; + int line = 0; + + static const int kCF = FD::Convert("CF"); + static const int kCE = FD::Convert("CE"); + static const int kCFE = FD::Convert("CFE"); + + while(!unscored_grammar.eof()) + { + ++line; + unscored_grammar.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + ParseLine(buf, &cur_key, &cur_counts); + + //loop over all the Target side phrases that this source aligns to + for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) + { + + /*Compute phrase translation prob. + Print out scores in this format: + Phrase trnaslation prob P(F|E) + Phrase translation prob P(E|F) + Lexical weighting prob lex(F|E) + Lexical weighting prob lex(E|F) + */ + + float pEF_ = it->second.counts.value(kCFE) / it->second.counts.value(kCF); + float pFE_ = it->second.counts.value(kCFE) / it->second.counts.value(kCE); + + map > foreign_aligned; + map > english_aligned; + + //Loop over all the alignment points to compute lexical translation probability + al = it->second.aligns; + for(ita = al.begin(); ita != al.end(); ++ita) + { + + if (DEBUG) + { + cerr << "\nA:" << ita->first << "," << ita->second << "::"; + cerr << TD::Convert(cur_key[ita->first + 2]) << "-" << TD::Convert(it->first[ita->second]); + } + + + //Lookup this alignment probability in the table + int temp = table.word_translation[pair (cur_key[ita->first+2],it->first[ita->second])]; + float f2e=0, e2f=0; + if ( table.total_foreign[cur_key[ita->first+2]] != 0) + f2e = (float) temp / table.total_foreign[cur_key[ita->first+2]]; + if ( table.total_english[it->first[ita->second]] !=0 ) + e2f = (float) temp / table.total_english[it->first[ita->second]]; + if (DEBUG) printf (" %d %E %E\n", temp, f2e, e2f); + + + //local counts to keep track of which things haven't been aligned, to later compute their null alignment + if (foreign_aligned.count(cur_key[ita->first+2])) + { + foreign_aligned[ cur_key[ita->first+2] ].first++; + foreign_aligned[ cur_key[ita->first+2] ].second += e2f; + } + else + foreign_aligned [ cur_key[ita->first+2] ] = pair (1,e2f); + + + + if (english_aligned.count( it->first[ ita->second] )) + { + english_aligned[ it->first[ ita->second ]].first++; + english_aligned[ it->first[ ita->second] ].second += f2e; + } + else + english_aligned [ it->first[ ita->second] ] = pair (1,f2e); + + + + + } + + float final_lex_f2e=1, final_lex_e2f=1; + static const WordID NULL_ = TD::Convert("NULL"); + + //compute lexical weight P(F|E) and include unaligned foreign words + for(int i=0;i temp_lex_prob = foreign_aligned[cur_key[i]]; + final_lex_e2f *= temp_lex_prob.second / temp_lex_prob.first; + } + else //dealing with null alignment + { + int temp_count = table.word_translation[pair (cur_key[i],NULL_)]; + float temp_e2f = (float) temp_count / table.total_english[NULL_]; + final_lex_e2f *= temp_e2f; + } + + } + + //compute P(E|F) unaligned english words + for(int j=0; j< it->first.size(); j++) + { + if (!table.total_english.count(it->first[j])) continue; + + if (english_aligned.count(it->first[j])) + { + pair temp_lex_prob = english_aligned[it->first[j]]; + final_lex_f2e *= temp_lex_prob.second / temp_lex_prob.first; + } + else //dealing with null + { + int temp_count = table.word_translation[pair (NULL_,it->first[j])]; + float temp_f2e = (float) temp_count / table.total_foreign[NULL_]; + final_lex_f2e *= temp_f2e; + } + } + + + scored_grammar << TD::GetString(cur_key); + string lhs = TD::Convert(cur_key[0]); + scored_grammar << " " << TD::GetString(it->first) << " |||"; + if(lhs.find('_')!=string::npos) { + scored_grammar << " Bkoff=" << safenlog(3.0f); + } else { + scored_grammar << " FGivenE=" << safenlog(pFE_) << " EGivenF=" << safenlog(pEF_); + scored_grammar << " LexE2F=" << safenlog(final_lex_e2f) << " LexF2E=" << safenlog(final_lex_f2e); + } + scored_grammar << endl; + } + } +} + -- cgit v1.2.3