summaryrefslogtreecommitdiff
path: root/decoder/trule.h
blob: d2b1babe4efaa67f3ade23de1e69ebc3f4bc0295 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#ifndef _RULE_H_
#define _RULE_H_

#include <cassert>
#include <cstddef>
#include <algorithm>
#include <string>
#include <vector>
#include <boost/shared_ptr.hpp>

#include "sparse_vector.h"
#include "wordid.h"

// Forward declarations.  TRule is declared up front so that the shared
// ownership handle TRulePtr can be named before (and inside) the class
// definition.  SpanInfo is only forward-declared in this header.
struct SpanInfo;
class TRule;
typedef boost::shared_ptr<TRule> TRulePtr;

// Translation rule.
//
// Holds a target side e_, a source side f_, a left-hand-side category lhs_
// and a sparse vector of feature values (scores_).  Non-positive WordIDs in
// e_ and f_ stand for variables (see the member comments below); positive
// WordIDs are terminals.
class TRule {
 public:
  // Every constructor gives each scalar member a defined value.  arity_
  // starts at 0; ComputeArity() (defined elsewhere) presumably recomputes it.
  // NOTE(review): callers that build rules containing variables from raw
  // vectors probably need to call ComputeArity() themselves -- confirm.
  TRule() : lhs_(0), arity_(0), prev_i(-1), prev_j(-1) { }
  explicit TRule(const std::vector<WordID>& e)
      : e_(e), lhs_(0), arity_(0), prev_i(-1), prev_j(-1) {}
  TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs)
      : e_(e), f_(f), lhs_(lhs), arity_(0), prev_i(-1), prev_j(-1) {}

  // deprecated - this will be private soon
  // Members receive defined defaults before parsing so that whatever
  // ReadFromString (defined elsewhere) leaves untouched is never left
  // indeterminate.
  explicit TRule(const std::string& text, bool strict = false, bool mono = false)
      : lhs_(0), arity_(0), prev_i(-1), prev_j(-1) {
    ReadFromString(text, strict, mono);
  }

  // make a rule from a hiero-like rule table, e.g.
  //    [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1]
  // if misformatted, returns NULL
  static TRule* CreateRuleSynchronous(const std::string& rule);

  // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g:
  //    el gato ||| the cat ||| Feature_2=0.34
  static TRule* CreateRulePhrasetable(const std::string& rule);

  // make a rule from a non-synchronous CFG representation, e.g.:
  //    [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT]
  static TRule* CreateRuleMonolingual(const std::string& rule);

  // Writes the target-side yield of this rule into *result (which is cleared
  // first).  Terminals of e_ are copied unchanged; each variable c (c < 1) is
  // replaced by the contents of *var_values[-c], since e_ encodes the first
  // variable as 0, the second as -1, and so on.  In debug builds, asserts
  // that every index is in range and that e_ contains exactly
  // var_values.size() variable occurrences.
  void ESubstitute(const std::vector<const std::vector<WordID>* >& var_values,
                   std::vector<WordID>* result) const {
    std::size_t vc = 0;
    result->clear();
    for (std::vector<WordID>::const_iterator i = e_.begin(); i != e_.end(); ++i) {
      const WordID c = *i;
      if (c < 1) {
        ++vc;
        assert(static_cast<std::size_t>(-c) < var_values.size());
        const std::vector<WordID>& var_value = *var_values[-c];
        result->insert(result->end(), var_value.begin(), var_value.end());
      } else {
        result->push_back(c);
      }
    }
    assert(vc == var_values.size());
  }

  // Writes the source-side yield of this rule into *result (which is cleared
  // first).  Terminals of f_ are copied unchanged; the k-th variable (c < 1)
  // met while scanning f_ left to right is replaced by *var_values[k].  In
  // debug builds, asserts that f_ contains exactly var_values.size()
  // variable occurrences.
  void FSubstitute(const std::vector<const std::vector<WordID>* >& var_values,
                   std::vector<WordID>* result) const {
    std::size_t vc = 0;
    result->clear();
    for (std::vector<WordID>::const_iterator i = f_.begin(); i != f_.end(); ++i) {
      const WordID c = *i;
      if (c < 1) {
        assert(vc < var_values.size());
        const std::vector<WordID>& var_value = *var_values[vc++];
        result->insert(result->end(), var_value.begin(), var_value.end());
      } else {
        result->push_back(c);
      }
    }
    assert(vc == var_values.size());
  }

  // Parses a rule from its text form (defined elsewhere).  The bool result
  // presumably signals parse success -- confirm against the implementation.
  bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false);

  // True once the rule has a non-empty target side.
  bool Initialized() const { return !e_.empty(); }

  // Text rendering of the rule (defined elsewhere).
  std::string AsString(bool verbose = true) const;

  // A rule whose target side is a single reference to the first variable
  // (e_ == {0}); every other member keeps its default-constructed value.
  static TRule DummyRule() {
    TRule res;
    res.e_.resize(1, 0);
    return res;
  }

  const std::vector<WordID>& f() const { return f_; }
  const std::vector<WordID>& e() const { return e_; }

  // Terminal counts on each side, computed as length minus Arity(); these
  // assume arity_ equals the number of variables on that side.
  int EWords() const { return ELength() - Arity(); }
  int FWords() const { return FLength() - Arity(); }
  int FLength() const { return f_.size(); }
  int ELength() const { return e_.size(); }
  int Arity() const { return arity_; }
  // True when the rule has one variable and a single-symbol source side.
  bool IsUnary() const { return (Arity() == 1) && (f_.size() == 1); }
  const SparseVector<double>& GetFeatureValues() const { return scores_; }
  double Score(int i) const { return scores_[i]; }
  WordID GetLHS() const { return lhs_; }
  // Recomputes arity_ (defined elsewhere).
  void ComputeArity();

  // 0 = first variable, -1 = second variable, -2 = third ...
  std::vector<WordID> e_;
  // < 0: *-1 = encoding of category of variable
  std::vector<WordID> f_;
  WordID lhs_;
  SparseVector<double> scores_;
  char arity_;
  TRulePtr parent_rule_;  // usually NULL, except when doing constrained decoding

  // this is only used when doing synchronous parsing
  short int prev_i;
  short int prev_j;

 private:
  bool SanityCheck() const;
};

#endif