| field | value | date |
|---|---|---|
| author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 |
| committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 |
| commit | 0172721855098ca02b207231a654dffa5e4eb1c9 (patch) | |
| tree | 8069c3a62e2d72bd64a2cdeee9724b2679c8a56b /extools/extract.cc | |
| parent | 37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff) | |
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'extools/extract.cc')
| mode | path | lines |
|---|---|---|
| -rw-r--r-- | extools/extract.cc | 341 |

1 files changed, 341 insertions, 0 deletions
```diff
diff --git a/extools/extract.cc b/extools/extract.cc
new file mode 100644
index 00000000..6ad124d2
--- /dev/null
+++ b/extools/extract.cc
@@ -0,0 +1,341 @@
+#include "extract.h"
+
+#include <queue>
+#include <vector>
+#include <utility>
+#include <tr1/unordered_map>
+#include <set>
+
+#include <boost/functional/hash.hpp>
+
+#include "sentence_pair.h"
+#include "tdict.h"
+#include "wordid.h"
+#include "array2d.h"
+
+using namespace std;
+using namespace tr1;
+
+namespace {
+  inline bool IsWhitespace(char c) { return c == ' ' || c == '\t'; }
+
+  inline void SkipWhitespace(const char* buf, int* ptr) {
+    while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); }
+  }
+}
+
+Extract::RuleObserver::~RuleObserver() {
+  cerr << "Rules extracted: " << count << endl;
+}
+
+void Extract::ExtractBasePhrases(const int max_base_phrase_size,
+                        const AnnotatedParallelSentence& sentence,
+                        vector<ParallelSpan>* phrases) {
+  phrases->clear();
+
+  vector<pair<int,int> > f_spans(sentence.f_len, pair<int,int>(sentence.e_len, 0));
+  vector<pair<int,int> > e_spans(sentence.e_len, pair<int,int>(sentence.f_len, 0));
+  // for each alignment point in e, precompute the minimal consistent phrases in f
+  // for each alignment point in f, precompute the minimal consistent phrases in e
+  for (int i = 0; i < sentence.f_len; ++i) {
+    for (int j = 0; j < sentence.e_len; ++j) {
+      if (sentence.aligned(i,j)) {
+        if (j < f_spans[i].first) f_spans[i].first = j;
+        f_spans[i].second = j+1;
+        if (i < e_spans[j].first) e_spans[j].first = i;
+        e_spans[j].second = i+1;
+      }
+    }
+  }
+
+  for (int i1 = 0; i1 < sentence.f_len; ++i1) {
+    if (sentence.f_aligned[i1] == 0) continue;
+    int j1 = sentence.e_len;
+    int j2 = 0;
+    const int i_limit = min(sentence.f_len, i1 + max_base_phrase_size);
+    for (int i2 = i1 + 1; i2 <= i_limit; ++i2) {
+      if (sentence.f_aligned[i2-1] == 0) continue;
+      // cerr << "F has aligned span " << i1 << " to " << i2 << endl;
+      j1 = min(j1, f_spans[i2-1].first);
+      j2 = max(j2, f_spans[i2-1].second);
+      if (j1 >= j2) continue;
+      if (j2 - j1 > max_base_phrase_size) continue;
+      int condition = 0;
+      for (int j = j1; j < j2; ++j) {
+        if (e_spans[j].first < i1) { condition = 1; break; }
+        if (e_spans[j].second > i2) { condition = 2; break; }
+      }
+      if (condition == 1) break;
+      if (condition == 2) continue;
+      // category types added later!
+      phrases->push_back(ParallelSpan(i1, i2, j1, j2));
+      // cerr << i1 << " " << i2 << " : " << j1 << " " << j2 << endl;
+    }
+  }
+}
+
+void Extract::LoosenPhraseBounds(const AnnotatedParallelSentence& sentence,
+                                 const int max_base_phrase_size,
+                                 vector<ParallelSpan>* phrases) {
+  const int num_phrases = phrases->size();
+  map<int, map<int, map<int, map<int, bool> > > > marker;
+  for (int i = 0; i < num_phrases; ++i) {
+    const ParallelSpan& cur = (*phrases)[i];
+    marker[cur.i1][cur.i2][cur.j1][cur.j2] = true;
+  }
+  for (int i = 0; i < num_phrases; ++i) {
+    const ParallelSpan& cur = (*phrases)[i];
+    const int i1_max = cur.i1;
+    const int i2_min = cur.i2;
+    const int j1_max = cur.j1;
+    const int j2_min = cur.j2;
+    int i1_min = i1_max;
+    while (i1_min > 0 && sentence.f_aligned[i1_min-1] == 0) { --i1_min; }
+    int j1_min = j1_max;
+    while (j1_min > 0 && sentence.e_aligned[j1_min-1] == 0) { --j1_min; }
+    int i2_max = i2_min;
+    while (i2_max < sentence.f_len && sentence.f_aligned[i2_max] == 0) { ++i2_max; }
+    int j2_max = j2_min;
+    while (j2_max < sentence.e_len && sentence.e_aligned[j2_max] == 0) { ++j2_max; }
+    for (int i1 = i1_min; i1 <= i1_max; ++i1) {
+      const int ilim = min(i2_max, i1 + max_base_phrase_size);
+      for (int i2 = max(i1+1,i2_min); i2 <= ilim; ++i2) {
+        for (int j1 = j1_min; j1 <= j1_max; ++j1) {
+          const int jlim = min(j2_max, j1 + max_base_phrase_size);
+          for (int j2 = max(j1+1, j2_min); j2 <= jlim; ++j2) {
+            bool& seen = marker[i1][i2][j1][j2];
+            if (!seen)
+              phrases->push_back(ParallelSpan(i1,i2,j1,j2));
+            seen = true;
+          }
+        }
+      }
+    }
+  }
+}
+
+// this uses the TARGET span (i,j) to annotate phrases, will copy
+// phrases if there is more than one annotation.
+// TODO: support source annotation
+void Extract::AnnotatePhrasesWithCategoryTypes(const WordID default_cat,
+                                      const Array2D<vector<WordID> >& types,
+                                      vector<ParallelSpan>* phrases) {
+  const int num_unannotated_phrases = phrases->size();
+  // have to use num_unannotated_phrases since we may grow the vector
+  for (int i = 0; i < num_unannotated_phrases; ++i) {
+    ParallelSpan& phrase = (*phrases)[i];
+    const vector<WordID>* pcats = &types(phrase.j1, phrase.j2);
+    if (pcats->empty() && default_cat != 0) {
+      static vector<WordID> s_default(1, default_cat);
+      pcats = &s_default;
+    }
+    if (pcats->empty()) {
+      cerr << "ERROR span " << phrase.i1 << "," << phrase.i2 << "-"
+           << phrase.j1 << "," << phrase.j2 << " has no type. "
+              "Did you forget --default_category?\n";
+    }
+    const vector<WordID>& cats = *pcats;
+    phrase.cat = cats[0];
+    for (int ci = 1; ci < cats.size(); ++ci) {
+      ParallelSpan new_phrase = phrase;
+      new_phrase.cat = cats[ci];
+      phrases->push_back(new_phrase);
+    }
+  }
+}
+
+// a partially complete (f-side) of a rule
+struct RuleItem {
+  vector<ParallelSpan> f;
+  int i,j,syms,vars;
+  explicit RuleItem(int pi) : i(pi), j(pi), syms(), vars() {}
+  void Extend(const WordID& fword) {
+    f.push_back(ParallelSpan(fword));
+    ++j;
+    ++syms;
+  }
+  void Extend(const ParallelSpan& subphrase) {
+    f.push_back(subphrase);
+    j += subphrase.i2 - subphrase.i1;
+    ++vars;
+    ++syms;
+  }
+  bool RuleFEndsInVariable() const {
+    if (f.size() > 0) {
+      return f.back().IsVariable();
+    } else { return false; }
+  }
+};
+
+void Extract::ExtractConsistentRules(const AnnotatedParallelSentence& sentence,
+                          const vector<ParallelSpan>& phrases,
+                          const int max_vars,
+                          const int max_syms,
+                          const bool permit_adjacent_nonterminals,
+                          const bool require_aligned_terminal,
+                          RuleObserver* observer) {
+  queue<RuleItem> q;  // agenda for BFS
+  int max_len = -1;
+  unordered_map<pair<short, short>, vector<ParallelSpan>, boost::hash<pair<short, short> > > fspans;
+  vector<vector<ParallelSpan> > spans_by_start(sentence.f_len);
+  set<int> starts;
+  for (int i = 0; i < phrases.size(); ++i) {
+    fspans[make_pair(phrases[i].i1,phrases[i].i2)].push_back(phrases[i]);
+    max_len = max(max_len, phrases[i].i2 - phrases[i].i1);
+    // have we already added a rule item starting at phrases[i].i1?
+    if (starts.insert(phrases[i].i1).second)
+      q.push(RuleItem(phrases[i].i1));
+    spans_by_start[phrases[i].i1].push_back(phrases[i]);
+  }
+  starts.clear();
+  vector<pair<int,int> > next_e(sentence.e_len);
+  vector<WordID> cur_rhs_f, cur_rhs_e;
+  vector<pair<short, short> > cur_terminal_align;
+  vector<int> cur_es, cur_fs;
+  while(!q.empty()) {
+    const RuleItem& rule = q.front();
+
+    // extend the partial rule
+    if (rule.j < sentence.f_len && (rule.j - rule.i) < max_len && rule.syms < max_syms) {
+      RuleItem ew = rule;
+
+      // extend with a word
+      ew.Extend(sentence.f[ew.j]);
+      q.push(ew);
+
+      // with variables
+      if (rule.vars < max_vars &&
+          !spans_by_start[rule.j].empty() &&
+          ((!rule.RuleFEndsInVariable()) || permit_adjacent_nonterminals)) {
+        const vector<ParallelSpan>& sub_phrases = spans_by_start[rule.j];
+        for (int it = 0; it < sub_phrases.size(); ++it) {
+          if (sub_phrases[it].i2 - sub_phrases[it].i1 + rule.j - rule.i <= max_len) {
+            RuleItem ev = rule;
+            ev.Extend(sub_phrases[it]);
+            q.push(ev);
+            assert(ev.j <= sentence.f_len);
+          }
+        }
+      }
+    }
+    // determine if rule is consistent
+    if (rule.syms > 0 &&
+        fspans.count(make_pair(rule.i,rule.j)) &&
+        (!rule.RuleFEndsInVariable() || rule.syms > 1)) {
+      const vector<ParallelSpan>& orig_spans = fspans[make_pair(rule.i,rule.j)];
+      for (int s = 0; s < orig_spans.size(); ++s) {
+        const ParallelSpan& orig_span = orig_spans[s];
+        const WordID lhs = orig_span.cat;
+        for (int j = orig_span.j1; j < orig_span.j2; ++j) next_e[j].first = -1;
+        int nt_index_e = 0;
+        for (int i = 0; i < rule.f.size(); ++i) {
+          const ParallelSpan& cur = rule.f[i];
+          if (cur.IsVariable())
+            next_e[cur.j1] = pair<int,int>(cur.j2, ++nt_index_e);
+        }
+        cur_rhs_f.clear();
+        cur_rhs_e.clear();
+        cur_terminal_align.clear();
+        cur_fs.clear();
+        cur_es.clear();
+
+        const int elen = orig_span.j2 - orig_span.j1;
+        vector<int> isvar(elen, 0);
+        int fbias = rule.i;
+        bool bad_rule = false;
+        bool has_aligned_terminal = false;
+        for (int i = 0; i < rule.f.size(); ++i) {
+          const ParallelSpan& cur = rule.f[i];
+          cur_rhs_f.push_back(cur.cat);
+          if (cur.cat > 0) {   // terminal
+            if (sentence.f_aligned[fbias + i]) has_aligned_terminal = true;
+            cur_fs.push_back(fbias + i);
+          } else {             // non-terminal
+            int subj1 = cur.j1 - orig_span.j1;
+            int subj2 = cur.j2 - orig_span.j1;
+            if (subj1 < 0 || subj2 > elen) { bad_rule = true; break; }
+            for (int j = subj1; j < subj2 && !bad_rule; ++j) {
+              int& isvarj = isvar[j];
+              isvarj = true;
+            }
+            if (bad_rule) break;
+            cur_fs.push_back(-1);
+            fbias += cur.i2 - cur.i1 - 1;
+          }
+        }
+        if (require_aligned_terminal && !has_aligned_terminal) bad_rule = true;
+        if (!bad_rule) {
+          for (int j = orig_span.j1; j < orig_span.j2; ++j) {
+            if (next_e[j].first < 0) {
+              cur_rhs_e.push_back(sentence.e[j]);
+              cur_es.push_back(j);
+            } else {
+              cur_rhs_e.push_back(1 - next_e[j].second);  // next_e[j].second is NT gap index
+              cur_es.push_back(-1);
+              j = next_e[j].first - 1;
+            }
+          }
+          for (short i = 0; i < cur_fs.size(); ++i)
+            if (cur_fs[i] >= 0)
+              for (short j = 0; j < cur_es.size(); ++j)
+                if (cur_es[j] >= 0 && sentence.aligned(cur_fs[i],cur_es[j]))
+                  cur_terminal_align.push_back(make_pair(i,j));
+          observer->CountRule(lhs, cur_rhs_f, cur_rhs_e, cur_terminal_align);
+        }
+      }
+    }
+    q.pop();
+  }
+}
+
+void RuleStatistics::ParseRuleStatistics(const char* buf, int start, int end) {
+  int ptr = start;
+  counts.clear();
+  aligns.clear();
+  while (ptr < end) {
+    SkipWhitespace(buf, &ptr);
+    int vstart = ptr;
+    while(ptr < end && buf[ptr] != '=') ++ptr;
+    assert(buf[ptr] == '=');
+    assert(ptr > vstart);
+    if (buf[vstart] == 'A' && buf[vstart+1] == '=') {
+      ++ptr;
+      while (ptr < end && !IsWhitespace(buf[ptr])) {
+        while(ptr < end && buf[ptr] == ',') { ++ptr; }
+        assert(ptr < end);
+        vstart = ptr;
+        while(ptr < end && buf[ptr] != ',' && !IsWhitespace(buf[ptr])) { ++ptr; }
+        if (ptr > vstart) {
+          short a, b;
+          AnnotatedParallelSentence::ReadAlignmentPoint(buf, vstart, ptr, false, &a, &b);
+          aligns.push_back(make_pair(a,b));
+        }
+      }
+    } else {
+      int name = FD::Convert(string(buf,vstart,ptr-vstart));
+      ++ptr;
+      vstart = ptr;
+      while(ptr < end && !IsWhitespace(buf[ptr])) { ++ptr; }
+      assert(ptr > vstart);
+      counts.set_value(name, strtod(buf + vstart, NULL));
+    }
+  }
+}
+
+ostream& operator<<(ostream& os, const RuleStatistics& s) {
+  bool needspace = false;
+  for (SparseVector<float>::const_iterator it = s.counts.begin(); it != s.counts.end(); ++it) {
+    if (needspace) os << ' '; else needspace = true;
+    os << FD::Convert(it->first) << '=' << it->second;
+  }
+  if (s.aligns.size() > 0) {
+    os << " A=";
+    needspace = false;
+    for (int i = 0; i < s.aligns.size(); ++i) {
+      if (needspace) os << ','; else needspace = true;
+      os << s.aligns[i].first << '-' << s.aligns[i].second;
+    }
+  }
+  return os;
+}
```
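For context, `RuleStatistics::ParseRuleStatistics` and the matching `operator<<` in the diff above imply a simple text record format: whitespace-separated `name=value` feature counts, optionally followed by an `A=` field of comma-separated `f-e` alignment points (e.g. `count=2 A=0-0,1-2`). The snippet below is a minimal standalone sketch of a parser for that surface format; it deliberately avoids the repository's `FD`, `SparseVector`, and `AnnotatedParallelSentence` types, and the feature names and input line are hypothetical.

```cpp
// Standalone sketch (not part of the repository): parses the same surface
// format that RuleStatistics::ParseRuleStatistics accepts and operator<< emits,
// e.g. "count=2 lfe=-0.69 A=0-0,1-2". Names and values are hypothetical.
#include <cstdlib>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

int main() {
  const std::string line = "count=2 lfe=-0.69 A=0-0,1-2";  // hypothetical input
  std::map<std::string, double> counts;
  std::vector<std::pair<int,int> > aligns;
  std::istringstream in(line);
  std::string tok;
  while (in >> tok) {                       // tokens are whitespace-separated
    const std::string::size_type eq = tok.find('=');
    if (eq == std::string::npos) continue;  // malformed token; the real code asserts
    const std::string name = tok.substr(0, eq);
    const std::string val = tok.substr(eq + 1);
    if (name == "A") {                      // comma-separated "f-e" alignment points
      std::istringstream pts(val);
      std::string pt;
      while (std::getline(pts, pt, ',')) {
        const std::string::size_type dash = pt.find('-');
        aligns.push_back(std::make_pair(atoi(pt.substr(0, dash).c_str()),
                                        atoi(pt.substr(dash + 1).c_str())));
      }
    } else {                                // "name=value" feature count
      counts[name] = strtod(val.c_str(), NULL);
    }
  }
  std::cout << "features: " << counts.size()
            << ", alignment points: " << aligns.size() << std::endl;
  return 0;
}
```

Compiled and run on the hypothetical input, this should print `features: 2, alignment points: 2`.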
