From 10a668822715cee024a7e7391c62caa8e078e840 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 12 Mar 2014 02:30:26 -0400 Subject: add support for internal tree structure on SCFG rules --- decoder/rule_lexer.ll | 42 +++++++++++++------ decoder/tree2string_translator.cc | 87 +++++++++++++++++++++++++++++++++++++++ decoder/tree_fragment.h | 59 ++++++++++++++++++++++++++ 3 files changed, 176 insertions(+), 12 deletions(-) create mode 100644 decoder/tree2string_translator.cc create mode 100644 decoder/tree_fragment.h (limited to 'decoder') diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index c6a85919..05963d05 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -12,6 +12,7 @@ #include "fdict.h" #include "trule.h" #include "verbose.h" +#include "tree_fragment.h" int lex_line = 0; std::istream* scfglex_stream = NULL; @@ -51,6 +52,7 @@ int scfglex_src_nts[MAX_ARITY]; // float scfglex_nt_size_vars[MAX_ARITY]; std::stack ctf_rule_stack; unsigned int ctf_level = 0; +boost::shared_ptr scfglex_tree; #define MAX_ALS 2000 AlignmentPoint scfglex_als[MAX_ALS]; @@ -120,7 +122,7 @@ void check_and_update_ctf_stack(const TRulePtr& rp) { REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf NT [^\t \[\],]+ -%x LHS_END SRC TRG FEATS FEATVAL ALIGNS +%x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE %% [ \t] { @@ -205,7 +207,13 @@ NT [^\t \[\],]+ ++scfglex_src_rhs_size; } [ \t]+ { ; } - +[^|\n]+ { + if (yyleng > 0) { + int len = yyleng; + while(len > 1 && yytext[len - 1] != ')') { --len; } + scfglex_tree.reset(new cdec::TreeFragment(std::string(yytext, len), true)); + } + } \|\|\| { BEGIN(FEATS); } @@ -216,7 +224,7 @@ NT [^\t \[\],]+ } [ \t]+ { ; } -\n { +\n { if (scfglex_src_arity != scfglex_trg_arity) { std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); @@ -224,18 +232,25 @@ NT [^\t \[\],]+ // const bool ignore_grammar_features = false; // if (ignore_grammar_features) scfglex_num_feats = 0; TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als)); - check_and_update_ctf_stack(rp); - TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); + if (scfglex_tree) { + if (scfglex_tree->frontier_sites != rp->Arity()) { + std::cerr << "Arity mismatch with tree annotation: " << *scfglex_tree << std::endl; + abort(); + } + rp->tree_structure.swap(scfglex_tree); + } + check_and_update_ctf_stack(rp); + TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra); - ctf_rule_stack.push(rp); + ctf_rule_stack.push(rp); // std::cerr << rp->AsString() << std::endl; num_rules++; - lex_line++; - if (!SILENT) { - if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; } - if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } - } - ctf_level = 0; + lex_line++; + if (!SILENT) { + if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; } + if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } + } + ctf_level = 0; BEGIN(INITIAL); } @@ -253,6 +268,9 @@ NT [^\t \[\],]+ \|\|\| { BEGIN(ALIGNS); } +\|\|\|[ \t]* { + BEGIN(TREE); + } {REAL} { scfglex_feat_vals[scfglex_num_feats] = strtod(yytext, NULL); ++scfglex_num_feats; diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc new file mode 100644 index 00000000..4ccc54e2 --- /dev/null +++ b/decoder/tree2string_translator.cc @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include "tree_fragment.h" +#include "translator.h" +#include "hg.h" +#include "sentence_metadata.h" +#include "filelib.h" +#include "stringlib.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +// root: S +// A implication: (S [A] *INCOMPLETE* +// B implication: (S [A] [B] *INCOMPLETE* +// *0* implication: (S _[A] [B]) +// a implication: (S (A a *INCOMPLETE* [B]) +// a implication: (S (A a a *INCOMPLETE* [B]) +// *0* implication: (S (A a a) _[B]) +// D implication: (S (A a a) (B [D] *INCOMPLETE*) +// *0* implication: (S (A a a) (B _[D])) +// d implication: (S (A a a) (B (D d *INCOMPLETE*)) +// *0* implication: (S (A a a) (B (D d))) +// --there are no further outgoing links possible-- + +// root: S +// A implication: (S [A] *INCOMPLETE* +// B implication: (S [A] [B] *INCOMPLETE* +// *0* implication: (S _[A] [B]) +// *0* implication: (S [A] _[B]) +// b implication: (S [A] (B b *INCOMPLETE*)) +struct Tree2StringGrammarNode { + unordered_map next; + string rules; +}; + +void ReadTree2StringGrammar(istream* in, unordered_map* proots) { + unordered_map& roots = *proots; + string line; + while(getline(*in, line)) { + size_t pos = line.find("|||"); + assert(pos != string::npos); + assert(pos > 3); + if (line[pos - 1] == ' ') --pos; + cdec::TreeFragment rule_src(line.substr(0, pos), true); + } +} + +struct Tree2StringTranslatorImpl { + unordered_map roots; // root['S'] gives rule network for S rules + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { + ReadFile rf(conf["grammar"].as>()[0]); + ReadTree2StringGrammar(rf.stream(), &roots); + } + bool Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + cdec::TreeFragment input_tree(input, false); + cerr << "Tree2StringTranslatorImpl: please implement this!\n"; + return false; + } +}; + +Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new Tree2StringTranslatorImpl(conf)) {} + +bool Tree2StringTranslator::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + return pimpl_->Translate(input, smeta, weights, minus_lm_forest); +} + +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv) { +} + +void Tree2StringTranslator::SentenceCompleteImpl() { +} + +std::string Tree2StringTranslator::GetDecoderType() const { + return "tree2string"; +} + diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h new file mode 100644 index 00000000..83cd1c1e --- /dev/null +++ b/decoder/tree_fragment.h @@ -0,0 +1,59 @@ +#ifndef TREE_FRAGMENT +#define TREE_FRAGMENT + +#include +#include +#include + +#include "tdict.h" + +namespace cdec { + +static const unsigned NT_BIT = 0x40000000u; +static const unsigned FRONTIER_BIT = 0x80000000u; +static const unsigned ALL_MASK = 0x0FFFFFFFu; + +inline bool IsInternalNT(unsigned x) { + return (x & NT_BIT); +} + +inline bool IsFrontier(unsigned x) { + return (x & FRONTIER_BIT); +} + +struct TreeFragmentProduction { + TreeFragmentProduction() {} + TreeFragmentProduction(int nttype, const std::vector& r) : lhs(nttype), rhs(r) {} + unsigned lhs; + std::vector rhs; +}; + +// this data structure represents a tree or forest +// productions can have mixtures of terminals and nonterminal symbols and non-terminal frontier sites +class TreeFragment { + public: + TreeFragment() : frontier_sites(), terminals() {} + // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) + explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); + void DebugRec(unsigned cur, std::ostream* out) const; + private: + // cp is the character index in the tree + // np keeps track of the nodes (nonterminals) that have been built + // symp keeps track of the terminal symbols that have been built + void ParseRec(const std::string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp); + public: + unsigned root; + unsigned char frontier_sites; + unsigned short terminals; + + std::vector nodes; +}; + +inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { + x.DebugRec(x.nodes.size() - 1, &os); + return os; +} + +} + +#endif -- cgit v1.2.3 From 70ef91b22ee4abc5e50c15c4eb08121739af2bfd Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 12 Mar 2014 02:31:37 -0400 Subject: tree_fragment stuff --- decoder/tree_fragment.cc | 115 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 decoder/tree_fragment.cc (limited to 'decoder') diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc new file mode 100644 index 00000000..d5c30f58 --- /dev/null +++ b/decoder/tree_fragment.cc @@ -0,0 +1,115 @@ +#include "tree_fragment.h" + +#include + +using namespace std; + +namespace cdec { + +TreeFragment::TreeFragment(const string& tree, bool allow_frontier_sites) { + int bal = 0; + const unsigned len = tree.size(); + unsigned cur = 0; + unsigned open = 0, close = 0; + for (auto& c : tree) { + ++cur; + if (c == '(') { ++open; ++bal; } + else if (c == ')') { + ++close; --bal; + if (bal < 1 && cur != len) { + cerr << "Badly formed tree detected at column " << cur << " in:\n" << tree << endl; + abort(); + } + } + } + nodes.resize(open); + unsigned cp = 0, symp = 0, np = 0; + ParseRec(tree, allow_frontier_sites, cp, symp, np, &cp, &symp, &np); + root = nodes.back().lhs; + //cerr << "ROOT: " << TD::Convert(root & ALL_MASK) << endl; + //DebugRec(open - 1, &cerr); cerr << "\n"; +} + +void TreeFragment::DebugRec(unsigned cur, ostream* out) const { + *out << '(' << TD::Convert(nodes[cur].lhs & ALL_MASK); + for (auto& x : nodes[cur].rhs) { + *out << ' '; + if (IsFrontier(x)) { + *out << '[' << TD::Convert(x & ALL_MASK) << ']'; + } else if (IsInternalNT(x)) { + DebugRec(x & ALL_MASK, out); + } else { // must be terminal + *out << TD::Convert(x); + } + } + *out << ')'; +} + +// cp is the character index in the tree +// np keeps track of the nodes (nonterminals) that have been built +// symp keeps track of the terminal symbols that have been built +void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) { + if (tree[cp] != '(') { + cerr << "Expected ( at " << cp << endl; + abort(); + } + const unsigned i = symp; + vector rhs; // w | 0 = terminal, w | NT_BIT, index | FRONTIER_BIT + ++cp; + while(tree[cp] == ' ') { ++cp; } + const unsigned nt_start = cp; + while(tree[cp] != ' ' && tree[cp] != '(' && tree[cp] != ')') { ++cp; } + const unsigned nt_end = cp; + while(tree[cp] == ' ') { ++cp; } + while (tree[cp] != ')') { + if (tree[cp] == '(') { + // recursively call parser to deal with constituent + ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np); + unsigned ind = np - 1; + rhs.push_back(ind | NT_BIT); + } else { // deal with terminal / nonterminal substitution + ++symp; + assert(tree[cp] != ' '); + const unsigned t_start = cp; + while(tree[cp] != ' ' && tree[cp] != ')' && tree[cp] != '(') { ++cp; } + const unsigned t_end = cp; + while(tree[cp] == ' ') { ++cp; } + // TODO: add a terminal symbol to the current edge + const bool is_terminal = tree[t_start] != '[' || (t_end - t_start < 3 || tree[t_end - 1] != ']'); + if (is_terminal) { + const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start)); + rhs.push_back(term); + // cerr << "T='" << TD::Convert(term) << "'\n"; + ++terminals; + } else { // frontier site (NT but no recursion) + const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2)) | FRONTIER_BIT; + rhs.push_back(nt); + ++frontier_sites; + // cerr << "FRONT-NT=[" << TD::Convert(nt & ALL_MASK) << "]\n"; + if (!afs) { + cerr << "Frontier sites not allowed in input: " << tree << endl; + abort(); + } + } + } + } // continuent has completed, cp is at ), build node + const unsigned j = symp; // span from (i,j) + // add an internal non-terminal symbol + const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT; + nodes[np] = TreeFragmentProduction(nt, rhs); + //cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->"; + //for (auto& x : rhs) { + // cerr << ' '; + // if (IsFrontier(x)) cerr << '*'; + // if (IsInternalNT(x)) cerr << TD::Convert(nodes[x & ALL_MASK].lhs & ALL_MASK); else + // cerr << TD::Convert(x & ALL_MASK); + //} + //cerr << "\n "; DebugRec(np,&cerr); cerr << endl; + ++cp; + while(tree[cp] == ' ' && cp < tree.size()) { ++cp; } + *pcp = cp; + *pnp = np + 1; + *psymp = symp; +} + +} -- cgit v1.2.3 From 80f465a250e7fcfc5dd476d04e39a43ef0c909a3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 13 Mar 2014 22:52:18 -0400 Subject: missing commit --- decoder/trule.h | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'decoder') diff --git a/decoder/trule.h b/decoder/trule.h index 6a33d052..e9a10bea 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -15,6 +15,8 @@ class TRule; typedef boost::shared_ptr TRulePtr; +namespace cdec { struct TreeFragment; } + struct AlignmentPoint { AlignmentPoint() : s_(), t_() {} AlignmentPoint(int s, int t) : s_(s), t_(t) {} @@ -159,6 +161,9 @@ class TRule { // only for coarse-to-fine decoding boost::shared_ptr > fine_rules_; + // optional, shows internal structure of TSG rules + boost::shared_ptr tree_structure; + private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} bool SanityCheck() const; -- cgit v1.2.3 From 3b4b66479ac80177393eb7953c3bcc268f3d9551 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 13 Mar 2014 22:59:01 -0400 Subject: missing makefile --- decoder/Makefile.am | 3 +++ 1 file changed, 3 insertions(+) (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index b735756d..c41cd7f9 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -129,6 +129,7 @@ libcdec_a_SOURCES = \ ffset.cc \ forest_writer.cc \ fst_translator.cc \ + tree2string_translator.cc \ grammar.cc \ hg.cc \ hg_intersect.cc \ @@ -141,6 +142,8 @@ libcdec_a_SOURCES = \ lattice.cc \ lexalign.cc \ lextrans.cc \ + tree_fragment.cc \ + tree_fragment.h \ maxtrans_blunsom.cc \ phrasebased_translator.cc \ phrasetable_fst.cc \ -- cgit v1.2.3 From 4d653c9f1769855c3b8a835922b5ab56a92bd94b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 13 Mar 2014 23:03:51 -0400 Subject: missing commit --- decoder/translator.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'decoder') diff --git a/decoder/translator.h b/decoder/translator.h index c0800e84..72b2f0b0 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -98,4 +98,20 @@ class RescoreTranslator : public Translator { boost::shared_ptr pimpl_; }; +class Tree2StringTranslatorImpl; +class Tree2StringTranslator : public Translator { + public: + Tree2StringTranslator(const boost::program_options::variables_map& conf); + virtual std::string GetDecoderType() const; + protected: + bool TranslateImpl(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + void ProcessMarkupHintsImpl(const std::map& kv); + void SentenceCompleteImpl(); + private: + boost::shared_ptr pimpl_; +}; + #endif -- cgit v1.2.3 From cc87bfed0697583b7c11243913254dde3c0047d4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 13 Mar 2014 23:06:50 -0400 Subject: possible gcc comp error --- decoder/tree2string_translator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 4ccc54e2..ac9c0d74 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -33,7 +33,7 @@ using namespace std; // *0* implication: (S [A] _[B]) // b implication: (S [A] (B b *INCOMPLETE*)) struct Tree2StringGrammarNode { - unordered_map next; + map next; string rules; }; -- cgit v1.2.3