diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/Makefile.am | 3 | ||||
| -rw-r--r-- | decoder/rule_lexer.ll | 42 | ||||
| -rw-r--r-- | decoder/translator.h | 16 | ||||
| -rw-r--r-- | decoder/tree2string_translator.cc | 87 | ||||
| -rw-r--r-- | decoder/tree_fragment.cc | 115 | ||||
| -rw-r--r-- | decoder/tree_fragment.h | 59 | ||||
| -rw-r--r-- | decoder/trule.h | 5 | 
7 files changed, 315 insertions, 12 deletions
| diff --git a/decoder/Makefile.am b/decoder/Makefile.am index c0371081..0b61c7cd 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -130,6 +130,7 @@ libcdec_a_SOURCES = \    ffset.cc \    forest_writer.cc \    fst_translator.cc \ +  tree2string_translator.cc \    grammar.cc \    hg.cc \    hg_intersect.cc \ @@ -142,6 +143,8 @@ libcdec_a_SOURCES = \    lattice.cc \    lexalign.cc \    lextrans.cc \ +  tree_fragment.cc \ +  tree_fragment.h \    maxtrans_blunsom.cc \    phrasebased_translator.cc \    phrasetable_fst.cc \ diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index c6a85919..05963d05 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -12,6 +12,7 @@  #include "fdict.h"  #include "trule.h"  #include "verbose.h" +#include "tree_fragment.h"  int lex_line = 0;  std::istream* scfglex_stream = NULL; @@ -51,6 +52,7 @@ int scfglex_src_nts[MAX_ARITY];  // float scfglex_nt_size_vars[MAX_ARITY];  std::stack<TRulePtr> ctf_rule_stack;  unsigned int ctf_level = 0; +boost::shared_ptr<cdec::TreeFragment> scfglex_tree;  #define MAX_ALS 2000  AlignmentPoint scfglex_als[MAX_ALS]; @@ -120,7 +122,7 @@ void check_and_update_ctf_stack(const TRulePtr& rp) {  REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf  NT [^\t \[\],]+ -%x LHS_END SRC TRG FEATS FEATVAL ALIGNS +%x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE  %%  <INITIAL>[ \t]	{ @@ -205,7 +207,13 @@ NT [^\t \[\],]+  		++scfglex_src_rhs_size;  		}  <SRC>[ \t]+	{ ; } - +<TREE>[^|\n]+	{ +                if (yyleng > 0) { +		  int len = yyleng; +                  while(len > 1 && yytext[len - 1] != ')') { --len; } +                  scfglex_tree.reset(new cdec::TreeFragment(std::string(yytext, len), true)); +                } +		}  <TRG>\|\|\|	{  		BEGIN(FEATS);  		} @@ -216,7 +224,7 @@ NT [^\t \[\],]+  		}  <TRG>[ \t]+	{ ; } -<TRG,FEATS,ALIGNS>\n	{ +<TRG,FEATS,ALIGNS,TREE>\n	{                  if (scfglex_src_arity != scfglex_trg_arity) {                    std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n";                    abort(); @@ -224,18 +232,25 @@ NT [^\t \[\],]+  		// const bool ignore_grammar_features = false;  		// if (ignore_grammar_features) scfglex_num_feats = 0;  		TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als)); -    check_and_update_ctf_stack(rp); -    TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); +		if (scfglex_tree) { +		  if (scfglex_tree->frontier_sites != rp->Arity()) { +		    std::cerr << "Arity mismatch with tree annotation: " << *scfglex_tree << std::endl; +		    abort(); +		  } +		  rp->tree_structure.swap(scfglex_tree); +		} +		check_and_update_ctf_stack(rp); +		TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());  		rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra); -    ctf_rule_stack.push(rp); +		ctf_rule_stack.push(rp);  		// std::cerr << rp->AsString() << std::endl;  		num_rules++; -    lex_line++; -    if (!SILENT) { -      if (num_rules %   50000 == 0) { std::cerr << '.' << std::flush; fl = true; } -      if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } -    } -    ctf_level = 0; +		lex_line++; +		if (!SILENT) { +		  if (num_rules %   50000 == 0) { std::cerr << '.' << std::flush; fl = true; } +		  if (num_rules % 2000000 == 0) { std::cerr << " [" << num_rules << "]\n"; fl = false; } +		} +		ctf_level = 0;  		BEGIN(INITIAL);  		} @@ -253,6 +268,9 @@ NT [^\t \[\],]+  <FEATS>\|\|\|	{  		BEGIN(ALIGNS);  		} +<ALIGNS>\|\|\|[ \t]*	{ +		BEGIN(TREE); +		}  <FEATVAL>{REAL}	{  		scfglex_feat_vals[scfglex_num_feats] = strtod(yytext, NULL);  		++scfglex_num_feats; diff --git a/decoder/translator.h b/decoder/translator.h index c0800e84..72b2f0b0 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -98,4 +98,20 @@ class RescoreTranslator : public Translator {    boost::shared_ptr<RescoreTranslatorImpl> pimpl_;  }; +class Tree2StringTranslatorImpl; +class Tree2StringTranslator : public Translator { + public: +  Tree2StringTranslator(const boost::program_options::variables_map& conf); +  virtual std::string GetDecoderType() const; + protected: +  bool TranslateImpl(const std::string& src, +                 SentenceMetadata* smeta, +                 const std::vector<double>& weights, +                 Hypergraph* minus_lm_forest); +  void ProcessMarkupHintsImpl(const std::map<std::string, std::string>& kv); +  void SentenceCompleteImpl(); + private: +  boost::shared_ptr<Tree2StringTranslatorImpl> pimpl_; +}; +  #endif diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc new file mode 100644 index 00000000..ac9c0d74 --- /dev/null +++ b/decoder/tree2string_translator.cc @@ -0,0 +1,87 @@ +#include <algorithm> +#include <vector> +#include <boost/functional/hash.hpp> +#include <unordered_map> +#include "tree_fragment.h" +#include "translator.h" +#include "hg.h" +#include "sentence_metadata.h" +#include "filelib.h" +#include "stringlib.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +// root: S +// A          implication: (S [A] *INCOMPLETE* +// B          implication: (S [A] [B] *INCOMPLETE* +// *0*        implication: (S _[A] [B]) +// a          implication: (S (A a *INCOMPLETE* [B]) +// a          implication: (S (A a a *INCOMPLETE* [B]) +// *0*        implication: (S (A a a) _[B]) +// D          implication: (S (A a a) (B [D] *INCOMPLETE*) +// *0*        implication: (S (A a a) (B _[D])) +// d          implication: (S (A a a) (B (D d *INCOMPLETE*)) +// *0*        implication: (S (A a a) (B (D d))) +// --there are no further outgoing links possible-- + +// root: S +// A          implication: (S [A] *INCOMPLETE* +// B          implication: (S [A] [B] *INCOMPLETE* +// *0*        implication: (S _[A] [B]) +// *0*        implication: (S [A] _[B]) +// b          implication: (S [A] (B b *INCOMPLETE*)) +struct Tree2StringGrammarNode { +  map<unsigned, Tree2StringGrammarNode> next; +  string rules; +}; + +void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) { +  unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots; +  string line; +  while(getline(*in, line)) { +    size_t pos = line.find("|||"); +    assert(pos != string::npos); +    assert(pos > 3); +    if (line[pos - 1] == ' ') --pos; +    cdec::TreeFragment rule_src(line.substr(0, pos), true); +  } +} + +struct Tree2StringTranslatorImpl { +  unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules +  Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { +    ReadFile rf(conf["grammar"].as<vector<string>>()[0]); +    ReadTree2StringGrammar(rf.stream(), &roots); +  } +  bool Translate(const string& input, +                 SentenceMetadata* smeta, +                 const vector<double>& weights, +                 Hypergraph* minus_lm_forest) { +    cdec::TreeFragment input_tree(input, false); +    cerr << "Tree2StringTranslatorImpl: please implement this!\n"; +    return false; +  } +}; + +Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) : +  pimpl_(new Tree2StringTranslatorImpl(conf)) {} + +bool Tree2StringTranslator::TranslateImpl(const string& input, +                               SentenceMetadata* smeta, +                               const vector<double>& weights, +                               Hypergraph* minus_lm_forest) { +  return pimpl_->Translate(input, smeta, weights, minus_lm_forest); +} + +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { +} + +void Tree2StringTranslator::SentenceCompleteImpl() { +} + +std::string Tree2StringTranslator::GetDecoderType() const { +  return "tree2string"; +} + diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc new file mode 100644 index 00000000..d5c30f58 --- /dev/null +++ b/decoder/tree_fragment.cc @@ -0,0 +1,115 @@ +#include "tree_fragment.h" + +#include <cassert> + +using namespace std; + +namespace cdec { + +TreeFragment::TreeFragment(const string& tree, bool allow_frontier_sites) { +  int bal = 0; +  const unsigned len = tree.size(); +  unsigned cur = 0; +  unsigned open = 0, close = 0; +  for (auto& c : tree) { +    ++cur; +    if (c == '(') { ++open; ++bal; } +    else if (c == ')') { +      ++close; --bal; +      if (bal < 1 && cur != len) { +        cerr << "Badly formed tree detected at column " << cur << " in:\n" << tree << endl; +        abort(); +      } +    } +  } +  nodes.resize(open); +  unsigned cp = 0, symp = 0, np = 0; +  ParseRec(tree, allow_frontier_sites, cp, symp, np, &cp, &symp, &np); +  root = nodes.back().lhs; +  //cerr << "ROOT: " << TD::Convert(root & ALL_MASK) << endl; +  //DebugRec(open - 1, &cerr); cerr << "\n"; +} + +void TreeFragment::DebugRec(unsigned cur, ostream* out) const { +  *out << '(' << TD::Convert(nodes[cur].lhs & ALL_MASK); +  for (auto& x : nodes[cur].rhs) { +    *out << ' '; +    if (IsFrontier(x)) { +      *out << '[' << TD::Convert(x & ALL_MASK) << ']'; +    } else if (IsInternalNT(x)) { +      DebugRec(x & ALL_MASK, out); +    } else { // must be terminal +      *out << TD::Convert(x); +    } +  } +  *out << ')'; +} + +// cp is the character index in the tree +// np keeps track of the nodes (nonterminals) that have been built +// symp keeps track of the terminal symbols that have been built +void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) { +  if (tree[cp] != '(') { +    cerr << "Expected ( at " << cp << endl; +    abort(); +  } +  const unsigned i = symp; +  vector<unsigned> rhs; // w | 0 = terminal, w | NT_BIT, index | FRONTIER_BIT +  ++cp; +  while(tree[cp] == ' ') { ++cp; } +  const unsigned nt_start = cp; +  while(tree[cp] != ' ' && tree[cp] != '(' && tree[cp] != ')') { ++cp; } +  const unsigned nt_end = cp; +  while(tree[cp] == ' ') { ++cp; } +  while (tree[cp] != ')') { +    if (tree[cp] == '(') { +      // recursively call parser to deal with constituent +      ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np); +      unsigned ind = np - 1; +      rhs.push_back(ind | NT_BIT); +    } else { // deal with terminal / nonterminal substitution +      ++symp; +      assert(tree[cp] != ' '); +      const unsigned t_start = cp; +      while(tree[cp] != ' ' && tree[cp] != ')' && tree[cp] != '(') { ++cp; } +      const unsigned t_end = cp; +      while(tree[cp] == ' ') { ++cp; } +      // TODO: add a terminal symbol to the current edge +      const bool is_terminal = tree[t_start] != '[' || (t_end - t_start < 3 || tree[t_end - 1] != ']'); +      if (is_terminal) { +        const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start)); +        rhs.push_back(term); +        // cerr << "T='" << TD::Convert(term) << "'\n"; +        ++terminals; +      } else { // frontier site (NT but no recursion) +        const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2)) | FRONTIER_BIT; +        rhs.push_back(nt); +        ++frontier_sites; +        // cerr << "FRONT-NT=[" << TD::Convert(nt & ALL_MASK) << "]\n"; +        if (!afs) { +          cerr << "Frontier sites not allowed in input: " << tree << endl; +          abort(); +        } +      }  +    } +  } // continuent has completed, cp is at ), build node +  const unsigned j = symp; // span from (i,j) +  // add an internal non-terminal symbol +  const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT; +  nodes[np] = TreeFragmentProduction(nt, rhs); +  //cerr << np << " production(" << i << "," << j << ")=  " << TD::Convert(nt & ALL_MASK) << " -->"; +  //for (auto& x : rhs) { +  //  cerr << ' '; +  //  if (IsFrontier(x)) cerr << '*'; +  //  if (IsInternalNT(x)) cerr << TD::Convert(nodes[x & ALL_MASK].lhs & ALL_MASK); else +  //    cerr << TD::Convert(x & ALL_MASK); +  //} +  //cerr << "\n   "; DebugRec(np,&cerr); cerr << endl; +  ++cp; +  while(tree[cp] == ' ' && cp < tree.size()) { ++cp; } +  *pcp = cp; +  *pnp = np + 1; +  *psymp = symp; +} + +} diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h new file mode 100644 index 00000000..83cd1c1e --- /dev/null +++ b/decoder/tree_fragment.h @@ -0,0 +1,59 @@ +#ifndef TREE_FRAGMENT +#define TREE_FRAGMENT + +#include <iostream> +#include <vector> +#include <string> + +#include "tdict.h" + +namespace cdec { + +static const unsigned NT_BIT       = 0x40000000u; +static const unsigned FRONTIER_BIT = 0x80000000u; +static const unsigned ALL_MASK     = 0x0FFFFFFFu; + +inline bool IsInternalNT(unsigned x) { +  return (x & NT_BIT); +} + +inline bool IsFrontier(unsigned x) { +  return (x & FRONTIER_BIT); +} + +struct TreeFragmentProduction { +  TreeFragmentProduction() {} +  TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {} +  unsigned lhs; +  std::vector<unsigned> rhs; +}; + +// this data structure represents a tree or forest +// productions can have mixtures of terminals and nonterminal symbols and non-terminal frontier sites +class TreeFragment { + public: +  TreeFragment() : frontier_sites(), terminals() {} +  // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) +  explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); +  void DebugRec(unsigned cur, std::ostream* out) const; + private: +  // cp is the character index in the tree +  // np keeps track of the nodes (nonterminals) that have been built +  // symp keeps track of the terminal symbols that have been built +  void ParseRec(const std::string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp); + public: +  unsigned root; +  unsigned char frontier_sites; +  unsigned short terminals; + +  std::vector<TreeFragmentProduction> nodes; +}; + +inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { +  x.DebugRec(x.nodes.size() - 1, &os); +  return os; +} + +} + +#endif diff --git a/decoder/trule.h b/decoder/trule.h index 6a33d052..e9a10bea 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -15,6 +15,8 @@  class TRule;  typedef boost::shared_ptr<TRule> TRulePtr; +namespace cdec { struct TreeFragment; } +  struct AlignmentPoint {    AlignmentPoint() : s_(), t_() {}    AlignmentPoint(int s, int t) : s_(s), t_(t) {} @@ -159,6 +161,9 @@ class TRule {    // only for coarse-to-fine decoding    boost::shared_ptr<std::vector<TRulePtr> > fine_rules_; +  // optional, shows internal structure of TSG rules +  boost::shared_ptr<cdec::TreeFragment> tree_structure; +   private:    TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}    bool SanityCheck() const; | 
