diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/Makefile.am | 2 | ||||
| -rw-r--r-- | decoder/decoder.cc | 16 | ||||
| -rw-r--r-- | decoder/decoder.h | 6 | ||||
| -rw-r--r-- | decoder/earley_composer.cc | 38 | ||||
| -rw-r--r-- | decoder/ff_ngrams.cc | 85 | ||||
| -rw-r--r-- | decoder/ff_ngrams.h | 2 | ||||
| -rw-r--r-- | decoder/ff_tagger.cc | 17 | ||||
| -rw-r--r-- | decoder/hg.cc | 63 | ||||
| -rw-r--r-- | decoder/hg.h | 17 | ||||
| -rw-r--r-- | decoder/hg_io.cc | 1 | ||||
| -rw-r--r-- | decoder/hg_remove_eps.cc | 91 | ||||
| -rw-r--r-- | decoder/hg_remove_eps.h | 13 | ||||
| -rw-r--r-- | decoder/inside_outside.h | 4 | ||||
| -rw-r--r-- | decoder/rescore_translator.cc | 58 | ||||
| -rw-r--r-- | decoder/scfg_translator.cc | 70 | ||||
| -rw-r--r-- | decoder/tagger.cc | 1 | ||||
| -rw-r--r-- | decoder/translator.h | 17 | ||||
| -rw-r--r-- | decoder/trule.cc | 4 | 
18 files changed, 347 insertions, 158 deletions
| diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 00d01e53..0a792549 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -37,9 +37,11 @@ libcdec_a_SOURCES = \    fst_translator.cc \    csplit.cc \    translator.cc \ +  rescore_translator.cc \    scfg_translator.cc \    hg.cc \    hg_io.cc \ +  hg_remove_eps.cc \    decoder.cc \    hg_intersect.cc \    hg_sampler.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 333f0fb6..a6f7b1ce 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -527,8 +527,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    }    formalism = LowercaseString(str("formalism",conf)); -  if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { -    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; +  if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { +    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";      cerr << dcmdline_options << endl;      exit(1);    } @@ -675,6 +675,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream      translator.reset(new LexicalTrans(conf));    else if (formalism == "lexalign")      translator.reset(new LexicalAlign(conf)); +  else if (formalism == "rescore") +    translator.reset(new RescoreTranslator(conf));    else if (formalism == "tagger")      translator.reset(new Tagger(conf));    else @@ -743,16 +745,14 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {  }  vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); }  const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } -void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { -  assert(pimpl_->translator->GetDecoderType() == "SCFG"); -  static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string); +void Decoder::AddSupplementalGrammar(GrammarPtr gp) { +  static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammar(gp);  } -void Decoder::SetSentenceGrammarFromString(const std::string& grammar_str) { +void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) {    assert(pimpl_->translator->GetDecoderType() == "SCFG"); -  static_cast<SCFGTranslator&>(*pimpl_->translator).SetSentenceGrammarFromString(grammar_str); +  static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);  } -  bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    string buf = input;    NgramCache::Clear();   // clear ngram cache for remote LM (if used) diff --git a/decoder/decoder.h b/decoder/decoder.h index 6b2f7b16..bef2ff5e 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -37,6 +37,8 @@ struct DecoderObserver {    virtual void NotifyDecodingComplete(const SentenceMetadata& smeta);  }; +struct Grammar;  // TODO once the decoder interface is cleaned up, +                 // this should be somewhere else  struct Decoder {    Decoder(int argc, char** argv);    Decoder(std::istream* config_file); @@ -54,8 +56,8 @@ struct Decoder {    // add grammar rules (currently only supported by SCFG decoders)    // that will be used on subsequent calls to Decode. rules should be in standard    // text format. This function does NOT read from a file. -  void SetSupplementalGrammar(const std::string& grammar); -  void SetSentenceGrammarFromString(const std::string& grammar_str); +  void AddSupplementalGrammar(boost::shared_ptr<Grammar> gp); +  void AddSupplementalGrammarFromString(const std::string& grammar_string);   private:    boost::program_options::variables_map conf;    boost::shared_ptr<DecoderImpl> pimpl_; diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc index d265d954..efce70a6 100644 --- a/decoder/earley_composer.cc +++ b/decoder/earley_composer.cc @@ -16,6 +16,7 @@  #include "sparse_vector.h"  #include "tdict.h"  #include "hg.h" +#include "hg_remove_eps.h"  using namespace std;  using namespace std::tr1; @@ -48,6 +49,27 @@ static void InitializeConstants() {  }  //////////////////////////////////////////////////////////// +TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) { +  TRule* r = new TRule(*kX1X2); +  r->lhs_ = lhs; +  r->f_[0] = rhs1; +  r->f_[1] = rhs2; +  return TRulePtr(r); +} + +TRulePtr CreateUnaryRule(int lhs, int rhs1) { +  TRule* r = new TRule(*kX1); +  r->lhs_ = lhs; +  r->f_[0] = rhs1; +  return TRulePtr(r); +} + +TRulePtr CreateEpsilonRule(int lhs) { +  TRule* r = new TRule(*kEPSRule); +  r->lhs_ = lhs; +  return TRulePtr(r); +} +  class EGrammarNode {    friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);    friend void AddGrammarRule(const string& r, map<WordID, EGrammarNode>* g); @@ -356,7 +378,7 @@ class EarleyComposerImpl {      }      if (goal_node) {        forest->PruneUnreachable(goal_node->id_); -      forest->EpsilonRemove(kEPS); +      RemoveEpsilons(forest, kEPS);      }      FreeAll();      return goal_node; @@ -557,24 +579,30 @@ class EarleyComposerImpl {      }      Hypergraph::Node*& head_node = edge2node[edge];      if (!head_node) -      head_node = hg->AddNode(kPHRASE); +      head_node = hg->AddNode(edge->cat);      if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) {        assert(goal_node == NULL || goal_node == head_node);        goal_node = head_node;      } +    int rhs1 = 0; +    int rhs2 = 0;      Hypergraph::TailNodeVector tail;      SparseVector<double> extra;      if (edge->IsCreatedByPredict()) {        // extra.set_value(FD::Convert("predict"), 1);      } else if (edge->IsCreatedByScan()) {        tail.push_back(edge2node[edge->active_parent]->id_); +      rhs1 = edge->active_parent->cat;        if (tps) {          tail.push_back(tps->id_); +        rhs2 = kPHRASE;        }        //extra.set_value(FD::Convert("scan"), 1);      } else if (edge->IsCreatedByComplete()) {        tail.push_back(edge2node[edge->active_parent]->id_); +      rhs1 = edge->active_parent->cat;        tail.push_back(edge2node[edge->passive_parent]->id_); +      rhs2 = edge->passive_parent->cat;        //extra.set_value(FD::Convert("complete"), 1);      } else {        assert(!"unexpected edge type!"); @@ -592,11 +620,11 @@ class EarleyComposerImpl {  #endif      Hypergraph::Edge* hg_edge = NULL;      if (tail.size() == 0) { -      hg_edge = hg->AddEdge(kEPSRule, tail); +      hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail);      } else if (tail.size() == 1) { -      hg_edge = hg->AddEdge(kX1, tail); +      hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail);      } else if (tail.size() == 2) { -      hg_edge = hg->AddEdge(kX1X2, tail); +      hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail);      }      if (edge->features)        hg_edge->feature_values_ += *edge->features; diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc index d6d79f5e..9c13fdbb 100644 --- a/decoder/ff_ngrams.cc +++ b/decoder/ff_ngrams.cc @@ -48,6 +48,9 @@ struct State {  namespace {    string Escape(const string& x) { +    if (x.find('=') == string::npos && x.find(';') == string::npos) { +      return x; +    }      string y = x;      for (int i = 0; i < y.size(); ++i) {        if (y[i] == '=') y[i]='_'; @@ -57,10 +60,17 @@ namespace {    }  } -static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) { +static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator) {    vector<string> const& argv=SplitOnWhitespace(in);    *explicit_markers = false;    *order = 3; +  prefixes.push_back("NOT-USED"); +  prefixes.push_back("U:"); // default unigram prefix +  prefixes.push_back("B:"); // default bigram prefix +  prefixes.push_back("T:"); // ...etc +  prefixes.push_back("4:"); // ...etc +  prefixes.push_back("5:"); // max allowed! +  target_separator = "_";  #define LMSPEC_NEXTARG if (i==argv.end()) {            \      cerr << "Missing argument for "<<*last<<". "; goto usage; \      } else { ++i; } @@ -73,6 +83,30 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order)        case 'x':          *explicit_markers = true;          break; +      case 'U': +	LMSPEC_NEXTARG; +	prefixes[1] = *i; +	break; +      case 'B': +	LMSPEC_NEXTARG; +	prefixes[2] = *i; +	break; +      case 'T': +	LMSPEC_NEXTARG; +	prefixes[3] = *i; +	break; +      case '4': +	LMSPEC_NEXTARG; +	prefixes[4] = *i; +	break; +      case '5': +	LMSPEC_NEXTARG; +	prefixes[5] = *i; +	break; +      case 'S': +	LMSPEC_NEXTARG; +	target_separator = *i; +	break;        case 'o':          LMSPEC_NEXTARG; *order=atoi((*i).c_str());          break; @@ -86,7 +120,29 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order)    }    return true;  usage: -  cerr << "NgramFeatures is incorrect!\n"; +  cerr << "Wrong parameters for NgramFeatures.\n\n" + +       << "NgramFeatures Usage: \n"			      +       << " feature_function=NgramFeatures filename.lm [-x] [-o <order>] \n" +       << " [-U <unigram-prefix>] [-B <bigram-prefix>][-T <trigram-prefix>]\n" +       << " [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]\n\n"  +     +       << "Defaults: \n" +       << "  <order>          = 3\n"  +       << "  <unigram-prefix> = U:\n" +       << "  <bigram-prefix>  = B:\n" +       << "  <trigram-prefix> = T:\n" +       << "  <4-gram-prefix>  = 4:\n" +       << "  <5-gram-prefix>  = 5:\n" +       << "  <separator>      = _\n" +       << "  -x (i.e. explicit sos/eos markers) is turned off\n\n" + +       << "Example configuration: \n" +       << "  feature_function=NgramFeatures -o 3 -T tri: -S |\n\n" + +       << "Example feature instantiation: \n" +       << "  tri:a|b|c \n\n"; +    return false;  } @@ -158,16 +214,12 @@ class NgramDetectorImpl {        int& fid = ft->fids[curword];        ++n;        if (!fid) { -        const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.)          ostringstream os; -        os << code[n] << ':'; +        os << prefixes_[n];          for (int i = n-1; i >= 0; --i) { -          os << (i != n-1 ? "_" : ""); +          os << (i != n-1 ? target_separator_ : "");            const string& tok = TD::Convert(buf[i]); -          if (tok.find('=') == string::npos) -            os << tok; -          else -            os << Escape(tok); +	  os << Escape(tok);          }          fid = FD::Convert(os.str());        } @@ -297,7 +349,8 @@ class NgramDetectorImpl {    }   public: -  explicit NgramDetectorImpl(bool explicit_markers, unsigned order) : +  explicit NgramDetectorImpl(bool explicit_markers, unsigned order, +			     vector<string>& prefixes, string& target_separator) :        kCDEC_UNK(TD::Convert("<unk>")) ,        add_sos_eos_(!explicit_markers) {      order_ = order; @@ -305,6 +358,8 @@ class NgramDetectorImpl {      unscored_size_offset_ = (order_ - 1) * sizeof(WordID);      is_complete_offset_ = unscored_size_offset_ + 1;      unscored_words_offset_ = is_complete_offset_ + 1; +    prefixes_ = prefixes; +    target_separator_ = target_separator;      // special handling of beginning / ending sentence markers      dummy_state_ = new char[state_size_]; @@ -340,6 +395,8 @@ class NgramDetectorImpl {    char* dummy_state_;    vector<const void*> dummy_ants_;    TRulePtr dummy_rule_; +  vector<string> prefixes_; +  string target_separator_;    struct FidTree {      map<WordID, int> fids;      map<WordID, FidTree> levels; @@ -348,11 +405,13 @@ class NgramDetectorImpl {  };  NgramDetector::NgramDetector(const string& param) { -  string filename, mapfile, featname; +  string filename, mapfile, featname, target_separator; +  vector<string> prefixes;    bool explicit_markers = false;    unsigned order = 3; -  ParseArgs(param, &explicit_markers, &order); -  pimpl_ = new NgramDetectorImpl(explicit_markers, order); +  ParseArgs(param, &explicit_markers, &order, prefixes, target_separator); +  pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes,  +				 target_separator);    SetStateSize(pimpl_->ReserveStateSize());  } diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h index 82f61b33..064dbb49 100644 --- a/decoder/ff_ngrams.h +++ b/decoder/ff_ngrams.h @@ -10,7 +10,7 @@  struct NgramDetectorImpl;  class NgramDetector : public FeatureFunction {   public: -  // param = "filename.lm [-o n]" +  // param = "filename.lm [-o <order>] [-U <unigram-prefix>] [-B <bigram-prefix>] [-T <trigram-prefix>] [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]    NgramDetector(const std::string& param);    ~NgramDetector();    virtual void FinalTraversalFeatures(const void* context, diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index 019315a2..fd9210fa 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -8,6 +8,17 @@  using namespace std; +namespace { +  string Escape(const string& x) { +    string y = x; +    for (int i = 0; i < y.size(); ++i) { +      if (y[i] == '=') y[i]='_'; +      if (y[i] == ';') y[i]='_'; +    } +    return y; +  } +} +  Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) :    FeatureFunction(sizeof(WordID)) {     no_uni_ = (LowercaseString(param) == "no_uni"); @@ -28,7 +39,7 @@ void Tagger_BigramIndicator::FireFeature(const WordID& left,        os << '_';        if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); }      } -    fid = FD::Convert(os.str()); +    fid = FD::Convert(Escape(os.str()));    }    features->set_value(fid, 1.0);  } @@ -90,7 +101,7 @@ void LexicalPairIndicator::FireFeature(WordID src,    if (!fid) {      ostringstream os;      os << name_ << ':' << TD::Convert(src) << ':' << TD::Convert(trg); -    fid = FD::Convert(os.str()); +    fid = FD::Convert(Escape(os.str()));    }    features->set_value(fid, 1.0);  } @@ -127,7 +138,7 @@ void OutputIndicator::FireFeature(WordID trg,      if (escape.count(trg)) trg = escape[trg];      ostringstream os;      os << "T:" << TD::Convert(trg); -    fid = FD::Convert(os.str()); +    fid = FD::Convert(Escape(os.str()));    }    features->set_value(fid, 1.0);  } diff --git a/decoder/hg.cc b/decoder/hg.cc index dd272221..7240a8ab 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -605,69 +605,6 @@ void Hypergraph::TopologicallySortNodesAndEdges(int goal_index,  #endif  } -TRulePtr Hypergraph::kEPSRule; -TRulePtr Hypergraph::kUnaryRule; - -void Hypergraph::EpsilonRemove(WordID eps) { -  if (!kEPSRule) { -    kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>")); -    kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); -  } -  vector<bool> kill(edges_.size(), false); -  for (unsigned i = 0; i < edges_.size(); ++i) { -    const Edge& edge = edges_[i]; -    if (edge.tail_nodes_.empty() && -        edge.rule_->f_.size() == 1 && -        edge.rule_->f_[0] == eps) { -      kill[i] = true; -      if (!edge.feature_values_.empty()) { -        Node& node = nodes_[edge.head_node_]; -        if (node.in_edges_.size() != 1) { -          cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n"; -          // this *probably* means that there are multiple derivations of the -          // same sequence via different paths through the input forest -          // this needs to be investigated and fixed -        } else { -          for (unsigned j = 0; j < node.out_edges_.size(); ++j) -            edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; -          // cerr << "PROMOTED " << edge.feature_values_ << endl; -	} -      } -    } -  } -  bool created_eps = false; -  PruneEdges(kill); -  for (unsigned i = 0; i < nodes_.size(); ++i) { -    const Node& node = nodes_[i]; -    if (node.in_edges_.empty()) { -      for (unsigned j = 0; j < node.out_edges_.size(); ++j) { -        Edge& edge = edges_[node.out_edges_[j]]; -        if (edge.rule_->Arity() == 2) { -          assert(edge.rule_->f_.size() == 2); -          assert(edge.rule_->e_.size() == 2); -          edge.rule_ = kUnaryRule; -          unsigned cur = node.id_; -          int t = -1; -          assert(edge.tail_nodes_.size() == 2); -          for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; } -          assert(t != -1); -          edge.tail_nodes_.resize(1); -          edge.tail_nodes_[0] = t; -        } else { -          edge.rule_ = kEPSRule; -          edge.rule_->f_[0] = eps; -          edge.rule_->e_[0] = eps; -          edge.tail_nodes_.clear(); -          created_eps = true; -        } -      } -    } -  } -  vector<bool> k2(edges_.size(), false); -  PruneEdges(k2); -  if (created_eps) EpsilonRemove(eps); -} -  struct EdgeWeightSorter {    const Hypergraph& hg;    EdgeWeightSorter(const Hypergraph& h) : hg(h) {} diff --git a/decoder/hg.h b/decoder/hg.h index 91d25f01..591e98ce 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -148,7 +148,7 @@ public:      void show(std::ostream &o,unsigned mask=SPAN|RULE) const {        o<<'{';        if (mask&CATEGORY) -        o<<TD::Convert(rule_->GetLHS()); +        o<< '[' << TD::Convert(-rule_->GetLHS()) << ']';        if (mask&PREV_SPAN)          o<<'<'<<prev_i_<<','<<prev_j_<<'>';        if (mask&SPAN) @@ -156,9 +156,9 @@ public:        if (mask&PROB)          o<<" p="<<edge_prob_;        if (mask&FEATURES) -        o<<" "<<feature_values_; +        o<<' '<<feature_values_;        if (mask&RULE) -        o<<rule_->AsString(mask&RULE_LHS); +        o<<' '<<rule_->AsString(mask&RULE_LHS);        if (USE_INFO_EDGE) {          std::string const& i=info();          if (mask&&!i.empty()) o << " |||"<<i; // remember, the initial space is expected as part of i @@ -384,14 +384,6 @@ public:    // compute the total number of paths in the forest    double NumberOfPaths() const; -  // BEWARE. this assumes that the source and target language -  // strings are identical and that there are no loops. -  // It assumes a bunch of other things about where the -  // epsilons will be.  It tries to assert failure if you -  // break these assumptions, but it may not. -  // TODO - make this work -  void EpsilonRemove(WordID eps); -    // multiple the weights vector by the edge feature vector    // (inner product) to set the edge probabilities    template <class V> @@ -535,9 +527,6 @@ public:  private:    Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} - -  static TRulePtr kEPSRule; -  static TRulePtr kUnaryRule;  }; diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index bfb2fb80..8bd40387 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -261,6 +261,7 @@ static void WriteRule(const TRule& r, ostream* out) {  }  bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { +  if (hg.empty()) { *out << "{}\n"; return true; }    map<const TRule*, int> rid;    ostream& o = *out;    rid[NULL] = 0; diff --git a/decoder/hg_remove_eps.cc b/decoder/hg_remove_eps.cc new file mode 100644 index 00000000..050c4876 --- /dev/null +++ b/decoder/hg_remove_eps.cc @@ -0,0 +1,91 @@ +#include "hg_remove_eps.h" + +#include <cassert> + +#include "trule.h" +#include "hg.h" + +using namespace std; + +namespace { +  TRulePtr kEPSRule; +  TRulePtr kUnaryRule; + +  TRulePtr CreateUnaryRule(int lhs, int rhs) { +    if (!kUnaryRule) kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); +    TRule* r = new TRule(*kUnaryRule); +    assert(lhs < 0); +    assert(rhs < 0); +    r->lhs_ = lhs; +    r->f_[0] = rhs; +    return TRulePtr(r); +  } + +  TRulePtr CreateEpsilonRule(int lhs, WordID eps) { +    if (!kEPSRule) kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>")); +    TRule* r = new TRule(*kEPSRule); +    r->lhs_ = lhs; +    assert(lhs < 0); +    assert(eps > 0); +    r->e_[0] = eps; +    r->f_[0] = eps; +    return TRulePtr(r); +  } +} + +void RemoveEpsilons(Hypergraph* g, WordID eps) { +  vector<bool> kill(g->edges_.size(), false); +  for (unsigned i = 0; i < g->edges_.size(); ++i) { +    const Hypergraph::Edge& edge = g->edges_[i]; +    if (edge.tail_nodes_.empty() && +        edge.rule_->f_.size() == 1 && +        edge.rule_->f_[0] == eps) { +      kill[i] = true; +      if (!edge.feature_values_.empty()) { +        Hypergraph::Node& node = g->nodes_[edge.head_node_]; +        if (node.in_edges_.size() != 1) { +          cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n"; +          // this *probably* means that there are multiple derivations of the +          // same sequence via different paths through the input forest +          // this needs to be investigated and fixed +        } else { +          for (unsigned j = 0; j < node.out_edges_.size(); ++j) +            g->edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; +          // cerr << "PROMOTED " << edge.feature_values_ << endl; +	} +      } +    } +  } +  bool created_eps = false; +  g->PruneEdges(kill); +  for (unsigned i = 0; i < g->nodes_.size(); ++i) { +    const Hypergraph::Node& node = g->nodes_[i]; +    if (node.in_edges_.empty()) { +      for (unsigned j = 0; j < node.out_edges_.size(); ++j) { +        Hypergraph::Edge& edge = g->edges_[node.out_edges_[j]]; +        const int lhs = edge.rule_->lhs_; +        if (edge.rule_->Arity() == 2) { +          assert(edge.rule_->f_.size() == 2); +          assert(edge.rule_->e_.size() == 2); +          unsigned cur = node.id_; +          int t = -1; +          assert(edge.tail_nodes_.size() == 2); +          int rhs = 0; +          for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; rhs = edge.rule_->f_[i]; } +          assert(t != -1); +          edge.tail_nodes_.resize(1); +          edge.tail_nodes_[0] = t; +          edge.rule_ = CreateUnaryRule(lhs, rhs); +        } else { +          edge.rule_ = CreateEpsilonRule(lhs, eps); +          edge.tail_nodes_.clear(); +          created_eps = true; +        } +      } +    } +  } +  vector<bool> k2(g->edges_.size(), false); +  g->PruneEdges(k2); +  if (created_eps) RemoveEpsilons(g, eps); +} + diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h new file mode 100644 index 00000000..82f06039 --- /dev/null +++ b/decoder/hg_remove_eps.h @@ -0,0 +1,13 @@ +#ifndef _HG_REMOVE_EPS_H_ +#define _HG_REMOVE_EPS_H_ + +#include "wordid.h" +class Hypergraph; + +// This is not a complete implementation of the general algorithm for +// doing this. It makes a few weird assumptions, for example, that +// if some nonterminal X rewrites as eps, then that is the only thing +// that it rewrites as. This needs to be fixed for the general case! +void RemoveEpsilons(Hypergraph* g, WordID eps); + +#endif diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h index bb7f9fcc..f73a1d3f 100644 --- a/decoder/inside_outside.h +++ b/decoder/inside_outside.h @@ -41,10 +41,6 @@ WeightType Inside(const Hypergraph& hg,      WeightType* const cur_node_inside_score = &inside_score[i];      Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_;      const unsigned num_in_edges = in.size(); -    if (num_in_edges == 0) { -      *cur_node_inside_score = WeightType(1); //FIXME: why not call weight(edge) instead? -      continue; -    }      for (unsigned j = 0; j < num_in_edges; ++j) {        const Hypergraph::Edge& edge = hg.edges_[in[j]];        WeightType score = weight(edge); diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc new file mode 100644 index 00000000..10192f7a --- /dev/null +++ b/decoder/rescore_translator.cc @@ -0,0 +1,58 @@ +#include "translator.h" + +#include <sstream> +#include <boost/shared_ptr.hpp> + +#include "sentence_metadata.h" +#include "hg.h" +#include "hg_io.h" +#include "tdict.h" + +using namespace std; + +struct RescoreTranslatorImpl { +  RescoreTranslatorImpl(const boost::program_options::variables_map& conf) : +      goal_sym(conf["goal"].as<string>()), +      kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), +      kGOAL(TD::Convert("Goal") * -1) { +  } + +  bool Translate(const string& input, +                 const vector<double>& weights, +                 Hypergraph* forest) { +    if (input == "{}") return false; +    if (input.find("{\"rules\"") == 0) { +      istringstream is(input); +      Hypergraph src_cfg_hg; +      if (!HypergraphIO::ReadFromJSON(&is, forest)) { +        cerr << "Parse error while reading HG from JSON.\n"; +        abort(); +      } +    } else { +      cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; +      abort(); +    } +    Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); +    Hypergraph::Node* goal = forest->AddNode(kGOAL); +    Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); +    forest->ConnectEdgeToHeadNode(hg_edge, goal); +    forest->Reweight(weights); +    return true; +  } + +  const string goal_sym; +  const TRulePtr kGOAL_RULE; +  const WordID kGOAL; +}; + +RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) : +  pimpl_(new RescoreTranslatorImpl(conf)) {} + +bool RescoreTranslator::TranslateImpl(const string& input, +                              SentenceMetadata* smeta, +                              const vector<double>& weights, +                              Hypergraph* minus_lm_forest) { +  smeta->SetSourceLength(0);  // don't know how to compute this +  return pimpl_->Translate(input, weights, minus_lm_forest); +} + diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 185f979a..a978cfc2 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -20,7 +20,6 @@  #define reverse_foreach BOOST_REVERSE_FOREACH  using namespace std; -static bool usingSentenceGrammar = false;  static bool printGrammarsUsed = false;  struct SCFGTranslatorImpl { @@ -91,31 +90,31 @@ struct SCFGTranslatorImpl {    bool show_tree_structure_;    unsigned int ctf_iterations_;    vector<GrammarPtr> grammars; -  GrammarPtr sup_grammar_; +  set<GrammarPtr> sup_grammars_; -  struct Equals { Equals(const GrammarPtr& v) : v_(v) {} -                  bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; }; +  struct ContainedIn { +    ContainedIn(const set<GrammarPtr>& gs) : gs_(gs) {} +    bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); } +    const set<GrammarPtr>& gs_; +  }; -  void SetSupplementalGrammar(const std::string& grammar_string) { -    grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end()); +  void AddSupplementalGrammarFromString(const std::string& grammar_string) { +    grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end());      istringstream in(grammar_string); -    sup_grammar_.reset(new TextGrammar(&in)); -    grammars.push_back(sup_grammar_); +    TextGrammar* sent_grammar = new TextGrammar(&in); +    sent_grammar->SetMaxSpan(max_span_limit); +    sent_grammar->SetGrammarName("SupFromString"); +    AddSupplementalGrammar(GrammarPtr(sent_grammar));    } -  struct NameEquals { NameEquals(const string name) : name_(name) {} -                      bool operator()(const GrammarPtr& x) const { return x->GetGrammarName() == name_; } const string name_; }; +  void AddSupplementalGrammar(GrammarPtr gp) { +    sup_grammars_.insert(gp); +    grammars.push_back(gp); +  } -  void SetSentenceGrammarFromString(const std::string& grammar_str) { -    assert(grammar_str != ""); -    if (!SILENT) cerr << "Setting sentence grammar" << endl; -    usingSentenceGrammar = true; -    istringstream in(grammar_str); -    TextGrammar* sent_grammar = new TextGrammar(&in); -    sent_grammar->SetMaxSpan(max_span_limit); -    sent_grammar->SetGrammarName("__psg"); -    grammars.erase(remove_if(grammars.begin(), grammars.end(), NameEquals("__psg")), grammars.end()); -    grammars.push_back(GrammarPtr(sent_grammar)); +  void RemoveSupplementalGrammars() { +    grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end()); +    sup_grammars_.clear();    }    bool Translate(const string& input, @@ -300,35 +299,24 @@ Check for grammar pointer in the sentence markup, for use with sentence specific   */  void SCFGTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {    map<string,string>::const_iterator it = kv.find("grammar"); - - -  if (it == kv.end()) { -    usingSentenceGrammar= false; -    return; +  if (it != kv.end()) { +    TextGrammar* sentGrammar = new TextGrammar(it->second); +    sentGrammar->SetMaxSpan(pimpl_->max_span_limit); +    sentGrammar->SetGrammarName(it->second); +    pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar));    } -  //Create sentence specific grammar from specified file name and load grammar into list of grammars -  usingSentenceGrammar = true; -  TextGrammar* sentGrammar = new TextGrammar(it->second); -  sentGrammar->SetMaxSpan(pimpl_->max_span_limit); -  sentGrammar->SetGrammarName(it->second); -  pimpl_->grammars.push_back(GrammarPtr(sentGrammar)); -  } -void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) { -  pimpl_->SetSupplementalGrammar(grammar); +void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) { +  pimpl_->AddSupplementalGrammarFromString(grammar);  } -void SCFGTranslator::SetSentenceGrammarFromString(const std::string& grammar_str) { -  pimpl_->SetSentenceGrammarFromString(grammar_str); +void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) { +  pimpl_->AddSupplementalGrammar(grammar);  }  void SCFGTranslator::SentenceCompleteImpl() { - -  if(usingSentenceGrammar)      // Drop the last sentence grammar from the list of grammars -    { -      pimpl_->grammars.pop_back(); -    } +  pimpl_->RemoveSupplementalGrammars();  }  std::string SCFGTranslator::GetDecoderType() const { diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 54890e85..63e855c8 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -54,6 +54,7 @@ struct TaggerImpl {        const int new_node_id = forest->AddNode(kXCAT)->id_;        for (int k = 0; k < tagset_.size(); ++k) {          TRulePtr rule(TRule::CreateLexicalRule(src, tagset_[k])); +        rule->lhs_ = kXCAT;          Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());          edge->i_ = i;          edge->j_ = i+1; diff --git a/decoder/translator.h b/decoder/translator.h index cfd3b08a..c0800e84 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -58,8 +58,8 @@ class SCFGTranslatorImpl;  class SCFGTranslator : public Translator {   public:    SCFGTranslator(const boost::program_options::variables_map& conf); -  void SetSupplementalGrammar(const std::string& grammar); -  void SetSentenceGrammarFromString(const std::string& grammar); +  void AddSupplementalGrammar(GrammarPtr gp); +  void AddSupplementalGrammarFromString(const std::string& grammar);    virtual std::string GetDecoderType() const;   protected:    bool TranslateImpl(const std::string& src, @@ -85,4 +85,17 @@ class FSTTranslator : public Translator {    boost::shared_ptr<FSTTranslatorImpl> pimpl_;  }; +class RescoreTranslatorImpl; +class RescoreTranslator : public Translator { + public: +  RescoreTranslator(const boost::program_options::variables_map& conf); + private: +  bool TranslateImpl(const std::string& src, +                 SentenceMetadata* smeta, +                 const std::vector<double>& weights, +                 Hypergraph* minus_lm_forest); + private: +  boost::shared_ptr<RescoreTranslatorImpl> pimpl_; +}; +  #endif diff --git a/decoder/trule.cc b/decoder/trule.cc index 187a003d..896f9f3d 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -237,9 +237,9 @@ void TRule::ComputeArity() {  string TRule::AsString(bool verbose) const {    ostringstream os;    int idx = 0; -  if (lhs_ && verbose) { +  if (lhs_) {      os << '[' << TD::Convert(lhs_ * -1) << "] |||"; -  } +  } else { os << "NOLHS |||"; }    for (unsigned i = 0; i < f_.size(); ++i) {      const WordID& w = f_[i];      if (w < 0) { | 
