diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/Makefile.am | 2 | ||||
| -rw-r--r-- | decoder/decoder.cc | 6 | ||||
| -rw-r--r-- | decoder/earley_composer.cc | 38 | ||||
| -rw-r--r-- | decoder/hg.cc | 63 | ||||
| -rw-r--r-- | decoder/hg.h | 17 | ||||
| -rw-r--r-- | decoder/hg_remove_eps.cc | 91 | ||||
| -rw-r--r-- | decoder/hg_remove_eps.h | 13 | ||||
| -rw-r--r-- | decoder/rescore_translator.cc | 57 | ||||
| -rw-r--r-- | decoder/translator.h | 13 | ||||
| -rw-r--r-- | decoder/trule.cc | 4 | 
10 files changed, 218 insertions, 86 deletions
| diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 00d01e53..0a792549 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -37,9 +37,11 @@ libcdec_a_SOURCES = \    fst_translator.cc \    csplit.cc \    translator.cc \ +  rescore_translator.cc \    scfg_translator.cc \    hg.cc \    hg_io.cc \ +  hg_remove_eps.cc \    decoder.cc \    hg_intersect.cc \    hg_sampler.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index ad4e9e07..a6f7b1ce 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -527,8 +527,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    }    formalism = LowercaseString(str("formalism",conf)); -  if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { -    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; +  if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { +    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";      cerr << dcmdline_options << endl;      exit(1);    } @@ -675,6 +675,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream      translator.reset(new LexicalTrans(conf));    else if (formalism == "lexalign")      translator.reset(new LexicalAlign(conf)); +  else if (formalism == "rescore") +    translator.reset(new RescoreTranslator(conf));    else if (formalism == "tagger")      translator.reset(new Tagger(conf));    else diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc index d265d954..efce70a6 100644 --- a/decoder/earley_composer.cc +++ b/decoder/earley_composer.cc @@ -16,6 +16,7 @@  #include "sparse_vector.h"  #include "tdict.h"  #include "hg.h" +#include "hg_remove_eps.h"  using namespace std;  using namespace std::tr1; @@ -48,6 +49,27 @@ static void InitializeConstants() {  }  //////////////////////////////////////////////////////////// +TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) { +  TRule* r = new TRule(*kX1X2); +  r->lhs_ = lhs; +  r->f_[0] = rhs1; +  r->f_[1] = rhs2; +  return TRulePtr(r); +} + +TRulePtr CreateUnaryRule(int lhs, int rhs1) { +  TRule* r = new TRule(*kX1); +  r->lhs_ = lhs; +  r->f_[0] = rhs1; +  return TRulePtr(r); +} + +TRulePtr CreateEpsilonRule(int lhs) { +  TRule* r = new TRule(*kEPSRule); +  r->lhs_ = lhs; +  return TRulePtr(r); +} +  class EGrammarNode {    friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);    friend void AddGrammarRule(const string& r, map<WordID, EGrammarNode>* g); @@ -356,7 +378,7 @@ class EarleyComposerImpl {      }      if (goal_node) {        forest->PruneUnreachable(goal_node->id_); -      forest->EpsilonRemove(kEPS); +      RemoveEpsilons(forest, kEPS);      }      FreeAll();      return goal_node; @@ -557,24 +579,30 @@ class EarleyComposerImpl {      }      Hypergraph::Node*& head_node = edge2node[edge];      if (!head_node) -      head_node = hg->AddNode(kPHRASE); +      head_node = hg->AddNode(edge->cat);      if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) {        assert(goal_node == NULL || goal_node == head_node);        goal_node = head_node;      } +    int rhs1 = 0; +    int rhs2 = 0;      Hypergraph::TailNodeVector tail;      SparseVector<double> extra;      if (edge->IsCreatedByPredict()) {        // extra.set_value(FD::Convert("predict"), 1);      } else if (edge->IsCreatedByScan()) {        tail.push_back(edge2node[edge->active_parent]->id_); +      rhs1 = edge->active_parent->cat;        if (tps) {          tail.push_back(tps->id_); +        rhs2 = kPHRASE;        }        //extra.set_value(FD::Convert("scan"), 1);      } else if (edge->IsCreatedByComplete()) {        tail.push_back(edge2node[edge->active_parent]->id_); +      rhs1 = edge->active_parent->cat;        tail.push_back(edge2node[edge->passive_parent]->id_); +      rhs2 = edge->passive_parent->cat;        //extra.set_value(FD::Convert("complete"), 1);      } else {        assert(!"unexpected edge type!"); @@ -592,11 +620,11 @@ class EarleyComposerImpl {  #endif      Hypergraph::Edge* hg_edge = NULL;      if (tail.size() == 0) { -      hg_edge = hg->AddEdge(kEPSRule, tail); +      hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail);      } else if (tail.size() == 1) { -      hg_edge = hg->AddEdge(kX1, tail); +      hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail);      } else if (tail.size() == 2) { -      hg_edge = hg->AddEdge(kX1X2, tail); +      hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail);      }      if (edge->features)        hg_edge->feature_values_ += *edge->features; diff --git a/decoder/hg.cc b/decoder/hg.cc index dd272221..7240a8ab 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -605,69 +605,6 @@ void Hypergraph::TopologicallySortNodesAndEdges(int goal_index,  #endif  } -TRulePtr Hypergraph::kEPSRule; -TRulePtr Hypergraph::kUnaryRule; - -void Hypergraph::EpsilonRemove(WordID eps) { -  if (!kEPSRule) { -    kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>")); -    kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); -  } -  vector<bool> kill(edges_.size(), false); -  for (unsigned i = 0; i < edges_.size(); ++i) { -    const Edge& edge = edges_[i]; -    if (edge.tail_nodes_.empty() && -        edge.rule_->f_.size() == 1 && -        edge.rule_->f_[0] == eps) { -      kill[i] = true; -      if (!edge.feature_values_.empty()) { -        Node& node = nodes_[edge.head_node_]; -        if (node.in_edges_.size() != 1) { -          cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n"; -          // this *probably* means that there are multiple derivations of the -          // same sequence via different paths through the input forest -          // this needs to be investigated and fixed -        } else { -          for (unsigned j = 0; j < node.out_edges_.size(); ++j) -            edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; -          // cerr << "PROMOTED " << edge.feature_values_ << endl; -	} -      } -    } -  } -  bool created_eps = false; -  PruneEdges(kill); -  for (unsigned i = 0; i < nodes_.size(); ++i) { -    const Node& node = nodes_[i]; -    if (node.in_edges_.empty()) { -      for (unsigned j = 0; j < node.out_edges_.size(); ++j) { -        Edge& edge = edges_[node.out_edges_[j]]; -        if (edge.rule_->Arity() == 2) { -          assert(edge.rule_->f_.size() == 2); -          assert(edge.rule_->e_.size() == 2); -          edge.rule_ = kUnaryRule; -          unsigned cur = node.id_; -          int t = -1; -          assert(edge.tail_nodes_.size() == 2); -          for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; } -          assert(t != -1); -          edge.tail_nodes_.resize(1); -          edge.tail_nodes_[0] = t; -        } else { -          edge.rule_ = kEPSRule; -          edge.rule_->f_[0] = eps; -          edge.rule_->e_[0] = eps; -          edge.tail_nodes_.clear(); -          created_eps = true; -        } -      } -    } -  } -  vector<bool> k2(edges_.size(), false); -  PruneEdges(k2); -  if (created_eps) EpsilonRemove(eps); -} -  struct EdgeWeightSorter {    const Hypergraph& hg;    EdgeWeightSorter(const Hypergraph& h) : hg(h) {} diff --git a/decoder/hg.h b/decoder/hg.h index 91d25f01..591e98ce 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -148,7 +148,7 @@ public:      void show(std::ostream &o,unsigned mask=SPAN|RULE) const {        o<<'{';        if (mask&CATEGORY) -        o<<TD::Convert(rule_->GetLHS()); +        o<< '[' << TD::Convert(-rule_->GetLHS()) << ']';        if (mask&PREV_SPAN)          o<<'<'<<prev_i_<<','<<prev_j_<<'>';        if (mask&SPAN) @@ -156,9 +156,9 @@ public:        if (mask&PROB)          o<<" p="<<edge_prob_;        if (mask&FEATURES) -        o<<" "<<feature_values_; +        o<<' '<<feature_values_;        if (mask&RULE) -        o<<rule_->AsString(mask&RULE_LHS); +        o<<' '<<rule_->AsString(mask&RULE_LHS);        if (USE_INFO_EDGE) {          std::string const& i=info();          if (mask&&!i.empty()) o << " |||"<<i; // remember, the initial space is expected as part of i @@ -384,14 +384,6 @@ public:    // compute the total number of paths in the forest    double NumberOfPaths() const; -  // BEWARE. this assumes that the source and target language -  // strings are identical and that there are no loops. -  // It assumes a bunch of other things about where the -  // epsilons will be.  It tries to assert failure if you -  // break these assumptions, but it may not. -  // TODO - make this work -  void EpsilonRemove(WordID eps); -    // multiple the weights vector by the edge feature vector    // (inner product) to set the edge probabilities    template <class V> @@ -535,9 +527,6 @@ public:  private:    Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} - -  static TRulePtr kEPSRule; -  static TRulePtr kUnaryRule;  }; diff --git a/decoder/hg_remove_eps.cc b/decoder/hg_remove_eps.cc new file mode 100644 index 00000000..050c4876 --- /dev/null +++ b/decoder/hg_remove_eps.cc @@ -0,0 +1,91 @@ +#include "hg_remove_eps.h" + +#include <cassert> + +#include "trule.h" +#include "hg.h" + +using namespace std; + +namespace { +  TRulePtr kEPSRule; +  TRulePtr kUnaryRule; + +  TRulePtr CreateUnaryRule(int lhs, int rhs) { +    if (!kUnaryRule) kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); +    TRule* r = new TRule(*kUnaryRule); +    assert(lhs < 0); +    assert(rhs < 0); +    r->lhs_ = lhs; +    r->f_[0] = rhs; +    return TRulePtr(r); +  } + +  TRulePtr CreateEpsilonRule(int lhs, WordID eps) { +    if (!kEPSRule) kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>")); +    TRule* r = new TRule(*kEPSRule); +    r->lhs_ = lhs; +    assert(lhs < 0); +    assert(eps > 0); +    r->e_[0] = eps; +    r->f_[0] = eps; +    return TRulePtr(r); +  } +} + +void RemoveEpsilons(Hypergraph* g, WordID eps) { +  vector<bool> kill(g->edges_.size(), false); +  for (unsigned i = 0; i < g->edges_.size(); ++i) { +    const Hypergraph::Edge& edge = g->edges_[i]; +    if (edge.tail_nodes_.empty() && +        edge.rule_->f_.size() == 1 && +        edge.rule_->f_[0] == eps) { +      kill[i] = true; +      if (!edge.feature_values_.empty()) { +        Hypergraph::Node& node = g->nodes_[edge.head_node_]; +        if (node.in_edges_.size() != 1) { +          cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n"; +          // this *probably* means that there are multiple derivations of the +          // same sequence via different paths through the input forest +          // this needs to be investigated and fixed +        } else { +          for (unsigned j = 0; j < node.out_edges_.size(); ++j) +            g->edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; +          // cerr << "PROMOTED " << edge.feature_values_ << endl; +	} +      } +    } +  } +  bool created_eps = false; +  g->PruneEdges(kill); +  for (unsigned i = 0; i < g->nodes_.size(); ++i) { +    const Hypergraph::Node& node = g->nodes_[i]; +    if (node.in_edges_.empty()) { +      for (unsigned j = 0; j < node.out_edges_.size(); ++j) { +        Hypergraph::Edge& edge = g->edges_[node.out_edges_[j]]; +        const int lhs = edge.rule_->lhs_; +        if (edge.rule_->Arity() == 2) { +          assert(edge.rule_->f_.size() == 2); +          assert(edge.rule_->e_.size() == 2); +          unsigned cur = node.id_; +          int t = -1; +          assert(edge.tail_nodes_.size() == 2); +          int rhs = 0; +          for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; rhs = edge.rule_->f_[i]; } +          assert(t != -1); +          edge.tail_nodes_.resize(1); +          edge.tail_nodes_[0] = t; +          edge.rule_ = CreateUnaryRule(lhs, rhs); +        } else { +          edge.rule_ = CreateEpsilonRule(lhs, eps); +          edge.tail_nodes_.clear(); +          created_eps = true; +        } +      } +    } +  } +  vector<bool> k2(g->edges_.size(), false); +  g->PruneEdges(k2); +  if (created_eps) RemoveEpsilons(g, eps); +} + diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h new file mode 100644 index 00000000..82f06039 --- /dev/null +++ b/decoder/hg_remove_eps.h @@ -0,0 +1,13 @@ +#ifndef _HG_REMOVE_EPS_H_ +#define _HG_REMOVE_EPS_H_ + +#include "wordid.h" +class Hypergraph; + +// This is not a complete implementation of the general algorithm for +// doing this. It makes a few weird assumptions, for example, that +// if some nonterminal X rewrites as eps, then that is the only thing +// that it rewrites as. This needs to be fixed for the general case! +void RemoveEpsilons(Hypergraph* g, WordID eps); + +#endif diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc new file mode 100644 index 00000000..5c417393 --- /dev/null +++ b/decoder/rescore_translator.cc @@ -0,0 +1,57 @@ +#include "translator.h" + +#include <sstream> +#include <boost/shared_ptr.hpp> + +#include "sentence_metadata.h" +#include "hg.h" +#include "hg_io.h" +#include "tdict.h" + +using namespace std; + +struct RescoreTranslatorImpl { +  RescoreTranslatorImpl(const boost::program_options::variables_map& conf) : +      goal_sym(conf["goal"].as<string>()), +      kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), +      kGOAL(TD::Convert("Goal") * -1) { +  } + +  bool Translate(const string& input, +                 const vector<double>& weights, +                 Hypergraph* forest) { +    if (input.find("{\"rules\"") == 0) { +      istringstream is(input); +      Hypergraph src_cfg_hg; +      if (!HypergraphIO::ReadFromJSON(&is, forest)) { +        cerr << "Parse error while reading HG from JSON.\n"; +        abort(); +      } +    } else { +      cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; +      abort(); +    } +    Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); +    Hypergraph::Node* goal = forest->AddNode(kGOAL); +    Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); +    forest->ConnectEdgeToHeadNode(hg_edge, goal); +    forest->Reweight(weights); +    return true; +  } + +  const string goal_sym; +  const TRulePtr kGOAL_RULE; +  const WordID kGOAL; +}; + +RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) : +  pimpl_(new RescoreTranslatorImpl(conf)) {} + +bool RescoreTranslator::TranslateImpl(const string& input, +                              SentenceMetadata* smeta, +                              const vector<double>& weights, +                              Hypergraph* minus_lm_forest) { +  smeta->SetSourceLength(0);  // don't know how to compute this +  return pimpl_->Translate(input, weights, minus_lm_forest); +} + diff --git a/decoder/translator.h b/decoder/translator.h index fc2bb760..c0800e84 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -85,4 +85,17 @@ class FSTTranslator : public Translator {    boost::shared_ptr<FSTTranslatorImpl> pimpl_;  }; +class RescoreTranslatorImpl; +class RescoreTranslator : public Translator { + public: +  RescoreTranslator(const boost::program_options::variables_map& conf); + private: +  bool TranslateImpl(const std::string& src, +                 SentenceMetadata* smeta, +                 const std::vector<double>& weights, +                 Hypergraph* minus_lm_forest); + private: +  boost::shared_ptr<RescoreTranslatorImpl> pimpl_; +}; +  #endif diff --git a/decoder/trule.cc b/decoder/trule.cc index 187a003d..896f9f3d 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -237,9 +237,9 @@ void TRule::ComputeArity() {  string TRule::AsString(bool verbose) const {    ostringstream os;    int idx = 0; -  if (lhs_ && verbose) { +  if (lhs_) {      os << '[' << TD::Convert(lhs_ * -1) << "] |||"; -  } +  } else { os << "NOLHS |||"; }    for (unsigned i = 0; i < f_.size(); ++i) {      const WordID& w = f_[i];      if (w < 0) { | 
