summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am3
-rw-r--r--decoder/rule_lexer.ll42
-rw-r--r--decoder/translator.h16
-rw-r--r--decoder/tree2string_translator.cc87
-rw-r--r--decoder/tree_fragment.cc115
-rw-r--r--decoder/tree_fragment.h59
-rw-r--r--decoder/trule.h5
7 files changed, 315 insertions, 12 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index c0371081..0b61c7cd 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -130,6 +130,7 @@ libcdec_a_SOURCES = \
ffset.cc \
forest_writer.cc \
fst_translator.cc \
+ tree2string_translator.cc \
grammar.cc \
hg.cc \
hg_intersect.cc \
@@ -142,6 +143,8 @@ libcdec_a_SOURCES = \
lattice.cc \
lexalign.cc \
lextrans.cc \
+ tree_fragment.cc \
+ tree_fragment.h \
maxtrans_blunsom.cc \
phrasebased_translator.cc \
phrasetable_fst.cc \
diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll
index c6a85919..05963d05 100644
--- a/decoder/rule_lexer.ll
+++ b/decoder/rule_lexer.ll
@@ -12,6 +12,7 @@
#include "fdict.h"
#include "trule.h"
#include "verbose.h"
+#include "tree_fragment.h"
int lex_line = 0;
std::istream* scfglex_stream = NULL;
@@ -51,6 +52,7 @@ int scfglex_src_nts[MAX_ARITY];
// float scfglex_nt_size_vars[MAX_ARITY];
std::stack<TRulePtr> ctf_rule_stack;
unsigned int ctf_level = 0;
+boost::shared_ptr<cdec::TreeFragment> scfglex_tree;
#define MAX_ALS 2000
AlignmentPoint scfglex_als[MAX_ALS];
@@ -120,7 +122,7 @@ void check_and_update_ctf_stack(const TRulePtr& rp) {
REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf
NT [^\t \[\],]+
-%x LHS_END SRC TRG FEATS FEATVAL ALIGNS
+%x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE
%%
<INITIAL>[ \t] {
@@ -205,7 +207,13 @@ NT [^\t \[\],]+
++scfglex_src_rhs_size;
}
<SRC>[ \t]+ { ; }
-
+<TREE>[^|\n]+ {
+ if (yyleng > 0) {
+ int len = yyleng;
+ while(len > 1 && yytext[len - 1] != ')') { --len; }
+ scfglex_tree.reset(new cdec::TreeFragment(std::string(yytext, len), true));
+ }
+ }
<TRG>\|\|\| {
BEGIN(FEATS);
}
@@ -216,7 +224,7 @@ NT [^\t \[\],]+
}
<TRG>[ \t]+ { ; }
-<TRG,FEATS,ALIGNS>\n {
+<TRG,FEATS,ALIGNS,TREE>\n {
if (scfglex_src_arity != scfglex_trg_arity) {
std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n";
abort();
@@ -224,18 +232,25 @@ NT [^\t \[\],]+
// const bool ignore_grammar_features = false;
// if (ignore_grammar_features) scfglex_num_feats = 0;
TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als));
- check_and_update_ctf_stack(rp);
- TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
+ if (scfglex_tree) {
+ if (scfglex_tree->frontier_sites != rp->Arity()) {
+ std::cerr << "Arity mismatch with tree annotation: " << *scfglex_tree << std::endl;
+ abort();
+ }
+ rp->tree_structure.swap(scfglex_tree);
+ }
+ check_and_update_ctf_stack(rp);
+ TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra);
- ctf_rule_stack.push(rp);
+ ctf_rule_stack.push(rp);
// std::cerr << rp->AsString() << std::endl;
num_rules++;
- lex_line++;
- if (!SILENT) {
- if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
- if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
- }
- ctf_level = 0;
+ lex_line++;
+ if (!SILENT) {
+ if (num_rules % 50000 == 0) { std::cerr << '.' << std::flush; fl = true; }
+ if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; }
+ }
+ ctf_level = 0;
BEGIN(INITIAL);
}
@@ -253,6 +268,9 @@ NT [^\t \[\],]+
<FEATS>\|\|\| {
BEGIN(ALIGNS);
}
+<ALIGNS>\|\|\|[ \t]* {
+ BEGIN(TREE);
+ }
<FEATVAL>{REAL} {
scfglex_feat_vals[scfglex_num_feats] = strtod(yytext, NULL);
++scfglex_num_feats;
diff --git a/decoder/translator.h b/decoder/translator.h
index c0800e84..72b2f0b0 100644
--- a/decoder/translator.h
+++ b/decoder/translator.h
@@ -98,4 +98,20 @@ class RescoreTranslator : public Translator {
boost::shared_ptr<RescoreTranslatorImpl> pimpl_;
};
+class Tree2StringTranslatorImpl;
+class Tree2StringTranslator : public Translator {
+ public:
+ Tree2StringTranslator(const boost::program_options::variables_map& conf);
+ virtual std::string GetDecoderType() const;
+ protected:
+ bool TranslateImpl(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ void ProcessMarkupHintsImpl(const std::map<std::string, std::string>& kv);
+ void SentenceCompleteImpl();
+ private:
+ boost::shared_ptr<Tree2StringTranslatorImpl> pimpl_;
+};
+
#endif
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
new file mode 100644
index 00000000..ac9c0d74
--- /dev/null
+++ b/decoder/tree2string_translator.cc
@@ -0,0 +1,87 @@
+#include <algorithm>
+#include <vector>
+#include <boost/functional/hash.hpp>
+#include <unordered_map>
+#include "tree_fragment.h"
+#include "translator.h"
+#include "hg.h"
+#include "sentence_metadata.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "tdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+// root: S
+// A implication: (S [A] *INCOMPLETE*
+// B implication: (S [A] [B] *INCOMPLETE*
+// *0* implication: (S _[A] [B])
+// a implication: (S (A a *INCOMPLETE* [B])
+// a implication: (S (A a a *INCOMPLETE* [B])
+// *0* implication: (S (A a a) _[B])
+// D implication: (S (A a a) (B [D] *INCOMPLETE*)
+// *0* implication: (S (A a a) (B _[D]))
+// d implication: (S (A a a) (B (D d *INCOMPLETE*))
+// *0* implication: (S (A a a) (B (D d)))
+// --there are no further outgoing links possible--
+
+// root: S
+// A implication: (S [A] *INCOMPLETE*
+// B implication: (S [A] [B] *INCOMPLETE*
+// *0* implication: (S _[A] [B])
+// *0* implication: (S [A] _[B])
+// b implication: (S [A] (B b *INCOMPLETE*))
+struct Tree2StringGrammarNode {
+ map<unsigned, Tree2StringGrammarNode> next;
+ string rules;
+};
+
+void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) {
+ unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots;
+ string line;
+ while(getline(*in, line)) {
+ size_t pos = line.find("|||");
+ assert(pos != string::npos);
+ assert(pos > 3);
+ if (line[pos - 1] == ' ') --pos;
+ cdec::TreeFragment rule_src(line.substr(0, pos), true);
+ }
+}
+
+struct Tree2StringTranslatorImpl {
+ unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules
+ Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) {
+ ReadFile rf(conf["grammar"].as<vector<string>>()[0]);
+ ReadTree2StringGrammar(rf.stream(), &roots);
+ }
+ bool Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ cdec::TreeFragment input_tree(input, false);
+ cerr << "Tree2StringTranslatorImpl: please implement this!\n";
+ return false;
+ }
+};
+
+Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new Tree2StringTranslatorImpl(conf)) {}
+
+bool Tree2StringTranslator::TranslateImpl(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
+}
+
+void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
+}
+
+void Tree2StringTranslator::SentenceCompleteImpl() {
+}
+
+std::string Tree2StringTranslator::GetDecoderType() const {
+ return "tree2string";
+}
+
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
new file mode 100644
index 00000000..d5c30f58
--- /dev/null
+++ b/decoder/tree_fragment.cc
@@ -0,0 +1,115 @@
+#include "tree_fragment.h"
+
+#include <cassert>
+
+using namespace std;
+
+namespace cdec {
+
+TreeFragment::TreeFragment(const string& tree, bool allow_frontier_sites) {
+ int bal = 0;
+ const unsigned len = tree.size();
+ unsigned cur = 0;
+ unsigned open = 0, close = 0;
+ for (auto& c : tree) {
+ ++cur;
+ if (c == '(') { ++open; ++bal; }
+ else if (c == ')') {
+ ++close; --bal;
+ if (bal < 1 && cur != len) {
+ cerr << "Badly formed tree detected at column " << cur << " in:\n" << tree << endl;
+ abort();
+ }
+ }
+ }
+ nodes.resize(open);
+ unsigned cp = 0, symp = 0, np = 0;
+ ParseRec(tree, allow_frontier_sites, cp, symp, np, &cp, &symp, &np);
+ root = nodes.back().lhs;
+ //cerr << "ROOT: " << TD::Convert(root & ALL_MASK) << endl;
+ //DebugRec(open - 1, &cerr); cerr << "\n";
+}
+
+void TreeFragment::DebugRec(unsigned cur, ostream* out) const {
+ *out << '(' << TD::Convert(nodes[cur].lhs & ALL_MASK);
+ for (auto& x : nodes[cur].rhs) {
+ *out << ' ';
+ if (IsFrontier(x)) {
+ *out << '[' << TD::Convert(x & ALL_MASK) << ']';
+ } else if (IsInternalNT(x)) {
+ DebugRec(x & ALL_MASK, out);
+ } else { // must be terminal
+ *out << TD::Convert(x);
+ }
+ }
+ *out << ')';
+}
+
+// cp is the character index in the tree
+// np keeps track of the nodes (nonterminals) that have been built
+// symp keeps track of the terminal symbols that have been built
+void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) {
+ if (tree[cp] != '(') {
+ cerr << "Expected ( at " << cp << endl;
+ abort();
+ }
+ const unsigned i = symp;
+ vector<unsigned> rhs; // w | 0 = terminal, w | NT_BIT, index | FRONTIER_BIT
+ ++cp;
+ while(tree[cp] == ' ') { ++cp; }
+ const unsigned nt_start = cp;
+ while(tree[cp] != ' ' && tree[cp] != '(' && tree[cp] != ')') { ++cp; }
+ const unsigned nt_end = cp;
+ while(tree[cp] == ' ') { ++cp; }
+ while (tree[cp] != ')') {
+ if (tree[cp] == '(') {
+ // recursively call parser to deal with constituent
+ ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np);
+ unsigned ind = np - 1;
+ rhs.push_back(ind | NT_BIT);
+ } else { // deal with terminal / nonterminal substitution
+ ++symp;
+ assert(tree[cp] != ' ');
+ const unsigned t_start = cp;
+ while(tree[cp] != ' ' && tree[cp] != ')' && tree[cp] != '(') { ++cp; }
+ const unsigned t_end = cp;
+ while(tree[cp] == ' ') { ++cp; }
+ // TODO: add a terminal symbol to the current edge
+ const bool is_terminal = tree[t_start] != '[' || (t_end - t_start < 3 || tree[t_end - 1] != ']');
+ if (is_terminal) {
+ const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start));
+ rhs.push_back(term);
+ // cerr << "T='" << TD::Convert(term) << "'\n";
+ ++terminals;
+ } else { // frontier site (NT but no recursion)
+ const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2)) | FRONTIER_BIT;
+ rhs.push_back(nt);
+ ++frontier_sites;
+ // cerr << "FRONT-NT=[" << TD::Convert(nt & ALL_MASK) << "]\n";
+ if (!afs) {
+ cerr << "Frontier sites not allowed in input: " << tree << endl;
+ abort();
+ }
+ }
+ }
+ } // continuent has completed, cp is at ), build node
+ const unsigned j = symp; // span from (i,j)
+ // add an internal non-terminal symbol
+ const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT;
+ nodes[np] = TreeFragmentProduction(nt, rhs);
+ //cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->";
+ //for (auto& x : rhs) {
+ // cerr << ' ';
+ // if (IsFrontier(x)) cerr << '*';
+ // if (IsInternalNT(x)) cerr << TD::Convert(nodes[x & ALL_MASK].lhs & ALL_MASK); else
+ // cerr << TD::Convert(x & ALL_MASK);
+ //}
+ //cerr << "\n "; DebugRec(np,&cerr); cerr << endl;
+ ++cp;
+ while(tree[cp] == ' ' && cp < tree.size()) { ++cp; }
+ *pcp = cp;
+ *pnp = np + 1;
+ *psymp = symp;
+}
+
+}
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
new file mode 100644
index 00000000..83cd1c1e
--- /dev/null
+++ b/decoder/tree_fragment.h
@@ -0,0 +1,59 @@
+#ifndef TREE_FRAGMENT
+#define TREE_FRAGMENT
+
+#include <iostream>
+#include <vector>
+#include <string>
+
+#include "tdict.h"
+
+namespace cdec {
+
+static const unsigned NT_BIT = 0x40000000u;
+static const unsigned FRONTIER_BIT = 0x80000000u;
+static const unsigned ALL_MASK = 0x0FFFFFFFu;
+
+inline bool IsInternalNT(unsigned x) {
+ return (x & NT_BIT);
+}
+
+inline bool IsFrontier(unsigned x) {
+ return (x & FRONTIER_BIT);
+}
+
+struct TreeFragmentProduction {
+ TreeFragmentProduction() {}
+ TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {}
+ unsigned lhs;
+ std::vector<unsigned> rhs;
+};
+
+// this data structure represents a tree or forest
+// productions can have mixtures of terminals and nonterminal symbols and non-terminal frontier sites
+class TreeFragment {
+ public:
+ TreeFragment() : frontier_sites(), terminals() {}
+ // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar))))
+ explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false);
+ void DebugRec(unsigned cur, std::ostream* out) const;
+ private:
+ // cp is the character index in the tree
+ // np keeps track of the nodes (nonterminals) that have been built
+ // symp keeps track of the terminal symbols that have been built
+ void ParseRec(const std::string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp);
+ public:
+ unsigned root;
+ unsigned char frontier_sites;
+ unsigned short terminals;
+
+ std::vector<TreeFragmentProduction> nodes;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) {
+ x.DebugRec(x.nodes.size() - 1, &os);
+ return os;
+}
+
+}
+
+#endif
diff --git a/decoder/trule.h b/decoder/trule.h
index 6a33d052..e9a10bea 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -15,6 +15,8 @@
class TRule;
typedef boost::shared_ptr<TRule> TRulePtr;
+namespace cdec { struct TreeFragment; }
+
struct AlignmentPoint {
AlignmentPoint() : s_(), t_() {}
AlignmentPoint(int s, int t) : s_(s), t_(t) {}
@@ -159,6 +161,9 @@ class TRule {
// only for coarse-to-fine decoding
boost::shared_ptr<std::vector<TRulePtr> > fine_rules_;
+ // optional, shows internal structure of TSG rules
+ boost::shared_ptr<cdec::TreeFragment> tree_structure;
+
private:
TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}
bool SanityCheck() const;