From 8e60188039211c9a585625dd8309f2ba57b0cbb2 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 18 Mar 2014 01:41:17 -0400 Subject: star function --- decoder/Makefile.am | 1 - decoder/exp_semiring.h | 71 -------------------------------------------------- decoder/hg.h | 17 +++++++++++- 3 files changed, 16 insertions(+), 73 deletions(-) delete mode 100644 decoder/exp_semiring.h (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index c41cd7f9..7481192b 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -37,7 +37,6 @@ libcdec_a_SOURCES = \ csplit.h \ decoder.h \ earley_composer.h \ - exp_semiring.h \ factored_lexicon_helper.h \ ff.h \ ff_basic.h \ diff --git a/decoder/exp_semiring.h b/decoder/exp_semiring.h deleted file mode 100644 index 2a9034bb..00000000 --- a/decoder/exp_semiring.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef _EXP_SEMIRING_H_ -#define _EXP_SEMIRING_H_ - -#include - -// this file implements the first-order expectation semiring described -// in Li & Eisner (EMNLP 2009) - -// requirements: -// RType * RType ==> RType -// PType * PType ==> PType -// RType * PType ==> RType -// good examples: -// PType scalar, RType vector -// BAD examples: -// PType vector, RType scalar -template -struct PRPair { - PRPair() : p(), r() {} - // Inside algorithm requires that T(0) and T(1) - // return the 0 and 1 values of the semiring - explicit PRPair(double x) : p(x), r() {} - PRPair(const PType& p, const RType& r) : p(p), r(r) {} - PRPair& operator+=(const PRPair& o) { - p += o.p; - r += o.r; - return *this; - } - PRPair& operator*=(const PRPair& o) { - r = (o.r * p) + (o.p * r); - p *= o.p; - return *this; - } - PType p; - RType r; -}; - -template -std::ostream& operator<<(std::ostream& o, const PRPair& x) { - return o << '<' << x.p << ", " << x.r << '>'; -} - -template -const PRPair operator+(const PRPair& a, const PRPair& b) { - PRPair result = a; - result += b; - return result; -} - -template -const PRPair operator*(const PRPair& a, const PRPair& b) { - PRPair result = a; - result *= b; - return result; -} - -template -struct PRWeightFunction { - explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), - const RWeightFunction& rwf = RWeightFunction()) : - pweight(pwf), rweight(rwf) {} - PRPair operator()(const HG::Edge& e) const { - const P p = pweight(e); - const R r = rweight(e); - return PRPair(p, r * p); - } - const PWeightFunction pweight; - const RWeightFunction rweight; -}; - -#endif diff --git a/decoder/hg.h b/decoder/hg.h index 3d8cd9bc..343b99cf 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -25,6 +25,7 @@ #include "tdict.h" #include "trule.h" #include "prob.h" +#include "exp_semiring.h" #include "indices_after.h" #include "nt_span.h" @@ -527,7 +528,21 @@ struct EdgeFeaturesAndProbWeightFunction { struct TransitionCountWeightFunction { typedef double Weight; - inline double operator()(const HG::Edge& e) const { (void)e; return 1.0; } + inline double operator()(const HG::Edge&) const { return 1.0; } +}; + +template +struct PRWeightFunction { + explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), + const RWeightFunction& rwf = RWeightFunction()) : + pweight(pwf), rweight(rwf) {} + PRPair operator()(const HG::Edge& e) const { + const P p = pweight(e); + const R r = rweight(e); + return PRPair(p, r * p); + } + const PWeightFunction pweight; + const RWeightFunction rweight; }; #endif -- cgit v1.2.3 From 26ad3172c1c5712e877e05c01db3d03f50a5d98b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 27 Mar 2014 00:07:41 -0400 Subject: breadth first iterator for tree fragment --- decoder/tree2string_translator.cc | 32 +++++++---------- decoder/tree_fragment.cc | 8 +++++ decoder/tree_fragment.h | 73 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index ac9c0d74..1c249836 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -13,25 +13,6 @@ using namespace std; -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// a implication: (S (A a *INCOMPLETE* [B]) -// a implication: (S (A a a *INCOMPLETE* [B]) -// *0* implication: (S (A a a) _[B]) -// D implication: (S (A a a) (B [D] *INCOMPLETE*) -// *0* implication: (S (A a a) (B _[D])) -// d implication: (S (A a a) (B (D d *INCOMPLETE*)) -// *0* implication: (S (A a a) (B (D d))) -// --there are no further outgoing links possible-- - -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// *0* implication: (S [A] _[B]) -// b implication: (S [A] (B b *INCOMPLETE*)) struct Tree2StringGrammarNode { map next; string rules; @@ -44,8 +25,19 @@ void ReadTree2StringGrammar(istream* in, unordered_map 3); - if (line[pos - 1] == ' ') --pos; + unsigned xc = 0; + while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); + Tree2StringGrammarNode* cur = &roots[rule_src.root]; + for (auto sym : rule_src) + cur = &cur->next[sym]; + pos += 3 + xc; + while(line[pos] == ' ') { ++pos; } + size_t pos2 = line.find("|||", pos); + assert(pos2 != string::npos); + while (line[pos2 - 1] == ' ') { --pos2; } + cur->rules = line.substr(pos, pos2 - pos); + cerr << "OUTPUT = '" << cur->rules << "'\n"; } } diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index d5c30f58..93aad64e 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -112,4 +112,12 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned *psymp = symp; } +BreadthFirstIterator TreeFragment::begin() const { + return BreadthFirstIterator(this); +} + +BreadthFirstIterator TreeFragment::end() const { + return BreadthFirstIterator(this, 0); +} + } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 83cd1c1e..b1dbbae0 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -1,6 +1,7 @@ #ifndef TREE_FRAGMENT #define TREE_FRAGMENT +#include #include #include #include @@ -9,6 +10,8 @@ namespace cdec { +class BreadthFirstIterator; + static const unsigned NT_BIT = 0x40000000u; static const unsigned FRONTIER_BIT = 0x80000000u; static const unsigned ALL_MASK = 0x0FFFFFFFu; @@ -36,6 +39,15 @@ class TreeFragment { // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); void DebugRec(unsigned cur, std::ostream* out) const; + typedef BreadthFirstIterator iterator; + typedef ptrdiff_t difference_type; + typedef unsigned value_type; + typedef const unsigned * pointer; + typedef const unsigned & reference; + + iterator begin() const; + iterator end() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built @@ -49,6 +61,67 @@ class TreeFragment { std::vector nodes; }; +struct TFIState { + TFIState() : node(), rhspos() {} + TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {} + bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; } + bool operator!=(const TFIState& o) const { return node != o.node && rhspos != o.rhspos; } + unsigned short node; + unsigned short rhspos; +}; + +class BreadthFirstIterator : public std::iterator { + const TreeFragment* tf_; + std::queue q_; + unsigned sym; + public: + explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) { + q_.push(TFIState(tf->nodes.size() - 1, 0)); + Stage(); + } + BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {} + const unsigned& operator*() const { return sym; } + const unsigned* operator->() const { return &sym; } + bool operator==(const BreadthFirstIterator& other) const { + return (tf_ == other.tf_) && (q_ == other.q_); + } + bool operator!=(const BreadthFirstIterator& other) const { + return (tf_ != other.tf_) || (q_ != other.q_); + } + void Stage() { + if (q_.empty()) return; + const TFIState& s = q_.front(); + if (s.rhspos < 0) { + sym = tf_->nodes[s.node].lhs; + } else { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsInternalNT(sym)) { + q_.push(TFIState(sym & ALL_MASK, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs; + } + } + } + const BreadthFirstIterator& operator++() { + TFIState& s = q_.front(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos > len) { + q_.pop(); + Stage(); + } else if (s.rhspos == len) { + sym = 0; + } else { + Stage(); + } + return *this; + } + BreadthFirstIterator operator++(int) { + BreadthFirstIterator res = *this; + ++(*this); + return res; + } +}; + inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { x.DebugRec(x.nodes.size() - 1, &os); return os; -- cgit v1.2.3 From b27675f6d29a5b5da2f4211f3aa216aa321b2b97 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 27 Mar 2014 00:11:01 -0400 Subject: remove warnings --- decoder/tree_fragment.h | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) (limited to 'decoder') diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index b1dbbae0..a38dbdfa 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -91,14 +91,10 @@ class BreadthFirstIterator : public std::iteratornodes[s.node].lhs; - } else { - sym = tf_->nodes[s.node].rhs[s.rhspos]; - if (IsInternalNT(sym)) { - q_.push(TFIState(sym & ALL_MASK, 0)); - sym = tf_->nodes[sym & ALL_MASK].lhs; - } + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsInternalNT(sym)) { + q_.push(TFIState(sym & ALL_MASK, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs; } } const BreadthFirstIterator& operator++() { -- cgit v1.2.3 From 58c50cc49721445751499298d9e935601baaf7ca Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 30 Mar 2014 23:50:17 -0400 Subject: almost complete tree to string translator --- decoder/Makefile.am | 3 + decoder/decoder.cc | 6 +- decoder/t2s_test.cc | 110 ++++++++++++++++++++++++++++++++++ decoder/tree2string_translator.cc | 120 +++++++++++++++++++++++++++++++++----- decoder/tree_fragment.cc | 14 +++-- decoder/tree_fragment.h | 109 ++++++++++++++++++++++++---------- 6 files changed, 311 insertions(+), 51 deletions(-) create mode 100644 decoder/t2s_test.cc (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 7481192b..5c91fe65 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -4,9 +4,12 @@ noinst_PROGRAMS = \ trule_test \ hg_test \ parser_test \ + t2s_test \ grammar_test TESTS = trule_test parser_test grammar_test hg_test +t2s_test_SOURCES = t2s_test.cc +t2s_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a parser_test_SOURCES = parser_test.cc parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a grammar_test_SOURCES = grammar_test.cc diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 31049216..43e2640d 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; + if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -626,6 +626,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); + else if (formalism == "t2s") + translator.reset(new Tree2StringTranslator(conf)); else if (formalism == "fst") translator.reset(new FSTTranslator(conf)); else if (formalism == "pb") diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc new file mode 100644 index 00000000..3c46ea89 --- /dev/null +++ b/decoder/t2s_test.cc @@ -0,0 +1,110 @@ +#include "tree_fragment.h" + +#define BOOST_TEST_MODULE T2STest +#include +#include +#include +#include "tdict.h" + +using namespace std; + +BOOST_AUTO_TEST_CASE(TestTreeFragments) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + cdec::TreeFragment tree2("(S (NP (DT a) (NN cat)) (VP (V ate) (NP (DT the) (NN cake pie))))"); + vector a, b; + vector aw, bw; + cerr << "TREE1: " << tree << endl; + cerr << "TREE2: " << tree2 << endl; + for (auto& sym : tree) + if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + for (auto& sym : tree2) + if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); + BOOST_CHECK_EQUAL(a.size(), b.size()); + BOOST_CHECK_EQUAL(aw.size() + 1, bw.size()); + BOOST_CHECK_EQUAL(aw.size(), 5); + BOOST_CHECK_EQUAL(TD::GetString(aw), "the boy saw a cat"); + BOOST_CHECK_EQUAL(TD::GetString(bw), "a cat ate the cake pie"); + if (a != b) { + BOOST_CHECK_EQUAL(1,2); + } + + string nts; + for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) it.truncate(); + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); + + nts.clear(); + int ntc = 0; + for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) { + ++ntc; + if (ntc > 1) it.truncate(); + } + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + BOOST_CHECK_EQUAL(nts, "(S NP VP* (NP DT* NN*"); +} + +BOOST_AUTO_TEST_CASE(TestSharing) { + cdec::TreeFragment rule1("(S [NP] [VP])", true); + cdec::TreeFragment rule2("(S [NP] (VP [V] [NP]))", true); + string r1,r2; + for (auto sym : rule1) { + if (r1.size()) r1 += " "; + if (cdec::IsLHS(sym)) r1 += "("; + r1 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r1 += "*"; + } + for (auto sym : rule2) { + if (r2.size()) r2 += " "; + if (cdec::IsLHS(sym)) r2 += "("; + r2 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r2 += "*"; + } + cerr << rule1 << endl; + cerr << r1 << endl; + cerr << rule2 << endl; + cerr << r2 << endl; + BOOST_CHECK_EQUAL(r1, "(S NP* VP*"); + BOOST_CHECK_EQUAL(r2, "(S NP* VP (VP V* NP*"); +} + +BOOST_AUTO_TEST_CASE(TestEndInvariants) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + BOOST_CHECK(tree.end().at_end()); + BOOST_CHECK(!tree.begin().at_end()); +} + +BOOST_AUTO_TEST_CASE(TestBegins) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + for (auto it = tree.begin(1); it != tree.end(); ++it) { + cerr << TD::Convert(*it & cdec::ALL_MASK) << endl; + } +} + +BOOST_AUTO_TEST_CASE(TestRemainder) { + cdec::TreeFragment tree("(S (A a) (B b))"); + auto it = tree.begin(); + ++it; + BOOST_CHECK(cdec::IsRHS(*it)); + cerr << tree << endl; + auto itr = it.remainder(); + while(itr != tree.end()) { + cerr << TD::Convert(*itr & cdec::ALL_MASK) << endl; + ++itr; + } +} + + diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 1c249836..cd6ee550 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include #include "tree_fragment.h" @@ -15,11 +16,10 @@ using namespace std; struct Tree2StringGrammarNode { map next; - string rules; + vector rules; }; -void ReadTree2StringGrammar(istream* in, unordered_map* proots) { - unordered_map& roots = *proots; +void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -28,32 +28,124 @@ void ReadTree2StringGrammar(istream* in, unordered_map frhs; + for (auto sym : rule_src) { cur = &cur->next[sym]; + if (sym) { + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); + } + } + } + os << '[' << TD::Convert(-lhs) << "] |||"; + for (auto x : frhs) { + os << ' '; + if (x < 0) + os << '[' << TD::Convert(-x) << ']'; + else + os << TD::Convert(x); + } pos += 3 + xc; while(line[pos] == ' ') { ++pos; } - size_t pos2 = line.find("|||", pos); - assert(pos2 != string::npos); - while (line[pos2 - 1] == ' ') { --pos2; } - cur->rules = line.substr(pos, pos2 - pos); - cerr << "OUTPUT = '" << cur->rules << "'\n"; + os << " ||| " << line.substr(pos); + TRulePtr rule(new TRule(os.str())); + cur->rules.push_back(rule); } } +struct ParserState { + ParserState() : in_iter(), node() {} + cdec::TreeFragment::iterator in_iter; + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : + in_iter(it), + root_type(rt), + node(n) {} + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : + in_iter(it), + future_work(p.future_work), + root_type(p.root_type), + node(n) {} + vector future_work; + int root_type; // lhs of top level NT + Tree2StringGrammarNode* node; +}; + struct Tree2StringTranslatorImpl { - unordered_map roots; // root['S'] gives rule network for S rules + Tree2StringGrammarNode root; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { ReadFile rf(conf["grammar"].as>()[0]); - ReadTree2StringGrammar(rf.stream(), &roots); + ReadTree2StringGrammar(rf.stream(), &root); } bool Translate(const string& input, SentenceMetadata* smeta, const vector& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - cerr << "Tree2StringTranslatorImpl: please implement this!\n"; - return false; + const int kS = -TD::Convert("S"); + Hypergraph hg; + queue q; + q.push(ParserState(input_tree.begin(), &root, kS)); + while(!q.empty()) { + ParserState& s = q.front(); + + if (s.in_iter.at_end()) { // completed a traversal of a subtree + cerr << "I traversed a subtree of the input...\n"; + if (s.node->rules.size()) { + // TODO: build hypergraph + for (auto& r : s.node->rules) + cerr << "I can build: " << r->AsString() << endl; + for (auto& w : s.future_work) + q.push(w); + } else { + cerr << "I can't build anything :(\n"; + } + } else { // more input tree to match + unsigned sym = *s.in_iter; + if (cdec::IsLHS(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else if (cdec::IsRHS(sym)) { + //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + cdec::TreeFragment::iterator var = s.in_iter; + var.truncate(); + auto nit1 = s.node->next.find(sym); + auto nit2 = s.node->next.find(*var); + if (nit2 != s.node->next.end()) { + //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ParserState new_s(++var, &nit2->second, s); + ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); + new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q + q.push(new_s); + } + if (nit1 != s.node->next.end()) { + //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit1->second, s)); + } + } else if (cdec::IsTerminal(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else { + cerr << "This can never happen!\n"; abort(); + } + } + q.pop(); + } + minus_lm_forest->swap(hg); } }; diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 93aad64e..78a993b8 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -36,7 +36,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const { *out << ' '; if (IsFrontier(x)) { *out << '[' << TD::Convert(x & ALL_MASK) << ']'; - } else if (IsInternalNT(x)) { + } else if (IsRHS(x)) { DebugRec(x & ALL_MASK, out); } else { // must be terminal *out << TD::Convert(x); @@ -66,7 +66,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned // recursively call parser to deal with constituent ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np); unsigned ind = np - 1; - rhs.push_back(ind | NT_BIT); + rhs.push_back(ind | RHS_BIT); } else { // deal with terminal / nonterminal substitution ++symp; assert(tree[cp] != ' '); @@ -95,7 +95,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } // continuent has completed, cp is at ), build node const unsigned j = symp; // span from (i,j) // add an internal non-terminal symbol - const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT; + const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT; nodes[np] = TreeFragmentProduction(nt, rhs); //cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->"; //for (auto& x : rhs) { @@ -113,11 +113,15 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } BreadthFirstIterator TreeFragment::begin() const { - return BreadthFirstIterator(this); + return BreadthFirstIterator(this, nodes.size() - 1); +} + +BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { + return BreadthFirstIterator(this, node_idx); } BreadthFirstIterator TreeFragment::end() const { - return BreadthFirstIterator(this, 0); + return BreadthFirstIterator(this); } } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index a38dbdfa..b83afc27 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -1,7 +1,7 @@ #ifndef TREE_FRAGMENT #define TREE_FRAGMENT -#include +#include #include #include #include @@ -12,18 +12,32 @@ namespace cdec { class BreadthFirstIterator; -static const unsigned NT_BIT = 0x40000000u; -static const unsigned FRONTIER_BIT = 0x80000000u; -static const unsigned ALL_MASK = 0x0FFFFFFFu; +static const unsigned LHS_BIT = 0x10000000u; +static const unsigned RHS_BIT = 0x20000000u; +static const unsigned FRONTIER_BIT = 0x40000000u; +static const unsigned RESERVED_BIT = 0x80000000u; +static const unsigned ALL_MASK = 0x0FFFFFFFu; -inline bool IsInternalNT(unsigned x) { - return (x & NT_BIT); +inline bool IsNT(unsigned x) { + return (x & (LHS_BIT | RHS_BIT | FRONTIER_BIT)); +} + +inline bool IsLHS(unsigned x) { + return (x & LHS_BIT); +} + +inline bool IsRHS(unsigned x) { + return (x & RHS_BIT); } inline bool IsFrontier(unsigned x) { return (x & FRONTIER_BIT); } +inline bool IsTerminal(unsigned x) { + return (x & ALL_MASK) == x; +} + struct TreeFragmentProduction { TreeFragmentProduction() {} TreeFragmentProduction(int nttype, const std::vector& r) : lhs(nttype), rhs(r) {} @@ -46,6 +60,7 @@ class TreeFragment { typedef const unsigned & reference; iterator begin() const; + iterator begin(unsigned node_idx) const; iterator end() const; private: @@ -62,24 +77,28 @@ class TreeFragment { }; struct TFIState { - TFIState() : node(), rhspos() {} - TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {} - bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; } - bool operator!=(const TFIState& o) const { return node != o.node && rhspos != o.rhspos; } + TFIState() : node(), rhspos(), state() {} + TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} + bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } + bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } unsigned short node; unsigned short rhspos; + unsigned char state; }; class BreadthFirstIterator : public std::iterator { const TreeFragment* tf_; - std::queue q_; + std::deque q_; unsigned sym; public: - explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) { - q_.push(TFIState(tf->nodes.size() - 1, 0)); + BreadthFirstIterator() : tf_(), sym() {} + // used for begin + explicit BreadthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, 0, 0)); Stage(); } - BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {} + // used for end + explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {} const unsigned& operator*() const { return sym; } const unsigned* operator->() const { return &sym; } bool operator==(const BreadthFirstIterator& other) const { @@ -88,26 +107,20 @@ class BreadthFirstIterator : public std::iteratornodes[s.node].rhs[s.rhspos]; - if (IsInternalNT(sym)) { - q_.push(TFIState(sym & ALL_MASK, 0)); - sym = tf_->nodes[sym & ALL_MASK].lhs; - } - } const BreadthFirstIterator& operator++() { TFIState& s = q_.front(); - const unsigned len = tf_->nodes[s.node].rhs.size(); - s.rhspos++; - if (s.rhspos > len) { - q_.pop(); + if (s.state == 0) { + s.state++; Stage(); - } else if (s.rhspos == len) { - sym = 0; } else { - Stage(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_front(); + Stage(); + } else { + Stage(); + } } return *this; } @@ -116,6 +129,42 @@ class BreadthFirstIterator : public std::iteratornodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, 0, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + BreadthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } }; inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { -- cgit v1.2.3 From c80f10eed03430423c00a94b19ce7d3e68c8c65a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 00:19:19 -0400 Subject: minimally tested t2s translator --- decoder/tree2string_translator.cc | 48 +++++++++++++++++++++++++++++---------- decoder/tree_fragment.h | 1 + 2 files changed, 37 insertions(+), 12 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index cd6ee550..09eca147 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -65,17 +65,17 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { struct ParserState { ParserState() : in_iter(), node() {} cdec::TreeFragment::iterator in_iter; - ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) : in_iter(it), - root_type(rt), + input_node_idx(it.node_idx()), node(n) {} ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : in_iter(it), future_work(p.future_work), - root_type(p.root_type), + input_node_idx(p.input_node_idx), node(n) {} vector future_work; - int root_type; // lhs of top level NT + int input_node_idx; // lhs of top level NT Tree2StringGrammarNode* node; }; @@ -90,23 +90,40 @@ struct Tree2StringTranslatorImpl { const vector& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - const int kS = -TD::Convert("S"); Hypergraph hg; + hg.ReserveNodes(input_tree.nodes.size()); + vector tree2hg(input_tree.nodes.size() + 1, -1); queue q; - q.push(ParserState(input_tree.begin(), &root, kS)); + q.push(ParserState(input_tree.begin(), &root)); + unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); if (s.in_iter.at_end()) { // completed a traversal of a subtree - cerr << "I traversed a subtree of the input...\n"; + //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << + // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - // TODO: build hypergraph - for (auto& r : s.node->rules) - cerr << "I can build: " << r->AsString() << endl; + TailNodeVector tail; + int& node_id = tree2hg[s.input_node_idx]; + if (node_id < 0) + node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; + for (auto& n : s.future_work) { + int& nix = tree2hg[n.input_node_idx]; + if (nix < 0) + nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_; + tail.push_back(nix); + } + for (auto& r : s.node->rules) { + assert(tail.size() == r->Arity()); + HG::Edge* new_edge = hg.AddEdge(r, tail); + new_edge->feature_values_ = r->GetFeatureValues(); + // TODO: set i and j + hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); + } for (auto& w : s.future_work) q.push(w); } else { - cerr << "I can't build anything :(\n"; + //cerr << "I can't build anything :(\n"; } } else { // more input tree to match unsigned sym = *s.in_iter; @@ -125,7 +142,7 @@ struct Tree2StringTranslatorImpl { if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ParserState new_s(++var, &nit2->second, s); - ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); + ParserState new_work(s.in_iter.remainder(), &root); new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q q.push(new_s); } @@ -145,7 +162,14 @@ struct Tree2StringTranslatorImpl { } q.pop(); } + int goal = tree2hg[tree_top]; + if (goal < 0) return false; + //cerr << "Goal node: " << goal << endl; + hg.TopologicallySortNodesAndEdges(goal); + hg.Reweight(weights); + //hg.PrintGraphviz(); minus_lm_forest->swap(hg); + return true; } }; diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index b83afc27..ceb7fa60 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -107,6 +107,7 @@ class BreadthFirstIterator : public std::iterator Date: Tue, 1 Apr 2014 00:53:00 -0400 Subject: tree2string test, fix for edge case --- decoder/tree2string_translator.cc | 6 ++++++ tests/system_tests/t2s/cdec.ini | 2 ++ tests/system_tests/t2s/gold.statistics | 3 +++ tests/system_tests/t2s/gold.stdout | 1 + tests/system_tests/t2s/grammar.t2s | 8 ++++++++ tests/system_tests/t2s/input.txt | 1 + tests/system_tests/t2s/weights | 6 ++++++ 7 files changed, 27 insertions(+) create mode 100644 tests/system_tests/t2s/cdec.ini create mode 100644 tests/system_tests/t2s/gold.statistics create mode 100644 tests/system_tests/t2s/gold.stdout create mode 100644 tests/system_tests/t2s/grammar.t2s create mode 100644 tests/system_tests/t2s/input.txt create mode 100644 tests/system_tests/t2s/weights (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 09eca147..7bc49132 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -167,6 +167,12 @@ struct Tree2StringTranslatorImpl { //cerr << "Goal node: " << goal << endl; hg.TopologicallySortNodesAndEdges(goal); hg.Reweight(weights); + + // there might be nodes that cannot be derived + // the following takes care of them + vector prune(hg.edges_.size(), false); + hg.PruneEdges(prune, true); + //hg.PrintGraphviz(); minus_lm_forest->swap(hg); return true; diff --git a/tests/system_tests/t2s/cdec.ini b/tests/system_tests/t2s/cdec.ini new file mode 100644 index 00000000..ad83438f --- /dev/null +++ b/tests/system_tests/t2s/cdec.ini @@ -0,0 +1,2 @@ +formalism=t2s +grammar=grammar.t2s diff --git a/tests/system_tests/t2s/gold.statistics b/tests/system_tests/t2s/gold.statistics new file mode 100644 index 00000000..452cc93e --- /dev/null +++ b/tests/system_tests/t2s/gold.statistics @@ -0,0 +1,3 @@ +-lm_nodes 6 +-lm_edges 8 +-lm_paths 4 diff --git a/tests/system_tests/t2s/gold.stdout b/tests/system_tests/t2s/gold.stdout new file mode 100644 index 00000000..afb11818 --- /dev/null +++ b/tests/system_tests/t2s/gold.stdout @@ -0,0 +1 @@ +qiangshou bei jingfang jibi . diff --git a/tests/system_tests/t2s/grammar.t2s b/tests/system_tests/t2s/grammar.t2s new file mode 100644 index 00000000..2e6cf68c --- /dev/null +++ b/tests/system_tests/t2s/grammar.t2s @@ -0,0 +1,8 @@ +(S [NP-C] [VP] (PUNC .)) ||| [1] [2] . ||| R1=1 +(NP-C (DT the) (NN gunman)) ||| qiangshou ||| R2=1 +(NP-C (DT the) [NN]) ||| [1] ||| R2a=1 +(NN gunman) ||| qiangshou ||| R2b=1 +(VP (VBD was) (VP-C [VBN] (PP (IN by) [NP-C]))) ||| bei [2] [1] ||| R3=1 +(NP-C (DT the) (NN police)) ||| jingfang ||| R4=1 +(VBN killed) ||| jibi ||| R5=1 +(VBN killed) ||| killed' ||| R6=1 diff --git a/tests/system_tests/t2s/input.txt b/tests/system_tests/t2s/input.txt new file mode 100644 index 00000000..b8fe314e --- /dev/null +++ b/tests/system_tests/t2s/input.txt @@ -0,0 +1 @@ +(S (NP-C (DT the) (NN gunman)) (VP (VBD was) (VP-C (VBN killed) (PP (IN by) (NP-C (DT the) (NN police))))) (PUNC .)) diff --git a/tests/system_tests/t2s/weights b/tests/system_tests/t2s/weights new file mode 100644 index 00000000..4980db45 --- /dev/null +++ b/tests/system_tests/t2s/weights @@ -0,0 +1,6 @@ +R1 1 +R2a 1 +R2b 1 +R3 1 +R5 1 +R4 1 -- cgit v1.2.3 From 1fb851b365765b01dbbedaf6699e7ff4c2b41b4a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 18:47:20 -0400 Subject: deal with multiple grammars in t2s --- decoder/tree2string_translator.cc | 80 +++++++++++++++++++++++++++++---------- 1 file changed, 61 insertions(+), 19 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7bc49132..6966ccf8 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,8 +1,10 @@ #include #include #include +#include +#include +#include #include -#include #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -74,16 +76,43 @@ struct ParserState { future_work(p.future_work), input_node_idx(p.input_node_idx), node(n) {} - vector future_work; + bool operator==(const ParserState& o) const { + return node == o.node && input_node_idx == o.input_node_idx && + future_work == o.future_work && in_iter == o.in_iter; + } + vector future_work; int input_node_idx; // lhs of top level NT Tree2StringGrammarNode* node; }; +namespace std { + template<> + struct hash { + size_t operator()(const ParserState& s) const { + size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end()); + boost::hash_combine(h, boost::hash_value(s.node)); + boost::hash_combine(h, boost::hash_value(s.input_node_idx)); + //boost::hash_combine(h, ); + return h; + } + }; +}; + struct Tree2StringTranslatorImpl { - Tree2StringGrammarNode root; - Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { - ReadFile rf(conf["grammar"].as>()[0]); - ReadTree2StringGrammar(rf.stream(), &root); + vector> root; + bool add_pass_through_rules; + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : + add_pass_through_rules(conf.count("add_pass_through_rules")) { + if (conf.count("grammar")) { + const vector gf = conf["grammar"].as>(); + root.resize(gf.size()); + unsigned gc = 0; + for (auto& f : gf) { + ReadFile rf(f); + root[gc].reset(new Tree2StringGrammarNode); + ReadTree2StringGrammar(rf.stream(), &*root[gc++]); + } + } } bool Translate(const string& input, SentenceMetadata* smeta, @@ -94,7 +123,11 @@ struct Tree2StringTranslatorImpl { hg.ReserveNodes(input_tree.nodes.size()); vector tree2hg(input_tree.nodes.size() + 1, -1); queue q; - q.push(ParserState(input_tree.begin(), &root)); + unordered_set unique; // only create items one time + for (auto& g : root) { + q.push(ParserState(input_tree.begin(), g.get())); + unique.insert(q.back()); + } unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); @@ -103,14 +136,14 @@ struct Tree2StringTranslatorImpl { //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - TailNodeVector tail; int& node_id = tree2hg[s.input_node_idx]; if (node_id < 0) node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; - for (auto& n : s.future_work) { - int& nix = tree2hg[n.input_node_idx]; + TailNodeVector tail; + for (auto n : s.future_work) { + int& nix = tree2hg[n]; if (nix < 0) - nix = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK))->id_; + nix = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK))->id_; tail.push_back(nix); } for (auto& r : s.node->rules) { @@ -120,8 +153,13 @@ struct Tree2StringTranslatorImpl { // TODO: set i and j hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); } - for (auto& w : s.future_work) - q.push(w); + for (auto n : s.future_work) { + const auto it = input_tree.begin(n); // start tree iterator at node n + for (auto& g : root) { + ParserState s(it, g.get()); + if (unique.insert(s).second) q.push(s); + } + } } else { //cerr << "I can't build anything :(\n"; } @@ -131,7 +169,8 @@ struct Tree2StringTranslatorImpl { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - q.push(ParserState(++s.in_iter, &nit->second, s)); + ParserState news(++s.in_iter, &nit->second, s); + if (unique.insert(news).second) q.push(news); } } else if (cdec::IsRHS(sym)) { //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; @@ -141,20 +180,23 @@ struct Tree2StringTranslatorImpl { auto nit2 = s.node->next.find(*var); if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - ParserState new_s(++var, &nit2->second, s); - ParserState new_work(s.in_iter.remainder(), &root); + ++var; + const unsigned new_work = s.in_iter.child_node(); + ParserState new_s(var, &nit2->second, s); new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q - q.push(new_s); + if (unique.insert(new_s).second) q.push(new_s); } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; - q.push(ParserState(++s.in_iter, &nit1->second, s)); + const ParserState new_s(++s.in_iter, &nit1->second, s); + if (unique.insert(new_s).second) q.push(new_s); } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; - q.push(ParserState(++s.in_iter, &nit->second, s)); + const ParserState new_s(++s.in_iter, &nit->second, s); + if (unique.insert(new_s).second) q.push(new_s); } } else { cerr << "This can never happen!\n"; abort(); -- cgit v1.2.3 From d75c78040cb49e69e4702a2c35e997315a57150b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 18:49:22 -0400 Subject: deal with multiple grammars in t2s --- decoder/tree_fragment.h | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'decoder') diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index ceb7fa60..f1c4c106 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -139,6 +139,10 @@ class BreadthFirstIterator : public std::iterator Date: Tue, 1 Apr 2014 21:24:11 -0400 Subject: check for empty hg --- decoder/tree2string_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 6966ccf8..6f65658e 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -214,7 +214,7 @@ struct Tree2StringTranslatorImpl { // the following takes care of them vector prune(hg.edges_.size(), false); hg.PruneEdges(prune, true); - + if (hg.edges_.size() == 0) return false; //hg.PrintGraphviz(); minus_lm_forest->swap(hg); return true; -- cgit v1.2.3 From 10427eab298e019bffed15e71125d34b2e27f468 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Apr 2014 21:52:15 -0400 Subject: deal with pass through rules --- decoder/tree2string_translator.cc | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 6f65658e..4cd584fb 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -101,6 +101,7 @@ namespace std { struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; + unsigned remove_grammars; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : add_pass_through_rules(conf.count("add_pass_through_rules")) { if (conf.count("grammar")) { @@ -114,11 +115,56 @@ struct Tree2StringTranslatorImpl { } } } + + void CreatePassThroughRules(const cdec::TreeFragment& tree) { + static const int kFID = FD::Convert("PassThrough"); + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + for (auto& prod : tree.nodes) { + ostringstream os; + vector rhse, rhsf; + int ntc = 0; + int lhs = -(prod.lhs & cdec::ALL_MASK); + os << '(' << TD::Convert(-lhs); + for (auto& sym : prod.rhs) { + os << ' '; + if (cdec::IsTerminal(sym)) { + os << TD::Convert(sym); + rhse.push_back(sym); + rhsf.push_back(sym); + } else { + unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; + os << '[' << TD::Convert(id) << ']'; + rhsf.push_back(-id); + rhse.push_back(-ntc); + ++ntc; + } + } + os << ')'; + cdec::TreeFragment rule_src(os.str(), true); + Tree2StringGrammarNode* cur = root.back().get(); + for (auto sym : rule_src) + cur = &cur->next[sym]; + TRulePtr rule(new TRule(rhse, rhsf, lhs)); + rule->ComputeArity(); + rule->scores_.set_value(kFID, 1.0); + cur->rules.push_back(rule); + } + } + + void RemoveGrammars() { + assert(remove_grammars < root.size()); + root.resize(root.size() - remove_grammars); + } + bool Translate(const string& input, SentenceMetadata* smeta, const vector& weights, Hypergraph* minus_lm_forest) { + remove_grammars = 0; cdec::TreeFragment input_tree(input, false); + if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); vector tree2hg(input_tree.nodes.size() + 1, -1); @@ -235,6 +281,7 @@ void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv } void Tree2StringTranslator::SentenceCompleteImpl() { + pimpl_->RemoveGrammars(); } std::string Tree2StringTranslator::GetDecoderType() const { -- cgit v1.2.3 From 54e527dbce04fdac6a5dc30c491b53d405d42931 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 2 Apr 2014 20:19:35 -0400 Subject: fix floating point issue --- decoder/rule_lexer.ll | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index 05963d05..cc73c079 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -119,7 +119,7 @@ void check_and_update_ctf_stack(const TRulePtr& rp) { %} -REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf +REAL [\-+]?[0-9]+(\.[0-9]*)?([eE][-+]*[0-9]+)? NT [^\t \[\],]+ %x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE -- cgit v1.2.3 From cb85bfbe908f6001a92b0af99ea03f369416059e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 7 Apr 2014 00:54:52 -0400 Subject: clean up dead TRule code --- decoder/bottom_up_parser.cc | 2 +- decoder/grammar_test.cc | 6 +- decoder/hg_test.cc | 60 ++++++---- decoder/hg_test.h | 69 ++++------- decoder/rule_lexer.h | 1 + decoder/rule_lexer.ll | 47 ++++++-- decoder/scfg_translator.cc | 2 +- decoder/test_data/hg_test.hg | 1 + decoder/test_data/hg_test.hg_balanced | 1 + decoder/test_data/hg_test.hg_int | 1 + decoder/test_data/hg_test.lattice | 1 + decoder/test_data/hg_test.tiny | 1 + decoder/test_data/hg_test.tiny_lattice | 1 + decoder/test_data/small.json.gz | Bin 1561 -> 1733 bytes decoder/tree2string_translator.cc | 1 + decoder/trule.cc | 202 +++------------------------------ decoder/trule.h | 22 ++-- training/dpmert/lo_test.cc | 2 +- 18 files changed, 142 insertions(+), 278 deletions(-) create mode 100644 decoder/test_data/hg_test.hg create mode 100644 decoder/test_data/hg_test.hg_balanced create mode 100644 decoder/test_data/hg_test.hg_int create mode 100644 decoder/test_data/hg_test.lattice create mode 100644 decoder/test_data/hg_test.tiny create mode 100644 decoder/test_data/hg_test.tiny_lattice (limited to 'decoder') diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 8738c8f1..ff4c7a90 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -159,7 +159,7 @@ PassiveChart::PassiveChart(const string& goal, chart_(input.size()+1, input.size()+1), nodemap_(input.size()+1, input.size()+1), goal_cat_(TD::Convert(goal) * -1), - goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), goal_idx_(-1), lc_fid_(FD::Convert("LatticeCost")), unaries_() { diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 6d2c6e67..69240139 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -33,9 +33,9 @@ BOOST_AUTO_TEST_CASE(TestTextGrammar) { ModelSet models(w, ms); TextGrammar g; - TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3", true)); - TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4", true)); - TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3", true)); + TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3")); + TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4")); + TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3")); cerr << r1->AsString() << endl; g.AddRule(r1); g.AddRule(r2); diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 8519e559..95cfae51 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -18,8 +18,10 @@ using namespace std; BOOST_FIXTURE_TEST_SUITE( s, HGSetup ); BOOST_AUTO_TEST_CASE(Controlled) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + cerr << "PATH: " << path << "/hg.tiny\n"; Hypergraph hg; - CreateHG_tiny(&hg); + CreateHG_tiny(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -37,10 +39,11 @@ BOOST_AUTO_TEST_CASE(Controlled) { } BOOST_AUTO_TEST_CASE(Union) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg1; Hypergraph hg2; - CreateHG_tiny(&hg1); - CreateHG(&hg2); + CreateHG_tiny(path, &hg1); + CreateHG(path, &hg2); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -84,8 +87,9 @@ BOOST_AUTO_TEST_CASE(Union) { } BOOST_AUTO_TEST_CASE(ControlledKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); vector w(2); w[0]=0.4; w[1]=0.8; hg.Reweight(w); vector trans; @@ -107,10 +111,11 @@ BOOST_AUTO_TEST_CASE(ControlledKBest) { BOOST_AUTO_TEST_CASE(InsideScore) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); Hypergraph hg; - CreateTinyLatticeHG(&hg); + CreateTinyLatticeHG(path, &hg); hg.Reweight(wts); vector trans; prob_t cost = ViterbiESentence(hg, &trans); @@ -130,10 +135,11 @@ BOOST_AUTO_TEST_CASE(InsideScore) { BOOST_AUTO_TEST_CASE(PruneInsideOutside) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); SparseVector wts; wts.set_value(FD::Convert("Feature_1"), 1.0); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); hg.Reweight(wts); vector trans; prob_t cost = ViterbiESentence(hg, &trans); @@ -152,8 +158,9 @@ BOOST_AUTO_TEST_CASE(PruneInsideOutside) { } BOOST_AUTO_TEST_CASE(TestPruneEdges) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -166,8 +173,9 @@ BOOST_AUTO_TEST_CASE(TestPruneEdges) { } BOOST_AUTO_TEST_CASE(TestIntersect) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_int(&hg); + CreateHG_int(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -192,8 +200,9 @@ BOOST_AUTO_TEST_CASE(TestIntersect) { } BOOST_AUTO_TEST_CASE(TestPrune2) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_int(&hg); + CreateHG_int(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -207,8 +216,9 @@ BOOST_AUTO_TEST_CASE(TestPrune2) { } BOOST_AUTO_TEST_CASE(Sample) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("Feature_1"), 0.0); hg.Reweight(wts); @@ -220,6 +230,7 @@ BOOST_AUTO_TEST_CASE(Sample) { } BOOST_AUTO_TEST_CASE(PLF) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; HypergraphIO::ReadFromPLF(inplf, &hg); @@ -234,8 +245,9 @@ BOOST_AUTO_TEST_CASE(PLF) { } BOOST_AUTO_TEST_CASE(PushWeightsToGoal) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); vector w(2); w[0]=0.4; w[1]=0.8; hg.Reweight(w); vector trans; @@ -248,8 +260,9 @@ BOOST_AUTO_TEST_CASE(PushWeightsToGoal) { } BOOST_AUTO_TEST_CASE(TestSpecialKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHGBalanced(&hg); + CreateHGBalanced(path, &hg); vector w(1); w[0]=0; hg.Reweight(w); vector, prob_t> > list; @@ -264,8 +277,9 @@ BOOST_AUTO_TEST_CASE(TestSpecialKBest) { } BOOST_AUTO_TEST_CASE(TestGenericViterbi) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_tiny(&hg); + CreateHG_tiny(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -279,8 +293,9 @@ BOOST_AUTO_TEST_CASE(TestGenericViterbi) { } BOOST_AUTO_TEST_CASE(TestGenericInside) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateTinyLatticeHG(&hg); + CreateTinyLatticeHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -296,8 +311,9 @@ BOOST_AUTO_TEST_CASE(TestGenericInside) { } BOOST_AUTO_TEST_CASE(TestGenericInside2) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -322,8 +338,9 @@ BOOST_AUTO_TEST_CASE(TestGenericInside2) { } BOOST_AUTO_TEST_CASE(TestAddExpectations) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -338,8 +355,8 @@ BOOST_AUTO_TEST_CASE(TestAddExpectations) { } BOOST_AUTO_TEST_CASE(Small) { - Hypergraph hg; std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + Hypergraph hg; CreateSmallHG(&hg, path); SparseVector wts; wts.set_value(FD::Convert("Model_0"), -2.0); @@ -361,6 +378,7 @@ BOOST_AUTO_TEST_CASE(Small) { } BOOST_AUTO_TEST_CASE(JSONTest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); ostringstream os; JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str()); @@ -370,9 +388,10 @@ BOOST_AUTO_TEST_CASE(JSONTest) { } BOOST_AUTO_TEST_CASE(TestGenericKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); - //CreateHGBalanced(&hg); + CreateHG(path, &hg); + //CreateHGBalanced(path, &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -392,8 +411,9 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) { } BOOST_AUTO_TEST_CASE(TestReadWriteHG) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg,hg2; - CreateHG(&hg); + CreateHG(path, &hg); hg.edges_.front().j_ = 23; hg.edges_.back().prev_i_ = 99; ostringstream os; diff --git a/decoder/hg_test.h b/decoder/hg_test.h index e96cb0b1..b7bab3c2 100644 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -23,25 +23,13 @@ Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 // you can inherit from this or just use the static methods struct HGSetup { - enum { - HG, - HG_int, - HG_tiny, - HGBalanced, - LatticeHG, - TinyLatticeHG, - }; - static void CreateHG(Hypergraph* hg); - static void CreateHG_int(Hypergraph* hg); - static void CreateHG_tiny(Hypergraph* hg); - static void CreateHGBalanced(Hypergraph* hg); - static void CreateLatticeHG(Hypergraph* hg); - static void CreateTinyLatticeHG(Hypergraph* hg); - - static void Json(Hypergraph *hg,std::string const& json) { - std::istringstream i(json); - HypergraphIO::ReadFromJSON(&i, hg); - } + static void CreateHG(const std::string& path,Hypergraph* hg); + static void CreateHG_int(const std::string& path,Hypergraph* hg); + static void CreateHG_tiny(const std::string& path, Hypergraph* hg); + static void CreateHGBalanced(const std::string& path,Hypergraph* hg); + static void CreateLatticeHG(const std::string& path,Hypergraph* hg); + static void CreateTinyLatticeHG(const std::string& path,Hypergraph* hg); + static void JsonFile(Hypergraph *hg,std::string f) { ReadFile rf(f); HypergraphIO::ReadFromJSON(rf.stream(), hg); @@ -52,18 +40,6 @@ struct HGSetup { static void CreateSmallHG(Hypergraph *hg, std::string path) { JsonTestFile(hg,path,small_json); } }; -namespace { -Name HGjsons[]= { - "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}", -"{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| b\",3,\"[X] ||| a [1]\",4,\"[X] ||| [1] b\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,0.1],\"rule\":1},{\"tail\":[],\"feats\":[0,0.1],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X\"},\"edges\":[{\"tail\":[0],\"feats\":[0,0.3],\"rule\":3},{\"tail\":[0],\"feats\":[0,0.2],\"rule\":4}],\"node\":{\"in_edges\":[2,3],\"cat\":\"Goal\"}}", - "{\"rules\":[1,\"[X] ||| \",2,\"[X] ||| X [1]\",3,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,-2,1,-99],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.5,1,-0.8],\"rule\":2},{\"tail\":[0],\"feats\":[0,-0.7,1,-0.9],\"rule\":3}],\"node\":{\"in_edges\":[1,2]}}", - "{\"rules\":[1,\"[X] ||| i\",2,\"[X] ||| a\",3,\"[X] ||| b\",4,\"[X] ||| [1] [2]\",5,\"[X] ||| [1] [2]\",6,\"[X] ||| c\",7,\"[X] ||| d\",8,\"[X] ||| [1] [2]\",9,\"[X] ||| [1] [2]\",10,\"[X] ||| [1] [2]\",11,\"[X] ||| [1] [2]\",12,\"[X] ||| [1] [2]\",13,\"[X] ||| [1] [2]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[1,2],\"feats\":[],\"rule\":4},{\"tail\":[2,1],\"feats\":[],\"rule\":5}],\"node\":{\"in_edges\":[3,4]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":6}],\"node\":{\"in_edges\":[5]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":7}],\"node\":{\"in_edges\":[6]},\"edges\":[{\"tail\":[4,5],\"feats\":[],\"rule\":8},{\"tail\":[5,4],\"feats\":[],\"rule\":9}],\"node\":{\"in_edges\":[7,8]},\"edges\":[{\"tail\":[3,6],\"feats\":[],\"rule\":10},{\"tail\":[6,3],\"feats\":[],\"rule\":11}],\"node\":{\"in_edges\":[9,10]},\"edges\":[{\"tail\":[7,0],\"feats\":[],\"rule\":12},{\"tail\":[0,7],\"feats\":[],\"rule\":13}],\"node\":{\"in_edges\":[11,12]}}", - "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] A A\",4,\"[X] ||| [1] b\",5,\"[X] ||| [1] c\",6,\"[X] ||| [1] B C\",7,\"[X] ||| [1] A B C\",8,\"[X] ||| [1] CC\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[2,-0.3],\"rule\":1},{\"tail\":[0],\"feats\":[2,-0.6],\"rule\":2},{\"tail\":[0],\"feats\":[2,-1.7],\"rule\":3}],\"node\":{\"in_edges\":[0,1,2]},\"edges\":[{\"tail\":[1],\"feats\":[2,-0.5],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[2],\"feats\":[2,-0.6],\"rule\":5},{\"tail\":[1],\"feats\":[2,-0.8],\"rule\":6},{\"tail\":[0],\"feats\":[2,-0.01],\"rule\":7},{\"tail\":[2],\"feats\":[2,-0.8],\"rule\":8}],\"node\":{\"in_edges\":[4,5,6,7]}}", - "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] b\",4,\"[X] ||| [1] B'\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.2],\"rule\":1},{\"tail\":[0],\"feats\":[0,-0.6],\"rule\":2}],\"node\":{\"in_edges\":[0,1]},\"edges\":[{\"tail\":[1],\"feats\":[0,-0.1],\"rule\":3},{\"tail\":[1],\"feats\":[0,-0.9],\"rule\":4}],\"node\":{\"in_edges\":[2,3]}}", -}; - -} - void AddNullEdge(Hypergraph* hg) { TRule x; x.arity_ = 0; @@ -71,31 +47,36 @@ void AddNullEdge(Hypergraph* hg) { hg->edges_.back().head_node_ = 0; } -void HGSetup::CreateTinyLatticeHG(Hypergraph* hg) { - Json(hg,HGjsons[TinyLatticeHG]); +void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.tiny_lattice"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); AddNullEdge(hg); } -void HGSetup::CreateLatticeHG(Hypergraph* hg) { - Json(hg,HGjsons[LatticeHG]); +void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.lattice"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); AddNullEdge(hg); } -void HGSetup::CreateHG_tiny(Hypergraph* hg) { - Json(hg,HGjsons[HG_tiny]); +void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) { + ReadFile rf(path + "/hg_test.tiny"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHG_int(Hypergraph* hg) { - Json(hg,HGjsons[HG_int]); +void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg_int"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHG(Hypergraph* hg) { - Json(hg,HGjsons[HG]); +void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHGBalanced(Hypergraph* hg) { - Json(hg,HGjsons[HGBalanced]); +void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg_balanced"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } - #endif diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index f844e5b2..e15c056d 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -9,6 +9,7 @@ struct RuleLexer { typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra); static void ReadRules(std::istream* in, RuleCallback func, const std::string& fname, void* extra); + static void ReadRule(const std::string&, RuleCallback func, bool mono_rule, void* extra); }; #endif diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index cc73c079..d4a8d86b 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -14,6 +14,7 @@ #include "verbose.h" #include "tree_fragment.h" +bool lex_mono_rules = false; int lex_line = 0; std::istream* scfglex_stream = NULL; RuleLexer::RuleCallback rule_callback = NULL; @@ -120,7 +121,7 @@ void check_and_update_ctf_stack(const TRulePtr& rp) { %} REAL [\-+]?[0-9]+(\.[0-9]*)?([eE][-+]*[0-9]+)? -NT [^\t \[\],]+ +NT ([^\t \n\r\[\],]+|Goal) %x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE %% @@ -132,7 +133,7 @@ NT [^\t \[\],]+ \[{NT}\] { scfglex_tmp_token.assign(yytext + 1, yyleng - 2); scfglex_lhs = -TD::Convert(scfglex_tmp_token); - // std::cerr << scfglex_tmp_token << "\n"; + //std::cerr << "LHS: " << scfglex_tmp_token << "\n"; BEGIN(LHS_END); } @@ -199,9 +200,9 @@ NT [^\t \[\],]+ \|\|\| { memset(scfglex_nt_sanity, 0, scfglex_src_arity * sizeof(int)); - BEGIN(TRG); + if (lex_mono_rules) { BEGIN(FEATS); } else { BEGIN(TRG); } } -[^ \t]+ { +[^ \t\n\r]+ { scfglex_tmp_token.assign(yytext, yyleng); scfglex_src_rhs[scfglex_src_rhs_size] = TD::Convert(scfglex_tmp_token); ++scfglex_src_rhs_size; @@ -217,14 +218,28 @@ NT [^\t \[\],]+ \|\|\| { BEGIN(FEATS); } -[^ \t]+ { +[^ \t\n\r]+ { scfglex_tmp_token.assign(yytext, yyleng); scfglex_trg_rhs[scfglex_trg_rhs_size] = TD::Convert(scfglex_tmp_token); ++scfglex_trg_rhs_size; } [ \t]+ { ; } -\n { +\n { + if (lex_mono_rules) { + if (scfglex_trg_rhs_size != 0) { + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": expected monolingual rule\n"; + abort(); + } + scfglex_trg_arity = scfglex_src_arity; + scfglex_trg_rhs_size = scfglex_src_rhs_size; + int ntc = 0; + for (int i = 0; i < scfglex_src_rhs_size; ++i) + if (scfglex_trg_rhs[i] <= 0) + scfglex_trg_rhs[i] = ntc--; + else + scfglex_trg_rhs[i] = scfglex_src_rhs[i]; + } if (scfglex_src_arity != scfglex_trg_arity) { std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); @@ -243,7 +258,7 @@ NT [^\t \[\],]+ TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra); ctf_rule_stack.push(rp); - // std::cerr << rp->AsString() << std::endl; + //std::cerr << "RULE: " << rp->AsString() << std::endl; num_rules++; lex_line++; if (!SILENT) { @@ -317,7 +332,7 @@ NT [^\t \[\],]+ #include "filelib.h" -void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) { +static void init_default_feature_names() { if (scfglex_phrase_fnames.empty()) { scfglex_phrase_fnames.resize(100); for (int i = 0; i < scfglex_phrase_fnames.size(); ++i) { @@ -326,6 +341,11 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const scfglex_phrase_fnames[i] = FD::Convert(os.str()); } } +} + +void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) { + init_default_feature_names(); + lex_mono_rules = false; lex_line = 1; scfglex_fname = fname; scfglex_stream = in; @@ -334,3 +354,14 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const yylex(); } +void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) { + init_default_feature_names(); + lex_mono_rules = mono; + lex_line = 1; + rule_callback_extra = extra; + rule_callback = func; + yy_scan_string(srule.c_str()); + yylex(); + yylex_destroy(); +} + diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 159a1d60..88f62769 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -47,7 +47,7 @@ GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [1]")); AddRule(stop_glue); RefineRule(stop_glue, ctf_level); - TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [1] [2] ||| Glue=1")); + TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + "] ["+ default_nt + "] ||| [1] [2] ||| Glue=1")); AddRule(glue); RefineRule(glue, ctf_level); } diff --git a/decoder/test_data/hg_test.hg b/decoder/test_data/hg_test.hg new file mode 100644 index 00000000..ef98e9d4 --- /dev/null +++ b/decoder/test_data/hg_test.hg @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| a ||| a",2,"[X] ||| A [X] ||| A [1]",3,"[X] ||| c ||| c",4,"[X] ||| C [X] ||| C [1]",5,"[X] ||| [X] B [X] ||| [1] B [2]",6,"[X] ||| [X] b [X] ||| [1] b [2]",7,"[X] ||| X [X] ||| X [1]",8,"[X] ||| Z [X] ||| Z [1]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[24568,32767,24568,32767],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[0],"spans":[24568,32767,24568,32767],"feats":[0,-0.8,1,-0.1],"rule":2}],"node":{"in_edges":[1],"cat":"X"},"edges":[{"tail":[],"spans":[24568,32767,24568,32767],"feats":[1,-1],"rule":3}],"node":{"in_edges":[2],"cat":"X"},"edges":[{"tail":[2],"spans":[24568,32767,24568,32767],"feats":[0,-0.2,1,-0.1],"rule":4}],"node":{"in_edges":[3],"cat":"X"},"edges":[{"tail":[1,3],"spans":[24568,32767,24568,32767],"feats":[0,-1.2,1,-0.2],"rule":5},{"tail":[1,3],"spans":[24568,32767,24568,32767],"feats":[0,-0.5,1,-1.3],"rule":6}],"node":{"in_edges":[4,5],"cat":"X"},"edges":[{"tail":[4],"spans":[24568,32767,24568,32767],"feats":[0,-0.5,1,-0.8],"rule":7},{"tail":[4],"spans":[24568,32767,24568,32767],"feats":[0,-0.7,1,-0.9],"rule":8}],"node":{"in_edges":[6,7],"cat":"X"}} diff --git a/decoder/test_data/hg_test.hg_balanced b/decoder/test_data/hg_test.hg_balanced new file mode 100644 index 00000000..0f0f499f --- /dev/null +++ b/decoder/test_data/hg_test.hg_balanced @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| i ||| i",2,"[X] ||| a ||| a",3,"[X] ||| b ||| b",4,"[X] ||| [X] [X] ||| [1] [2]",5,"[X] ||| [X] [X] ||| [1] [2]",6,"[X] ||| c ||| c",7,"[X] ||| d ||| d",8,"[X] ||| [X] [X] ||| [1] [2]",9,"[X] ||| [X] [X] ||| [1] [2]",10,"[X] ||| [X] [X] ||| [1] [2]",11,"[X] ||| [X] [X] ||| [1] [2]",12,"[X] ||| [X] [X] ||| [1] [2]",13,"[X] ||| [X] [X] ||| [1] [2]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":2}],"node":{"in_edges":[1],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":3}],"node":{"in_edges":[2],"cat":"X"},"edges":[{"tail":[1,2],"spans":[32760,32767,32760,32767],"feats":[],"rule":4},{"tail":[2,1],"spans":[32760,32767,32760,32767],"feats":[],"rule":5}],"node":{"in_edges":[3,4],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":6}],"node":{"in_edges":[5],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":7}],"node":{"in_edges":[6],"cat":"X"},"edges":[{"tail":[4,5],"spans":[32760,32767,32760,32767],"feats":[],"rule":8},{"tail":[5,4],"spans":[32760,32767,32760,32767],"feats":[],"rule":9}],"node":{"in_edges":[7,8],"cat":"X"},"edges":[{"tail":[3,6],"spans":[32760,32767,32760,32767],"feats":[],"rule":10},{"tail":[6,3],"spans":[32760,32767,32760,32767],"feats":[],"rule":11}],"node":{"in_edges":[9,10],"cat":"X"},"edges":[{"tail":[7,0],"spans":[32760,32767,32760,32767],"feats":[],"rule":12},{"tail":[0,7],"spans":[32760,32767,32760,32767],"feats":[],"rule":13}],"node":{"in_edges":[11,12],"cat":"X"}} diff --git a/decoder/test_data/hg_test.hg_int b/decoder/test_data/hg_test.hg_int new file mode 100644 index 00000000..9c4603bc --- /dev/null +++ b/decoder/test_data/hg_test.hg_int @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| a ||| a",2,"[X] ||| b ||| b",3,"[X] ||| a [X] ||| a [1]",4,"[X] ||| [X] b ||| [1] b"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[-8200,32767,-8200,32767],"feats":[0,0.1],"rule":1},{"tail":[],"spans":[-8200,32767,-8200,32767],"feats":[0,0.1],"rule":2}],"node":{"in_edges":[0,1],"cat":"X"},"edges":[{"tail":[0],"spans":[-8200,32767,-8200,32767],"feats":[0,0.3],"rule":3},{"tail":[0],"spans":[-8200,32767,-8200,32767],"feats":[0,0.2],"rule":4}],"node":{"in_edges":[2,3],"cat":"Goal"}} diff --git a/decoder/test_data/hg_test.lattice b/decoder/test_data/hg_test.lattice new file mode 100644 index 00000000..29e021c5 --- /dev/null +++ b/decoder/test_data/hg_test.lattice @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| [X] a ||| [1] a",2,"[X] ||| [X] A ||| [1] A",3,"[X] ||| [X] A A ||| [1] A A",4,"[X] ||| [X] b ||| [1] b",5,"[X] ||| [X] c ||| [1] c",6,"[X] ||| [X] B C ||| [1] B C",7,"[X] ||| [X] A B C ||| [1] A B C",8,"[X] ||| [X] CC ||| [1] CC"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7"],"edges":[],"node":{"in_edges":[]},"edges":[{"tail":[0],"feats":[2,-0.3],"rule":1},{"tail":[0],"feats":[2,-0.6],"rule":2},{"tail":[0],"feats":[2,-1.7],"rule":3}],"node":{"in_edges":[0,1,2]},"edges":[{"tail":[1],"feats":[2,-0.5],"rule":4}],"node":{"in_edges":[3]},"edges":[{"tail":[2],"feats":[2,-0.6],"rule":5},{"tail":[1],"feats":[2,-0.8],"rule":6},{"tail":[0],"feats":[2,-0.01],"rule":7},{"tail":[2],"feats":[2,-0.8],"rule":8}],"node":{"in_edges":[4,5,6,7]}}" diff --git a/decoder/test_data/hg_test.tiny b/decoder/test_data/hg_test.tiny new file mode 100644 index 00000000..101b96e9 --- /dev/null +++ b/decoder/test_data/hg_test.tiny @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| ||| ",2,"[X] ||| X [X] ||| X [1]",3,"[X] ||| Z [X] ||| Z [1]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[25080,32767,25080,32767],"feats":[0,-2,1,-99],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[0],"spans":[25080,32767,25080,32767],"feats":[0,-0.5,1,-0.8],"rule":2},{"tail":[0],"spans":[25080,32767,25080,32767],"feats":[0,-0.7,1,-0.9],"rule":3}],"node":{"in_edges":[1,2],"cat":"X"}} diff --git a/decoder/test_data/hg_test.tiny_lattice b/decoder/test_data/hg_test.tiny_lattice new file mode 100644 index 00000000..b9adf3cd --- /dev/null +++ b/decoder/test_data/hg_test.tiny_lattice @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| [X] a ||| [1] a",2,"[X] ||| [X] A ||| [1] A",3,"[X] ||| [X] b ||| [1] b",4,"[X] ||| [X] B' ||| [1] B'"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7"],"edges":[],"node":{"in_edges":[]},"edges":[{"tail":[0],"feats":[0,-0.2],"rule":1},{"tail":[0],"feats":[0,-0.6],"rule":2}],"node":{"in_edges":[0,1]},"edges":[{"tail":[1],"feats":[0,-0.1],"rule":3},{"tail":[1],"feats":[0,-0.9],"rule":4}],"node":{"in_edges":[2,3]}} diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz index 892ba360..f6f37293 100644 Binary files a/decoder/test_data/small.json.gz and b/decoder/test_data/small.json.gz differ diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 4cd584fb..f288ab4e 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -174,6 +174,7 @@ struct Tree2StringTranslatorImpl { q.push(ParserState(input_tree.begin(), g.get())); unique.insert(q.back()); } + if (q.size() == 0) return false; unsigned tree_top = q.front().input_node_idx; while(!q.empty()) { ParserState& s = q.front(); diff --git a/decoder/trule.cc b/decoder/trule.cc index c22baae3..1bd5425f 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -17,73 +17,16 @@ bool TRule::IsGoal() const { return GetLHS() == kGOAL; } -static WordID ConvertTrgString(const string& w) { - const unsigned len = w.size(); - WordID id = 0; - // [X,0] or [0] - // for target rules, we ignore the category, just keep the index - if (len > 2 && w[0]=='[' && w[len-1]==']' && w[len-2] > '0' && w[len-2] <= '9' && - (len == 3 || (len > 4 && w[len-3] == ','))) { - id = w[len-2] - '0'; - id = 1 - id; - } else { - id = TD::Convert(w); - } - return id; -} - -static WordID ConvertSrcString(const string& w, bool mono = false) { - const unsigned len = w.size(); - // [X,0] - // for source rules, we keep the category and ignore the index (source rules are - // always numbered 1, 2, 3... - if (mono) { - if (len > 2 && w[0]=='[' && w[len-1]==']') { - if (len > 4 && w[len-3] == ',') { - cerr << "[ERROR] Monolingual rules mut not have non-terminal indices:\n " - << w << endl; - exit(1); - } - // TODO check that source indices go 1,2,3,etc. - return TD::Convert(w.substr(1, len-2)) * -1; - } else { - return TD::Convert(w); - } - } else { - if (len > 4 && w[0]=='[' && w[len-1]==']' && w[len-3] == ',' && w[len-2] > '0' && w[len-2] <= '9') { - return TD::Convert(w.substr(1, len-4)) * -1; - } else { - return TD::Convert(w); - } - } -} - -static WordID ConvertLHS(const string& w) { - if (w[0] == '[') { - const unsigned len = w.size(); - if (len < 3) { cerr << "Format error: " << w << endl; exit(1); } - return TD::Convert(w.substr(1, len-2)) * -1; - } else { - return TD::Convert(w) * -1; - } -} - TRule* TRule::CreateRuleSynchronous(const string& rule) { TRule* res = new TRule; - if (res->ReadFromString(rule, true, false)) return res; + if (res->ReadFromString(rule)) return res; cerr << "[ERROR] Failed to creating rule from: " << rule << endl; delete res; return NULL; } TRule* TRule::CreateRulePhrasetable(const string& rule) { - // TODO make this faster - // TODO add configuration for default NT type - if (rule[0] == '[') { - cerr << "Phrasetable rules shouldn't have a LHS / non-terminals:\n " << rule << endl; - return NULL; - } - TRule* res = new TRule("[X] ||| " + rule, true, false); + TRule* res = new TRule("[X] ||| " + rule); if (res->Arity() != 0) { cerr << "Phrasetable rules should have arity 0:\n " << rule << endl; delete res; @@ -93,138 +36,27 @@ TRule* TRule::CreateRulePhrasetable(const string& rule) { } TRule* TRule::CreateRuleMonolingual(const string& rule) { - return new TRule(rule, false, true); + return new TRule(rule, true); } namespace { -// callback for lexer +// callback for single rule lexer int n_assigned=0; -void assign_trule(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { - (void) ctf_level; - (void) coarse_rule; - TRule *assignto=(TRule *)extra; - *assignto=*new_rule; - ++n_assigned; -} - -} - -bool TRule::ReadFromString(const string& line, bool strict, bool mono) { - if (!is_single_line_stripped(line)) - cerr<<"\nWARNING: building rule from multi-line string "<1) - cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "<(extra) = *new_rule; + ++n_assigned; } - if (format >= 2 || (mono && format == 1)) { - while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } - while(is>>w && w!="|||") { f_.push_back(ConvertSrcString(w, mono)); } - if (!mono) { - while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } - } - int fv = 0; - if (is) { - string ss; - getline(is, ss); - //cerr << "L: " << ss << endl; - unsigned start = 0; - unsigned len = ss.size(); - const size_t ppos = ss.find(" |||"); - if (ppos != string::npos) { len = ppos; } - while (start < len) { - while(start < len && (ss[start] == ' ' || ss[start] == ';')) - ++start; - if (start == len) break; - unsigned end = start + 1; - while(end < len && (ss[end] != '=' && ss[end] != ' ' && ss[end] != ';')) - ++end; - if (end == len || ss[end] == ' ' || ss[end] == ';') { - //cerr << "PROC: '" << ss.substr(start, end - start) << "'\n"; - // non-named features - if (end != len) { ss[end] = 0; } - string fname = "PhraseModel_X"; - if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); } - fname[12]='0' + fv; - ++fv; - // if the feature set is frozen, this may return zero, indicating an - // undefined feature - const int fid = FD::Convert(fname); - if (fid) - scores_.set_value(fid, atof(&ss[start])); - //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl; - } else { - const int fid = FD::Convert(ss.substr(start, end - start)); - start = end + 1; - end = start + 1; - while(end < len && (ss[end] != ' ' && ss[end] != ';')) - ++end; - if (end < len) { ss[end] = 0; } - assert(start < len); - if (fid) - scores_.set_value(fid, atof(&ss[start])); - //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl; - } - start = end + 1; - } - } - } else if (format == 1) { - while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } - while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } - f_ = e_; - int x = ConvertLHS("[X]"); - for (unsigned i = 0; i < f_.size(); ++i) - if (f_[i] <= 0) { f_[i] = x; } - } else { - cerr << "F: " << format << endl; - cerr << "[ERROR] Don't know how to read:\n" << line << endl; - } - if (mono) { - e_ = f_; - int ci = 0; - for (unsigned i = 0; i < e_.size(); ++i) - if (e_[i] < 0) - e_[i] = ci--; - } - ComputeArity(); - return SanityCheck(); } -bool TRule::SanityCheck() const { - vector used(f_.size(), 0); - int ac = 0; - for (unsigned i = 0; i < e_.size(); ++i) { - int ind = e_[i]; - if (ind > 0) continue; - ind = -ind; - if ((++used[ind]) != 1) { - cerr << "[ERROR] e-side variable index " << (ind+1) << " used more than once!\n"; - return false; - } - ac++; - } - if (ac != Arity()) { - cerr << "[ERROR] e-side arity mismatches f-side\n"; - return false; - } - return true; +bool TRule::ReadFromString(const string& line, bool mono) { + n_assigned = 0; + //cerr << "LINE: " << line << " -- mono=" << mono << endl; + RuleLexer::ReadRule(line + '\n', assign_trule, mono, this); + if (n_assigned > 1) + cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "< 9 variables won't work - explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) { - ReadFromString(text, strict, mono); + explicit TRule(const std::string& text, bool mono = false) : prev_i(-1), prev_j(-1) { + ReadFromString(text, mono); } - // deprecated, use lexer // make a rule from a hiero-like rule table, e.g. // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1] - // if misformatted, returns NULL static TRule* CreateRuleSynchronous(const std::string& rule); - // deprecated, use lexer // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g: // el gato ||| the cat ||| Feature_2=0.34 static TRule* CreateRulePhrasetable(const std::string& rule); - // deprecated, use lexer // make a rule from a non-synchrnous CFG representation, e.g.: // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT] static TRule* CreateRuleMonolingual(const std::string& rule); @@ -80,11 +75,10 @@ class TRule { std::vector* result) const { unsigned vc = 0; result->clear(); - for (std::vector::const_iterator i = e_.begin(); i != e_.end(); ++i) { - const WordID& c = *i; + for (const auto& c : e_) { if (c < 1) { ++vc; - const std::vector& var_value = *var_values[-c]; + const auto& var_value = *var_values[-c]; std::copy(var_value.begin(), var_value.end(), std::back_inserter(*result)); @@ -99,10 +93,9 @@ class TRule { std::vector* result) const { unsigned vc = 0; result->clear(); - for (std::vector::const_iterator i = f_.begin(); i != f_.end(); ++i) { - const WordID& c = *i; + for (const auto& c : f_) { if (c < 1) { - const std::vector& var_value = *var_values[vc++]; + const auto& var_value = *var_values[vc++]; std::copy(var_value.begin(), var_value.end(), std::back_inserter(*result)); @@ -113,7 +106,7 @@ class TRule { assert(vc == var_values.size()); } - bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false); + bool ReadFromString(const std::string& line, bool monolingual = false); bool Initialized() const { return e_.size(); } @@ -166,7 +159,6 @@ class TRule { private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} - bool SanityCheck() const; }; inline size_t hash_value(const TRule& r) { diff --git a/training/dpmert/lo_test.cc b/training/dpmert/lo_test.cc index d89bcd99..b8776169 100644 --- a/training/dpmert/lo_test.cc +++ b/training/dpmert/lo_test.cc @@ -56,7 +56,7 @@ BOOST_AUTO_TEST_CASE(TestConvexHull) { } BOOST_AUTO_TEST_CASE(TestConvexHullInside) { - const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + const string json = "{\"rules\":[1,\"[X] ||| a ||| a\",2,\"[X] ||| A [X] ||| A [1]\",3,\"[X] ||| c ||| c\",4,\"[X] ||| C [X] ||| C [1]\",5,\"[X] ||| [X] B [X] ||| [1] B [2]\",6,\"[X] ||| [X] b [X] ||| [1] b [2]\",7,\"[X] ||| X [X] ||| X [1]\",8,\"[X] ||| Z [X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; Hypergraph hg; istringstream instr(json); HypergraphIO::ReadFromJSON(&instr, &hg); -- cgit v1.2.3 From 5a8b0aa3245f14a3c049f6f8d34c3a6f37cf070c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 7 Apr 2014 22:56:34 -0400 Subject: track node state for smarter union --- Makefile.am | 2 +- decoder/Makefile.am | 1 + decoder/apply_models.cc | 306 +++++++++++++++++----------------- decoder/bottom_up_parser.cc | 10 ++ decoder/decoder.cc | 13 ++ decoder/fst_translator.cc | 6 + decoder/hg.cc | 47 ++---- decoder/hg.h | 10 +- decoder/lexalign.cc | 5 + decoder/lextrans.cc | 5 + decoder/node_state_hash.h | 36 ++++ decoder/nt_span.h | 2 +- decoder/tagger.cc | 5 + decoder/tree2string_translator.cc | 14 +- mteval/Makefile.am | 8 +- tests/tools/filter-stderr.pl | 1 + utils/Makefile.am | 3 +- utils/hash.h | 21 +-- utils/murmur_hash.h | 186 --------------------- utils/murmur_hash3.cc | 340 ++++++++++++++++++++++++++++++++++++++ utils/murmur_hash3.h | 67 ++++++++ 21 files changed, 699 insertions(+), 389 deletions(-) create mode 100644 decoder/node_state_hash.h delete mode 100644 utils/murmur_hash.h create mode 100644 utils/murmur_hash3.cc create mode 100644 utils/murmur_hash3.h (limited to 'decoder') diff --git a/Makefile.am b/Makefile.am index 598293d1..88327477 100644 --- a/Makefile.am +++ b/Makefile.am @@ -3,13 +3,13 @@ # cyclic dependencies between these directories! SUBDIRS = \ utils \ - mteval \ klm/util/double-conversion \ klm/util \ klm/util/stream \ klm/lm \ klm/lm/builder \ klm/search \ + mteval \ decoder \ training \ word-aligner \ diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 5c91fe65..c85f17ed 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -144,6 +144,7 @@ libcdec_a_SOURCES = \ lattice.cc \ lexalign.cc \ lextrans.cc \ + node_state_hash.h \ tree_fragment.cc \ tree_fragment.h \ maxtrans_blunsom.cc \ diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9a8f60be..9f8bbead 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -19,6 +19,7 @@ namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; } #include +#include "node_state_hash.h" #include "verbose.h" #include "hg.h" #include "ff.h" @@ -229,7 +230,7 @@ public: D.clear(); } - void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); new_edge->edge_prob_ = item->out_edge_.edge_prob_; Candidate*& o_item = (*s2n)[item->state_]; @@ -238,6 +239,7 @@ public: int& node_id = o_item->node_index_; if (node_id < 0) { Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_); + new_node->node_hash = cdec::HashNode(head_node_hash, item->state_); // ID is combination of existing state + residual state node_states_.push_back(item->state_); node_id = new_node->id_; } @@ -287,7 +289,7 @@ public: cand.pop_back(); // cerr << "POPPED: " << *item << endl; PushSucc(*item, is_goal, &cand, &unique_cands); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); ++pops; } D_v.resize(state2node.size()); @@ -306,112 +308,112 @@ public: } void KBestFast(const int vert_index, const bool is_goal) { - // cerr << "KBest(" << vert_index << ")\n"; - CandidateList& D_v = D[vert_index]; - assert(D_v.empty()); - const Hypergraph::Node& v = in.nodes_[vert_index]; - // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; - const vector& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); - // cerr << "POPPED: " << *item << endl; - - PushSuccFast(*item, is_goal, &cand); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (auto& i : state2node) { + D_v[c++] = i.second; + // cerr << "MERGED: " << *i.second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - UniqueCandidateSet unique_accepted; - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); bool is_new = unique_accepted.insert(item).second; - assert(is_new); // these should all be unique! - // cerr << "POPPED: " << *item << endl; - - PushSuccFast2(*item, is_goal, &cand, &unique_accepted); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<tail_nodes_[i]].size()) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - if(item.j_[i]!=0){ - return; - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } } //PushSucc only if all ancest Cand are added void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ - CandidateHeap& cand = *pcand; - for (int i = 0; i < item.j_.size(); ++i) { - JVector j = item.j_; - ++j[i]; - if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { - Candidate query_unique(*item.in_edge_, j); - if (HasAllAncestors(&query_unique,ps)) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } } bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ - for (int i = 0; i < item->j_.size(); ++i) { - JVector j = item->j_; - --j[i]; - if (j[i] >=0) { - Candidate query_unique(*item->in_edge_, j); - if (cs->count(&query_unique) == 0) { - return false; - } - } - } - return true; + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; } const ModelSet& models; @@ -491,7 +493,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; - const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -507,7 +509,7 @@ struct NoPruningRescorer { typedef unordered_map > State2NodeIndex; - void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) { + void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, size_t head_node_hash, State2NodeIndex* state2node) { const int arity = in_edge.Arity(); Hypergraph::TailNodeVector ends(arity); for (int i = 0; i < arity; ++i) @@ -531,7 +533,9 @@ struct NoPruningRescorer { } int& head_plus1 = (*state2node)[head_state]; if (!head_plus1) { - head_plus1 = out.AddNode(in_edge.rule_->GetLHS())->id_ + 1; + HG::Node* new_node = out.AddNode(in_edge.rule_->GetLHS()); + new_node->node_hash = cdec::HashNode(head_node_hash, head_state); // ID is combination of existing state + residual state + head_plus1 = new_node->id_ + 1; node_states_.push_back(head_state); nodemap[in_edge.head_node_].push_back(head_plus1 - 1); } @@ -553,7 +557,7 @@ struct NoPruningRescorer { const Hypergraph::Node& node = in.nodes_[node_num]; for (int i = 0; i < node.in_edges_.size(); ++i) { const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; - ExpandEdge(edge, is_goal, &state2node); + ExpandEdge(edge, is_goal, node.node_hash, &state2node); } } @@ -605,16 +609,16 @@ void ApplyModelSet(const Hypergraph& in, cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } if (config.algorithm == IntersectionConfiguration::CUBE) { - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); } } else { diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index ff4c7a90..b30f1ec6 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -7,6 +7,8 @@ #include #include +#include "node_state_hash.h" +#include "nt_span.h" #include "hg.h" #include "array2d.h" #include "tdict.h" @@ -356,5 +358,13 @@ bool ExhaustiveBottomUpParser::Parse(const Lattice& input, kEPS = TD::Convert("*EPS*"); PassiveChart chart(goal_sym_, grammars_, input, forest); const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } return result; } diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 43e2640d..354ea2d9 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -750,6 +750,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { return false; } + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } + const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { @@ -813,6 +818,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { forest.swap(rescored_forest); forest.Reweight(cur_weights); if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation, conf.count("extract_rules"), extract_file); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } } if (conf.count("show_partition")) { @@ -984,6 +993,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; } forest.Reweight(last_weights); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; if (conf.count("show_partition")) { diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 074de4c9..4253b652 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -67,6 +67,12 @@ struct FSTTranslatorImpl { Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); forest->ConnectEdgeToHeadNode(hg_edge, goal); forest->Reweight(weights); + + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; } if (add_pass_through_rules) fst->ClearPassThroughTranslations(); diff --git a/decoder/hg.cc b/decoder/hg.cc index 7240a8ab..405169c6 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -1,14 +1,17 @@ -//TODO: lazily generate feature vectors for hyperarcs (because some of them will be pruned). this means 1) storing ref to rule for those features 2) providing ff interface for regenerating its feature vector from hyperedge+states and probably 3) still caching feat. vect on hyperedge once it's been generated. ff would normally just contribute its weighted score and result state, not component features. however, the hypergraph drops the state used by ffs after rescoring is done, so recomputation would have to start at the leaves and work bottom up. question: which takes more space, feature id+value, or state? - #include "hg.h" #include #include #include -#include #include #include #include +#ifndef HAVE_OLD_CPP +# include +#else +# include +namespace std { using std::tr1::unordered_set; } +#endif #include "viterbi.h" #include "inside_outside.h" @@ -17,28 +20,21 @@ using namespace std; -#if 0 -Hypergraph::Edge const* Hypergraph::ViterbiGoalEdge() const -{ - Edge const* r=0; - for (unsigned i=0,e=edges_.size();iIsGoal() && (!r || e.edge_prob_ > r->edge_prob_)) - r=&e; - } - return r; +bool Hypergraph::AreNodesUniquelyIdentified() const { + unordered_set s(nodes_.size() * 3 + 7); + for (const auto& n : nodes_) + if (!s.insert(n.node_hash).second) + return false; + return true; } -#endif -Hypergraph::Edge const* Hypergraph::ViterbiSortInEdges() -{ +Hypergraph::Edge const* Hypergraph::ViterbiSortInEdges() { NodeProbs nv; ComputeNodeViterbi(&nv); return SortInEdgesByNodeViterbi(nv); } -Hypergraph::Edge const* Hypergraph::SortInEdgesByNodeViterbi(NodeProbs const& nv) -{ +Hypergraph::Edge const* Hypergraph::SortInEdgesByNodeViterbi(NodeProbs const& nv) { EdgeProbs ev; ComputeEdgeViterbi(nv,&ev); return ViterbiSortInEdges(ev); @@ -375,9 +371,7 @@ bool Hypergraph::PruneInsideOutside(double alpha,double density,const EdgeMask* void Hypergraph::PrintGraphviz() const { int ei = 0; cerr << "digraph G {\n rankdir=LR;\n nodesep=.05;\n"; - for (vector::const_iterator i = edges_.begin(); - i != edges_.end(); ++i) { - const Edge& edge=*i; + for (const auto& edge : edges_) { ++ei; static const string none = ""; string rule = (edge.rule_ ? edge.rule_->AsString(false) : none); @@ -399,14 +393,9 @@ void Hypergraph::PrintGraphviz() const { } cerr << " A_" << ei << " -> " << edge.head_node_ << ";\n"; } - for (vector::const_iterator ni = nodes_.begin(); - ni != nodes_.end(); ++ni) { - cerr << " " << ni->id_ << "[label=\"" << (ni->cat_ < 0 ? TD::Convert(ni->cat_ * -1) : "") - //cerr << " " << ni->id_ << "[label=\"" << ni->cat_ - << " n=" << ni->id_ -// << ",x=" << &*ni -// << ",in=" << ni->in_edges_.size() -// << ",out=" << ni->out_edges_.size() + for (const auto& node : nodes_) { + cerr << " " << node.id_ << "[label=\"" << (node.cat_ < 0 ? TD::Convert(node.cat_ * -1) : "") + << " n=" << node.id_ << "\"];\n"; } cerr << "}\n"; diff --git a/decoder/hg.h b/decoder/hg.h index 343b99cf..43fb275b 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -142,13 +142,15 @@ namespace HG { // TODO get rid of cat_? // TODO keep cat_ and add span and/or state? :) struct Node { - Node() : id_(), cat_() {} + Node() : node_hash(), id_(), cat_() {} + size_t node_hash; // hash of all the information that makes this node unique int id_; // equal to this object's position in the nodes_ vector WordID cat_; // non-terminal category if <0, 0 if not set WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting + node_hash = o.node_hash; cat_=o.cat_; } void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) { @@ -192,13 +194,14 @@ public: SetNodeOrigin(nodeid,r); return r; } - Span NodeSpan(int nodeid) const { + Span NodeSpan(int nodeid, Span* prev = nullptr) const { Span s; Node const &n=nodes_[nodeid]; if (!n.in_edges_.empty()) { Edge const& e=edges_[n.in_edges_.front()]; s.l=e.i_; s.r=e.j_; + if (prev) { prev->l = e.prev_i_; prev->r = e.prev_j_; } } return s; } @@ -262,6 +265,9 @@ public: for (int i = 0; i < size; ++i) nodes_[i].id_ = i; } + // if all node states are unique, return true + bool AreNodesUniquelyIdentified() const; + // reserves space in the nodes vector to prevent memory locations // from changing void ReserveNodes(size_t n, size_t e = 0) { diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc index 6adb1892..11f20de7 100644 --- a/decoder/lexalign.cc +++ b/decoder/lexalign.cc @@ -124,6 +124,11 @@ bool LexicalAlign::TranslateImpl(const string& input, pimpl_->BuildTrellis(lattice, *smeta, forest); forest->is_linear_chain_ = true; forest->Reweight(weights); + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 8c3269bf..74a18c3f 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -280,6 +280,11 @@ bool LexicalTrans::TranslateImpl(const string& input, smeta->SetSourceLength(lattice.size()); if (!pimpl_->BuildTrellis(lattice, *smeta, forest)) return false; forest->Reweight(weights); + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h new file mode 100644 index 00000000..cdc05877 --- /dev/null +++ b/decoder/node_state_hash.h @@ -0,0 +1,36 @@ +#ifndef _NODE_STATE_HASH_ +#define _NODE_STATE_HASH_ + +#include +#include +#include "murmur_hash3.h" +#include "ffset.h" + +namespace cdec { + + struct FirstPassNode { + FirstPassNode(int cat, int i, int j, int pi, int pj) : lhs(cat), s(i), t(j), u(pi), v(pj) {} + int32_t lhs; + short s; + short t; + short u; + short v; + }; + + inline uint64_t HashNode(int cat, int i, int j, int pi, int pj) { + FirstPassNode fpn(cat, i, j, pi, pj); + return MurmurHash3_64(&fpn, sizeof(FirstPassNode), 2654435769U); + } + + inline uint64_t HashNode(uint64_t old_hash, const FFState& state) { + uint8_t buf[1024]; + std::memcpy(buf, &old_hash, sizeof(uint64_t)); + assert(state.size() < (1024u - sizeof(uint64_t))); + std::memcpy(&buf[sizeof(uint64_t)], state.begin(), state.size()); + return MurmurHash3_64(buf, sizeof(uint64_t) + state.size(), 2654435769U); + } + +} + +#endif + diff --git a/decoder/nt_span.h b/decoder/nt_span.h index a918f301..6ff9391f 100644 --- a/decoder/nt_span.h +++ b/decoder/nt_span.h @@ -7,7 +7,7 @@ struct Span { int l,r; - Span() : l(-1) { } + Span() : l(-1), r(-1) { } bool is_null() const { return l<0; } void print(std::ostream &o,char const* for_null="") const { if (is_null()) diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 63e855c8..30fb055f 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -108,6 +108,11 @@ bool Tagger::TranslateImpl(const string& input, pimpl_->BuildTrellis(sequence, forest); forest->Reweight(weights); forest->is_linear_chain_ = true; + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index f288ab4e..8d12d01d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -184,13 +184,19 @@ struct Tree2StringTranslatorImpl { // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { int& node_id = tree2hg[s.input_node_idx]; - if (node_id < 0) - node_id = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK))->id_; + if (node_id < 0) { + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = s.input_node_idx + 1; + node_id = new_node->id_; + } TailNodeVector tail; for (auto n : s.future_work) { int& nix = tree2hg[n]; - if (nix < 0) - nix = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK))->id_; + if (nix < 0) { + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK)); + new_node->node_hash = n + 1; + nix = new_node->id_; + } tail.push_back(nix); } for (auto& r : s.node->rules) { diff --git a/mteval/Makefile.am b/mteval/Makefile.am index 681e798e..08591c9a 100644 --- a/mteval/Makefile.am +++ b/mteval/Makefile.am @@ -1,6 +1,7 @@ bin_PROGRAMS = \ fast_score \ - mbr_kbest + mbr_kbest\ + marginalize noinst_PROGRAMS = \ scorer_test @@ -46,4 +47,7 @@ mbr_kbest_LDADD = libmteval.a ../utils/libutils.a scorer_test_SOURCES = scorer_test.cc scorer_test_LDADD = libmteval.a ../utils/libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -AM_CPPFLAGS = -DTEST_DATA=\"$(top_srcdir)/mteval/test_data\" -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir) -I$(top_srcdir)/utils +marginalize_SOURCES = marginalize.cc +marginalize_LDADD = libmteval.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../klm/util/double-conversion/libklm_util_double.a ../utils/libutils.a + +AM_CPPFLAGS = -DTEST_DATA=\"$(top_srcdir)/mteval/test_data\" -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/klm diff --git a/tests/tools/filter-stderr.pl b/tests/tools/filter-stderr.pl index 4a762324..54fe9210 100755 --- a/tests/tools/filter-stderr.pl +++ b/tests/tools/filter-stderr.pl @@ -13,6 +13,7 @@ if (/Init.*\s+Viterbi:\s+($REAL)/) { # -LM Viterbi: australia is have diplomatic relations with north korea one of the few countries . print "-lm_trans $1\n"; } +if (/NODES NOT UNIQUELY IDENTIFIED/) { print "NODES_NOT_UNIQUE 1\n"; } #Constr. forest (nodes/edges): 111/305 #Constr. forest (paths): 9899 if (/Constr\. forest\s+\(nodes\/edges\): (\d+)\/(\d+)/) { print "constr_nodes $1\nconstr_edges $2\n"; } diff --git a/utils/Makefile.am b/utils/Makefile.am index c0ce3509..341fd80b 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -39,7 +39,8 @@ libutils_a_SOURCES = \ kernel_string_subseq.h \ logval.h \ m.h \ - murmur_hash.h \ + murmur_hash3.h \ + murmur_hash3.cc \ named_enum.h \ null_deleter.h \ null_traits.h \ diff --git a/utils/hash.h b/utils/hash.h index e1426ffb..24d2b6ad 100644 --- a/utils/hash.h +++ b/utils/hash.h @@ -3,7 +3,7 @@ #include -#include "murmur_hash.h" +#include "murmur_hash3.h" #ifdef HAVE_CONFIG_H #include "config.h" @@ -44,23 +44,21 @@ const unsigned GOLDEN_MEAN_FRACTION=2654435769U; // assumes C is POD template -struct murmur_hash -{ - typedef MurmurInt result_type; +struct murmur_hash { + typedef size_t result_type; typedef C /*const&*/ argument_type; result_type operator()(argument_type const& c) const { - return MurmurHash((void*)&c,sizeof(c)); + return cdec::MurmurHash3_64((void*)&c, sizeof(c), GOLDEN_MEAN_FRACTION); } }; // murmur_hash_array isn't std guaranteed safe (you need to use string::data()) template <> -struct murmur_hash -{ - typedef MurmurInt result_type; +struct murmur_hash { + typedef size_t result_type; typedef std::string /*const&*/ argument_type; result_type operator()(argument_type const& c) const { - return MurmurHash(c.data(),c.size()); + return cdec::MurmurHash3_64(c.data(), c.size(), GOLDEN_MEAN_FRACTION); } }; @@ -68,10 +66,10 @@ struct murmur_hash template struct murmur_hash_array { - typedef MurmurInt result_type; + typedef size_t result_type; typedef C /*const&*/ argument_type; result_type operator()(argument_type const& c) const { - return MurmurHash(&*c.begin(),c.size()*sizeof(*c.begin())); + return cdec::MurmurHash3_64(&*c.begin(), c.size()*sizeof(*c.begin()), GOLDEN_MEAN_FRACTION); } }; @@ -95,7 +93,6 @@ typename H::mapped_type & get_or_construct(H &ht,K const& k,C0 const& c0) { } } - // get_or_call (0 arg) template typename H::mapped_type & get_or_call(H &ht,K const& k,F const& f) { diff --git a/utils/murmur_hash.h b/utils/murmur_hash.h deleted file mode 100644 index 6063d524..00000000 --- a/utils/murmur_hash.h +++ /dev/null @@ -1,186 +0,0 @@ -#ifndef _MURMUR_HASH_H_ -#define _MURMUR_HASH_H_ - -//NOTE: quite fast, nice collision properties, but endian dependent hash values - -#include "have_64_bits.h" -typedef uintptr_t MurmurInt; - -// MurmurHash2, by Austin Appleby - -static const uint32_t DEFAULT_SEED=2654435769U; - -#if HAVE_64_BITS -//MurmurInt MurmurHash(void const *key, int len, uint32_t seed=DEFAULT_SEED); - -inline uint64_t MurmurHash64( const void * key, int len, unsigned int seed=DEFAULT_SEED ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - -inline uint32_t MurmurHash32(void const *key, int len, uint32_t seed=DEFAULT_SEED) -{ - return (uint32_t) MurmurHash64(key,len,seed); -} - -inline MurmurInt MurmurHash(void const *key, int len, uint32_t seed=DEFAULT_SEED) -{ - return MurmurHash64(key,len,seed); -} - -#else -// 32-bit - -// Note - This code makes a few assumptions about how your machine behaves - -// 1. We can read a 4-byte value from any address without crashing -// 2. sizeof(int) == 4 -inline uint32_t MurmurHash32 ( const void * key, int len, uint32_t seed=DEFAULT_SEED) -{ - // 'm' and 'r' are mixing constants generated offline. - // They're not really 'magic', they just happen to work well. - - const uint32_t m = 0x5bd1e995; - const int r = 24; - - // Initialize the hash to a 'random' value - - uint32_t h = seed ^ len; - - // Mix 4 bytes at a time into the hash - - const unsigned char * data = (const unsigned char *)key; - - while(len >= 4) - { - uint32_t k = *(uint32_t *)data; - - k *= m; - k ^= k >> r; - k *= m; - - h *= m; - h ^= k; - - data += 4; - len -= 4; - } - - // Handle the last few bytes of the input array - - switch(len) - { - case 3: h ^= data[2] << 16; - case 2: h ^= data[1] << 8; - case 1: h ^= data[0]; - h *= m; - }; - - // Do a few final mixes of the hash to ensure the last few - // bytes are well-incorporated. - - h ^= h >> 13; - h *= m; - h ^= h >> 15; - - return h; -} - -inline MurmurInt MurmurHash ( const void * key, int len, uint32_t seed=DEFAULT_SEED) { - return MurmurHash32(key,len,seed); -} - -// 64-bit hash for 32-bit platforms - -inline uint64_t MurmurHash64 ( const void * key, int len, uint32_t seed=DEFAULT_SEED) -{ - const uint32_t m = 0x5bd1e995; - const int r = 24; - - uint32_t h1 = seed ^ len; - uint32_t h2 = 0; - - const uint32_t * data = (const uint32_t *)key; - - while(len >= 8) - { - uint32_t k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - uint32_t k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - uint32_t k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -#endif -//32bit - -#endif diff --git a/utils/murmur_hash3.cc b/utils/murmur_hash3.cc new file mode 100644 index 00000000..68a71d02 --- /dev/null +++ b/utils/murmur_hash3.cc @@ -0,0 +1,340 @@ +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. + +#include "murmur_hash3.h" + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) + +#define FORCE_INLINE __forceinline + +#include + +#define ROTL32(x,y) _rotl(x,y) +#define ROTL64(x,y) _rotl64(x,y) + +#define BIG_CONSTANT(x) (x) + +// Other compilers + +#else // defined(_MSC_VER) + +#define FORCE_INLINE inline __attribute__((always_inline)) + +namespace cdec { + +inline uint32_t rotl32 ( uint32_t x, int8_t r ) +{ + return (x << r) | (x >> (32 - r)); +} + +inline uint64_t rotl64 ( uint64_t x, int8_t r ) +{ + return (x << r) | (x >> (64 - r)); +} + +#define ROTL32(x,y) rotl32(x,y) +#define ROTL64(x,y) rotl64(x,y) + +#define BIG_CONSTANT(x) (x##LLU) + +#endif // !defined(_MSC_VER) + +//----------------------------------------------------------------------------- +// Block read - if your platform needs to do endian-swapping or can only +// handle aligned reads, do the conversion here + +FORCE_INLINE uint32_t getblock32 ( const uint32_t * p, int i ) +{ + return p[i]; +} + +FORCE_INLINE uint64_t getblock64 ( const uint64_t * p, int i ) +{ + return p[i]; +} + +//----------------------------------------------------------------------------- +// Finalization mix - force all bits of a hash block to avalanche + +FORCE_INLINE uint32_t fmix32 ( uint32_t h ) +{ + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + + return h; +} + +//---------- + +FORCE_INLINE uint64_t fmix64 ( uint64_t k ) +{ + k ^= k >> 33; + k *= BIG_CONSTANT(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + + return k; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x86_32 ( const void * key, int len, + uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 4; + + uint32_t h1 = seed; + + const uint32_t c1 = 0xcc9e2d51; + const uint32_t c2 = 0x1b873593; + + //---------- + // body + + const uint32_t * blocks = (const uint32_t *)(data + nblocks*4); + + for(int i = -nblocks; i; i++) + { + uint32_t k1 = getblock32(blocks,i); + + k1 *= c1; + k1 = ROTL32(k1,15); + k1 *= c2; + + h1 ^= k1; + h1 = ROTL32(h1,13); + h1 = h1*5+0xe6546b64; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*4); + + uint32_t k1 = 0; + + switch(len & 3) + { + case 3: k1 ^= tail[2] << 16; + case 2: k1 ^= tail[1] << 8; + case 1: k1 ^= tail[0]; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; + + h1 = fmix32(h1); + + *(uint32_t*)out = h1; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x86_128 ( const void * key, const int len, + uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint32_t h1 = seed; + uint32_t h2 = seed; + uint32_t h3 = seed; + uint32_t h4 = seed; + + const uint32_t c1 = 0x239b961b; + const uint32_t c2 = 0xab0e9789; + const uint32_t c3 = 0x38b34ae5; + const uint32_t c4 = 0xa1e38b93; + + //---------- + // body + + const uint32_t * blocks = (const uint32_t *)(data + nblocks*16); + + for(int i = -nblocks; i; i++) + { + uint32_t k1 = getblock32(blocks,i*4+0); + uint32_t k2 = getblock32(blocks,i*4+1); + uint32_t k3 = getblock32(blocks,i*4+2); + uint32_t k4 = getblock32(blocks,i*4+3); + + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + + h1 = ROTL32(h1,19); h1 += h2; h1 = h1*5+0x561ccd1b; + + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + + h2 = ROTL32(h2,17); h2 += h3; h2 = h2*5+0x0bcaa747; + + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + + h3 = ROTL32(h3,15); h3 += h4; h3 = h3*5+0x96cd1c35; + + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + + h4 = ROTL32(h4,13); h4 += h1; h4 = h4*5+0x32ac3b17; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*16); + + uint32_t k1 = 0; + uint32_t k2 = 0; + uint32_t k3 = 0; + uint32_t k4 = 0; + + switch(len & 15) + { + case 15: k4 ^= tail[14] << 16; + case 14: k4 ^= tail[13] << 8; + case 13: k4 ^= tail[12] << 0; + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + + case 12: k3 ^= tail[11] << 24; + case 11: k3 ^= tail[10] << 16; + case 10: k3 ^= tail[ 9] << 8; + case 9: k3 ^= tail[ 8] << 0; + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + + case 8: k2 ^= tail[ 7] << 24; + case 7: k2 ^= tail[ 6] << 16; + case 6: k2 ^= tail[ 5] << 8; + case 5: k2 ^= tail[ 4] << 0; + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + + case 4: k1 ^= tail[ 3] << 24; + case 3: k1 ^= tail[ 2] << 16; + case 2: k1 ^= tail[ 1] << 8; + case 1: k1 ^= tail[ 0] << 0; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; h3 ^= len; h4 ^= len; + + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + + h1 = fmix32(h1); + h2 = fmix32(h2); + h3 = fmix32(h3); + h4 = fmix32(h4); + + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + + ((uint32_t*)out)[0] = h1; + ((uint32_t*)out)[1] = h2; + ((uint32_t*)out)[2] = h3; + ((uint32_t*)out)[3] = h4; +} + +//----------------------------------------------------------------------------- + +void MurmurHash3_x64_128 ( const void * key, const int len, + const uint32_t seed, void * out ) +{ + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 16; + + uint64_t h1 = seed; + uint64_t h2 = seed; + + const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); + const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); + + //---------- + // body + + const uint64_t * blocks = (const uint64_t *)(data); + + for(int i = 0; i < nblocks; i++) + { + uint64_t k1 = getblock64(blocks,i*2+0); + uint64_t k2 = getblock64(blocks,i*2+1); + + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + + h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; + + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; + } + + //---------- + // tail + + const uint8_t * tail = (const uint8_t*)(data + nblocks*16); + + uint64_t k1 = 0; + uint64_t k2 = 0; + + switch(len & 15) + { + case 15: k2 ^= ((uint64_t)tail[14]) << 48; + case 14: k2 ^= ((uint64_t)tail[13]) << 40; + case 13: k2 ^= ((uint64_t)tail[12]) << 32; + case 12: k2 ^= ((uint64_t)tail[11]) << 24; + case 11: k2 ^= ((uint64_t)tail[10]) << 16; + case 10: k2 ^= ((uint64_t)tail[ 9]) << 8; + case 9: k2 ^= ((uint64_t)tail[ 8]) << 0; + k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; + + case 8: k1 ^= ((uint64_t)tail[ 7]) << 56; + case 7: k1 ^= ((uint64_t)tail[ 6]) << 48; + case 6: k1 ^= ((uint64_t)tail[ 5]) << 40; + case 5: k1 ^= ((uint64_t)tail[ 4]) << 32; + case 4: k1 ^= ((uint64_t)tail[ 3]) << 24; + case 3: k1 ^= ((uint64_t)tail[ 2]) << 16; + case 2: k1 ^= ((uint64_t)tail[ 1]) << 8; + case 1: k1 ^= ((uint64_t)tail[ 0]) << 0; + k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; + }; + + //---------- + // finalization + + h1 ^= len; h2 ^= len; + + h1 += h2; + h2 += h1; + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 += h2; + h2 += h1; + + ((uint64_t*)out)[0] = h1; + ((uint64_t*)out)[1] = h2; +} + +//----------------------------------------------------------------------------- + +} // namespace cdec + + diff --git a/utils/murmur_hash3.h b/utils/murmur_hash3.h new file mode 100644 index 00000000..a125d775 --- /dev/null +++ b/utils/murmur_hash3.h @@ -0,0 +1,67 @@ +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. + +#ifndef _MURMURHASH3_H_ +#define _MURMURHASH3_H_ + +//----------------------------------------------------------------------------- +// Platform-specific functions and macros + +// Microsoft Visual Studio + +#if defined(_MSC_VER) && (_MSC_VER < 1600) + +typedef unsigned char uint8_t; +typedef unsigned int uint32_t; +typedef unsigned __int64 uint64_t; + +// Other compilers + +#else // defined(_MSC_VER) + +#include + +#endif // !defined(_MSC_VER) + +//----------------------------------------------------------------------------- + +namespace cdec { + +void MurmurHash3_x86_32 ( const void * key, int len, uint32_t seed, void * out ); + +void MurmurHash3_x86_128 ( const void * key, int len, uint32_t seed, void * out ); + +void MurmurHash3_x64_128 ( const void * key, int len, uint32_t seed, void * out ); + +namespace { + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wunused-function" + #endif + template inline void cdecMurmurHashNativeBackend(const void * key, int len, uint32_t seed, void * out) { + MurmurHash3_x86_128(key, len, seed, out); + } + template <> inline void cdecMurmurHashNativeBackend<4>(const void * key, int len, uint32_t seed, void * out) { + MurmurHash3_x64_128(key, len, seed, out); + } + #ifdef __clang__ + #pragma clang diagnostic pop + #endif +} // namespace + +inline uint64_t MurmurHash3_64(const void * key, int len, uint32_t seed) { + uint64_t out[2]; + cdecMurmurHashNativeBackend(key, len, seed, &out); + return out[0]; +} + +inline void MurmurHash3_128(const void * key, int len, uint32_t seed, void * out) { + cdecMurmurHashNativeBackend(key, len, seed, out); +} + +} + +//----------------------------------------------------------------------------- + +#endif // _MURMURHASH3_H_ -- cgit v1.2.3 From 4b1ae0b11bb9fc4ee4c48b1443f36a824dc92fbe Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 8 Apr 2014 02:02:07 -0400 Subject: track node_hash field --- decoder/hg_io.cc | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) (limited to 'decoder') diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index 64c6663e..eb0be3d4 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -1,5 +1,7 @@ #include "hg_io.h" +#include +#include #include #include #include @@ -15,10 +17,15 @@ using namespace std; struct HGReader : public JSONParser { HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } - void CreateNode(const string& cat, const vector& in_edges) { + void CreateNode(const string& cat, const string& shash, const vector& in_edges) { WordID c = TD::Convert("X") * -1; if (!cat.empty()) c = TD::Convert(cat) * -1; Hypergraph::Node* node = hg.AddNode(c); + char* dend; + if (shash.size()) + node->node_hash = strtoull(shash.c_str(), &dend, 16); + else + node->node_hash = 0; for (int i = 0; i < in_edges.size(); ++i) { if (in_edges[i] >= hg.edges_.size()) { cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] @@ -102,17 +109,19 @@ struct HGReader : public JSONParser { ++nodes; in_edges.clear(); cat.clear(); + shash.clear(); state = 9; break; case 9: if (type == JSON_T_OBJECT_END) { //cerr << "Creating NODE\n"; - CreateNode(cat, in_edges); + CreateNode(cat, shash, in_edges); state = 0; break; } assert(type == JSON_T_KEY); cur_key = value->vu.str.value; if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } + if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; } cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; return false; case 10: @@ -224,6 +233,12 @@ struct HGReader : public JSONParser { assert(spanc < 4); spans[spanc] = value->vu.integer_value; ++spanc; + break; + case 24: // read node hash + assert(type == JSON_T_STRING); + shash = value->vu.str.value; + state = 9; + break; } return true; } @@ -231,6 +246,7 @@ struct HGReader : public JSONParser { string cat; SmallVectorUnsigned tail; vector in_edges; + string shash; TRulePtr cur_rule; map rules; vector fdict; @@ -340,6 +356,9 @@ bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* o << ",\"cat\":"; JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o); } + char buf[48]; + sprintf(buf, "%016lX", node.node_hash); + o << ",\"node_hash\":\"" << buf << "\""; o << "}"; } o << "}\n"; -- cgit v1.2.3 From 5daa3f144c980624eded2ff498d5e22f6716e969 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 8 Apr 2014 23:16:24 -0400 Subject: smarter union --- decoder/hg.cc | 1 + decoder/hg_test.cc | 17 +++++++++ decoder/hg_union.cc | 105 +++++++++++++++++++++++++++++++++++----------------- 3 files changed, 89 insertions(+), 34 deletions(-) (limited to 'decoder') diff --git a/decoder/hg.cc b/decoder/hg.cc index 405169c6..e456fa7c 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -396,6 +396,7 @@ void Hypergraph::PrintGraphviz() const { for (const auto& node : nodes_) { cerr << " " << node.id_ << "[label=\"" << (node.cat_ < 0 ? TD::Convert(node.cat_ * -1) : "") << " n=" << node.id_ + << " h=" << node.node_hash << "\"];\n"; } cerr << "}\n"; diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 95cfae51..5cb8626a 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -44,6 +44,13 @@ BOOST_AUTO_TEST_CASE(Union) { Hypergraph hg2; CreateHG_tiny(path, &hg1); CreateHG(path, &hg2); + int nc = 0; + for (auto& node: hg1.nodes_) + node.node_hash = ++nc; + for (auto& node: hg2.nodes_) + node.node_hash = ++nc; + hg1.nodes_.back().node_hash = nc; + SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -56,8 +63,11 @@ BOOST_AUTO_TEST_CASE(Union) { int l2 = ViterbiPathLength(hg2); cerr << c1 << "\t" << TD::GetString(t1) << endl; cerr << c2 << "\t" << TD::GetString(t2) << endl; + hg1.PrintGraphviz(); + hg2.PrintGraphviz(); HG::Union(hg2, &hg1); hg1.Reweight(wts); + hg1.PrintGraphviz(); c3 = ViterbiESentence(hg1, &t3); int l3 = ViterbiPathLength(hg1); cerr << c3 << "\t" << TD::GetString(t3) << endl; @@ -84,6 +94,13 @@ BOOST_AUTO_TEST_CASE(Union) { BOOST_CHECK_CLOSE(log(list[0].second), log(c4), 1e-4); BOOST_CHECK_EQUAL(list.size(), 6); BOOST_CHECK_CLOSE(log(list.back().second / list.front().second), -97.7, 1e-4); + hg1 = hg2; + BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size()); + BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size()); + HG::Union(hg1, &hg2); // this should be a no-op + BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size()); + BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size()); + cerr << "DONE UNION\n"; } BOOST_AUTO_TEST_CASE(ControlledKBest) { diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc index 37082976..4899e716 100644 --- a/decoder/hg_union.cc +++ b/decoder/hg_union.cc @@ -1,56 +1,93 @@ #include "hg_union.h" +#ifndef HAVE_OLD_CPP +# include +#else +# include +namespace std { using std::tr1::unordered_set; } +#endif + #include "hg.h" +#include "sparse_vector.h" using namespace std; namespace HG { +static bool EdgesMatch(const HG::Edge& a, const Hypergraph& ahg, const HG::Edge& b, const Hypergraph& bhg) { + const unsigned arity = a.tail_nodes_.size(); + if (arity != b.tail_nodes_.size()) return false; + if (a.rule_->e() != b.rule_->e()) return false; + if (a.rule_->f() != b.rule_->f()) return false; + + for (unsigned i = 0; i < arity; ++i) + if (ahg.nodes_[a.tail_nodes_[i]].node_hash != bhg.nodes_[b.tail_nodes_[i]].node_hash) return false; + const SparseVector diff = a.feature_values_ - b.feature_values_; + for (auto& kv : diff) + if (fabs(kv.second) > 1e-6) return false; + return true; +} + void Union(const Hypergraph& in, Hypergraph* out) { if (&in == out) return; if (out->nodes_.empty()) { out->nodes_ = in.nodes_; out->edges_ = in.edges_; return; } - unsigned noff = out->nodes_.size(); - unsigned eoff = out->edges_.size(); - int ogoal = in.nodes_.size() - 1; - int cgoal = noff - 1; - // keep a single goal node, so add nodes.size - 1 - out->nodes_.resize(out->nodes_.size() + ogoal); - // add all edges - out->edges_.resize(out->edges_.size() + in.edges_.size()); - - for (int i = 0; i < ogoal; ++i) { - const Hypergraph::Node& on = in.nodes_[i]; - Hypergraph::Node& cn = out->nodes_[i + noff]; - cn.id_ = i + noff; - cn.in_edges_.resize(on.in_edges_.size()); - for (unsigned j = 0; j < on.in_edges_.size(); ++j) - cn.in_edges_[j] = on.in_edges_[j] + eoff; - - cn.out_edges_.resize(on.out_edges_.size()); - for (unsigned j = 0; j < on.out_edges_.size(); ++j) - cn.out_edges_[j] = on.out_edges_[j] + eoff; + if (!in.AreNodesUniquelyIdentified()) { + cerr << "Union: Nodes are not uniquely identified in input!\n"; + abort(); + } + if (!out->AreNodesUniquelyIdentified()) { + cerr << "Union: Nodes are not uniquely identified in output!\n"; + abort(); } + if (out->nodes_.back().node_hash != in.nodes_.back().node_hash) { + cerr << "Union: Goal nodes are mismatched!\n"; + abort(); + } + const int cgoal = out->nodes_.back().id_; - for (unsigned i = 0; i < in.edges_.size(); ++i) { - const Hypergraph::Edge& oe = in.edges_[i]; - Hypergraph::Edge& ce = out->edges_[i + eoff]; - ce.id_ = i + eoff; - ce.rule_ = oe.rule_; - ce.feature_values_ = oe.feature_values_; - if (oe.head_node_ == ogoal) { - ce.head_node_ = cgoal; - out->nodes_[cgoal].in_edges_.push_back(ce.id_); - } else { - ce.head_node_ = oe.head_node_ + noff; + unordered_map h2n; + for (const auto& node : out->nodes_) + h2n[node.node_hash] = node.id_; + for (const auto& node : in.nodes_) { + if (h2n.count(node.node_hash) == 0) { + HG::Node* new_node = out->AddNode(node.cat_); + new_node->node_hash = node.node_hash; + h2n[node.node_hash] = new_node->id_; } - ce.tail_nodes_.resize(oe.tail_nodes_.size()); - for (unsigned j = 0; j < oe.tail_nodes_.size(); ++j) - ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff; } + for (const auto& in_node : in.nodes_) { + HG::Node& out_node = out->nodes_[h2n[in_node.node_hash]]; + for (const auto oeid : out_node.in_edges_) { + // TODO hash currently existing edges for quick check for duplication + } + for (const auto ieid : in_node.in_edges_) { + const HG::Edge& in_edge = in.edges_[ieid]; + // TODO: replace slow N^2 check with hashing + bool edge_exists = false; + for (const auto oeid : out_node.in_edges_) { + if (EdgesMatch(in_edge, in, out->edges_[oeid], *out)) { + edge_exists = true; + break; + } + } + if (!edge_exists) { + const unsigned arity = in_edge.tail_nodes_.size(); + TailNodeVector t(arity); + HG::Node& head = out->nodes_[h2n[in_node.node_hash]]; + for (unsigned i = 0; i < arity; ++i) + t[i] = h2n[in.nodes_[in_edge.tail_nodes_[i]].node_hash]; + HG::Edge* new_edge = out->AddEdge(in_edge, t); + out->ConnectEdgeToHeadNode(new_edge, &head); + //cerr << "Created: " << new_edge->rule_->AsString() << " [head=" << new_edge->head_node_ << "]\n"; + } //else { + // cerr << "Not created: " << in.edges_[ieid].rule_->AsString() << "\n"; + //} + } + } out->TopologicallySortNodesAndEdges(cgoal); } -- cgit v1.2.3 From a154f22c8ef2d170fa4a8179b1211898eb6831d3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 9 Apr 2014 01:06:33 -0400 Subject: don't hash on an internal id --- decoder/node_state_hash.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'decoder') diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h index cdc05877..9fc01a09 100644 --- a/decoder/node_state_hash.h +++ b/decoder/node_state_hash.h @@ -3,14 +3,19 @@ #include #include +#include "tdict.h" #include "murmur_hash3.h" #include "ffset.h" namespace cdec { struct FirstPassNode { - FirstPassNode(int cat, int i, int j, int pi, int pj) : lhs(cat), s(i), t(j), u(pi), v(pj) {} - int32_t lhs; + FirstPassNode(int cat, int i, int j, int pi, int pj) : s(i), t(j), u(pi), v(pj) { + memset(lhs, 0, 120); + unsigned it = 0; + for (auto& c : TD::Convert(-cat)) { lhs[it++] = c; if (it == 120) break; } + } + char lhs[120]; short s; short t; short u; @@ -23,6 +28,7 @@ namespace cdec { } inline uint64_t HashNode(uint64_t old_hash, const FFState& state) { + if (state.size() == 0) return old_hash; uint8_t buf[1024]; std::memcpy(buf, &old_hash, sizeof(uint64_t)); assert(state.size() < (1024u - sizeof(uint64_t))); -- cgit v1.2.3 From d92457adef112e34e14c1ce9a43f95cf3c0da4d4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 9 Apr 2014 01:09:07 -0400 Subject: better logging of union stats --- decoder/hg_union.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) (limited to 'decoder') diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc index 4899e716..a659b6bc 100644 --- a/decoder/hg_union.cc +++ b/decoder/hg_union.cc @@ -7,6 +7,7 @@ namespace std { using std::tr1::unordered_set; } #endif +#include "verbose.h" #include "hg.h" #include "sparse_vector.h" @@ -43,7 +44,7 @@ void Union(const Hypergraph& in, Hypergraph* out) { abort(); } if (out->nodes_.back().node_hash != in.nodes_.back().node_hash) { - cerr << "Union: Goal nodes are mismatched!\n"; + cerr << "Union: Goal nodes are mismatched!\n a=" << in.nodes_.back().node_hash << " b=" << out->nodes_.back().node_hash << "\n"; abort(); } const int cgoal = out->nodes_.back().id_; @@ -59,6 +60,8 @@ void Union(const Hypergraph& in, Hypergraph* out) { } } + double n_exists = 0; + double n_created = 0; for (const auto& in_node : in.nodes_) { HG::Node& out_node = out->nodes_[h2n[in_node.node_hash]]; for (const auto oeid : out_node.in_edges_) { @@ -82,12 +85,20 @@ void Union(const Hypergraph& in, Hypergraph* out) { t[i] = h2n[in.nodes_[in_edge.tail_nodes_[i]].node_hash]; HG::Edge* new_edge = out->AddEdge(in_edge, t); out->ConnectEdgeToHeadNode(new_edge, &head); + ++n_created; //cerr << "Created: " << new_edge->rule_->AsString() << " [head=" << new_edge->head_node_ << "]\n"; - } //else { + } else { + ++n_exists; + } // cerr << "Not created: " << in.edges_[ieid].rule_->AsString() << "\n"; //} } } + if (!SILENT) + cerr << " Union: edges_created=" << n_created + << " edges_already_existing=" + << n_exists << " ratio_new=" << (n_created / (n_exists + n_created)) + << endl; out->TopologicallySortNodesAndEdges(cgoal); } -- cgit v1.2.3 From 34a148b9882cc983c8292978503849b114fc2983 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 16 Apr 2014 00:36:30 -0400 Subject: fix for bug due to using wrong tree traversal --- decoder/t2s_test.cc | 8 ++- decoder/tree2string_translator.cc | 18 +++--- decoder/tree_fragment.cc | 12 ---- decoder/tree_fragment.h | 112 +++++++++++++++++++++++++++++++++++++- 4 files changed, 125 insertions(+), 25 deletions(-) (limited to 'decoder') diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc index 3c46ea89..5ebb2662 100644 --- a/decoder/t2s_test.cc +++ b/decoder/t2s_test.cc @@ -15,8 +15,11 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { vector aw, bw; cerr << "TREE1: " << tree << endl; cerr << "TREE2: " << tree2 << endl; - for (auto& sym : tree) + for (auto& sym : tree) { + if (cdec::IsLHS(sym)) cerr << "("; + cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + } for (auto& sym : tree2) if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); BOOST_CHECK_EQUAL(a.size(), b.size()); @@ -38,11 +41,12 @@ BOOST_AUTO_TEST_CASE(TestTreeFragments) { if (cdec::IsFrontier(*it)) nts += "*"; } } + cerr << "Truncated: " << nts << endl; BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); nts.clear(); int ntc = 0; - for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + for (auto it = tree.bfs_begin(); it != tree.bfs_end(); ++it) { if (cdec::IsNT(*it)) { if (cdec::IsRHS(*it)) { ++ntc; diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 8d12d01d..3fbf1ee5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -38,14 +38,13 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { // but it will not generate source strings correctly vector frhs; for (auto sym : rule_src) { + //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; cur = &cur->next[sym]; - if (sym) { - if (cdec::IsFrontier(sym)) { // frontier symbols -> variables - int nt = (sym & cdec::ALL_MASK); - frhs.push_back(-nt); - } else if (cdec::IsTerminal(sym)) { - frhs.push_back(sym); - } + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); } } os << '[' << TD::Convert(-lhs) << "] |||"; @@ -61,6 +60,7 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { os << " ||| " << line.substr(pos); TRulePtr rule(new TRule(os.str())); cur->rules.push_back(rule); + //cerr << "RULE: " << rule->AsString() << "\n\n"; } } @@ -82,7 +82,7 @@ struct ParserState { } vector future_work; int input_node_idx; // lhs of top level NT - Tree2StringGrammarNode* node; + Tree2StringGrammarNode* node; // pointer into grammar }; namespace std { @@ -239,11 +239,13 @@ struct Tree2StringTranslatorImpl { new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; const ParserState new_s(++s.in_iter, &nit1->second, s); if (unique.insert(new_s).second) q.push(new_s); } + //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 78a993b8..4d429f42 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -112,16 +112,4 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned *psymp = symp; } -BreadthFirstIterator TreeFragment::begin() const { - return BreadthFirstIterator(this, nodes.size() - 1); -} - -BreadthFirstIterator TreeFragment::begin(unsigned node_idx) const { - return BreadthFirstIterator(this, node_idx); -} - -BreadthFirstIterator TreeFragment::end() const { - return BreadthFirstIterator(this); -} - } diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index f1c4c106..4a704bc4 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -11,6 +11,7 @@ namespace cdec { class BreadthFirstIterator; +class DepthFirstIterator; static const unsigned LHS_BIT = 0x10000000u; static const unsigned RHS_BIT = 0x20000000u; @@ -53,16 +54,21 @@ class TreeFragment { // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); void DebugRec(unsigned cur, std::ostream* out) const; - typedef BreadthFirstIterator iterator; + typedef DepthFirstIterator iterator; typedef ptrdiff_t difference_type; typedef unsigned value_type; typedef const unsigned * pointer; typedef const unsigned & reference; + // default iterator is DFS iterator begin() const; iterator begin(unsigned node_idx) const; iterator end() const; + BreadthFirstIterator bfs_begin() const; + BreadthFirstIterator bfs_begin(unsigned node_idx) const; + BreadthFirstIterator bfs_end() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built @@ -78,14 +84,106 @@ class TreeFragment { struct TFIState { TFIState() : node(), rhspos(), state() {} - TFIState(unsigned n, unsigned p, unsigned s) : node(n), rhspos(p), state(s) {} + TFIState(unsigned n, int p, unsigned s) : node(n), rhspos(p), state(s) {} bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } unsigned short node; - unsigned short rhspos; + short rhspos; unsigned char state; }; +class DepthFirstIterator : public std::iterator { + const TreeFragment* tf_; + std::deque q_; + unsigned sym; + public: + DepthFirstIterator() : tf_(), sym() {} + // used for begin + explicit DepthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, -1, 0)); + Stage(); + q_.back().state++; + } + // used for end + explicit DepthFirstIterator(const TreeFragment* tf) : tf_(tf) {} + const unsigned& operator*() const { return sym; } + const unsigned* operator->() const { return &sym; } + bool operator==(const DepthFirstIterator& other) const { + return (tf_ == other.tf_) && (q_ == other.q_); + } + bool operator!=(const DepthFirstIterator& other) const { + return (tf_ != other.tf_) || (q_ != other.q_); + } + unsigned node_idx() const { return q_.front().node; } + const DepthFirstIterator& operator++() { + TFIState& s = q_.back(); + if (s.state == 0) { + Stage(); + s.state++; + } else if (s.state == 1) { + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_back(); + while (!q_.empty()) { + TFIState& s = q_.back(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos < len) break; + q_.pop_back(); + } + } + Stage(); + } + return *this; + } + DepthFirstIterator operator++(int) { + DepthFirstIterator res = *this; + ++(*this); + return res; + } + // tell iterator not to explore the subtree rooted at sym + // should only be called once per NT symbol encountered + const DepthFirstIterator& truncate() { + assert(IsRHS(sym)); + sym &= ALL_MASK; + sym |= FRONTIER_BIT; + q_.pop_back(); + return *this; + } + unsigned child_node() const { + assert(IsRHS(sym)); + return q_.back().node; + } + DepthFirstIterator remainder() const { + assert(IsRHS(sym)); + return DepthFirstIterator(tf_, q_.back()); + } + bool at_end() const { + return q_.empty(); + } + private: + void Stage() { + if (q_.empty()) return; + const TFIState& s = q_.back(); + if (s.state == 0) { + sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else if (s.state == 1) { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, -1, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + DepthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } +}; + class BreadthFirstIterator : public std::iterator { const TreeFragment* tf_; std::deque q_; @@ -172,6 +270,14 @@ class BreadthFirstIterator : public std::iterator Date: Thu, 17 Apr 2014 20:55:34 -0400 Subject: fix rescoring --- decoder/trule.cc | 6 ++++++ tests/system_tests/cfg_rescore/README | 4 ++++ tests/system_tests/cfg_rescore/cdec.ini | 2 ++ tests/system_tests/cfg_rescore/gold.statistics | 3 +++ tests/system_tests/cfg_rescore/gold.stdout | 4 ++++ tests/system_tests/cfg_rescore/input.cfg | 9 +++++++++ tests/system_tests/cfg_rescore/input.txt | 1 + tests/system_tests/cfg_rescore/weights | 3 +++ 8 files changed, 32 insertions(+) create mode 100644 tests/system_tests/cfg_rescore/README create mode 100644 tests/system_tests/cfg_rescore/cdec.ini create mode 100644 tests/system_tests/cfg_rescore/gold.statistics create mode 100644 tests/system_tests/cfg_rescore/gold.stdout create mode 100644 tests/system_tests/cfg_rescore/input.cfg create mode 100644 tests/system_tests/cfg_rescore/input.txt create mode 100644 tests/system_tests/cfg_rescore/weights (limited to 'decoder') diff --git a/decoder/trule.cc b/decoder/trule.cc index 1bd5425f..bee211d5 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -56,6 +56,12 @@ bool TRule::ReadFromString(const string& line, bool mono) { RuleLexer::ReadRule(line + '\n', assign_trule, mono, this); if (n_assigned > 1) cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "< input.txt diff --git a/tests/system_tests/cfg_rescore/cdec.ini b/tests/system_tests/cfg_rescore/cdec.ini new file mode 100644 index 00000000..1a913f2d --- /dev/null +++ b/tests/system_tests/cfg_rescore/cdec.ini @@ -0,0 +1,2 @@ +formalism=rescore +k_best=100 diff --git a/tests/system_tests/cfg_rescore/gold.statistics b/tests/system_tests/cfg_rescore/gold.statistics new file mode 100644 index 00000000..7b05e2d8 --- /dev/null +++ b/tests/system_tests/cfg_rescore/gold.statistics @@ -0,0 +1,3 @@ +-lm_nodes 8 +-lm_edges 10 +-lm_paths 4 diff --git a/tests/system_tests/cfg_rescore/gold.stdout b/tests/system_tests/cfg_rescore/gold.stdout new file mode 100644 index 00000000..ccf99263 --- /dev/null +++ b/tests/system_tests/cfg_rescore/gold.stdout @@ -0,0 +1,4 @@ +0 ||| the broccoli was eaten by John ||| Passive=1 Definite=1 ||| 2 +0 ||| John ate the broccoli ||| Active=1 Definite=1 ||| 1.1 +0 ||| broccoli was eaten by John ||| Passive=1 ||| 1 +0 ||| John ate broccoli ||| Active=1 ||| 0.1 diff --git a/tests/system_tests/cfg_rescore/input.cfg b/tests/system_tests/cfg_rescore/input.cfg new file mode 100644 index 00000000..0073cb7b --- /dev/null +++ b/tests/system_tests/cfg_rescore/input.cfg @@ -0,0 +1,9 @@ +[S] ||| [S1] +[S1] ||| [NP1] [VP] ||| Active=1 +[VP] ||| [V] [NP2] +[V] ||| ate +[VPSV] ||| was eaten +[S1] ||| [NP2] [VPSV] by [NP1] ||| Passive=1 +[NP1] ||| John +[NP2] ||| broccoli +[NP2] ||| the broccoli ||| Definite=1 diff --git a/tests/system_tests/cfg_rescore/input.txt b/tests/system_tests/cfg_rescore/input.txt new file mode 100644 index 00000000..71fc26bc --- /dev/null +++ b/tests/system_tests/cfg_rescore/input.txt @@ -0,0 +1 @@ +{"rules":[1,"[S] ||| [S1] ||| [1]",2,"[S1] ||| [NP1] [VP] ||| [1] [2] ||| Active=1",3,"[VP] ||| [V] [NP2] ||| [1] [2]",4,"[V] ||| ate ||| ate",5,"[VPSV] ||| was eaten ||| was eaten",6,"[S1] ||| [NP2] [VPSV] by [NP1] ||| [1] [2] by [3] ||| Passive=1",7,"[NP1] ||| John ||| John",8,"[NP2] ||| broccoli ||| broccoli",9,"[NP2] ||| the broccoli ||| the broccoli ||| Definite=1"],"features":["PhraseModel_0","PhraseModel_1","PhraseModel_2","PhraseModel_3","PhraseModel_4","PhraseModel_5","PhraseModel_6","PhraseModel_7","PhraseModel_8","PhraseModel_9","PhraseModel_10","PhraseModel_11","PhraseModel_12","PhraseModel_13","PhraseModel_14","PhraseModel_15","PhraseModel_16","PhraseModel_17","PhraseModel_18","PhraseModel_19","PhraseModel_20","PhraseModel_21","PhraseModel_22","PhraseModel_23","PhraseModel_24","PhraseModel_25","PhraseModel_26","PhraseModel_27","PhraseModel_28","PhraseModel_29","PhraseModel_30","PhraseModel_31","PhraseModel_32","PhraseModel_33","PhraseModel_34","PhraseModel_35","PhraseModel_36","PhraseModel_37","PhraseModel_38","PhraseModel_39","PhraseModel_40","PhraseModel_41","PhraseModel_42","PhraseModel_43","PhraseModel_44","PhraseModel_45","PhraseModel_46","PhraseModel_47","PhraseModel_48","PhraseModel_49","PhraseModel_50","PhraseModel_51","PhraseModel_52","PhraseModel_53","PhraseModel_54","PhraseModel_55","PhraseModel_56","PhraseModel_57","PhraseModel_58","PhraseModel_59","PhraseModel_60","PhraseModel_61","PhraseModel_62","PhraseModel_63","PhraseModel_64","PhraseModel_65","PhraseModel_66","PhraseModel_67","PhraseModel_68","PhraseModel_69","PhraseModel_70","PhraseModel_71","PhraseModel_72","PhraseModel_73","PhraseModel_74","PhraseModel_75","PhraseModel_76","PhraseModel_77","PhraseModel_78","PhraseModel_79","PhraseModel_80","PhraseModel_81","PhraseModel_82","PhraseModel_83","PhraseModel_84","PhraseModel_85","PhraseModel_86","PhraseModel_87","PhraseModel_88","PhraseModel_89","PhraseModel_90","PhraseModel_91","PhraseModel_92","PhraseModel_93","PhraseModel_94","PhraseModel_95","PhraseModel_96","PhraseModel_97","PhraseModel_98","PhraseModel_99","Active","Passive","Definite"],"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":7}],"node":{"in_edges":[0],"cat":"NP1","node_hash":"0000000000000007"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":4}],"node":{"in_edges":[1],"cat":"V","node_hash":"0000000000000004"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":8},{"tail":[],"spans":[-1,-1,-1,-1],"feats":[102,1],"rule":9}],"node":{"in_edges":[2,3],"cat":"NP2","node_hash":"0000000000000009"},"edges":[{"tail":[1,2],"spans":[-1,-1,-1,-1],"feats":[],"rule":3}],"node":{"in_edges":[4],"cat":"VP","node_hash":"0000000000000003"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":5}],"node":{"in_edges":[5],"cat":"VPSV","node_hash":"0000000000000005"},"edges":[{"tail":[0,3],"spans":[-1,-1,-1,-1],"feats":[100,1],"rule":2},{"tail":[2,4,0],"spans":[-1,-1,-1,-1],"feats":[101,1],"rule":6}],"node":{"in_edges":[6,7],"cat":"S1","node_hash":"0000000000000006"},"edges":[{"tail":[5],"spans":[-1,-1,-1,-1],"feats":[],"rule":1}],"node":{"in_edges":[8],"cat":"S","node_hash":"0000000000000001"}} diff --git a/tests/system_tests/cfg_rescore/weights b/tests/system_tests/cfg_rescore/weights new file mode 100644 index 00000000..bd3bb1af --- /dev/null +++ b/tests/system_tests/cfg_rescore/weights @@ -0,0 +1,3 @@ +Active 0.1 +Passive 1 +Definite 1 -- cgit v1.2.3 From c5531cdbcecb6e902ab08c9297f16f03b829497f Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 25 Apr 2014 02:01:59 -0400 Subject: support for multiple xRs states in parser (not yet in rules) --- decoder/tree2string_translator.cc | 121 ++++++++++++++++++++++++++------------ decoder/trule.h | 3 + training/utils/grammar_convert.cc | 5 +- 3 files changed, 89 insertions(+), 40 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 3fbf1ee5..29caaf8f 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -30,12 +30,15 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { unsigned xc = 0; while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); - Tree2StringGrammarNode* cur = root; + // TODO transducer_state should (optionally?) be read from input + const unsigned transducer_state = 0; + Tree2StringGrammarNode* cur = &root->next[transducer_state]; ostringstream os; int lhs = -(rule_src.root & cdec::ALL_MASK); // build source RHS for SCFG projection // TODO - this is buggy - it will generate a well-formed SCFG rule - // but it will not generate source strings correctly + // so it will not generate source strings correctly + // it will, however, generate target translations appropriately vector frhs; for (auto sym : rule_src) { //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; @@ -59,40 +62,65 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { while(line[pos] == ' ') { ++pos; } os << " ||| " << line.substr(pos); TRulePtr rule(new TRule(os.str())); + // TODO the transducer_state you end up in after using this rule (for each NT) + // needs to be read and encoded somehow in the rule (for use XXX) cur->rules.push_back(rule); //cerr << "RULE: " << rule->AsString() << "\n\n"; } } +// represents where in an input parse tree the transducer must continue +// and what state it is in +struct TransducerState { + TransducerState() : input_node_idx(), transducer_state() {} + TransducerState(unsigned n, unsigned q) : input_node_idx(n), transducer_state(q) {} + bool operator==(const TransducerState& o) const { + return input_node_idx == o.input_node_idx && + transducer_state == o.transducer_state; + } + unsigned input_node_idx; + unsigned transducer_state; +}; + +// represents the state of the composition algorithm struct ParserState { ParserState() : in_iter(), node() {} cdec::TreeFragment::iterator in_iter; - ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n) : + ParserState(const cdec::TreeFragment::iterator& it, unsigned q, Tree2StringGrammarNode* n) : in_iter(it), - input_node_idx(it.node_idx()), + task(it.node_idx(), q), node(n) {} ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : in_iter(it), future_work(p.future_work), - input_node_idx(p.input_node_idx), + task(p.task), node(n) {} bool operator==(const ParserState& o) const { - return node == o.node && input_node_idx == o.input_node_idx && + return node == o.node && task == o.task && future_work == o.future_work && in_iter == o.in_iter; } - vector future_work; - int input_node_idx; // lhs of top level NT - Tree2StringGrammarNode* node; // pointer into grammar + vector future_work; + TransducerState task; // subtree root where and in what state did the transducer start? + Tree2StringGrammarNode* node; // pointer into grammar trie }; namespace std { + template<> + struct hash { + size_t operator()(const TransducerState& q) const { + size_t h = boost::hash_value(q.transducer_state); + boost::hash_combine(h, boost::hash_value(q.input_node_idx)); + return h; + } + }; template<> struct hash { size_t operator()(const ParserState& s) const { - size_t h = boost::hash_range(s.future_work.begin(), s.future_work.end()); - boost::hash_combine(h, boost::hash_value(s.node)); - boost::hash_combine(h, boost::hash_value(s.input_node_idx)); - //boost::hash_combine(h, ); + size_t h = boost::hash_value(s.node); + for (auto& w : s.future_work) + boost::hash_combine(h, hash()(w)); + boost::hash_combine(h, hash()(s.task)); + // TODO hash with iterator return h; } }; @@ -144,6 +172,9 @@ struct Tree2StringTranslatorImpl { os << ')'; cdec::TreeFragment rule_src(os.str(), true); Tree2StringGrammarNode* cur = root.back().get(); + // do we need all transducer states here??? a list??? no pass through rules??? + unsigned transducer_state = 0; + cur = &cur->next[transducer_state]; for (auto sym : rule_src) cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); @@ -167,15 +198,19 @@ struct Tree2StringTranslatorImpl { if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); - vector tree2hg(input_tree.nodes.size() + 1, -1); + unordered_map x2hg(input_tree.nodes.size() * 5); queue q; unordered_set unique; // only create items one time for (auto& g : root) { - q.push(ParserState(input_tree.begin(), g.get())); - unique.insert(q.back()); + unsigned q_0 = 0; // TODO initialize q_0 properly once multi-state transducers are supported + auto rit = g->next.find(q_0); + if (rit != g->next.end()) { // does this g have this transducer state? + q.push(ParserState(input_tree.begin(), q_0, &rit->second)); + unique.insert(q.back()); + } } if (q.size() == 0) return false; - unsigned tree_top = q.front().input_node_idx; + const TransducerState tree_top = q.front().task; while(!q.empty()) { ParserState& s = q.front(); @@ -183,21 +218,24 @@ struct Tree2StringTranslatorImpl { //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { - int& node_id = tree2hg[s.input_node_idx]; - if (node_id < 0) { - HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK)); - new_node->node_hash = s.input_node_idx + 1; - node_id = new_node->id_; + auto it = x2hg.find(s.task); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.task.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash()(s.task); + it = x2hg.insert(make_pair(s.task, new_node->id_)).first; } + const unsigned node_id = it->second; TailNodeVector tail; - for (auto n : s.future_work) { - int& nix = tree2hg[n]; - if (nix < 0) { - HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n].lhs & cdec::ALL_MASK)); - new_node->node_hash = n + 1; - nix = new_node->id_; + for (const auto& n : s.future_work) { + auto it = x2hg.find(n); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash()(n); + it = x2hg.insert(make_pair(n, new_node->id_)).first; } - tail.push_back(nix); + tail.push_back(it->second); } for (auto& r : s.node->rules) { assert(tail.size() == r->Arity()); @@ -206,11 +244,14 @@ struct Tree2StringTranslatorImpl { // TODO: set i and j hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); } - for (auto n : s.future_work) { - const auto it = input_tree.begin(n); // start tree iterator at node n + for (const auto& n : s.future_work) { + const auto it = input_tree.begin(n.input_node_idx); // start tree iterator at node n for (auto& g : root) { - ParserState s(it, g.get()); - if (unique.insert(s).second) q.push(s); + auto rit = g->next.find(n.transducer_state); + if (rit != g->next.end()) { // does this g have this transducer state? + const ParserState s(it, n.transducer_state, &rit->second); + if (unique.insert(s).second) q.push(s); + } } } } else { @@ -234,9 +275,13 @@ struct Tree2StringTranslatorImpl { if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ++var; - const unsigned new_work = s.in_iter.child_node(); + // TODO: find out from rule what the new target state is (the 0 in the next line) + // if it is associated with the rule, we won't know until we match the whole input + // so the 0 may be okay (if this is the case, which is probably the easiest thing, + // then the state must be dealt with when the future work becomes real work) + const TransducerState new_task(s.in_iter.child_node(), 0); ParserState new_s(var, &nit2->second, s); - new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q + new_s.future_work.push_back(new_task); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } @@ -259,10 +304,10 @@ struct Tree2StringTranslatorImpl { } q.pop(); } - int goal = tree2hg[tree_top]; - if (goal < 0) return false; + const auto goal_it = x2hg.find(tree_top); + if (goal_it == x2hg.end()) return false; //cerr << "Goal node: " << goal << endl; - hg.TopologicallySortNodesAndEdges(goal); + hg.TopologicallySortNodesAndEdges(goal_it->second); hg.Reweight(weights); // there might be nodes that cannot be derived diff --git a/decoder/trule.h b/decoder/trule.h index 7dced5a1..cc370757 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -42,6 +42,9 @@ class TRule { scores_.set_value(feat_ids[i], feat_vals[i]); } + TRule(WordID lhs, const WordID* src, int src_size, const WordID* trg, int trg_size, int arity, int pi, int pj) : + e_(trg, trg + trg_size), f_(src, src + src_size), lhs_(lhs), arity_(arity), prev_i(pi), prev_j(pj) {} + bool IsGoal() const; explicit TRule(const std::vector& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {} diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc index 607a7cb9..58d1957c 100644 --- a/training/utils/grammar_convert.cc +++ b/training/utils/grammar_convert.cc @@ -292,10 +292,10 @@ int main(int argc, char **argv) { int lc = 0; Hypergraph hg; map lhs2node; + string line; while(*in) { - string line; + getline(*in,line); ++lc; - getline(*in, line); if (is_json_input) { if (line.empty() || line[0] == '#') continue; string ref; @@ -342,6 +342,7 @@ int main(int argc, char **argv) { edge->feature_values_ = tr->scores_; Hypergraph::Node* node = &hg.nodes_[head]; hg.ConnectEdgeToHeadNode(edge, node); + node->node_hash = lc; } } } -- cgit v1.2.3 From 94d515b00d66d9ea18758b951c1ce4789f3875b0 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 25 Apr 2014 13:20:06 -0400 Subject: fix tree-to-string forest so it works with cube pruning assumptions --- decoder/tree2string_translator.cc | 17 ++++++++++++++++- tests/system_tests/t2s/gold.statistics | 4 ++-- 2 files changed, 18 insertions(+), 3 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 29caaf8f..c353f7ca 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -126,6 +126,17 @@ namespace std { }; }; +void AddDummyGoalNode(Hypergraph* hg) { + static const int kGOAL = -TD::Convert("Goal"); + static TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X] ||| [1]")); + unsigned old_goal_node_idx = hg->nodes_.size() - 1; + HG::Node* goal_node = hg->AddNode(kGOAL); + goal_node->node_hash = 1; + TailNodeVector tail(1, old_goal_node_idx); + HG::Edge* new_edge = hg->AddEdge(kGOAL_RULE, tail); + hg->ConnectEdgeToHeadNode(new_edge, goal_node); +} + struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; @@ -308,14 +319,18 @@ struct Tree2StringTranslatorImpl { if (goal_it == x2hg.end()) return false; //cerr << "Goal node: " << goal << endl; hg.TopologicallySortNodesAndEdges(goal_it->second); - hg.Reweight(weights); // there might be nodes that cannot be derived // the following takes care of them vector prune(hg.edges_.size(), false); hg.PruneEdges(prune, true); if (hg.edges_.size() == 0) return false; + // rescoring assumes the goal edge is arity 1 (code laziness), add that here + AddDummyGoalNode(&hg); + + hg.Reweight(weights); //hg.PrintGraphviz(); + minus_lm_forest->swap(hg); return true; } diff --git a/tests/system_tests/t2s/gold.statistics b/tests/system_tests/t2s/gold.statistics index 452cc93e..5778a24c 100644 --- a/tests/system_tests/t2s/gold.statistics +++ b/tests/system_tests/t2s/gold.statistics @@ -1,3 +1,3 @@ --lm_nodes 6 --lm_edges 8 +-lm_nodes 7 +-lm_edges 9 -lm_paths 4 -- cgit v1.2.3 From 1400adb9f5e8ebf8ea1658d658893f0731b6fc22 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 25 Apr 2014 23:45:32 -0400 Subject: check for non-rescorable hypergraphs --- decoder/decoder.cc | 5 +++++ decoder/hg.cc | 7 +++++++ decoder/hg.h | 4 ++++ tests/system_tests/cfg_rescore/input.cfg | 5 ++--- tests/system_tests/cfg_rescore/input.txt | 2 +- training/utils/grammar_convert.cc | 16 ++++++++++++---- 6 files changed, 31 insertions(+), 8 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 354ea2d9..41f36822 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -755,6 +755,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; } + if (!forest.ArePreGoalEdgesArity1()) { + cerr << "Pre-goal edges are not arity-1. The decoder requires this.\n"; + abort(); + } + const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { diff --git a/decoder/hg.cc b/decoder/hg.cc index e456fa7c..46543b01 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -28,6 +28,13 @@ bool Hypergraph::AreNodesUniquelyIdentified() const { return true; } +bool Hypergraph::ArePreGoalEdgesArity1() const { + auto& n = nodes_.back(); + for (auto eid : n.in_edges_) + if (edges_[eid].Arity() != 1) return false; + return true; +} + Hypergraph::Edge const* Hypergraph::ViterbiSortInEdges() { NodeProbs nv; ComputeNodeViterbi(&nv); diff --git a/decoder/hg.h b/decoder/hg.h index 43fb275b..4ed27d87 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -268,6 +268,10 @@ public: // if all node states are unique, return true bool AreNodesUniquelyIdentified() const; + // the feature function interface assumes that pre-goal edges are + // arity 1 (this simplifies the "final transition" feature computation) + bool ArePreGoalEdgesArity1() const; + // reserves space in the nodes vector to prevent memory locations // from changing void ReserveNodes(size_t n, size_t e = 0) { diff --git a/tests/system_tests/cfg_rescore/input.cfg b/tests/system_tests/cfg_rescore/input.cfg index 0073cb7b..75602c75 100644 --- a/tests/system_tests/cfg_rescore/input.cfg +++ b/tests/system_tests/cfg_rescore/input.cfg @@ -1,9 +1,8 @@ -[S] ||| [S1] -[S1] ||| [NP1] [VP] ||| Active=1 +[S] ||| [NP1] [VP] ||| Active=1 +[S] ||| [NP2] [VPSV] by [NP1] ||| Passive=1 [VP] ||| [V] [NP2] [V] ||| ate [VPSV] ||| was eaten -[S1] ||| [NP2] [VPSV] by [NP1] ||| Passive=1 [NP1] ||| John [NP2] ||| broccoli [NP2] ||| the broccoli ||| Definite=1 diff --git a/tests/system_tests/cfg_rescore/input.txt b/tests/system_tests/cfg_rescore/input.txt index 71fc26bc..2999a5fb 100644 --- a/tests/system_tests/cfg_rescore/input.txt +++ b/tests/system_tests/cfg_rescore/input.txt @@ -1 +1 @@ -{"rules":[1,"[S] ||| [S1] ||| [1]",2,"[S1] ||| [NP1] [VP] ||| [1] [2] ||| Active=1",3,"[VP] ||| [V] [NP2] ||| [1] [2]",4,"[V] ||| ate ||| ate",5,"[VPSV] ||| was eaten ||| was eaten",6,"[S1] ||| [NP2] [VPSV] by [NP1] ||| [1] [2] by [3] ||| Passive=1",7,"[NP1] ||| John ||| John",8,"[NP2] ||| broccoli ||| broccoli",9,"[NP2] ||| the broccoli ||| the broccoli ||| Definite=1"],"features":["PhraseModel_0","PhraseModel_1","PhraseModel_2","PhraseModel_3","PhraseModel_4","PhraseModel_5","PhraseModel_6","PhraseModel_7","PhraseModel_8","PhraseModel_9","PhraseModel_10","PhraseModel_11","PhraseModel_12","PhraseModel_13","PhraseModel_14","PhraseModel_15","PhraseModel_16","PhraseModel_17","PhraseModel_18","PhraseModel_19","PhraseModel_20","PhraseModel_21","PhraseModel_22","PhraseModel_23","PhraseModel_24","PhraseModel_25","PhraseModel_26","PhraseModel_27","PhraseModel_28","PhraseModel_29","PhraseModel_30","PhraseModel_31","PhraseModel_32","PhraseModel_33","PhraseModel_34","PhraseModel_35","PhraseModel_36","PhraseModel_37","PhraseModel_38","PhraseModel_39","PhraseModel_40","PhraseModel_41","PhraseModel_42","PhraseModel_43","PhraseModel_44","PhraseModel_45","PhraseModel_46","PhraseModel_47","PhraseModel_48","PhraseModel_49","PhraseModel_50","PhraseModel_51","PhraseModel_52","PhraseModel_53","PhraseModel_54","PhraseModel_55","PhraseModel_56","PhraseModel_57","PhraseModel_58","PhraseModel_59","PhraseModel_60","PhraseModel_61","PhraseModel_62","PhraseModel_63","PhraseModel_64","PhraseModel_65","PhraseModel_66","PhraseModel_67","PhraseModel_68","PhraseModel_69","PhraseModel_70","PhraseModel_71","PhraseModel_72","PhraseModel_73","PhraseModel_74","PhraseModel_75","PhraseModel_76","PhraseModel_77","PhraseModel_78","PhraseModel_79","PhraseModel_80","PhraseModel_81","PhraseModel_82","PhraseModel_83","PhraseModel_84","PhraseModel_85","PhraseModel_86","PhraseModel_87","PhraseModel_88","PhraseModel_89","PhraseModel_90","PhraseModel_91","PhraseModel_92","PhraseModel_93","PhraseModel_94","PhraseModel_95","PhraseModel_96","PhraseModel_97","PhraseModel_98","PhraseModel_99","Active","Passive","Definite"],"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":7}],"node":{"in_edges":[0],"cat":"NP1","node_hash":"0000000000000007"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":4}],"node":{"in_edges":[1],"cat":"V","node_hash":"0000000000000004"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":8},{"tail":[],"spans":[-1,-1,-1,-1],"feats":[102,1],"rule":9}],"node":{"in_edges":[2,3],"cat":"NP2","node_hash":"0000000000000009"},"edges":[{"tail":[1,2],"spans":[-1,-1,-1,-1],"feats":[],"rule":3}],"node":{"in_edges":[4],"cat":"VP","node_hash":"0000000000000003"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":5}],"node":{"in_edges":[5],"cat":"VPSV","node_hash":"0000000000000005"},"edges":[{"tail":[0,3],"spans":[-1,-1,-1,-1],"feats":[100,1],"rule":2},{"tail":[2,4,0],"spans":[-1,-1,-1,-1],"feats":[101,1],"rule":6}],"node":{"in_edges":[6,7],"cat":"S1","node_hash":"0000000000000006"},"edges":[{"tail":[5],"spans":[-1,-1,-1,-1],"feats":[],"rule":1}],"node":{"in_edges":[8],"cat":"S","node_hash":"0000000000000001"}} +{"rules":[1,"[S] ||| [NP1] [VP] ||| [1] [2] ||| Active=1",2,"[S] ||| [NP2] [VPSV] by [NP1] ||| [1] [2] by [3] ||| Passive=1",3,"[VP] ||| [V] [NP2] ||| [1] [2]",4,"[V] ||| ate ||| ate",5,"[VPSV] ||| was eaten ||| was eaten",6,"[NP1] ||| John ||| John",7,"[NP2] ||| broccoli ||| broccoli",8,"[NP2] ||| the broccoli ||| the broccoli ||| Definite=1",9,"[Goal] ||| [X] ||| [1]"],"features":["PhraseModel_0","PhraseModel_1","PhraseModel_2","PhraseModel_3","PhraseModel_4","PhraseModel_5","PhraseModel_6","PhraseModel_7","PhraseModel_8","PhraseModel_9","PhraseModel_10","PhraseModel_11","PhraseModel_12","PhraseModel_13","PhraseModel_14","PhraseModel_15","PhraseModel_16","PhraseModel_17","PhraseModel_18","PhraseModel_19","PhraseModel_20","PhraseModel_21","PhraseModel_22","PhraseModel_23","PhraseModel_24","PhraseModel_25","PhraseModel_26","PhraseModel_27","PhraseModel_28","PhraseModel_29","PhraseModel_30","PhraseModel_31","PhraseModel_32","PhraseModel_33","PhraseModel_34","PhraseModel_35","PhraseModel_36","PhraseModel_37","PhraseModel_38","PhraseModel_39","PhraseModel_40","PhraseModel_41","PhraseModel_42","PhraseModel_43","PhraseModel_44","PhraseModel_45","PhraseModel_46","PhraseModel_47","PhraseModel_48","PhraseModel_49","PhraseModel_50","PhraseModel_51","PhraseModel_52","PhraseModel_53","PhraseModel_54","PhraseModel_55","PhraseModel_56","PhraseModel_57","PhraseModel_58","PhraseModel_59","PhraseModel_60","PhraseModel_61","PhraseModel_62","PhraseModel_63","PhraseModel_64","PhraseModel_65","PhraseModel_66","PhraseModel_67","PhraseModel_68","PhraseModel_69","PhraseModel_70","PhraseModel_71","PhraseModel_72","PhraseModel_73","PhraseModel_74","PhraseModel_75","PhraseModel_76","PhraseModel_77","PhraseModel_78","PhraseModel_79","PhraseModel_80","PhraseModel_81","PhraseModel_82","PhraseModel_83","PhraseModel_84","PhraseModel_85","PhraseModel_86","PhraseModel_87","PhraseModel_88","PhraseModel_89","PhraseModel_90","PhraseModel_91","PhraseModel_92","PhraseModel_93","PhraseModel_94","PhraseModel_95","PhraseModel_96","PhraseModel_97","PhraseModel_98","PhraseModel_99","Active","Passive","Definite"],"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":6}],"node":{"in_edges":[0],"cat":"NP1","node_hash":"0000000000000006"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":4}],"node":{"in_edges":[1],"cat":"V","node_hash":"0000000000000004"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":7},{"tail":[],"spans":[-1,-1,-1,-1],"feats":[102,1],"rule":8}],"node":{"in_edges":[2,3],"cat":"NP2","node_hash":"0000000000000008"},"edges":[{"tail":[1,2],"spans":[-1,-1,-1,-1],"feats":[],"rule":3}],"node":{"in_edges":[4],"cat":"VP","node_hash":"0000000000000003"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":5}],"node":{"in_edges":[5],"cat":"VPSV","node_hash":"0000000000000005"},"edges":[{"tail":[0,3],"spans":[-1,-1,-1,-1],"feats":[100,1],"rule":1},{"tail":[2,4,0],"spans":[-1,-1,-1,-1],"feats":[101,1],"rule":2}],"node":{"in_edges":[6,7],"cat":"S","node_hash":"0000000000000002"},"edges":[{"tail":[5],"spans":[-1,-1,-1,-1],"feats":[],"rule":9}],"node":{"in_edges":[8],"cat":"Goal","node_hash":"000000000000003D"}} diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc index 58d1957c..5c1b4d4a 100644 --- a/training/utils/grammar_convert.cc +++ b/training/utils/grammar_convert.cc @@ -56,15 +56,22 @@ int GetOrCreateNode(const WordID& lhs, map* lhs2node, Hypergraph* h return node_id - 1; } +void AddDummyGoalNode(Hypergraph* hg) { + static const int kGOAL = -TD::Convert("Goal"); + static TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X] ||| [1]")); + unsigned old_goal_node_idx = hg->nodes_.size() - 1; + HG::Node* goal_node = hg->AddNode(kGOAL); + goal_node->node_hash = goal_node->id_ * 10 + 1; + TailNodeVector tail(1, old_goal_node_idx); + HG::Edge* new_edge = hg->AddEdge(kGOAL_RULE, tail); + hg->ConnectEdgeToHeadNode(new_edge, goal_node); +} + void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { if (goal < 0) { cerr << "Error! [S] not found in grammar!\n"; exit(1); } - if (hg->nodes_[goal].in_edges_.size() != 1) { - cerr << "Error! [S] has more than one rewrite!\n"; - exit(1); - } int old_size = hg->nodes_.size(); hg->TopologicallySortNodesAndEdges(goal); if (hg->nodes_.size() != old_size) { @@ -319,6 +326,7 @@ int main(int argc, char **argv) { if (line.empty()) { int goal = lhs2node[kSTART] - 1; FilterAndCheckCorrectness(goal, &hg); + AddDummyGoalNode(&hg); ProcessHypergraph(w, conf, "", &hg); hg.clear(); lhs2node.clear(); -- cgit v1.2.3 From e027ec43dad0fa2bd9d2d7161ec0355bc1641178 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 27 Apr 2014 21:05:33 +0200 Subject: clean up headers --- decoder/tree_fragment.cc | 2 ++ decoder/tree_fragment.h | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'decoder') diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 4d429f42..696c8601 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -2,6 +2,8 @@ #include +#include "tdict.h" + using namespace std; namespace cdec { diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 4a704bc4..f9dfa8cc 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -5,8 +5,7 @@ #include #include #include - -#include "tdict.h" +#include namespace cdec { -- cgit v1.2.3 From d1530e378037f310d1759e52b715c266043681ed Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 May 2014 12:11:08 -0400 Subject: turn of span filtering --- decoder/hg_intersect.cc | 2 +- decoder/tree2string_translator.cc | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'decoder') diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index 31a9a1ce..02f5a401 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -92,7 +92,7 @@ bool Intersect(const Lattice& target, Hypergraph* hg) { return FastLinearIntersect(target, hg); vector rem(hg->edges_.size(), false); - const RuleFilter filter(target, 15); // TODO make configurable + const RuleFilter filter(target, 9999); // TODO make configurable for (unsigned i = 0; i < rem.size(); ++i) rem[i] = filter(*hg->edges_[i].rule_); hg->PruneEdges(rem, true); diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index c353f7ca..fafb0d97 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -128,12 +128,14 @@ namespace std { void AddDummyGoalNode(Hypergraph* hg) { static const int kGOAL = -TD::Convert("Goal"); - static TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X] ||| [1]")); unsigned old_goal_node_idx = hg->nodes_.size() - 1; + int old_goal_cat = hg->nodes_[old_goal_node_idx].cat_; + TRulePtr goal_rule(new TRule("[Goal] ||| [X] ||| [1]")); + goal_rule->f_[0] = old_goal_cat; HG::Node* goal_node = hg->AddNode(kGOAL); goal_node->node_hash = 1; TailNodeVector tail(1, old_goal_node_idx); - HG::Edge* new_edge = hg->AddEdge(kGOAL_RULE, tail); + HG::Edge* new_edge = hg->AddEdge(goal_rule, tail); hg->ConnectEdgeToHeadNode(new_edge, goal_node); } -- cgit v1.2.3 From 3b80feb119dc3c42b1d7137943c86f3c0bfd0119 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 May 2014 21:20:46 -0400 Subject: better features --- decoder/tree2string_translator.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index fafb0d97..38daeeb5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -158,7 +158,11 @@ struct Tree2StringTranslatorImpl { } void CreatePassThroughRules(const cdec::TreeFragment& tree) { + static const int kFIDlex = FD::Convert("PassThrough_Lexical"); + static const int kFIDabs = FD::Convert("PassThrough_Abstract"); + static const int kFIDmix = FD::Convert("PassThrough_Mix"); static const int kFID = FD::Convert("PassThrough"); + static unordered_map pntfid; root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; @@ -167,14 +171,24 @@ struct Tree2StringTranslatorImpl { vector rhse, rhsf; int ntc = 0; int lhs = -(prod.lhs & cdec::ALL_MASK); + int &ntfid = pntfid[lhs]; + if (!ntfid) { + ostringstream fos; + fos << "PassThrough:" << TD::Convert(-lhs); + ntfid = FD::Convert(fos.str()); + } + bool has_lex = false; + bool has_nt = false; os << '(' << TD::Convert(-lhs); for (auto& sym : prod.rhs) { os << ' '; if (cdec::IsTerminal(sym)) { + has_lex = true; os << TD::Convert(sym); rhse.push_back(sym); rhsf.push_back(sym); } else { + has_nt = true; unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; os << '[' << TD::Convert(id) << ']'; rhsf.push_back(-id); @@ -192,7 +206,12 @@ struct Tree2StringTranslatorImpl { cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); rule->ComputeArity(); + rule->scores_.set_value(ntfid, 1.0); rule->scores_.set_value(kFID, 1.0); + if (has_lex && has_nt) + rule->scores_.set_value(kFIDmix, 1.0); + else if (has_lex) rule->scores_.set_value(kFIDlex, 1.0); + else if (has_nt) rule->scores_.set_value(kFIDabs, 1.0); cur->rules.push_back(rule); } } -- cgit v1.2.3 From 57d87fb8aa4bdda78b02afbf43b332d662e802f4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 May 2014 21:24:35 -0400 Subject: missing header --- decoder/tree_fragment.h | 1 + 1 file changed, 1 insertion(+) (limited to 'decoder') diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index f9dfa8cc..79722b5a 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace cdec { -- cgit v1.2.3 From f5cc1e175e11a59ec1297d4d823f0d59820bf346 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 9 May 2014 00:01:40 -0400 Subject: remove fixed bug warning --- decoder/tree2string_translator.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 38daeeb5..5d7aa5e2 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -21,6 +21,8 @@ struct Tree2StringGrammarNode { vector rules; }; +// this needs to be rewritten so it is fast and checks errors well +// use a lexer probably void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { string line; while(getline(*in, line)) { @@ -36,10 +38,8 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { ostringstream os; int lhs = -(rule_src.root & cdec::ALL_MASK); // build source RHS for SCFG projection - // TODO - this is buggy - it will generate a well-formed SCFG rule - // so it will not generate source strings correctly - // it will, however, generate target translations appropriately vector frhs; + // we traverse the rule_src in left to right, DFS order for (auto sym : rule_src) { //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; cur = &cur->next[sym]; @@ -48,7 +48,7 @@ void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { frhs.push_back(-nt); } else if (cdec::IsTerminal(sym)) { frhs.push_back(sym); - } + } // else internal NT, nothing to do } os << '[' << TD::Convert(-lhs) << "] |||"; for (auto x : frhs) { -- cgit v1.2.3 From 58d5516904c7b529486e1000c1047ac7bf91ed1d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 May 2014 17:46:20 -0400 Subject: stub for t2t translator --- decoder/decoder.cc | 8 +++++--- decoder/translator.h | 3 ++- decoder/tree2string_translator.cc | 12 +++++++----- 3 files changed, 14 insertions(+), 9 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 41f36822..6783cad0 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "t2s" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; + if (formalism != "t2s" && formalism != "t2t" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 't2t', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -627,7 +627,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); else if (formalism == "t2s") - translator.reset(new Tree2StringTranslator(conf)); + translator.reset(new Tree2StringTranslator(conf, false)); + else if (formalism == "t2t") + translator.reset(new Tree2StringTranslator(conf, true)); else if (formalism == "fst") translator.reset(new FSTTranslator(conf)); else if (formalism == "pb") diff --git a/decoder/translator.h b/decoder/translator.h index 72b2f0b0..ba218a0b 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -101,7 +101,8 @@ class RescoreTranslator : public Translator { class Tree2StringTranslatorImpl; class Tree2StringTranslator : public Translator { public: - Tree2StringTranslator(const boost::program_options::variables_map& conf); + Tree2StringTranslator(const boost::program_options::variables_map& conf, + bool has_multiple_states); virtual std::string GetDecoderType() const; protected: bool TranslateImpl(const std::string& src, diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 5d7aa5e2..101ed21c 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -23,7 +23,7 @@ struct Tree2StringGrammarNode { // this needs to be rewritten so it is fast and checks errors well // use a lexer probably -void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { +void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -143,7 +143,8 @@ struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; unsigned remove_grammars; - Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, + bool has_multiple_states) : add_pass_through_rules(conf.count("add_pass_through_rules")) { if (conf.count("grammar")) { const vector gf = conf["grammar"].as>(); @@ -152,7 +153,7 @@ struct Tree2StringTranslatorImpl { for (auto& f : gf) { ReadFile rf(f); root[gc].reset(new Tree2StringGrammarNode); - ReadTree2StringGrammar(rf.stream(), &*root[gc++]); + ReadTree2StringGrammar(rf.stream(), &*root[gc++], has_multiple_states); } } } @@ -357,8 +358,9 @@ struct Tree2StringTranslatorImpl { } }; -Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) : - pimpl_(new Tree2StringTranslatorImpl(conf)) {} +Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf, + bool has_multiple_states) : + pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, -- cgit v1.2.3 From 0026ae2bf5417401af9cf3c1dd357def5ae23e0a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 May 2014 17:53:53 -0400 Subject: check for duplicates when creating pass through rules --- decoder/tree2string_translator.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 101ed21c..8d624820 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -167,9 +167,8 @@ struct Tree2StringTranslatorImpl { root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; + unordered_set,boost::hash>> unique_rule_check; for (auto& prod : tree.nodes) { - ostringstream os; - vector rhse, rhsf; int ntc = 0; int lhs = -(prod.lhs & cdec::ALL_MASK); int &ntfid = pntfid[lhs]; @@ -178,8 +177,16 @@ struct Tree2StringTranslatorImpl { fos << "PassThrough:" << TD::Convert(-lhs); ntfid = FD::Convert(fos.str()); } + + // check for duplicate rule in tree + vector key = prod.rhs; + key.push_back(prod.lhs); + if (!unique_rule_check.insert(key).second) continue; + bool has_lex = false; bool has_nt = false; + vector rhse, rhsf; + ostringstream os; os << '(' << TD::Convert(-lhs); for (auto& sym : prod.rhs) { os << ' '; -- cgit v1.2.3 From fdaf46f6df031eb024507c39d7cf920219c6eb5e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 May 2014 19:04:31 -0400 Subject: fix unique check --- decoder/tree2string_translator.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 8d624820..7b37887e 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -167,7 +167,7 @@ struct Tree2StringTranslatorImpl { root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; - unordered_set,boost::hash>> unique_rule_check; + unordered_set,boost::hash>> unique_rule_check; for (auto& prod : tree.nodes) { int ntc = 0; int lhs = -(prod.lhs & cdec::ALL_MASK); @@ -179,9 +179,8 @@ struct Tree2StringTranslatorImpl { } // check for duplicate rule in tree - vector key = prod.rhs; + vector key; key.push_back(prod.lhs); - if (!unique_rule_check.insert(key).second) continue; bool has_lex = false; bool has_nt = false; @@ -195,16 +194,19 @@ struct Tree2StringTranslatorImpl { os << TD::Convert(sym); rhse.push_back(sym); rhsf.push_back(sym); + key.push_back(sym); } else { has_nt = true; unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; os << '[' << TD::Convert(id) << ']'; rhsf.push_back(-id); rhse.push_back(-ntc); + key.push_back(-id); ++ntc; } } os << ')'; + if (!unique_rule_check.insert(key).second) continue; cdec::TreeFragment rule_src(os.str(), true); Tree2StringGrammarNode* cur = root.back().get(); // do we need all transducer states here??? a list??? no pass through rules??? -- cgit v1.2.3 From 362ece653b71b3772cba24214855c47bb14e8117 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 24 May 2014 02:09:49 -0400 Subject: support per sentence tree-to-string grammars --- decoder/tree2string_translator.cc | 46 +++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 7b37887e..b5b47d5d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -5,6 +5,7 @@ #include #include #include +#include "fast_lexical_cast.hpp" #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -23,7 +24,7 @@ struct Tree2StringGrammarNode { // this needs to be rewritten so it is fast and checks errors well // use a lexer probably -void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { +static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -142,10 +143,12 @@ void AddDummyGoalNode(Hypergraph* hg) { struct Tree2StringTranslatorImpl { vector> root; bool add_pass_through_rules; + bool has_multiple_states; unsigned remove_grammars; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, bool has_multiple_states) : - add_pass_through_rules(conf.count("add_pass_through_rules")) { + add_pass_through_rules(conf.count("add_pass_through_rules")), + has_multiple_states(has_multiple_states) { if (conf.count("grammar")) { const vector gf = conf["grammar"].as>(); root.resize(gf.size()); @@ -158,6 +161,15 @@ struct Tree2StringTranslatorImpl { } } + // loads a per-sentence grammar + void LoadSupplementalGrammar(const string& gfile) { + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + ReadFile rf(gfile); + ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states); + } + void CreatePassThroughRules(const cdec::TreeFragment& tree) { static const int kFIDlex = FD::Convert("PassThrough_Lexical"); static const int kFIDabs = FD::Convert("PassThrough_Abstract"); @@ -227,7 +239,7 @@ struct Tree2StringTranslatorImpl { } void RemoveGrammars() { - assert(remove_grammars < root.size()); + assert(remove_grammars <= root.size()); root.resize(root.size() - remove_grammars); } @@ -235,7 +247,6 @@ struct Tree2StringTranslatorImpl { SentenceMetadata* smeta, const vector& weights, Hypergraph* minus_lm_forest) { - remove_grammars = 0; cdec::TreeFragment input_tree(input, false); if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; @@ -371,6 +382,30 @@ Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::varia bool has_multiple_states) : pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv) { + pimpl_->remove_grammars = 0; + if (kv.find("grammar0") != kv.end()) { + cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n"; + abort(); + } + unsigned gc = 0; + set loaded; + while(true) { + string gkey = "grammar"; + if (gc > 0) gkey += boost::lexical_cast(gc); + ++gc; + map::const_iterator it = kv.find(gkey); + if (it == kv.end()) break; + const string& gfile = it->second; + if (loaded.count(gfile) == 1) { + cerr << "Attempting to load " << gfile << " twice!\n"; + abort(); + } + loaded.insert(gfile); + pimpl_->LoadSupplementalGrammar(gfile); + } +} + bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, const vector& weights, @@ -378,9 +413,6 @@ bool Tree2StringTranslator::TranslateImpl(const string& input, return pimpl_->Translate(input, smeta, weights, minus_lm_forest); } -void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv) { -} - void Tree2StringTranslator::SentenceCompleteImpl() { pimpl_->RemoveGrammars(); } -- cgit v1.2.3