diff options
| -rw-r--r-- | extools/Makefile.am | 13 | ||||
| -rw-r--r-- | extools/extract.cc | 50 | ||||
| -rw-r--r-- | extools/extract.h | 17 | ||||
| -rw-r--r-- | extools/extractor.cc | 1 | ||||
| -rw-r--r-- | extools/featurize_grammar.cc | 4 | ||||
| -rw-r--r-- | extools/filter_grammar.cc | 167 | ||||
| -rw-r--r-- | extools/filter_score_grammar.cc | 1 | ||||
| -rw-r--r-- | extools/mr_stripe_rule_reduce.cc | 1 | ||||
| -rw-r--r-- | extools/sg_lexer.l | 242 | ||||
| -rw-r--r-- | extools/striped_grammar.cc | 67 | ||||
| -rw-r--r-- | extools/striped_grammar.h | 54 | 
11 files changed, 412 insertions, 205 deletions
| diff --git a/extools/Makefile.am b/extools/Makefile.am index fc02f831..1c0da21b 100644 --- a/extools/Makefile.am +++ b/extools/Makefile.am @@ -8,15 +8,18 @@ bin_PROGRAMS = \  noinst_PROGRAMS = -filter_score_grammar_SOURCES = filter_score_grammar.cc extract.cc sentence_pair.cc +sg_lexer.cc: sg_lexer.l +	$(LEX) -s -CF -8 -o$@ $< + +filter_score_grammar_SOURCES = filter_score_grammar.cc extract.cc sentence_pair.cc striped_grammar.cc  filter_score_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  filter_score_grammar_LDFLAGS = -all-static -filter_grammar_SOURCES = filter_grammar.cc extract.cc sentence_pair.cc +filter_grammar_SOURCES = filter_grammar.cc extract.cc sentence_pair.cc striped_grammar.cc sg_lexer.cc  filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  filter_grammar_LDFLAGS = -all-static -featurize_grammar_SOURCES = featurize_grammar.cc extract.cc sentence_pair.cc +featurize_grammar_SOURCES = featurize_grammar.cc extract.cc sentence_pair.cc sg_lexer.cc striped_grammar.cc  featurize_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  featurize_grammar_LDFLAGS = -all-static @@ -24,11 +27,11 @@ build_lexical_translation_SOURCES = build_lexical_translation.cc extract.cc sent  build_lexical_translation_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  build_lexical_translation_LDFLAGS = -all-static -mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc +mr_stripe_rule_reduce_SOURCES = mr_stripe_rule_reduce.cc extract.cc sentence_pair.cc striped_grammar.cc  mr_stripe_rule_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  mr_stripe_rule_reduce_LDFLAGS = -all-static -extractor_SOURCES = sentence_pair.cc extract.cc extractor.cc +extractor_SOURCES = sentence_pair.cc extract.cc extractor.cc striped_grammar.cc  extractor_LDADD = $(top_srcdir)/decoder/libcdec.a -lz  extractor_LDFLAGS = -all-static diff --git a/extools/extract.cc b/extools/extract.cc index 14497089..567348f4 100644 --- a/extools/extract.cc +++ b/extools/extract.cc @@ -320,53 +320,3 @@ void Extract::ExtractConsistentRules(const AnnotatedParallelSentence& sentence,    }  } -void RuleStatistics::ParseRuleStatistics(const char* buf, int start, int end) { -  int ptr = start; -  counts.clear(); -  aligns.clear(); -  while (ptr < end) { -    SkipWhitespace(buf, &ptr); -    int vstart = ptr; -    while(ptr < end && buf[ptr] != '=') ++ptr; -    assert(buf[ptr] == '='); -    assert(ptr > vstart); -    if (buf[vstart] == 'A' && buf[vstart+1] == '=') { -      ++ptr; -      while (ptr < end && !IsWhitespace(buf[ptr])) { -        while(ptr < end && buf[ptr] == ',') { ++ptr; } -        assert(ptr < end); -        vstart = ptr; -        while(ptr < end && buf[ptr] != ',' && !IsWhitespace(buf[ptr])) { ++ptr; } -        if (ptr > vstart) { -          short a, b; -          AnnotatedParallelSentence::ReadAlignmentPoint(buf, vstart, ptr, false, &a, &b); -          aligns.push_back(make_pair(a,b)); -        } -      } -    } else { -      int name = FD::Convert(string(buf,vstart,ptr-vstart)); -      ++ptr; -      vstart = ptr; -      while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } -      assert(ptr > vstart); -      counts.set_value(name, strtod(buf + vstart, NULL)); -    } -  } -} - -ostream& operator<<(ostream& os, const RuleStatistics& s) { -  bool needspace = false; -  for (SparseVector<float>::const_iterator it = s.counts.begin(); it != s.counts.end(); ++it) { -    if (needspace) os << ' '; else needspace = true; -    os << FD::Convert(it->first) << '=' << it->second; -  } -  if (s.aligns.size() > 0) { -    os << " A="; -    needspace = false; -    for (int i = 0; i < s.aligns.size(); ++i) { -      if (needspace) os << ','; else needspace = true; -      os << s.aligns[i].first << '-' << s.aligns[i].second; -    } -  } -  return os; -}
\ No newline at end of file diff --git a/extools/extract.h b/extools/extract.h index 72017034..76292bed 100644 --- a/extools/extract.h +++ b/extools/extract.h @@ -91,21 +91,4 @@ struct Extract {                            std::vector<WordID>* all_cats);  }; -// represents statistics / information about a rule pair -struct RuleStatistics { -  SparseVector<float> counts; -  std::vector<std::pair<short,short> > aligns; -  RuleStatistics() {} -  RuleStatistics(int name, float val, const std::vector<std::pair<short,short> >& al) : -      aligns(al) { -    counts.set_value(name, val); -  } -  void ParseRuleStatistics(const char* buf, int start, int end); -  RuleStatistics& operator+=(const RuleStatistics& rhs) { -    counts += rhs.counts; -    return *this; -  } -}; -std::ostream& operator<<(std::ostream& os, const RuleStatistics& s); -  #endif diff --git a/extools/extractor.cc b/extools/extractor.cc index 7149bfd7..bc27e408 100644 --- a/extools/extractor.cc +++ b/extools/extractor.cc @@ -16,6 +16,7 @@  #include "wordid.h"  #include "array2d.h"  #include "filelib.h" +#include "striped_grammar.h"  using namespace std;  using namespace std::tr1; diff --git a/extools/featurize_grammar.cc b/extools/featurize_grammar.cc index 27c0dadf..547f390a 100644 --- a/extools/featurize_grammar.cc +++ b/extools/featurize_grammar.cc @@ -21,6 +21,7 @@  #include "tdict.h"  #include "lex_trans_tbl.h"  #include "filelib.h" +#include "striped_grammar.h"  #include <boost/tuple/tuple.hpp>  #include <boost/shared_ptr.hpp> @@ -137,7 +138,6 @@ bool validate_non_terminal(const std::string& s)  namespace {    inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } -  inline bool IsBracket(char c){return c == '[' || c == ']';}    inline void SkipWhitespace(const char* buf, int* ptr) {      while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); }    } @@ -407,7 +407,7 @@ struct BackoffRule : public FeatureExtractor {                                 const RuleStatistics& /*info*/,                                 SparseVector<float>* result) const {      (void) lhs; (void) src; (void) trg; -    string lhstr = TD::Convert(lhs); +    const string& lhstr = TD::Convert(lhs);      if(lhstr.find('_')!=string::npos)          result->set_value(fid_, -1);    } diff --git a/extools/filter_grammar.cc b/extools/filter_grammar.cc index 6f0dcdfc..ca329de1 100644 --- a/extools/filter_grammar.cc +++ b/extools/filter_grammar.cc @@ -6,8 +6,6 @@  #include <map>  #include <vector>  #include <utility> -#include <cstdlib> -#include <fstream>  #include <tr1/unordered_map>  #include "suffix_tree.h" @@ -16,8 +14,8 @@  #include "extract.h"  #include "fdict.h"  #include "tdict.h" -#include "lex_trans_tbl.h"  #include "filelib.h" +#include "striped_grammar.h"  #include <boost/shared_ptr.hpp>  #include <boost/functional/hash.hpp> @@ -30,8 +28,6 @@ namespace po = boost::program_options;  static const size_t MAX_LINE_LENGTH = 64000000; -typedef unordered_map<vector<WordID>, RuleStatistics, boost::hash<vector<WordID> > > ID2RuleStatistics; -  void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options");    opts.add_options() @@ -51,85 +47,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {      exit(1);    }  }    -namespace { -  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } -  inline void SkipWhitespace(const char* buf, int* ptr) { -    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } -  } -} - -int ReadPhraseUntilDividerOrEnd(const char* buf, const int sstart, const int end, vector<WordID>* p) { -  static const WordID kDIV = TD::Convert("|||"); -  int ptr = sstart; -  while(ptr < end) { -    while(ptr < end && IsWhitespace(buf[ptr])) { ++ptr; } -    int start = ptr; -    while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } -    if (ptr == start) {cerr << "Warning! empty token.\n"; return ptr; } -    const WordID w = TD::Convert(string(buf, start, ptr - start)); -    if (w == kDIV) return ptr; -    p->push_back(w); -  } -  assert(p->size() > 0); -  return ptr; -} - - -void ParseLine(const char* buf, vector<WordID>* cur_key, ID2RuleStatistics* counts) { -  static const WordID kDIV = TD::Convert("|||"); -  counts->clear(); -  int ptr = 0; -  while(buf[ptr] != 0 && buf[ptr] != '\t') { ++ptr; } -  if (buf[ptr] != '\t') { -    cerr << "Missing tab separator between key and value!\n INPUT=" << buf << endl; -    exit(1); -  } -  cur_key->clear(); -  // key is: "[X] ||| word word word" -  int tmpp = ReadPhraseUntilDividerOrEnd(buf, 0, ptr, cur_key); -  cur_key->push_back(kDIV); -  ReadPhraseUntilDividerOrEnd(buf, tmpp, ptr, cur_key); -  ++ptr; -  int start = ptr; -  int end = ptr; -  int state = 0; // 0=reading label, 1=reading count -  vector<WordID> name; -  while(buf[ptr] != 0) { -    while(buf[ptr] != 0 && buf[ptr] != '|') { ++ptr; } -    if (buf[ptr] == '|') { -      ++ptr; -      if (buf[ptr] == '|') { -        ++ptr; -        if (buf[ptr] == '|') { -          ++ptr; -          end = ptr - 3; -          while (end > start && IsWhitespace(buf[end-1])) { --end; } -          if (start == end) { -            cerr << "Got empty token!\n  LINE=" << buf << endl; -            exit(1); -          } -          switch (state) { -            case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; -            case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; -            default: cerr << "Can't happen\n"; abort(); -          } -          SkipWhitespace(buf, &ptr); -          start = ptr; -        } -      } -    } -  } -  end=ptr; -  while (end > start && IsWhitespace(buf[end-1])) { --end; } -  if (end > start) { -    switch (state) { -      case 0: ++state; name.clear(); ReadPhraseUntilDividerOrEnd(buf, start, end, &name); break; -      case 1: --state; (*counts)[name].ParseRuleStatistics(buf, start, end); break; -      default: cerr << "Can't happen\n"; abort(); -    } -  } -} -  struct SourceFilter {    // return true to keep the rule, otherwise false @@ -138,8 +55,7 @@ struct SourceFilter {  };  struct DumbSuffixTreeFilter : SourceFilter { -  DumbSuffixTreeFilter(const string& corpus) : -      kDIV(TD::Convert("|||")) { +  DumbSuffixTreeFilter(const string& corpus) {      cerr << "Build suffix tree from test set in " << corpus << endl;      assert(FileExists(corpus));      ReadFile rfts(corpus); @@ -163,68 +79,57 @@ struct DumbSuffixTreeFilter : SourceFilter {      }      delete[] buf;    } -  virtual bool Matches(const vector<WordID>& key) const { +  virtual bool Matches(const vector<WordID>& src_rhs) const {      const Node<int>* curnode = &root; -    const int ks = key.size() - 1; -    for(int i=0; i < ks; i++) { -      const string& word = TD::Convert(key[i]); -      if (key[i] == kDIV || (word[0] == '[' && word[word.size() - 1] == ']')) { // non-terminal +    for(int i=0; i < src_rhs.size(); i++) { +      if (src_rhs[i] <= 0) {          curnode = &root;        } else if (curnode) { -        curnode = curnode->Extend(key[i]); +        curnode = curnode->Extend(src_rhs[i]);          if (!curnode) return false;        }      }      return true;    } -  const WordID kDIV;    Node<int> root;  }; +boost::shared_ptr<SourceFilter> filter; +multimap<float, ID2RuleStatistics::const_iterator> options;  +int kCOUNT; +int max_options; + +void cb(WordID lhs, const vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void*) { +  options.clear(); +  if (!filter || filter->Matches(src_rhs)) { +    for (ID2RuleStatistics::const_iterator it = rules.begin(); it != rules.end(); ++it) { +      options.insert(make_pair(-it->second.counts.value(kCOUNT), it)); +    } +    int ocount = 0; +    cout << '[' << TD::Convert(-lhs) << ']' << " ||| "; +    WriteNamed(src_rhs, &cout); +    cout << '\t'; +    bool first = true; +    for (multimap<float,ID2RuleStatistics::const_iterator>::iterator it = options.begin(); it != options.end(); ++it) { +      if (first) { first = false; } else { cout << " ||| "; } +      WriteAnonymous(it->second->first, &cout); +      cout << " ||| " << it->second->second; +      ++ocount; +      if (ocount == max_options) break; +    } +    cout << endl; +  } +} +  int main(int argc, char** argv){    po::variables_map conf;    InitCommandLine(argc, argv, &conf); -  const int max_options = conf["top_e_given_f"].as<size_t>();; +  max_options = conf["top_e_given_f"].as<size_t>();; +  kCOUNT = FD::Convert("CFE");    istream& unscored_grammar = cin; -    cerr << "Loading test set " << conf["test_set"].as<string>() << "...\n"; -  boost::shared_ptr<SourceFilter> filter;    filter.reset(new DumbSuffixTreeFilter(conf["test_set"].as<string>())); -    cerr << "Filtering...\n"; -  //score unscored grammar -  char* buf = new char[MAX_LINE_LENGTH]; - -  ID2RuleStatistics acc, cur_counts; -  vector<WordID> key, cur_key,temp_key; -  int line = 0; - -  multimap<float, ID2RuleStatistics::const_iterator> options;  -  const int kCOUNT = FD::Convert("CFE"); -  while(!unscored_grammar.eof()) -    { -      ++line; -      options.clear(); -      unscored_grammar.getline(buf, MAX_LINE_LENGTH); -      if (buf[0] == 0) continue; -      ParseLine(buf, &cur_key, &cur_counts); -      if (!filter || filter->Matches(cur_key)) { -        // sort by counts -        for (ID2RuleStatistics::const_iterator it = cur_counts.begin(); it != cur_counts.end(); ++it) { -          options.insert(make_pair(-it->second.counts.value(kCOUNT), it)); -        } -        int ocount = 0; -        cout << TD::GetString(cur_key) << '\t'; - -        bool first = true; -        for (multimap<float,ID2RuleStatistics::const_iterator>::iterator it = options.begin(); it != options.end(); ++it) { -          if (first) { first = false; } else { cout << " ||| "; } -          cout << TD::GetString(it->second->first) << " ||| " << it->second->second; -          ++ocount; -          if (ocount == max_options) break; -        } -        cout << endl; -      } -    } +  StripedGrammarLexer::ReadStripedGrammar(&unscored_grammar, cb, NULL);  } diff --git a/extools/filter_score_grammar.cc b/extools/filter_score_grammar.cc index fe9a2a07..24f5fd1c 100644 --- a/extools/filter_score_grammar.cc +++ b/extools/filter_score_grammar.cc @@ -18,6 +18,7 @@  #include "tdict.h"  #include "lex_trans_tbl.h"  #include "filelib.h" +#include "striped_grammar.h"  #include <boost/shared_ptr.hpp>  #include <boost/functional/hash.hpp> diff --git a/extools/mr_stripe_rule_reduce.cc b/extools/mr_stripe_rule_reduce.cc index 902b6a07..3298a801 100644 --- a/extools/mr_stripe_rule_reduce.cc +++ b/extools/mr_stripe_rule_reduce.cc @@ -12,6 +12,7 @@  #include "sentence_pair.h"  #include "fdict.h"  #include "extract.h" +#include "striped_grammar.h"  using namespace std;  using namespace std::tr1; diff --git a/extools/sg_lexer.l b/extools/sg_lexer.l new file mode 100644 index 00000000..f115e5bd --- /dev/null +++ b/extools/sg_lexer.l @@ -0,0 +1,242 @@ +%{ +#include "rule_lexer.h" + +#include <string> +#include <iostream> +#include <sstream> +#include <cstring> +#include <cassert> +#include "tdict.h" +#include "fdict.h" +#include "trule.h" +#include "striped_grammar.h" + +int lex_line = 0; +std::istream* sglex_stream = NULL; +StripedGrammarLexer::GrammarCallback grammar_callback = NULL; +void* grammar_callback_extra = NULL; + +#undef YY_INPUT +#define YY_INPUT(buf, result, max_size) (result = sglex_stream->read(buf, max_size).gcount()) + +#define YY_SKIP_YYWRAP 1 +int num_rules = 0; +int yywrap() { return 1; } +bool fl = true; +#define MAX_TOKEN_SIZE 255 +std::string sglex_tmp_token(MAX_TOKEN_SIZE, '\0'); + +#define MAX_RULE_SIZE 48 +WordID sglex_src_rhs[MAX_RULE_SIZE]; +WordID sglex_trg_rhs[MAX_RULE_SIZE]; +int sglex_src_rhs_size; +int sglex_trg_rhs_size; +WordID sglex_lhs; +int sglex_src_arity; +int sglex_trg_arity; + +#define MAX_FEATS 100 +int sglex_feat_ids[MAX_FEATS]; +double sglex_feat_vals[MAX_FEATS]; +int sglex_num_feats; + +#define MAX_ARITY 20 +int sglex_nt_sanity[MAX_ARITY]; +int sglex_src_nts[MAX_ARITY]; +float sglex_nt_size_means[MAX_ARITY]; +float sglex_nt_size_vars[MAX_ARITY]; + +std::vector<WordID> cur_src_rhs; +std::vector<WordID> cur_trg_rhs; +ID2RuleStatistics cur_options; +RuleStatistics* cur_stats = NULL; +int sglex_cur_fid = 0; + +static void sanity_check_trg_index(int index) { +  if (index > sglex_src_arity) { +    std::cerr << "Target index " << index << " exceeds source arity " << sglex_src_arity << std::endl; +    abort(); +  } +  int& flag = sglex_nt_sanity[index - 1]; +  if (flag) { +    std::cerr << "Target index " << index << " used multiple times!" << std::endl; +    abort(); +  } +  flag = 1; +} + +static void sglex_reset() { +  sglex_src_arity = 0; +  sglex_trg_arity = 0; +  sglex_num_feats = 0; +  sglex_src_rhs_size = 0; +  sglex_trg_rhs_size = 0; +} + +%} + +REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf +NT [\-#$A-Z_:=.",\\][\-#$".A-Z+/=_0-9!:@\\]* +ALIGN [0-9]+-[0-9]+ + +%x LHS_END SRC TRG FEATS FEATVAL ALIGNS +%% + +<INITIAL>[ ]	; + +<INITIAL>\[{NT}\]   { +		sglex_tmp_token.assign(yytext + 1, yyleng - 2); +		sglex_lhs = -TD::Convert(sglex_tmp_token); +		// std::cerr << sglex_tmp_token << "\n"; +  		BEGIN(LHS_END); +		} + +<SRC>\[{NT}\]   { +		sglex_tmp_token.assign(yytext + 1, yyleng - 2); +		sglex_src_nts[sglex_src_arity] = sglex_src_rhs[sglex_src_rhs_size] = -TD::Convert(sglex_tmp_token); +		++sglex_src_arity; +		++sglex_src_rhs_size; +		} + +<LHS_END>[ ] { ; } +<LHS_END>\|\|\|	{ +		sglex_reset(); +		BEGIN(SRC); +		} +<INITIAL,LHS_END>.	{ +		std::cerr << "Line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl; +		exit(1); +		} + + +<SRC>\[{NT},[1-9][0-9]?\]   { +		int index = yytext[yyleng - 2] - '0'; +		if (yytext[yyleng - 3] == ',') { +		  sglex_tmp_token.assign(yytext + 1, yyleng - 4); +		} else { +		  sglex_tmp_token.assign(yytext + 1, yyleng - 5); +		  index += 10 * (yytext[yyleng - 3] - '0'); +		} +		if ((sglex_src_arity+1) != index) { +			std::cerr << "Src indices must go in order: expected " << sglex_src_arity << " but got " << index << std::endl; +			abort(); +		} +		sglex_src_nts[sglex_src_arity] = sglex_src_rhs[sglex_src_rhs_size] = -TD::Convert(sglex_tmp_token); +		++sglex_src_rhs_size; +		++sglex_src_arity; +		} + +<SRC>[^ \t]+	{  +		sglex_tmp_token.assign(yytext, yyleng); +		sglex_src_rhs[sglex_src_rhs_size] = TD::Convert(sglex_tmp_token); +		++sglex_src_rhs_size; +		} +<SRC>[ ]	{ ; } +<SRC>\t		{ +		//std::cerr << "LHS=" << TD::Convert(-sglex_lhs) << " "; +		//std::cerr << "  src_size: " << sglex_src_rhs_size << std::endl; +		//std::cerr << "  src_arity: " << sglex_src_arity << std::endl; +		memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int)); +		cur_options.clear(); +		sglex_trg_rhs_size = 0; +		BEGIN(TRG); +		} + +<TRG>\[[1-9][0-9]?\]   { +		int index = yytext[yyleng - 2] - '0'; +		if (yyleng == 4) { +		  index += 10 * (yytext[yyleng - 3] - '0'); +		} +		++sglex_trg_arity; +		sanity_check_trg_index(index); +		sglex_trg_rhs[sglex_trg_rhs_size] = 1 - index; +		++sglex_trg_rhs_size; +} + +<TRG>\|\|\|	{ +		assert(sglex_trg_rhs_size > 0); +		cur_trg_rhs.resize(sglex_trg_rhs_size); +		for (int i = 0; i < sglex_trg_rhs_size; ++i) +			cur_trg_rhs[i] = sglex_trg_rhs[i]; +		cur_stats = &cur_options[cur_trg_rhs]; +		BEGIN(FEATS); +		} + +<TRG>[^ ]+	{ +		sglex_tmp_token.assign(yytext, yyleng); +		sglex_trg_rhs[sglex_trg_rhs_size] = TD::Convert(sglex_tmp_token); +		 +		++sglex_trg_rhs_size; +		} +<TRG>[ ]+	{ ; } + +<FEATS>\n	{ +		assert(sglex_lhs < 0); +		assert(sglex_src_rhs_size > 0); +		cur_src_rhs.resize(sglex_src_rhs_size); +		for (int i = 0; i < sglex_src_rhs_size; ++i) +			cur_src_rhs[i] = sglex_src_rhs[i]; +		grammar_callback(sglex_lhs, cur_src_rhs, cur_options, grammar_callback_extra); +		cur_options.clear(); +		BEGIN(INITIAL); +		} +<FEATS>[ ]+	{ ; } +<FEATS>\|\|\|	{ +		memset(sglex_nt_sanity, 0, sglex_src_arity * sizeof(int)); +		sglex_trg_rhs_size = 0; +		BEGIN(TRG); +		} +<FEATS>[A-Z][A-Z_0-9]*=	{ +		// std::cerr << "FV: " << yytext << std::endl; +		sglex_tmp_token.assign(yytext, yyleng - 1); +		sglex_cur_fid = FD::Convert(sglex_tmp_token); +		static const int Afid = FD::Convert("A"); +		if (sglex_cur_fid == Afid) { +			BEGIN(ALIGNS); +		} else { +			BEGIN(FEATVAL); +		} +		} +<FEATVAL>{REAL}	{ +		// std::cerr << "Feature val input: " << yytext << std::endl; +		cur_stats->counts.set_value(sglex_cur_fid, strtod(yytext, NULL)); +		BEGIN(FEATS); +		} +<FEATVAL>.	{ +		std::cerr << "Feature val unexpected input: " << yytext << std::endl; +		exit(1); +		} +<FEATS>.	{ +		std::cerr << "Features unexpected input: " << yytext << std::endl; +		exit(1); +		} +<ALIGNS>{ALIGN}(,{ALIGN})*	{ +		assert(cur_stats->aligns.empty()); +		int i = 0; +		while(i < yyleng) { +			short a = 0; +			short b = 0; +			while (yytext[i] != '-') { a *= 10; a += yytext[i] - '0'; ++i; } +			++i; +			while (yytext[i] != ',' && i < yyleng) { b *= 10; b += yytext[i] - '0'; ++i; } +			++i; +			cur_stats->aligns.push_back(std::make_pair(a,b)); +		} +		BEGIN(FEATS); +		} +<ALIGNS>.	{ +		std::cerr << "Aligns unexpected input: " << yytext << std::endl; +		exit(1); +		} +%% + +#include "filelib.h" + +void StripedGrammarLexer::ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra) { +  lex_line = 1; +  sglex_stream = in; +  grammar_callback_extra = extra; +  grammar_callback = func; +  yylex(); +} + diff --git a/extools/striped_grammar.cc b/extools/striped_grammar.cc new file mode 100644 index 00000000..accf44eb --- /dev/null +++ b/extools/striped_grammar.cc @@ -0,0 +1,67 @@ +#include "striped_grammar.h" + +#include <iostream> + +#include "sentence_pair.h" + +using namespace std; + +namespace { +  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; } + +  inline void SkipWhitespace(const char* buf, int* ptr) { +    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } +  } +} + +void RuleStatistics::ParseRuleStatistics(const char* buf, int start, int end) { +  int ptr = start; +  counts.clear(); +  aligns.clear(); +  while (ptr < end) { +    SkipWhitespace(buf, &ptr); +    int vstart = ptr; +    while(ptr < end && buf[ptr] != '=') ++ptr; +    assert(buf[ptr] == '='); +    assert(ptr > vstart); +    if (buf[vstart] == 'A' && buf[vstart+1] == '=') { +      ++ptr; +      while (ptr < end && !IsWhitespace(buf[ptr])) { +        while(ptr < end && buf[ptr] == ',') { ++ptr; } +        assert(ptr < end); +        vstart = ptr; +        while(ptr < end && buf[ptr] != ',' && !IsWhitespace(buf[ptr])) { ++ptr; } +        if (ptr > vstart) { +          short a, b; +          AnnotatedParallelSentence::ReadAlignmentPoint(buf, vstart, ptr, false, &a, &b); +          aligns.push_back(make_pair(a,b)); +        } +      } +    } else { +      int name = FD::Convert(string(buf,vstart,ptr-vstart)); +      ++ptr; +      vstart = ptr; +      while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; } +      assert(ptr > vstart); +      counts.set_value(name, strtod(buf + vstart, NULL)); +    } +  } +} + +ostream& operator<<(ostream& os, const RuleStatistics& s) { +  bool needspace = false; +  for (SparseVector<float>::const_iterator it = s.counts.begin(); it != s.counts.end(); ++it) { +    if (needspace) os << ' '; else needspace = true; +    os << FD::Convert(it->first) << '=' << it->second; +  } +  if (s.aligns.size() > 0) { +    os << " A="; +    needspace = false; +    for (int i = 0; i < s.aligns.size(); ++i) { +      if (needspace) os << ','; else needspace = true; +      os << s.aligns[i].first << '-' << s.aligns[i].second; +    } +  } +  return os; +} + diff --git a/extools/striped_grammar.h b/extools/striped_grammar.h new file mode 100644 index 00000000..cdf529d6 --- /dev/null +++ b/extools/striped_grammar.h @@ -0,0 +1,54 @@ +#ifndef _STRIPED_GRAMMAR_H_ +#define _STRIPED_GRAMMAR_H_ + +#include <iostream> +#include <boost/functional/hash.hpp> +#include <vector> +#include <tr1/unordered_map> +#include "sparse_vector.h" +#include "wordid.h" +#include "tdict.h" + +// represents statistics / information about a rule pair +struct RuleStatistics { +  SparseVector<float> counts; +  std::vector<std::pair<short,short> > aligns; +  RuleStatistics() {} +  RuleStatistics(int name, float val, const std::vector<std::pair<short,short> >& al) : +      aligns(al) { +    counts.set_value(name, val); +  } +  void ParseRuleStatistics(const char* buf, int start, int end); +  RuleStatistics& operator+=(const RuleStatistics& rhs) { +    counts += rhs.counts; +    return *this; +  } +}; +std::ostream& operator<<(std::ostream& os, const RuleStatistics& s); + +inline void WriteNamed(const std::vector<WordID>& v, std::ostream* os) { +  bool first = true; +  for (int i = 0; i < v.size(); ++i) { +    if (first) { first = false; } else { (*os) << ' '; } +    if (v[i] < 0) { (*os) << '[' << TD::Convert(-v[i]) << ']'; } +    else (*os) << TD::Convert(v[i]); +  } +} + +inline void WriteAnonymous(const std::vector<WordID>& v, std::ostream* os) { +  bool first = true; +  for (int i = 0; i < v.size(); ++i) { +    if (first) { first = false; } else { (*os) << ' '; } +    if (v[i] <= 0) { (*os) << '[' << (1-v[i]) << ']'; } +    else (*os) << TD::Convert(v[i]); +  } +} + +typedef std::tr1::unordered_map<std::vector<WordID>, RuleStatistics, boost::hash<std::vector<WordID> > > ID2RuleStatistics; + +struct StripedGrammarLexer { +  typedef void (*GrammarCallback)(WordID lhs, const std::vector<WordID>& src_rhs, const ID2RuleStatistics& rules, void *extra); +  static void ReadStripedGrammar(std::istream* in, GrammarCallback func, void* extra); +}; + +#endif | 
