authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
commit0172721855098ca02b207231a654dffa5e4eb1c9 (patch)
tree8069c3a62e2d72bd64a2cdeee9724b2679c8a56b /decoder/trule.h
parent37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff)
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/trule.h')
-rw-r--r--  decoder/trule.h  |  145
1 file changed, 145 insertions, 0 deletions
diff --git a/decoder/trule.h b/decoder/trule.h
new file mode 100644
index 00000000..7fb92924
--- /dev/null
+++ b/decoder/trule.h
@@ -0,0 +1,145 @@
+#ifndef _RULE_H_
+#define _RULE_H_
+
+#include <algorithm>
+#include <vector>
+#include <cassert>
+#include <boost/shared_ptr.hpp>
+
+#include "sparse_vector.h"
+#include "wordid.h"
+
+class TRule;
+typedef boost::shared_ptr<TRule> TRulePtr;
+
+struct NTSizeSummaryStatistics {
+ NTSizeSummaryStatistics(int arity) : means(arity), vars(arity) {}
+ std::vector<float> means;
+ std::vector<float> vars;
+};
+
+// Translation rule
+class TRule {
+ public:
+  TRule() : lhs_(0), arity_(0), prev_i(-1), prev_j(-1) { }
+ TRule(WordID lhs, const WordID* src, int src_size, const WordID* trg, int trg_size, const int* feat_ids, const double* feat_vals, int feat_size, int arity) :
+ e_(trg, trg + trg_size), f_(src, src + src_size), lhs_(lhs), arity_(arity), prev_i(-1), prev_j(-1) {
+ for (int i = 0; i < feat_size; ++i)
+ scores_.set_value(feat_ids[i], feat_vals[i]);
+ }
+
+  explicit TRule(const std::vector<WordID>& e) : e_(e), lhs_(0), arity_(0), prev_i(-1), prev_j(-1) {}
+  // NB: arity_ is zero-initialized here, not derived from f (see ComputeArity() below)
+  TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs) :
+    e_(e), f_(f), lhs_(lhs), arity_(0), prev_i(-1), prev_j(-1) {}
+
+ // deprecated - this will be private soon
+ explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) {
+ ReadFromString(text, strict, mono);
+ }
+
+ // deprecated, use lexer
+ // make a rule from a hiero-like rule table, e.g.
+ // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1]
+ // if misformatted, returns NULL
+ static TRule* CreateRuleSynchronous(const std::string& rule);
+
+ // deprecated, use lexer
+  // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g.:
+ // el gato ||| the cat ||| Feature_2=0.34
+ static TRule* CreateRulePhrasetable(const std::string& rule);
+
+ // deprecated, use lexer
+  // make a rule from a non-synchronous CFG representation, e.g.:
+ // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT]
+ static TRule* CreateRuleMonolingual(const std::string& rule);
+
+ static TRule* CreateLexicalRule(const WordID& src, const WordID& trg) {
+ return new TRule(src, trg);
+ }
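+
+  // usage sketch (illustration only; the example string is the phrasetable format
+  // documented above, and the NULL check assumes the same NULL-on-error behavior
+  // as CreateRuleSynchronous):
+  //   TRulePtr r(CreateRulePhrasetable("el gato ||| the cat ||| Feature_2=0.34"));
+  //   if (r) std::cerr << r->AsString() << std::endl;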
+
+  // copy the target (e) side into *result, replacing each variable with the
+  // contents of the corresponding entry of var_values (see the e_ encoding below)
+  void ESubstitute(const std::vector<const std::vector<WordID>* >& var_values,
+                   std::vector<WordID>* result) const {
+ int vc = 0;
+ result->clear();
+ for (std::vector<WordID>::const_iterator i = e_.begin(); i != e_.end(); ++i) {
+ const WordID& c = *i;
+ if (c < 1) {
+ ++vc;
+ const std::vector<WordID>& var_value = *var_values[-c];
+ std::copy(var_value.begin(),
+ var_value.end(),
+ std::back_inserter(*result));
+ } else {
+ result->push_back(c);
+ }
+ }
+ assert(vc == var_values.size());
+ }
+
+  // copy the source (f) side into *result, replacing each variable, in order of
+  // appearance, with the contents of the corresponding entry of var_values
+  void FSubstitute(const std::vector<const std::vector<WordID>* >& var_values,
+                   std::vector<WordID>* result) const {
+ int vc = 0;
+ result->clear();
+ for (std::vector<WordID>::const_iterator i = f_.begin(); i != f_.end(); ++i) {
+ const WordID& c = *i;
+ if (c < 1) {
+ const std::vector<WordID>& var_value = *var_values[vc++];
+ std::copy(var_value.begin(),
+ var_value.end(),
+ std::back_inserter(*result));
+ } else {
+ result->push_back(c);
+ }
+ }
+ assert(vc == var_values.size());
+ }
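+
+  // substitution sketch (hypothetical names; assumes a rule with two variables
+  // whose values have already been built elsewhere):
+  //   std::vector<WordID> first_val, second_val;
+  //   std::vector<const std::vector<WordID>* > var_values;
+  //   var_values.push_back(&first_val); var_values.push_back(&second_val);
+  //   std::vector<WordID> out;
+  //   rule.ESubstitute(var_values, &out);  // target side with variables expanded
+  //   rule.FSubstitute(var_values, &out);  // likewise for the source side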
+
+ bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false);
+
+  bool Initialized() const { return !e_.empty(); }
+
+ std::string AsString(bool verbose = true) const;
+
+ static TRule DummyRule() {
+ TRule res;
+ res.e_.resize(1, 0);
+ return res;
+ }
+
+ const std::vector<WordID>& f() const { return f_; }
+ const std::vector<WordID>& e() const { return e_; }
+
+ int EWords() const { return ELength() - Arity(); }
+ int FWords() const { return FLength() - Arity(); }
+ int FLength() const { return f_.size(); }
+ int ELength() const { return e_.size(); }
+ int Arity() const { return arity_; }
+ bool IsUnary() const { return (Arity() == 1) && (f_.size() == 1); }
+ const SparseVector<double>& GetFeatureValues() const { return scores_; }
+ double Score(int i) const { return scores_[i]; }
+ WordID GetLHS() const { return lhs_; }
+ void ComputeArity();
+
+  // target (e) side; a nonterminal variable is encoded as 0 for the first
+  // variable, -1 for the second, -2 for the third, ...
+  std::vector<WordID> e_;
+  // source (f) side; a value < 0 is a nonterminal variable, and multiplying it
+  // by -1 gives the WordID of the variable's category
+  std::vector<WordID> f_;
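+  // example (assuming the hiero rule "[X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1]"
+  // shown above, and writing id(w) for the vocabulary WordID of w):
+  //   f_ = { -id(X), id(DE), -id(X) }    e_ = { -1, id(of), id(the), 0 }
+  // i.e. the first target-side gap is filled by the second source variable and
+  // the last gap by the first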
+ WordID lhs_;
+ SparseVector<double> scores_;
+
+ char arity_;
+ TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding
+
+ // this is only used when doing synchronous parsing
+ short int prev_i;
+ short int prev_j;
+
+ // may be null
+ boost::shared_ptr<NTSizeSummaryStatistics> nt_size_summary_;
+
+ private:
+  TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(0), arity_(0), prev_i(-1), prev_j(-1) {}
+ bool SanityCheck() const;
+};
+
+#endif