From 649b5ffc7c81182ba39d338b11bfe2e9a05544b5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 16 Apr 2014 00:36:30 -0400 Subject: fix for bug due to using wrong tree traversal --- decoder/t2s_test.cc | 8 ++- decoder/tree2string_translator.cc | 18 +++--- decoder/tree_fragment.cc | 12 ---- decoder/tree_fragment.h | 112 +++++++++++++++++++++++++++++++++++++- 4 files changed, 125 insertions(+), 25 deletions(-) (limited to 'decoder') diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc index 3c46ea89..5ebb2662 100644 --- a/decoder/t2s_test.cc +++ b/decoder/t2s_test.cc @@ -15,8 +15,11 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { vector aw, bw; cerr << "TREE1: " << tree << endl; cerr << "TREE2: " << tree2 << endl; - for (auto& sym : tree) + for (auto& sym : tree) { + if (cdec::IsLHS(sym)) cerr << "("; + cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + } for (auto& sym : tree2) if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); BOOST_CHECK_EQUAL(a.size(), b.size()); @@ -38,11 +41,12 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { if (cdec::IsFrontier(*it)) nts += "*"; } } + cerr << "Truncated: " << nts << endl; BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); nts.clear(); int ntc = 0; - for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + for (auto it = tree.bfs_begin(); it != tree.bfs_end(); ++it) { if (cdec::IsNT(*it)) { if (cdec::IsRHS(*it)) { ++ntc; diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 8d12d01d..3fbf1ee5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -38,14 +38,13 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { // but it will not generate source strings correctly vector frhs; for (auto sym : rule_src) { + //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; cur = &cur->next[sym]; - if (sym) { - if (cdec::IsFrontier(sym)) { // frontier symbols -> variables - int nt = (sym & cdec::ALL_MASK); - frhs.push_back(-nt); - } else if (cdec::IsTerminal(sym)) { - frhs.push_back(sym); - } + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); } } os << '[' << TD::Convert(-lhs) << "] |||"; @@ -61,6 +60,7 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { os << " ||| " << line.substr(pos); TRulePtr rule(new TRule(os.str())); cur->rules.push_back(rule); + //cerr << "RULE: " << rule->AsString() << "\n\n"; } } @@ -82,7 +82,7 @@ struct ParserState { } vector future_work; int input_node_idx; // lhs of top level NT - Tree2StringGrammarNode* node; + Tree2StringGrammarNode* node; // pointer into grammar }; namespace std { @@ -239,11 +239,13 @@ struct Tree2StringTranslatorImpl { new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; const ParserState new_s(++s.in_iter, &nit1->second, s); if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 78a993b8..4d429f42 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -112,16 +112,4 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned *psymp = symp; } -BreadthFirstIterator TreeFragment::begin() const { - return BreadthFirstIterator(this, nodes.size() - 1); -} - -BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { - return BreadthFirstIterator(this, node_idx); -} - -BreadthFirstIterator TreeFragment::end() const { - return BreadthFirstIterator(this); -} - } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index f1c4c106..4a704bc4 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -11,6 +11,7 @@ namespace cdec { class BreadthFirstIterator; +class DepthFirstIterator; static const unsigned LHS_BIT = 0x10000000u; static const unsigned RHS_BIT = 0x20000000u; @@ -53,16 +54,21 @@ class TreeFragment { // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); void DebugRec(unsigned cur, std::ostream* out) const; - typedef BreadthFirstIterator iterator; + typedef DepthFirstIterator iterator; typedef ptrdiff_t difference_type; typedef unsigned value_type; typedef const unsigned * pointer; typedef const unsigned & reference; + // default iterator is DFS iterator begin() const; iterator begin(unsigned node_idx) const; iterator end() const; + BreadthFirstIterator bfs_begin() const; + BreadthFirstIterator bfs_begin(unsigned node_idx) const; + BreadthFirstIterator bfs_end() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built @@ -78,14 +84,106 @@ class TreeFragment { struct TFIState { TFIState() : node(), rhspos(), state() {} - TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} + TFIState(unsigned n, int p, unsigned s) : node(n), rhspos(p), state(s) {} bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } unsigned short node; - unsigned short rhspos; + short rhspos; unsigned char state; }; +class DepthFirstIterator : public std::iterator { + const TreeFragment* tf_; + std::deque q_; + unsigned sym; + public: + DepthFirstIterator() : tf_(), sym() {} + // used for begin + explicit DepthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, -1, 0)); + Stage(); + q_.back().state++; + } + // used for end + explicit DepthFirstIterator(const TreeFragment* tf) : tf_(tf) {} + const unsigned& operator*() const { return sym; } + const unsigned* operator->() const { return &sym; } + bool operator==(const DepthFirstIterator& other) const { + return (tf_ == other.tf_) && (q_ == other.q_); + } + bool operator!=(const DepthFirstIterator& other) const { + return (tf_ != other.tf_) || (q_ != other.q_); + } + unsigned node_idx() const { return q_.front().node; } + const DepthFirstIterator& operator++() { + TFIState& s = q_.back(); + if (s.state == 0) { + Stage(); + s.state++; + } else if (s.state == 1) { + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_back(); + while (!q_.empty()) { + TFIState& s = q_.back(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos < len) break; + q_.pop_back(); + } + } + Stage(); + } + return *this; + } + DepthFirstIterator operator++(int) { + DepthFirstIterator res = *this; + ++(*this); + return res; + } + // tell iterator not to explore the subtree rooted at sym + // should only be called once per NT symbol encountered + const DepthFirstIterator& truncate() { + assert(IsRHS(sym)); + sym &= ALL_MASK; + sym |= FRONTIER_BIT; + q_.pop_back(); + return *this; + } + unsigned child_node() const { + assert(IsRHS(sym)); + return q_.back().node; + } + DepthFirstIterator remainder() const { + assert(IsRHS(sym)); + return DepthFirstIterator(tf_, q_.back()); + } + bool at_end() const { + return q_.empty(); + } + private: + void Stage() { + if (q_.empty()) return; + const TFIState& s = q_.back(); + if (s.state == 0) { + sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else if (s.state == 1) { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, -1, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + DepthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } +}; + class BreadthFirstIterator : public std::iterator { const TreeFragment* tf_; std::deque q_; @@ -172,6 +270,14 @@ class BreadthFirstIterator : public std::iterator Date: Thu, 17 Apr 2014 20:55:34 -0400 Subject: fix rescoring --- decoder/trule.cc | 6 ++++++ tests/system_tests/cfg_rescore/README | 4 ++++ tests/system_tests/cfg_rescore/cdec.ini | 2 ++ tests/system_tests/cfg_rescore/gold.statistics | 3 +++ tests/system_tests/cfg_rescore/gold.stdout | 4 ++++ tests/system_tests/cfg_rescore/input.cfg | 9 +++++++++ tests/system_tests/cfg_rescore/input.txt | 1 + tests/system_tests/cfg_rescore/weights | 3 +++ 8 files changed, 32 insertions(+) create mode 100644 tests/system_tests/cfg_rescore/README create mode 100644 tests/system_tests/cfg_rescore/cdec.ini create mode 100644 tests/system_tests/cfg_rescore/gold.statistics create mode 100644 tests/system_tests/cfg_rescore/gold.stdout create mode 100644 tests/system_tests/cfg_rescore/input.cfg create mode 100644 tests/system_tests/cfg_rescore/input.txt create mode 100644 tests/system_tests/cfg_rescore/weights (limited to 'decoder') diff --git a/decoder/trule.cc b/decoder/trule.cc index 1bd5425f..bee211d5 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -56,6 +56,12 @@ bool TRule::ReadFromString(const string& line, bool mono) { RuleLexer::ReadRule(line + '\n', assign_trule, mono, this); if (n_assigned > 1) cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "< input.txt diff --git a/tests/system_tests/cfg_rescore/cdec.ini b/tests/system_tests/cfg_rescore/cdec.ini new file mode 100644 index 00000000..1a913f2d --- /dev/null +++ b/tests/system_tests/cfg_rescore/cdec.ini @@ -0,0 +1,2 @@ +formalism=rescore +k_best=100 diff --git a/tests/system_tests/cfg_rescore/gold.statistics b/tests/system_tests/cfg_rescore/gold.statistics new file mode 100644 index 00000000..7b05e2d8 --- /dev/null +++ b/tests/system_tests/cfg_rescore/gold.statistics @@ -0,0 +1,3 @@ +-lm_nodes 8 +-lm_edges 10 +-lm_paths 4 diff --git a/tests/system_tests/cfg_rescore/gold.stdout b/tests/system_tests/cfg_rescore/gold.stdout new file mode 100644 index 00000000..ccf99263 --- /dev/null +++ b/tests/system_tests/cfg_rescore/gold.stdout @@ -0,0 +1,4 @@ +0 ||| the broccoli was eaten by John ||| Passive=1 Definite=1 ||| 2 +0 ||| John ate the broccoli ||| Active=1 Definite=1 ||| 1.1 +0 ||| broccoli was eaten by John ||| Passive=1 ||| 1 +0 ||| John ate broccoli ||| Active=1 ||| 0.1 diff --git a/tests/system_tests/cfg_rescore/input.cfg b/tests/system_tests/cfg_rescore/input.cfg new file mode 100644 index 00000000..0073cb7b --- /dev/null +++ b/tests/system_tests/cfg_rescore/input.cfg @@ -0,0 +1,9 @@ +[S] ||| [S1] +[S1] ||| [NP1] [VP] ||| Active=1 +[VP] ||| [V] [NP2] +[V] ||| ate +[VPSV] ||| was eaten +[S1] ||| [NP2] [VPSV] by [NP1] ||| Passive=1 +[NP1] ||| John +[NP2] ||| broccoli +[NP2] ||| the broccoli ||| Definite=1 diff --git a/tests/system_tests/cfg_rescore/input.txt b/tests/system_tests/cfg_rescore/input.txt new file mode 100644 index 00000000..71fc26bc --- /dev/null +++ b/tests/system_tests/cfg_rescore/input.txt @@ -0,0 +1 @@ +{"rules":[1,"[S] ||| [S1] ||| [1]",2,"[S1] ||| [NP1] [VP] ||| [1] [2] ||| Active=1",3,"[VP] ||| [V] [NP2] ||| [1] [2]",4,"[V] ||| ate ||| ate",5,"[VPSV] ||| was eaten ||| was eaten",6,"[S1] ||| [NP2] [VPSV] by [NP1] ||| [1] [2] by [3] ||| Passive=1",7,"[NP1] ||| John ||| John",8,"[NP2] ||| broccoli ||| broccoli",9,"[NP2] ||| the broccoli ||| the broccoli ||| Definite=1"],"features":["PhraseModel_0","PhraseModel_1","PhraseModel_2","PhraseModel_3","PhraseModel_4","PhraseModel_5","PhraseModel_6","PhraseModel_7","PhraseModel_8","PhraseModel_9","PhraseModel_10","PhraseModel_11","PhraseModel_12","PhraseModel_13","PhraseModel_14","PhraseModel_15","PhraseModel_16","PhraseModel_17","PhraseModel_18","PhraseModel_19","PhraseModel_20","PhraseModel_21","PhraseModel_22","PhraseModel_23","PhraseModel_24","PhraseModel_25","PhraseModel_26","PhraseModel_27","PhraseModel_28","PhraseModel_29","PhraseModel_30","PhraseModel_31","PhraseModel_32","PhraseModel_33","PhraseModel_34","PhraseModel_35","PhraseModel_36","PhraseModel_37","PhraseModel_38","PhraseModel_39","PhraseModel_40","PhraseModel_41","PhraseModel_42","PhraseModel_43","PhraseModel_44","PhraseModel_45","PhraseModel_46","PhraseModel_47","PhraseModel_48","PhraseModel_49","PhraseModel_50","PhraseModel_51","PhraseModel_52","PhraseModel_53","PhraseModel_54","PhraseModel_55","PhraseModel_56","PhraseModel_57","PhraseModel_58","PhraseModel_59","PhraseModel_60","PhraseModel_61","PhraseModel_62","PhraseModel_63","PhraseModel_64","PhraseModel_65","PhraseModel_66","PhraseModel_67","PhraseModel_68","PhraseModel_69","PhraseModel_70","PhraseModel_71","PhraseModel_72","PhraseModel_73","PhraseModel_74","PhraseModel_75","PhraseModel_76","PhraseModel_77","PhraseModel_78","PhraseModel_79","PhraseModel_80","PhraseModel_81","PhraseModel_82","PhraseModel_83","PhraseModel_84","PhraseModel_85","PhraseModel_86","PhraseModel_87","PhraseModel_88","PhraseModel_89","PhraseModel_90","PhraseModel_91","PhraseModel_92","PhraseModel_93","PhraseModel_94","PhraseModel_95","PhraseModel_96","PhraseModel_97","PhraseModel_98","PhraseModel_99","Active","Passive","Definite"],"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":7}],"node":{"in_edges":[0],"cat":"NP1","node_hash":"0000000000000007"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":4}],"node":{"in_edges":[1],"cat":"V","node_hash":"0000000000000004"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":8},{"tail":[],"spans":[-1,-1,-1,-1],"feats":[102,1],"rule":9}],"node":{"in_edges":[2,3],"cat":"NP2","node_hash":"0000000000000009"},"edges":[{"tail":[1,2],"spans":[-1,-1,-1,-1],"feats":[],"rule":3}],"node":{"in_edges":[4],"cat":"VP","node_hash":"0000000000000003"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":5}],"node":{"in_edges":[5],"cat":"VPSV","node_hash":"0000000000000005"},"edges":[{"tail":[0,3],"spans":[-1,-1,-1,-1],"feats":[100,1],"rule":2},{"tail":[2,4,0],"spans":[-1,-1,-1,-1],"feats":[101,1],"rule":6}],"node":{"in_edges":[6,7],"cat":"S1","node_hash":"0000000000000006"},"edges":[{"tail":[5],"spans":[-1,-1,-1,-1],"feats":[],"rule":1}],"node":{"in_edges":[8],"cat":"S","node_hash":"0000000000000001"}} diff --git a/tests/system_tests/cfg_rescore/weights b/tests/system_tests/cfg_rescore/weights new file mode 100644 index 00000000..bd3bb1af --- /dev/null +++ b/tests/system_tests/cfg_rescore/weights @@ -0,0 +1,3 @@ +Active 0.1 +Passive 1 +Definite 1 -- cgit v1.2.3