summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/tree2string_translator.cc34
-rw-r--r--decoder/tree_fragment.cc10
-rw-r--r--decoder/tree_fragment.h6
-rw-r--r--decoder/trule.h3
4 files changed, 33 insertions, 20 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index b5b47d5d..d61b9aba 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -26,14 +26,21 @@ struct Tree2StringGrammarNode {
// use a lexer probably
static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) {
string line;
+ int lc = 0;
while(getline(*in, line)) {
- size_t pos = line.find("|||");
- assert(pos != string::npos);
- assert(pos > 3);
- unsigned xc = 0;
- while (line[pos - 1] == ' ') { --pos; xc++; }
- cdec::TreeFragment rule_src(line.substr(0, pos), true);
- // TODO transducer_state should (optionally?) be read from input
+ ++lc;
+ std::vector<StringPiece> fields = TokenizeMultisep(line, " ||| ");
+ if (has_multiple_states && fields.size() != 4) {
+ cerr << "Expected 4 fields in rule file but line " << lc << " is:\n" << line << endl;
+ abort();
+ }
+ if (!has_multiple_states && fields.size() != 3) {
+ cerr << "Expected 3 fields in rule file but line " << lc << " is:\n" << line << endl;
+ abort();
+ }
+
+ cdec::TreeFragment rule_src(fields[has_multiple_states ? 1 : 0], true);
+ // TODO transducer_state should be read from input
const unsigned transducer_state = 0;
Tree2StringGrammarNode* cur = &root->next[transducer_state];
ostringstream os;
@@ -59,12 +66,13 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo
else
os << TD::Convert(x);
}
- pos += 3 + xc;
- while(line[pos] == ' ') { ++pos; }
- os << " ||| " << line.substr(pos);
- TRulePtr rule(new TRule(os.str()));
- // TODO the transducer_state you end up in after using this rule (for each NT)
- // needs to be read and encoded somehow in the rule (for use XXX)
+ TRulePtr rule;
+ if (has_multiple_states) {
+ cerr << "Not implemented...\n"; abort(); // TODO read in states
+ } else {
+ os << " ||| " << fields[1] << " ||| " << fields[2];
+ rule.reset(new TRule(os.str()));
+ }
cur->rules.push_back(rule);
//cerr << "RULE: " << rule->AsString() << "\n\n";
}
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index 696c8601..aad0b2c4 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -8,7 +8,7 @@ using namespace std;
namespace cdec {
-TreeFragment::TreeFragment(const string& tree, bool allow_frontier_sites) {
+TreeFragment::TreeFragment(const StringPiece& tree, bool allow_frontier_sites) {
int bal = 0;
const unsigned len = tree.size();
unsigned cur = 0;
@@ -50,7 +50,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const {
// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built
// symp keeps track of the terminal symbols that have been built
-void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) {
+void TreeFragment::ParseRec(const StringPiece& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) {
if (tree[cp] != '(') {
cerr << "Expected ( at " << cp << endl;
abort();
@@ -79,12 +79,12 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
// TODO: add a terminal symbol to the current edge
const bool is_terminal = tree[t_start] != '[' || (t_end - t_start < 3 || tree[t_end - 1] != ']');
if (is_terminal) {
- const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start));
+ const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start).as_string());
rhs.push_back(term);
// cerr << "T='" << TD::Convert(term) << "'\n";
++terminals;
} else { // frontier site (NT but no recursion)
- const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2)) | FRONTIER_BIT;
+ const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2).as_string()) | FRONTIER_BIT;
rhs.push_back(nt);
++frontier_sites;
// cerr << "FRONT-NT=[" << TD::Convert(nt & ALL_MASK) << "]\n";
@@ -97,7 +97,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
} // continuent has completed, cp is at ), build node
const unsigned j = symp; // span from (i,j)
// add an internal non-terminal symbol
- const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT;
+ const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start).as_string()) | RHS_BIT;
nodes[np] = TreeFragmentProduction(nt, rhs);
//cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->";
//for (auto& x : rhs) {
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index 79722b5a..8bb7251a 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -8,6 +8,8 @@
#include <cassert>
#include <cstddef>
+#include "string_piece.hh"
+
namespace cdec {
class BreadthFirstIterator;
@@ -52,7 +54,7 @@ class TreeFragment {
public:
TreeFragment() : frontier_sites(), terminals() {}
// (S (NP a (X b) c d) (VP (V foo) (NP (NN bar))))
- explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false);
+ explicit TreeFragment(const StringPiece& tree, bool allow_frontier_sites = false);
void DebugRec(unsigned cur, std::ostream* out) const;
typedef DepthFirstIterator iterator;
typedef ptrdiff_t difference_type;
@@ -73,7 +75,7 @@ class TreeFragment {
// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built
// symp keeps track of the terminal symbols that have been built
- void ParseRec(const std::string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp);
+ void ParseRec(const StringPiece& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp);
public:
unsigned root;
unsigned char frontier_sites;
diff --git a/decoder/trule.h b/decoder/trule.h
index cc370757..243b0da9 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -144,6 +144,9 @@ class TRule {
SparseVector<double> scores_;
char arity_;
+ std::vector<WordID> ext_states_; // in t2s or t2t translation, this is of length arity_ and
+ // indicates what state the transducer is in after having processed
+ // this transduction rule
// these attributes are application-specific and should probably be refactored
TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding