summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-03-12 02:30:26 -0400
committerChris Dyer <redpony@gmail.com>2014-03-12 02:30:26 -0400
commit10a668822715cee024a7e7391c62caa8e078e840 (patch)
treec6db1a3055f3dd589c3bffd6e54d9bb544b8e1e3
parent284383880f043edb2d67afbe2f64237c466245c1 (diff)
add support for internal tree structure on SCFG rules
-rw-r--r--decoder/rule_lexer.ll42
-rw-r--r--decoder/tree2string_translator.cc87
-rw-r--r--decoder/tree_fragment.h59
3 files changed, 176 insertions, 12 deletions
diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll
index c6a85919..05963d05 100644
--- a/decoder/rule_lexer.ll
+++ b/decoder/rule_lexer.ll
@@ -12,6 +12,7 @@
#include "fdict.h"
#include "trule.h"
#include "verbose.h"
+#include "tree_fragment.h"
int lex_line = 0;
std::istream* scfglex_stream = NULL;
@@ -51,6 +52,7 @@ int scfglex_src_nts[MAX_ARITY];
// float scfglex_nt_size_vars[MAX_ARITY];
std::stack<TRulePtr> ctf_rule_stack;
unsigned int ctf_level = 0;
+boost::shared_ptr<cdec::TreeFragment> scfglex_tree;
#define MAX_ALS 2000
AlignmentPoint scfglex_als[MAX_ALS];
@@ -120,7 +122,7 @@ void check_and_update_ctf_stack(const TRulePtr& rp) {
REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf
NT [^\t \[\],]+
-%x LHS_END SRC TRG FEATS FEATVAL ALIGNS
+%x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE
%%
<INITIAL>[ \t] {
@@ -205,7 +207,13 @@ NT [^\t \[\],]+
++scfglex_src_rhs_size;
}
<SRC>[ \t]+ { ; }
-
+<TREE>[^|\n]+ {
+ if (yyleng > 0) {
+ int len = yyleng;
+ while(len > 1 && yytext[len - 1] != ')') { --len; }
+ scfglex_tree.reset(new cdec::TreeFragment(std::string(yytext, len), true));
+ }
+ }
<TRG>\|\|\| {
BEGIN(FEATS);
}
@@ -216,7 +224,7 @@ NT [^\t \[\],]+
}
<TRG>[ \t]+ { ; }
-<TRG,FEATS,ALIGNS>\n {
+<TRG,FEATS,ALIGNS,TREE>\n {
if (scfglex_src_arity != scfglex_trg_arity) {
std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n";
abort();
@@ -224,18 +232,25 @@ NT [^\t \[\],]+
// const bool ignore_grammar_features = false;
// if (ignore_grammar_features) scfglex_num_feats = 0;
TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als));
- check_and_update_ctf_stack(rp);
- TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
+ if (scfglex_tree) {
+ if (scfglex_tree->frontier_sites != rp->Arity()) {
+ std::cerr << "Arity mismatch with tree annotation: " << *scfglex_tree << std::endl;
+ abort();
+ }
+ rp->tree_structure.swap(scfglex_tree);
+ }
+ check_and_update_ctf_stack(rp);
+ TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra);
- ctf_rule_stack.push(rp);
+ ctf_rule_stack.push(rp);
// std::cerr << rp->AsString() << std::endl;
num_rules++;
- lex_line++;
- if (!SILENT) {
- if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
- if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
- }
- ctf_level = 0;
+ lex_line++;
+ if (!SILENT) {
+ if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
+ if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
+ }
+ ctf_level = 0;
BEGIN(INITIAL);
}
@@ -253,6 +268,9 @@ NT [^\t \[\],]+
<FEATS>\|\|\| {
BEGIN(ALIGNS);
}
+<ALIGNS>\|\|\|[ \t]* {
+ BEGIN(TREE);
+ }
<FEATVAL>{REAL} {
scfglex_feat_vals[scfglex_num_feats] = strtod(yytext, NULL);
++scfglex_num_feats;
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
new file mode 100644
index 00000000..4ccc54e2
--- /dev/null
+++ b/decoder/tree2string_translator.cc
@@ -0,0 +1,87 @@
+#include <algorithm>
+#include <vector>
+#include <boost/functional/hash.hpp>
+#include <unordered_map>
+#include "tree_fragment.h"
+#include "translator.h"
+#include "hg.h"
+#include "sentence_metadata.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "tdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+// root: S
+// A implication: (S [A] *INCOMPLETE*
+// B implication: (S [A] [B] *INCOMPLETE*
+// *0* implication: (S _[A] [B])
+// a implication: (S (A a *INCOMPLETE* [B])
+// a implication: (S (A a a *INCOMPLETE* [B])
+// *0* implication: (S (A a a) _[B])
+// D implication: (S (A a a) (B [D] *INCOMPLETE*)
+// *0* implication: (S (A a a) (B _[D]))
+// d implication: (S (A a a) (B (D d *INCOMPLETE*))
+// *0* implication: (S (A a a) (B (D d)))
+// --there are no further outgoing links possible--
+
+// root: S
+// A implication: (S [A] *INCOMPLETE*
+// B implication: (S [A] [B] *INCOMPLETE*
+// *0* implication: (S _[A] [B])
+// *0* implication: (S [A] _[B])
+// b implication: (S [A] (B b *INCOMPLETE*))
+struct Tree2StringGrammarNode {
+ unordered_map<unsigned, Tree2StringGrammarNode> next;
+ string rules;
+};
+
+void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) {
+ unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots;
+ string line;
+ while(getline(*in, line)) {
+ size_t pos = line.find("|||");
+ assert(pos != string::npos);
+ assert(pos > 3);
+ if (line[pos - 1] == ' ') --pos;
+ cdec::TreeFragment rule_src(line.substr(0, pos), true);
+ }
+}
+
+struct Tree2StringTranslatorImpl {
+ unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules
+ Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) {
+ ReadFile rf(conf["grammar"].as<vector<string>>()[0]);
+ ReadTree2StringGrammar(rf.stream(), &roots);
+ }
+ bool Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ cdec::TreeFragment input_tree(input, false);
+ cerr << "Tree2StringTranslatorImpl: please implement this!\n";
+ return false;
+ }
+};
+
+Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new Tree2StringTranslatorImpl(conf)) {}
+
+bool Tree2StringTranslator::TranslateImpl(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
+}
+
+void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
+}
+
+void Tree2StringTranslator::SentenceCompleteImpl() {
+}
+
+std::string Tree2StringTranslator::GetDecoderType() const {
+ return "tree2string";
+}
+
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
new file mode 100644
index 00000000..83cd1c1e
--- /dev/null
+++ b/decoder/tree_fragment.h
@@ -0,0 +1,59 @@
+#ifndef TREE_FRAGMENT
+#define TREE_FRAGMENT
+
+#include <iostream>
+#include <vector>
+#include <string>
+
+#include "tdict.h"
+
+namespace cdec {
+
+static const unsigned NT_BIT = 0x40000000u;
+static const unsigned FRONTIER_BIT = 0x80000000u;
+static const unsigned ALL_MASK = 0x0FFFFFFFu;
+
+inline bool IsInternalNT(unsigned x) {
+ return (x & NT_BIT);
+}
+
+inline bool IsFrontier(unsigned x) {
+ return (x & FRONTIER_BIT);
+}
+
+struct TreeFragmentProduction {
+ TreeFragmentProduction() {}
+ TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {}
+ unsigned lhs;
+ std::vector<unsigned> rhs;
+};
+
+// this data structure represents a tree or forest
+// productions can have mixtures of terminals and nonterminal symbols and non-terminal frontier sites
+class TreeFragment {
+ public:
+ TreeFragment() : frontier_sites(), terminals() {}
+ // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar))))
+ explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false);
+ void DebugRec(unsigned cur, std::ostream* out) const;
+ private:
+ // cp is the character index in the tree
+ // np keeps track of the nodes (nonterminals) that have been built
+ // symp keeps track of the terminal symbols that have been built
+ void ParseRec(const std::string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp);
+ public:
+ unsigned root;
+ unsigned char frontier_sites;
+ unsigned short terminals;
+
+ std::vector<TreeFragmentProduction> nodes;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) {
+ x.DebugRec(x.nodes.size() - 1, &os);
+ return os;
+}
+
+}
+
+#endif