summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-03-30 23:50:17 -0400
committerChris Dyer <redpony@gmail.com>2014-03-30 23:50:17 -0400
commit8372086f2fc4bd765fdd05e8cf95faeb147a6587 (patch)
treefa4ac0342bc1259ce96c61fa9fffb5f8252d0333 /decoder
parentca29417acd47dbbd2aa68cd31fcd3129e6482bf7 (diff)
almost complete tree to string translator
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am3
-rw-r--r--decoder/decoder.cc6
-rw-r--r--decoder/t2s_test.cc110
-rw-r--r--decoder/tree2string_translator.cc120
-rw-r--r--decoder/tree_fragment.cc14
-rw-r--r--decoder/tree_fragment.h109
6 files changed, 311 insertions, 51 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 7481192b..5c91fe65 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -4,9 +4,12 @@ noinst_PROGRAMS = \
trule_test \
hg_test \
parser_test \
+ t2s_test \
grammar_test
TESTS = trule_test parser_test grammar_test hg_test
+t2s_test_SOURCES = t2s_test.cc
+t2s_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a
parser_test_SOURCES = parser_test.cc
parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a
grammar_test_SOURCES = grammar_test.cc
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 31049216..43e2640d 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
formalism = LowercaseString(str("formalism",conf));
- if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") {
- cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";
+ if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -626,6 +626,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
// set up translation back end
if (formalism == "scfg")
translator.reset(new SCFGTranslator(conf));
+ else if (formalism == "t2s")
+ translator.reset(new Tree2StringTranslator(conf));
else if (formalism == "fst")
translator.reset(new FSTTranslator(conf));
else if (formalism == "pb")
diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc
new file mode 100644
index 00000000..3c46ea89
--- /dev/null
+++ b/decoder/t2s_test.cc
@@ -0,0 +1,110 @@
+#include "tree_fragment.h"
+
+#define BOOST_TEST_MODULE T2STest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+#include <iostream>
+#include "tdict.h"
+
+using namespace std;
+
+BOOST_AUTO_TEST_CASE(TestTreeFragments) {
+ cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))");
+ cdec::TreeFragment tree2("(S (NP (DT a) (NN cat)) (VP (V ate) (NP (DT the) (NN cake pie))))");
+ vector<unsigned> a, b;
+ vector<WordID> aw, bw;
+ cerr << "TREE1: " << tree << endl;
+ cerr << "TREE2: " << tree2 << endl;
+ for (auto& sym : tree)
+ if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym);
+ for (auto& sym : tree2)
+ if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym);
+ BOOST_CHECK_EQUAL(a.size(), b.size());
+ BOOST_CHECK_EQUAL(aw.size() + 1, bw.size());
+ BOOST_CHECK_EQUAL(aw.size(), 5);
+ BOOST_CHECK_EQUAL(TD::GetString(aw), "the boy saw a cat");
+ BOOST_CHECK_EQUAL(TD::GetString(bw), "a cat ate the cake pie");
+ if (a != b) {
+ BOOST_CHECK_EQUAL(1,2);
+ }
+
+ string nts;
+ for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) {
+ if (cdec::IsNT(*it)) {
+ if (cdec::IsRHS(*it)) it.truncate();
+ if (nts.size()) nts += " ";
+ if (cdec::IsLHS(*it)) nts += "(";
+ nts += TD::Convert(*it & cdec::ALL_MASK);
+ if (cdec::IsFrontier(*it)) nts += "*";
+ }
+ }
+ BOOST_CHECK_EQUAL(nts, "(S NP* VP*");
+
+ nts.clear();
+ int ntc = 0;
+ for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) {
+ if (cdec::IsNT(*it)) {
+ if (cdec::IsRHS(*it)) {
+ ++ntc;
+ if (ntc > 1) it.truncate();
+ }
+ if (nts.size()) nts += " ";
+ if (cdec::IsLHS(*it)) nts += "(";
+ nts += TD::Convert(*it & cdec::ALL_MASK);
+ if (cdec::IsFrontier(*it)) nts += "*";
+ }
+ }
+ BOOST_CHECK_EQUAL(nts, "(S NP VP* (NP DT* NN*");
+}
+
+BOOST_AUTO_TEST_CASE(TestSharing) {
+ cdec::TreeFragment rule1("(S [NP] [VP])", true);
+ cdec::TreeFragment rule2("(S [NP] (VP [V] [NP]))", true);
+ string r1,r2;
+ for (auto sym : rule1) {
+ if (r1.size()) r1 += " ";
+ if (cdec::IsLHS(sym)) r1 += "(";
+ r1 += TD::Convert(sym & cdec::ALL_MASK);
+ if (cdec::IsFrontier(sym)) r1 += "*";
+ }
+ for (auto sym : rule2) {
+ if (r2.size()) r2 += " ";
+ if (cdec::IsLHS(sym)) r2 += "(";
+ r2 += TD::Convert(sym & cdec::ALL_MASK);
+ if (cdec::IsFrontier(sym)) r2 += "*";
+ }
+ cerr << rule1 << endl;
+ cerr << r1 << endl;
+ cerr << rule2 << endl;
+ cerr << r2 << endl;
+ BOOST_CHECK_EQUAL(r1, "(S NP* VP*");
+ BOOST_CHECK_EQUAL(r2, "(S NP* VP (VP V* NP*");
+}
+
+BOOST_AUTO_TEST_CASE(TestEndInvariants) {
+ cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))");
+ BOOST_CHECK(tree.end().at_end());
+ BOOST_CHECK(!tree.begin().at_end());
+}
+
+BOOST_AUTO_TEST_CASE(TestBegins) {
+ cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))");
+ for (auto it = tree.begin(1); it != tree.end(); ++it) {
+ cerr << TD::Convert(*it & cdec::ALL_MASK) << endl;
+ }
+}
+
+BOOST_AUTO_TEST_CASE(TestRemainder) {
+ cdec::TreeFragment tree("(S (A a) (B b))");
+ auto it = tree.begin();
+ ++it;
+ BOOST_CHECK(cdec::IsRHS(*it));
+ cerr << tree << endl;
+ auto itr = it.remainder();
+ while(itr != tree.end()) {
+ cerr << TD::Convert(*itr & cdec::ALL_MASK) << endl;
+ ++itr;
+ }
+}
+
+
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 1c249836..cd6ee550 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -1,5 +1,6 @@
#include <algorithm>
#include <vector>
+#include <queue>
#include <boost/functional/hash.hpp>
#include <unordered_map>
#include "tree_fragment.h"
@@ -15,11 +16,10 @@ using namespace std;
struct Tree2StringGrammarNode {
map<unsigned, Tree2StringGrammarNode> next;
- string rules;
+ vector<TRulePtr> rules;
};
-void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) {
- unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots;
+void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) {
string line;
while(getline(*in, line)) {
size_t pos = line.find("|||");
@@ -28,32 +28,124 @@ void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGram
unsigned xc = 0;
while (line[pos - 1] == ' ') { --pos; xc++; }
cdec::TreeFragment rule_src(line.substr(0, pos), true);
- Tree2StringGrammarNode* cur = &roots[rule_src.root];
- for (auto sym : rule_src)
+ Tree2StringGrammarNode* cur = root;
+ ostringstream os;
+ int lhs = -(rule_src.root & cdec::ALL_MASK);
+ // build source RHS for SCFG projection
+ // TODO - this is buggy - it will generate a well-formed SCFG rule
+ // but it will not generate source strings correctly
+ vector<int> frhs;
+ for (auto sym : rule_src) {
cur = &cur->next[sym];
+ if (sym) {
+ if (cdec::IsFrontier(sym)) { // frontier symbols -> variables
+ int nt = (sym & cdec::ALL_MASK);
+ frhs.push_back(-nt);
+ } else if (cdec::IsTerminal(sym)) {
+ frhs.push_back(sym);
+ }
+ }
+ }
+ os << '[' << TD::Convert(-lhs) << "] |||";
+ for (auto x : frhs) {
+ os << ' ';
+ if (x < 0)
+ os << '[' << TD::Convert(-x) << ']';
+ else
+ os << TD::Convert(x);
+ }
pos += 3 + xc;
while(line[pos] == ' ') { ++pos; }
- size_t pos2 = line.find("|||", pos);
- assert(pos2 != string::npos);
- while (line[pos2 - 1] == ' ') { --pos2; }
- cur->rules = line.substr(pos, pos2 - pos);
- cerr << "OUTPUT = '" << cur->rules << "'\n";
+ os << " ||| " << line.substr(pos);
+ TRulePtr rule(new TRule(os.str()));
+ cur->rules.push_back(rule);
}
}
+struct ParserState {
+ ParserState() : in_iter(), node() {}
+ cdec::TreeFragment::iterator in_iter;
+ ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) :
+ in_iter(it),
+ root_type(rt),
+ node(n) {}
+ ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) :
+ in_iter(it),
+ future_work(p.future_work),
+ root_type(p.root_type),
+ node(n) {}
+ vector<ParserState> future_work;
+ int root_type; // lhs of top level NT
+ Tree2StringGrammarNode* node;
+};
+
struct Tree2StringTranslatorImpl {
- unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules
+ Tree2StringGrammarNode root;
Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) {
ReadFile rf(conf["grammar"].as<vector<string>>()[0]);
- ReadTree2StringGrammar(rf.stream(), &roots);
+ ReadTree2StringGrammar(rf.stream(), &root);
}
bool Translate(const string& input,
SentenceMetadata* smeta,
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
cdec::TreeFragment input_tree(input, false);
- cerr << "Tree2StringTranslatorImpl: please implement this!\n";
- return false;
+ const int kS = -TD::Convert("S");
+ Hypergraph hg;
+ queue<ParserState> q;
+ q.push(ParserState(input_tree.begin(), &root, kS));
+ while(!q.empty()) {
+ ParserState& s = q.front();
+
+ if (s.in_iter.at_end()) { // completed a traversal of a subtree
+ cerr << "I traversed a subtree of the input...\n";
+ if (s.node->rules.size()) {
+ // TODO: build hypergraph
+ for (auto& r : s.node->rules)
+ cerr << "I can build: " << r->AsString() << endl;
+ for (auto& w : s.future_work)
+ q.push(w);
+ } else {
+ cerr << "I can't build anything :(\n";
+ }
+ } else { // more input tree to match
+ unsigned sym = *s.in_iter;
+ if (cdec::IsLHS(sym)) {
+ auto nit = s.node->next.find(sym);
+ if (nit != s.node->next.end()) {
+ //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+ q.push(ParserState(++s.in_iter, &nit->second, s));
+ }
+ } else if (cdec::IsRHS(sym)) {
+ //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+ cdec::TreeFragment::iterator var = s.in_iter;
+ var.truncate();
+ auto nit1 = s.node->next.find(sym);
+ auto nit2 = s.node->next.find(*var);
+ if (nit2 != s.node->next.end()) {
+ //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+ ParserState new_s(++var, &nit2->second, s);
+ ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK));
+ new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q
+ q.push(new_s);
+ }
+ if (nit1 != s.node->next.end()) {
+ //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl;
+ q.push(ParserState(++s.in_iter, &nit1->second, s));
+ }
+ } else if (cdec::IsTerminal(sym)) {
+ auto nit = s.node->next.find(sym);
+ if (nit != s.node->next.end()) {
+ //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl;
+ q.push(ParserState(++s.in_iter, &nit->second, s));
+ }
+ } else {
+ cerr << "This can never happen!\n"; abort();
+ }
+ }
+ q.pop();
+ }
+ minus_lm_forest->swap(hg);
}
};
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index 93aad64e..78a993b8 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -36,7 +36,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const {
*out << ' ';
if (IsFrontier(x)) {
*out << '[' << TD::Convert(x & ALL_MASK) << ']';
- } else if (IsInternalNT(x)) {
+ } else if (IsRHS(x)) {
DebugRec(x & ALL_MASK, out);
} else { // must be terminal
*out << TD::Convert(x);
@@ -66,7 +66,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
// recursively call parser to deal with constituent
ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np);
unsigned ind = np - 1;
- rhs.push_back(ind | NT_BIT);
+ rhs.push_back(ind | RHS_BIT);
} else { // deal with terminal / nonterminal substitution
++symp;
assert(tree[cp] != ' ');
@@ -95,7 +95,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
} // continuent has completed, cp is at ), build node
const unsigned j = symp; // span from (i,j)
// add an internal non-terminal symbol
- const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT;
+ const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT;
nodes[np] = TreeFragmentProduction(nt, rhs);
//cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->";
//for (auto& x : rhs) {
@@ -113,11 +113,15 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned
}
BreadthFirstIterator TreeFragment::begin() const {
- return BreadthFirstIterator(this);
+ return BreadthFirstIterator(this, nodes.size() - 1);
+}
+
+BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const {
+ return BreadthFirstIterator(this, node_idx);
}
BreadthFirstIterator TreeFragment::end() const {
- return BreadthFirstIterator(this, 0);
+ return BreadthFirstIterator(this);
}
}
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index a38dbdfa..b83afc27 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -1,7 +1,7 @@
#ifndef TREE_FRAGMENT
#define TREE_FRAGMENT
-#include <queue>
+#include <deque>
#include <iostream>
#include <vector>
#include <string>
@@ -12,18 +12,32 @@ namespace cdec {
class BreadthFirstIterator;
-static const unsigned NT_BIT = 0x40000000u;
-static const unsigned FRONTIER_BIT = 0x80000000u;
-static const unsigned ALL_MASK = 0x0FFFFFFFu;
+static const unsigned LHS_BIT = 0x10000000u;
+static const unsigned RHS_BIT = 0x20000000u;
+static const unsigned FRONTIER_BIT = 0x40000000u;
+static const unsigned RESERVED_BIT = 0x80000000u;
+static const unsigned ALL_MASK = 0x0FFFFFFFu;
-inline bool IsInternalNT(unsigned x) {
- return (x & NT_BIT);
+inline bool IsNT(unsigned x) {
+ return (x & (LHS_BIT | RHS_BIT | FRONTIER_BIT));
+}
+
+inline bool IsLHS(unsigned x) {
+ return (x & LHS_BIT);
+}
+
+inline bool IsRHS(unsigned x) {
+ return (x & RHS_BIT);
}
inline bool IsFrontier(unsigned x) {
return (x & FRONTIER_BIT);
}
+inline bool IsTerminal(unsigned x) {
+ return (x & ALL_MASK) == x;
+}
+
struct TreeFragmentProduction {
TreeFragmentProduction() {}
TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {}
@@ -46,6 +60,7 @@ class TreeFragment {
typedef const unsigned & reference;
iterator begin() const;
+ iterator begin(unsigned node_idx) const;
iterator end() const;
private:
@@ -62,24 +77,28 @@ class TreeFragment {
};
struct TFIState {
- TFIState() : node(), rhspos() {}
- TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {}
- bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; }
- bool operator!=(const TFIState& o) const { return node != o.node && rhspos != o.rhspos; }
+ TFIState() : node(), rhspos(), state() {}
+ TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {}
+ bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; }
+ bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; }
unsigned short node;
unsigned short rhspos;
+ unsigned char state;
};
class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> {
const TreeFragment* tf_;
- std::queue<TFIState> q_;
+ std::deque<TFIState> q_;
unsigned sym;
public:
- explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {
- q_.push(TFIState(tf->nodes.size() - 1, 0));
+ BreadthFirstIterator() : tf_(), sym() {}
+ // used for begin
+ explicit BreadthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) {
+ q_.push_back(TFIState(node_idx, 0, 0));
Stage();
}
- BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {}
+ // used for end
+ explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {}
const unsigned& operator*() const { return sym; }
const unsigned* operator->() const { return &sym; }
bool operator==(const BreadthFirstIterator& other) const {
@@ -88,26 +107,20 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns
bool operator!=(const BreadthFirstIterator& other) const {
return (tf_ != other.tf_) || (q_ != other.q_);
}
- void Stage() {
- if (q_.empty()) return;
- const TFIState& s = q_.front();
- sym = tf_->nodes[s.node].rhs[s.rhspos];
- if (IsInternalNT(sym)) {
- q_.push(TFIState(sym & ALL_MASK, 0));
- sym = tf_->nodes[sym & ALL_MASK].lhs;
- }
- }
const BreadthFirstIterator& operator++() {
TFIState& s = q_.front();
- const unsigned len = tf_->nodes[s.node].rhs.size();
- s.rhspos++;
- if (s.rhspos > len) {
- q_.pop();
+ if (s.state == 0) {
+ s.state++;
Stage();
- } else if (s.rhspos == len) {
- sym = 0;
} else {
- Stage();
+ const unsigned len = tf_->nodes[s.node].rhs.size();
+ s.rhspos++;
+ if (s.rhspos >= len) {
+ q_.pop_front();
+ Stage();
+ } else {
+ Stage();
+ }
}
return *this;
}
@@ -116,6 +129,42 @@ class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, uns
++(*this);
return res;
}
+ // tell iterator not to explore the subtree rooted at sym
+ // should only be called once per NT symbol encountered
+ const BreadthFirstIterator& truncate() {
+ assert(IsRHS(sym));
+ sym &= ALL_MASK;
+ sym |= FRONTIER_BIT;
+ q_.pop_back();
+ return *this;
+ }
+ BreadthFirstIterator remainder() const {
+ assert(IsRHS(sym));
+ return BreadthFirstIterator(tf_, q_.back());
+ }
+ bool at_end() const {
+ return q_.empty();
+ }
+ private:
+ void Stage() {
+ if (q_.empty()) return;
+ const TFIState& s = q_.front();
+ if (s.state == 0) {
+ sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT;
+ } else {
+ sym = tf_->nodes[s.node].rhs[s.rhspos];
+ if (IsRHS(sym)) {
+ q_.push_back(TFIState(sym & ALL_MASK, 0, 0));
+ sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT;
+ }
+ }
+ }
+
+ // used by remainder
+ BreadthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) {
+ q_.push_back(s);
+ Stage();
+ }
};
inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) {