From a15b6133bdb8eee1cdbc67712f3a5e51c4ec5377 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 13 Jul 2010 06:29:00 +0000 Subject: start moving toward striped grammar lexer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@233 ec762483-ff6d-05da-a07a-a48fb63a330f --- extools/Makefile.am | 13 ++- extools/extract.cc | 50 -------- extools/extract.h | 17 --- extools/extractor.cc | 1 + extools/featurize_grammar.cc | 4 +- extools/filter_grammar.cc | 167 ++++++--------------------- extools/filter_score_grammar.cc | 1 + extools/mr_stripe_rule_reduce.cc | 1 + extools/sg_lexer.l | 242 +++++++++++++++++++++++++++++++++++++++ extools/striped_grammar.cc | 67 +++++++++++ extools/striped_grammar.h | 54 +++++++++ 11 files changed, 412 insertions(+), 205 deletions(-) create mode 100644 extools/sg_lexer.l create mode 100644 extools/striped_grammar.cc create mode 100644 extools/striped_grammar.h (limited to 'extools') diff --git a/extools/Makefile.am b/extools/Makefile.am index fc02f831..1c0da21b 100644 --- a/extools/Makefile.am +++ b/extools/Makefile.am @@ -8,15 +8,18 @@ bin_PROGRAMS = \ noinst_PROGRAMS = -filter_score_grammar_SOURCES = filter_score_grammar.cc extract.cc sentence_pair.cc +sg_lexer.cc: sg_lexer.l + $(LEX) -s -CF -8 -o$@ $< + +filter_score_grammar_SOURCES = filter_score_grammar.cc extract.cc sentence_pair.cc striped_grammar.cc filter_score_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz filter_score_grammar_LDFLAGS = -all-static -filter_grammar_SOURCES = filter_grammar.cc extract.cc sentence_pair.cc +filter_grammar_SOURCES = filter_grammar.cc extract.cc sentence_pair.cc striped_grammar.cc sg_lexer.cc filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz filter_grammar_LDFLAGS = -all-static -featurize_grammar_SOURCES = featurize_grammar.cc extract.cc sentence_pair.cc +featurize_grammar_SOURCES = featurize_grammar.cc extract.cc sentence_pair.cc sg_lexer.cc striped_grammar.cc featurize_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz featurize_grammar_LDFLAGS = -all-static @@ -24,11 +27,11 @@ build_lexical_translation_SOURCES = build_lexical_translation.cc extract.cc sent build_lexical_translation_LDADD = $(top_srcdir)/decoder/libcdec.a -lz build_lexical_translation_LDFLAGS = -all-static -mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc +mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc mr_stripe_rule_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz mr_stripe_rule_reduce_LDFLAGS = -all-static -extractor_SOURCES = sentence_pair.cc extract.cc extractor.cc +extractor_SOURCES = sentence_pair.cc extract.cc extractor.cc striped_grammar.cc extractor_LDADD = $(top_srcdir)/decoder/libcdec.a -lz extractor_LDFLAGS = -all-static diff --git a/extools/extract.cc b/extools/extract.cc index 14497089..567348f4 100644 --- a/extools/extract.cc +++ b/extools/extract.cc @@ -320,53 +320,3 @@ void Extract::ExtractConsistentRules(const AnnotatedParallelSentence& sentence, } } -void RuleStatistics::ParseRuleStatistics(const char* buf, int start, int end) { - int ptr = start; - counts.clear(); - aligns.clear(); - while (ptr < end) { - SkipWhitespace(buf, &ptr); - int vstart = ptr; - while(ptr < end && buf[ptr] != '=') ++ptr; - assert(buf[ptr] == '='); - assert(ptr > vstart); - if (buf[vstart] == 'A' && buf[vstart+1] == '=') { - ++ptr; - while (ptr < end && !IsWhitespace(buf[ptr])) { - while(ptr < end && buf[ptr] == ',') { ++ptr; } - assert(ptr < end); - vstart = ptr; - while(ptr < end && buf[ptr] != ',' && !IsWhitespace(buf[ptr])) { ++ptr; } - if (ptr > vstart) { - short a, b; - AnnotatedParallelSentence::ReadAlignmentPoint(buf, vstart, ptr, false, &a, &b); - aligns.push_back(make_pair(a,b)); - } - } - } else { - int name = FD::Convert(string(buf,vstart,ptr-vstart)); - ++ptr; - vstart = ptr; - while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } - assert(ptr > vstart); - counts.set_value(name, strtod(buf + vstart, NULL)); - } - } -} - -ostream& operator<<(ostream& os, const RuleStatistics& s) { - bool needspace = false; - for (SparseVector::const_iterator it = s.counts.begin(); it != s.counts.end(); ++it) { - if (needspace) os << ' '; else needspace = true; - os << FD::Convert(it->first) << '=' << it->second; - } - if (s.aligns.size() > 0) { - os << " A="; - needspace = false; - for (int i = 0; i < s.aligns.size(); ++i) { - if (needspace) os << ','; else needspace = true; - os << s.aligns[i].first << '-' << s.aligns[i].second; - } - } - return os; -} \ No newline at end of file diff --git a/extools/extract.h b/extools/extract.h index 72017034..76292bed 100644 --- a/extools/extract.h +++ b/extools/extract.h @@ -91,21 +91,4 @@ struct Extract { std::vector* all_cats); }; -// represents statistics / information about a rule pair -struct RuleStatistics { - SparseVector counts; - std::vector > aligns; - RuleStatistics() {} - RuleStatistics(int name, float val, const std::vector >& al) : - aligns(al) { - counts.set_value(name, val); - } - void ParseRuleStatistics(const char* buf, int start, int end); - RuleStatistics& operator+=(const RuleStatistics& rhs) { - counts += rhs.counts; - return *this; - } -}; -std::ostream& operator<<(std::ostream& os, const RuleStatistics& s); - #endif diff --git a/extools/extractor.cc b/extools/extractor.cc index 7149bfd7..bc27e408 100644 --- a/extools/extractor.cc +++ b/extools/extractor.cc @@ -16,6 +16,7 @@ #include "wordid.h" #include "array2d.h" #include "filelib.h" +#include "striped_grammar.h" using namespace std; using namespace std::tr1; diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 27c0dadf..547f390a 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -21,6 +21,7 @@ #include "tdict.h" #include "lex_trans_tbl.h" #include "filelib.h" +#include "striped_grammar.h" #include #include @@ -137,7 +138,6 @@ bool validate_non_terminal(const std::string& s) namespace { inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - inline bool IsBracket(char c){return c == '[' || c == ']';} inline void SkipWhitespace(const char* buf, int* ptr) { while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } } @@ -407,7 +407,7 @@ struct BackoffRule : public FeatureExtractor { const RuleStatistics& /*info*/, SparseVector* result) const { (void) lhs; (void) src; (void) trg; - string lhstr = TD::Convert(lhs); + const string& lhstr = TD::Convert(lhs); if(lhstr.find('_')!=string::npos) result->set_value(fid_, -1); } diff --git a/extools/filter_grammar.cc b/extools/filter_grammar.cc index 6f0dcdfc..ca329de1 100644 --- a/extools/filter_grammar.cc +++ b/extools/filter_grammar.cc @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include "suffix_tree.h" @@ -16,8 +14,8 @@ #include "extract.h" #include "fdict.h" #include "tdict.h" -#include "lex_trans_tbl.h" #include "filelib.h" +#include "striped_grammar.h" #include #include @@ -30,8 +28,6 @@ namespace po = boost::program_options; static const size_t MAX_LINE_LENGTH = 64000000; -typedef unordered_map, RuleStatistics, boost::hash > > ID2RuleStatistics; - void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -51,85 +47,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { exit(1); } } -namespace { - inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } - inline void SkipWhitespace(const char* buf, int* ptr) { - while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } - } -} - -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector* p) { - static const WordID kDIV = TD::Convert("|||"); - int ptr = sstart; - while(ptr < end) { - while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } - int start = ptr; - while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } - if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } - const WordID w = TD::Convert(string(buf, start, ptr - start)); - if (w == kDIV) return ptr; - p->push_back(w); - } - assert(p->size() > 0); - return ptr; -} - - -void ParseLine(const char* buf, vector* cur_key, ID2RuleStatistics* counts) { - static const WordID kDIV = TD::Convert("|||"); - counts->clear(); - int ptr = 0; - while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } - if (buf[ptr] != '\t') { - cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; - exit(1); - } - cur_key->clear(); - // key is: "[X] ||| word word word" - int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); - cur_key->push_back(kDIV); - ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); - ++ptr; - int start = ptr; - int end = ptr; - int state = 0; // 0=reading label, 1=reading count - vector name; - while(buf[ptr] != 0) { - while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - if (buf[ptr] == '|') { - ++ptr; - end = ptr - 3; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (start == end) { - cerr << "Got empty token!\n LINE=" << buf << endl; - exit(1); - } - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - SkipWhitespace(buf, &ptr); - start = ptr; - } - } - } - } - end=ptr; - while (end > start && IsWhitespace(buf[end-1])) { --end; } - if (end > start) { - switch (state) { - case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; - case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; - default: cerr << "Can't happen\n"; abort(); - } - } -} - struct SourceFilter { // return true to keep the rule, otherwise false @@ -138,8 +55,7 @@ struct SourceFilter { }; struct DumbSuffixTreeFilter : SourceFilter { - DumbSuffixTreeFilter(const string& corpus) : - kDIV(TD::Convert("|||")) { + DumbSuffixTreeFilter(const string& corpus) { cerr << "Build suffix tree from test set in " << corpus << endl; assert(FileExists(corpus)); ReadFile rfts(corpus); @@ -163,68 +79,57 @@ struct DumbSuffixTreeFilter : SourceFilter { } delete[] buf; } - virtual bool Matches(const vector& key) const { + virtual bool Matches(const vector& src_rhs) const { const Node* curnode = &root; - const int ks = key.size() - 1; - for(int i=0; i < ks; i++) { - const string& word = TD::Convert(key[i]); - if (key[i] == kDIV || (word[0] == '[' && word[word.size() - 1] == ']')) { // non-terminal + for(int i=0; i < src_rhs.size(); i++) { + if (src_rhs[i] <= 0) { curnode = &root; } else if (curnode) { - curnode = curnode->Extend(key[i]); + curnode = curnode->Extend(src_rhs[i]); if (!curnode) return false; } } return true; } - const WordID kDIV; Node root; }; +boost::shared_ptr filter; +multimap options; +int kCOUNT; +int max_options; + +void cb(WordID lhs, const vector& src_rhs, const ID2RuleStatistics& rules, void*) { + options.clear(); + if (!filter || filter->Matches(src_rhs)) { + for (ID2RuleStatistics::const_iterator it = rules.begin(); it != rules.end(); ++it) { + options.insert(make_pair(-it->second.counts.value(kCOUNT), it)); + } + int ocount = 0; + cout << '[' << TD::Convert(-lhs) << ']' << " ||| "; + WriteNamed(src_rhs, &cout); + cout << '\t'; + bool first = true; + for (multimap::iterator it = options.begin(); it != options.end(); ++it) { + if (first) { first = false; } else { cout << " ||| "; } + WriteAnonymous(it->second->first, &cout); + cout << " ||| " << it->second->second; + ++ocount; + if (ocount == max_options) break; + } + cout << endl; + } +} + int main(int argc, char** argv){ po::variables_map conf; InitCommandLine(argc, argv, &conf); - const int max_options = conf["top_e_given_f"].as();; + max_options = conf["top_e_given_f"].as();; + kCOUNT = FD::Convert("CFE"); istream& unscored_grammar = cin; - cerr << "Loading test set " << conf["test_set"].as() << "...\n"; - boost::shared_ptr filter; filter.reset(new DumbSuffixTreeFilter(conf["test_set"].as())); - cerr << "Filtering...\n"; - //score unscored grammar - char* buf = new char[MAX_LINE_LENGTH]; - - ID2RuleStatistics acc, cur_counts; - vector key, cur_key,temp_key; - int line = 0; - - multimap options; - const int kCOUNT = FD::Convert("CFE"); - while(!unscored_grammar.eof()) - { - ++line; - options.clear(); - unscored_grammar.getline(buf, MAX_LINE_LENGTH); - if (buf[0] == 0) continue; - ParseLine(buf, &cur_key, &cur_counts); - if (!filter || filter->Matches(cur_key)) { - // sort by counts - for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { - options.insert(make_pair(-it->second.counts.value(kCOUNT), it)); - } - int ocount = 0; - cout << TD::GetString(cur_key) << '\t'; - - bool first = true; - for (multimap::iterator it = options.begin(); it != options.end(); ++it) { - if (first) { first = false; } else { cout << " ||| "; } - cout << TD::GetString(it->second->first) << " ||| " << it->second->second; - ++ocount; - if (ocount == max_options) break; - } - cout << endl; - } - } + StripedGrammarLexer::ReadStripedGrammar(&unscored_grammar, cb, NULL); } diff --git a/extools/filter_score_grammar.cc b/extools/filter_score_grammar.cc index fe9a2a07..24f5fd1c 100644 --- a/extools/filter_score_grammar.cc +++ b/extools/filter_score_grammar.cc @@ -18,6 +18,7 @@ #include "tdict.h" #include "lex_trans_tbl.h" #include "filelib.h" +#include "striped_grammar.h" #include #include diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc index 902b6a07..3298a801 100644 --- a/extools/mr_stripe_rule_reduce.cc +++ b/extools/mr_stripe_rule_reduce.cc @@ -12,6 +12,7 @@ #include "sentence_pair.h" #include "fdict.h" #include "extract.h" +#include "striped_grammar.h" using namespace std; using namespace std::tr1; diff --git a/extools/sg_lexer.l b/extools/sg_lexer.l new file mode 100644 index 00000000..f115e5bd --- /dev/null +++ b/extools/sg_lexer.l @@ -0,0 +1,242 @@ +%{ +#include "rule_lexer.h" + +#include +#include +#include +#include +#include +#include "tdict.h" +#include "fdict.h" +#include "trule.h" +#include "striped_grammar.h" + +int lex_line = 0; +std::istream* sglex_stream = NULL; +StripedGrammarLexer::GrammarCallback grammar_callback = NULL; +void* grammar_callback_extra = NULL; + +#undef YY_INPUT +#define YY_INPUT(buf, result, max_size) (result = sglex_stream->read(buf, max_size).gcount()) + +#define YY_SKIP_YYWRAP 1 +int num_rules = 0; +int yywrap() { return 1; } +bool fl = true; +#define MAX_TOKEN_SIZE 255 +std::string sglex_tmp_token(MAX_TOKEN_SIZE, '\0'); + +#define MAX_RULE_SIZE 48 +WordID sglex_src_rhs[MAX_RULE_SIZE]; +WordID sglex_trg_rhs[MAX_RULE_SIZE]; +int sglex_src_rhs_size; +int sglex_trg_rhs_size; +WordID sglex_lhs; +int sglex_src_arity; +int sglex_trg_arity; + +#define MAX_FEATS 100 +int sglex_feat_ids[MAX_FEATS]; +double sglex_feat_vals[MAX_FEATS]; +int sglex_num_feats; + +#define MAX_ARITY 20 +int sglex_nt_sanity[MAX_ARITY]; +int sglex_src_nts[MAX_ARITY]; +float sglex_nt_size_means[MAX_ARITY]; +float sglex_nt_size_vars[MAX_ARITY]; + +std::vector cur_src_rhs; +std::vector cur_trg_rhs; +ID2RuleStatistics cur_options; +RuleStatistics* cur_stats = NULL; +int sglex_cur_fid = 0; + +static void sanity_check_trg_index(int index) { + if (index > sglex_src_arity) { + std::cerr << "Target index " << index << " exceeds source arity " << sglex_src_arity << std::endl; + abort(); + } + int& flag = sglex_nt_sanity[index - 1]; + if (flag) { + std::cerr << "Target index " << index << " used multiple times!" << std::endl; + abort(); + } + flag = 1; +} + +static void sglex_reset() { + sglex_src_arity = 0; + sglex_trg_arity = 0; + sglex_num_feats = 0; + sglex_src_rhs_size = 0; + sglex_trg_rhs_size = 0; +} + +%} + +REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf +NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]* +ALIGN [0-9]+-[0-9]+ + +%x LHS_END SRC TRG FEATS FEATVAL ALIGNS +%% + +[ ] ; + +\[{NT}\] { + sglex_tmp_token.assign(yytext + 1, yyleng - 2); + sglex_lhs = -TD::Convert(sglex_tmp_token); + // std::cerr << sglex_tmp_token << "\n"; + BEGIN(LHS_END); + } + +\[{NT}\] { + sglex_tmp_token.assign(yytext + 1, yyleng - 2); + sglex_src_nts[sglex_src_arity] = sglex_src_rhs[sglex_src_rhs_size] = -TD::Convert(sglex_tmp_token); + ++sglex_src_arity; + ++sglex_src_rhs_size; + } + +[ ] { ; } +\|\|\| { + sglex_reset(); + BEGIN(SRC); + } +. { + std::cerr << "Line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl; + exit(1); + } + + +\[{NT},[1-9][0-9]?\] { + int index = yytext[yyleng - 2] - '0'; + if (yytext[yyleng - 3] == ',') { + sglex_tmp_token.assign(yytext + 1, yyleng - 4); + } else { + sglex_tmp_token.assign(yytext + 1, yyleng - 5); + index += 10 * (yytext[yyleng - 3] - '0'); + } + if ((sglex_src_arity+1) != index) { + std::cerr << "Src indices must go in order: expected " << sglex_src_arity << " but got " << index << std::endl; + abort(); + } + sglex_src_nts[sglex_src_arity] = sglex_src_rhs[sglex_src_rhs_size] = -TD::Convert(sglex_tmp_token); + ++sglex_src_rhs_size; + ++sglex_src_arity; + } + +[^ \t]+ { + sglex_tmp_token.assign(yytext, yyleng); + sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token); + ++sglex_src_rhs_size; + } +[ ] { ; } +\t { + //std::cerr << "LHS=" << TD::Convert(-sglex_lhs) << " "; + //std::cerr << " src_size: " << sglex_src_rhs_size << std::endl; + //std::cerr << " src_arity: " << sglex_src_arity << std::endl; + memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int)); + cur_options.clear(); + sglex_trg_rhs_size = 0; + BEGIN(TRG); + } + +\[[1-9][0-9]?\] { + int index = yytext[yyleng - 2] - '0'; + if (yyleng == 4) { + index += 10 * (yytext[yyleng - 3] - '0'); + } + ++sglex_trg_arity; + sanity_check_trg_index(index); + sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index; + ++sglex_trg_rhs_size; +} + +\|\|\| { + assert(sglex_trg_rhs_size > 0); + cur_trg_rhs.resize(sglex_trg_rhs_size); + for (int i = 0; i < sglex_trg_rhs_size; ++i) + cur_trg_rhs[i] = sglex_trg_rhs[i]; + cur_stats = &cur_options[cur_trg_rhs]; + BEGIN(FEATS); + } + +[^ ]+ { + sglex_tmp_token.assign(yytext, yyleng); + sglex_trg_rhs[sglex_trg_rhs_size] = TD::Convert(sglex_tmp_token); + + ++sglex_trg_rhs_size; + } +[ ]+ { ; } + +\n { + assert(sglex_lhs < 0); + assert(sglex_src_rhs_size > 0); + cur_src_rhs.resize(sglex_src_rhs_size); + for (int i = 0; i < sglex_src_rhs_size; ++i) + cur_src_rhs[i] = sglex_src_rhs[i]; + grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra); + cur_options.clear(); + BEGIN(INITIAL); + } +[ ]+ { ; } +\|\|\| { + memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int)); + sglex_trg_rhs_size = 0; + BEGIN(TRG); + } +[A-Z][A-Z_0-9]*= { + // std::cerr << "FV: " << yytext << std::endl; + sglex_tmp_token.assign(yytext, yyleng - 1); + sglex_cur_fid = FD::Convert(sglex_tmp_token); + static const int Afid = FD::Convert("A"); + if (sglex_cur_fid == Afid) { + BEGIN(ALIGNS); + } else { + BEGIN(FEATVAL); + } + } +{REAL} { + // std::cerr << "Feature val input: " << yytext << std::endl; + cur_stats->counts.set_value(sglex_cur_fid, strtod(yytext, NULL)); + BEGIN(FEATS); + } +. { + std::cerr << "Feature val unexpected input: " << yytext << std::endl; + exit(1); + } +. { + std::cerr << "Features unexpected input: " << yytext << std::endl; + exit(1); + } +{ALIGN}(,{ALIGN})* { + assert(cur_stats->aligns.empty()); + int i = 0; + while(i < yyleng) { + short a = 0; + short b = 0; + while (yytext[i] != '-') { a *= 10; a += yytext[i] - '0'; ++i; } + ++i; + while (yytext[i] != ',' && i < yyleng) { b *= 10; b += yytext[i] - '0'; ++i; } + ++i; + cur_stats->aligns.push_back(std::make_pair(a,b)); + } + BEGIN(FEATS); + } +. { + std::cerr << "Aligns unexpected input: " << yytext << std::endl; + exit(1); + } +%% + +#include "filelib.h" + +void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra) { + lex_line = 1; + sglex_stream = in; + grammar_callback_extra = extra; + grammar_callback = func; + yylex(); +} + diff --git a/extools/striped_grammar.cc b/extools/striped_grammar.cc new file mode 100644 index 00000000..accf44eb --- /dev/null +++ b/extools/striped_grammar.cc @@ -0,0 +1,67 @@ +#include "striped_grammar.h" + +#include + +#include "sentence_pair.h" + +using namespace std; + +namespace { + inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } + + inline void SkipWhitespace(const char* buf, int* ptr) { + while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } + } +} + +void RuleStatistics::ParseRuleStatistics(const char* buf, int start, int end) { + int ptr = start; + counts.clear(); + aligns.clear(); + while (ptr < end) { + SkipWhitespace(buf, &ptr); + int vstart = ptr; + while(ptr < end && buf[ptr] != '=') ++ptr; + assert(buf[ptr] == '='); + assert(ptr > vstart); + if (buf[vstart] == 'A' && buf[vstart+1] == '=') { + ++ptr; + while (ptr < end && !IsWhitespace(buf[ptr])) { + while(ptr < end && buf[ptr] == ',') { ++ptr; } + assert(ptr < end); + vstart = ptr; + while(ptr < end && buf[ptr] != ',' && !IsWhitespace(buf[ptr])) { ++ptr; } + if (ptr > vstart) { + short a, b; + AnnotatedParallelSentence::ReadAlignmentPoint(buf, vstart, ptr, false, &a, &b); + aligns.push_back(make_pair(a,b)); + } + } + } else { + int name = FD::Convert(string(buf,vstart,ptr-vstart)); + ++ptr; + vstart = ptr; + while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } + assert(ptr > vstart); + counts.set_value(name, strtod(buf + vstart, NULL)); + } + } +} + +ostream& operator<<(ostream& os, const RuleStatistics& s) { + bool needspace = false; + for (SparseVector::const_iterator it = s.counts.begin(); it != s.counts.end(); ++it) { + if (needspace) os << ' '; else needspace = true; + os << FD::Convert(it->first) << '=' << it->second; + } + if (s.aligns.size() > 0) { + os << " A="; + needspace = false; + for (int i = 0; i < s.aligns.size(); ++i) { + if (needspace) os << ','; else needspace = true; + os << s.aligns[i].first << '-' << s.aligns[i].second; + } + } + return os; +} + diff --git a/extools/striped_grammar.h b/extools/striped_grammar.h new file mode 100644 index 00000000..cdf529d6 --- /dev/null +++ b/extools/striped_grammar.h @@ -0,0 +1,54 @@ +#ifndef _STRIPED_GRAMMAR_H_ +#define _STRIPED_GRAMMAR_H_ + +#include +#include +#include +#include +#include "sparse_vector.h" +#include "wordid.h" +#include "tdict.h" + +// represents statistics / information about a rule pair +struct RuleStatistics { + SparseVector counts; + std::vector > aligns; + RuleStatistics() {} + RuleStatistics(int name, float val, const std::vector >& al) : + aligns(al) { + counts.set_value(name, val); + } + void ParseRuleStatistics(const char* buf, int start, int end); + RuleStatistics& operator+=(const RuleStatistics& rhs) { + counts += rhs.counts; + return *this; + } +}; +std::ostream& operator<<(std::ostream& os, const RuleStatistics& s); + +inline void WriteNamed(const std::vector& v, std::ostream* os) { + bool first = true; + for (int i = 0; i < v.size(); ++i) { + if (first) { first = false; } else { (*os) << ' '; } + if (v[i] < 0) { (*os) << '[' << TD::Convert(-v[i]) << ']'; } + else (*os) << TD::Convert(v[i]); + } +} + +inline void WriteAnonymous(const std::vector& v, std::ostream* os) { + bool first = true; + for (int i = 0; i < v.size(); ++i) { + if (first) { first = false; } else { (*os) << ' '; } + if (v[i] <= 0) { (*os) << '[' << (1-v[i]) << ']'; } + else (*os) << TD::Convert(v[i]); + } +} + +typedef std::tr1::unordered_map, RuleStatistics, boost::hash > > ID2RuleStatistics; + +struct StripedGrammarLexer { + typedef void (*GrammarCallback)(WordID lhs, const std::vector& src_rhs, const ID2RuleStatistics& rules, void *extra); + static void ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra); +}; + +#endif -- cgit v1.2.3