diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/tree2string_translator.cc | 34 | ||||
-rw-r--r-- | decoder/tree_fragment.cc | 10 | ||||
-rw-r--r-- | decoder/tree_fragment.h | 6 | ||||
-rw-r--r-- | decoder/trule.h | 3 |
4 files changed, 33 insertions, 20 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index b5b47d5d..d61b9aba 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -26,14 +26,21 @@ struct Tree2StringGrammarNode { // use a lexer probably static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; + int lc = 0; while(getline(*in, line)) { - size_t pos = line.find("|||"); - assert(pos != string::npos); - assert(pos > 3); - unsigned xc = 0; - while (line[pos - 1] == ' ') { --pos; xc++; } - cdec::TreeFragment rule_src(line.substr(0, pos), true); - // TODO transducer_state should (optionally?) be read from input + ++lc; + std::vector<StringPiece> fields = TokenizeMultisep(line, " ||| "); + if (has_multiple_states && fields.size() != 4) { + cerr << "Expected 4 fields in rule file but line " << lc << " is:\n" << line << endl; + abort(); + } + if (!has_multiple_states && fields.size() != 3) { + cerr << "Expected 3 fields in rule file but line " << lc << " is:\n" << line << endl; + abort(); + } + + cdec::TreeFragment rule_src(fields[has_multiple_states ? 1 : 0], true); + // TODO transducer_state should be read from input const unsigned transducer_state = 0; Tree2StringGrammarNode* cur = &root->next[transducer_state]; ostringstream os; @@ -59,12 +66,13 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo else os << TD::Convert(x); } - pos += 3 + xc; - while(line[pos] == ' ') { ++pos; } - os << " ||| " << line.substr(pos); - TRulePtr rule(new TRule(os.str())); - // TODO the transducer_state you end up in after using this rule (for each NT) - // needs to be read and encoded somehow in the rule (for use XXX) + TRulePtr rule; + if (has_multiple_states) { + cerr << "Not implemented...\n"; abort(); // TODO read in states + } else { + os << " ||| " << fields[1] << " ||| " << fields[2]; + rule.reset(new TRule(os.str())); + } cur->rules.push_back(rule); //cerr << "RULE: " << rule->AsString() << "\n\n"; } diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 696c8601..aad0b2c4 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -8,7 +8,7 @@ using namespace std; namespace cdec { -TreeFragment::TreeFragment(const string& tree, bool allow_frontier_sites) { +TreeFragment::TreeFragment(const StringPiece& tree, bool allow_frontier_sites) { int bal = 0; const unsigned len = tree.size(); unsigned cur = 0; @@ -50,7 +50,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const { // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built // symp keeps track of the terminal symbols that have been built -void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) { +void TreeFragment::ParseRec(const StringPiece& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) { if (tree[cp] != '(') { cerr << "Expected ( at " << cp << endl; abort(); @@ -79,12 +79,12 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned // TODO: add a terminal symbol to the current edge const bool is_terminal = tree[t_start] != '[' || (t_end - t_start < 3 || tree[t_end - 1] != ']'); if (is_terminal) { - const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start)); + const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start).as_string()); rhs.push_back(term); // cerr << "T='" << TD::Convert(term) << "'\n"; ++terminals; } else { // frontier site (NT but no recursion) - const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2)) | FRONTIER_BIT; + const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2).as_string()) | FRONTIER_BIT; rhs.push_back(nt); ++frontier_sites; // cerr << "FRONT-NT=[" << TD::Convert(nt & ALL_MASK) << "]\n"; @@ -97,7 +97,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } // continuent has completed, cp is at ), build node const unsigned j = symp; // span from (i,j) // add an internal non-terminal symbol - const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT; + const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start).as_string()) | RHS_BIT; nodes[np] = TreeFragmentProduction(nt, rhs); //cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->"; //for (auto& x : rhs) { diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 79722b5a..8bb7251a 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -8,6 +8,8 @@ #include <cassert> #include <cstddef> +#include "string_piece.hh" + namespace cdec { class BreadthFirstIterator; @@ -52,7 +54,7 @@ class TreeFragment { public: TreeFragment() : frontier_sites(), terminals() {} // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) - explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); + explicit TreeFragment(const StringPiece& tree, bool allow_frontier_sites = false); void DebugRec(unsigned cur, std::ostream* out) const; typedef DepthFirstIterator iterator; typedef ptrdiff_t difference_type; @@ -73,7 +75,7 @@ class TreeFragment { // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built // symp keeps track of the terminal symbols that have been built - void ParseRec(const std::string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp); + void ParseRec(const StringPiece& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp); public: unsigned root; unsigned char frontier_sites; diff --git a/decoder/trule.h b/decoder/trule.h index cc370757..243b0da9 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -144,6 +144,9 @@ class TRule { SparseVector<double> scores_; char arity_; + std::vector<WordID> ext_states_; // in t2s or t2t translation, this is of length arity_ and + // indicates what state the transducer is in after having processed + // this transduction rule // these attributes are application-specific and should probably be refactored TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding |