 decoder/decoder.cc          |  5 ++++-
 decoder/ff.cc               |  6 ++++++
 decoder/ff.h                |  8 ++++++++
 decoder/scfg_translator.cc  | 14 +++++++-------
 decoder/sentence_metadata.h | 10 ++++++++++
 5 files changed, 35 insertions(+), 8 deletions(-)
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 537fdffa..b975a5fc 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -662,9 +662,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
 //FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important.
   const bool has_ref = ref.size() > 0;
   SentenceMetadata smeta(sent_id, ref);
+  smeta.sgml_.swap(sgml);
   o->NotifyDecodingStart(smeta);
   Hypergraph forest;          // -LM forest
-  translator->ProcessMarkupHints(sgml);
+  translator->ProcessMarkupHints(smeta.sgml_);
   Timer t("Translation");
   const bool translation_successful =
     translator->Translate(to_translate, &smeta, feature_weights, &forest);
@@ -696,6 +697,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
     Timer t("prelm rescoring");
     forest.Reweight(prelm_feature_weights);
     Hypergraph prelm_forest;
+    prelm_models->PrepareForInput(smeta);
     ApplyModelSet(forest,
                   smeta,
                   *prelm_models,
@@ -713,6 +715,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
   bool has_late_models = !late_models->empty();
   if (has_late_models) {
     Timer t("Forest rescoring:");
+    late_models->PrepareForInput(smeta);
     forest.Reweight(feature_weights);
     Hypergraph lm_forest;
     ApplyModelSet(forest,
diff --git a/decoder/ff.cc b/decoder/ff.cc
index a32c0dcb..1258bc79 100644
--- a/decoder/ff.cc
+++ b/decoder/ff.cc
@@ -13,6 +13,7 @@ using namespace std;
 
 FeatureFunction::~FeatureFunction() {}
 
+void FeatureFunction::PrepareForInput(const SentenceMetadata&) {}
 
 void FeatureFunction::FinalTraversalFeatures(const void* /* ant_state */,
                                              SparseVector<double>* /* features */) const {
@@ -163,6 +164,11 @@ ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>
   }
 }
 
+void ModelSet::PrepareForInput(const SentenceMetadata& smeta) {
+  for (int i = 0; i < models_.size(); ++i)
+    const_cast<FeatureFunction*>(models_[i])->PrepareForInput(smeta);
+}
+
 void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
                                  const Hypergraph& /* hg */,
                                  const FFStates& node_states,
diff --git a/decoder/ff.h b/decoder/ff.h
index 904b9eb8..e470f9a9 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -52,6 +52,10 @@ public:
   // stateless feature that doesn't depend on source span: override and return true.  then your feature can be precomputed over rules.
   virtual bool rule_feature() const { return false; }
 
+  // called once, per input, before any feature calls to TraversalFeatures, etc.
+  // used to initialize sentence-specific data structures
+  virtual void PrepareForInput(const SentenceMetadata& smeta);
+
   //OVERRIDE THIS:
   virtual Features features() const { return single_feature(FD::Convert(name_)); }
   // returns the number of bytes of context that this feature function will
@@ -274,6 +278,10 @@ class ModelSet {
                         Hypergraph::Edge* edge,
                         SentenceMetadata const& smeta) const;
 
+  // this is called once before any feature functions apply to a hypergraph
+  // it can be used to initialize sentence-specific data structures
+  void PrepareForInput(const SentenceMetadata& smeta);
+
   bool empty() const { return models_.empty(); }
 
   bool stateless() const { return !state_size_; }
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 60123e6f..afe796a5 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -34,14 +34,14 @@ struct SCFGTranslatorImpl {
     if(conf.count("grammar")){
       vector<string> gfiles = conf["grammar"].as<vector<string> >();
       for (int i = 0; i < gfiles.size(); ++i) {
-    	  if (!SILENT) cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
-    	  TextGrammar* g = new TextGrammar(gfiles[i]);
-    	  g->SetMaxSpan(max_span_limit);
-    	  g->SetGrammarName(gfiles[i]);
-    	  grammars.push_back(GrammarPtr(g));
-	    }
+        if (!SILENT) cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+        TextGrammar* g = new TextGrammar(gfiles[i]);
+        g->SetMaxSpan(max_span_limit);
+        g->SetGrammarName(gfiles[i]);
+        grammars.push_back(GrammarPtr(g));
+      }
+      if (!SILENT) cerr << endl;
     }
-    cerr << std::endl;
     if (conf.count("scfg_extra_glue_grammar")) {
       GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
       g->SetGrammarName("ExtraGlueGrammar");
diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h
index c9a78dd3..eab9f15d 100644
--- a/decoder/sentence_metadata.h
+++ b/decoder/sentence_metadata.h
@@ -1,11 +1,14 @@
 #ifndef _SENTENCE_METADATA_H_
 #define _SENTENCE_METADATA_H_
 
+#include <string>
+#include <map>
 #include <cassert>
 #include "lattice.h"
 #include "scorer.h"
 
 struct SentenceMetadata {
+  friend class DecoderImpl;
   SentenceMetadata(int id, const Lattice& ref) :
     sent_id_(id),
     src_len_(-1),
@@ -42,7 +45,14 @@ struct SentenceMetadata {
   const DocScorer& GetDocScorer() const { return *ds; }
   double GetDocLen() const {return doc_len;}
 
+  std::string GetSGMLValue(const std::string& key) const {
+    std::map<std::string, std::string>::const_iterator it = sgml_.find(key);
+    if (it == sgml_.end()) return "";
+    return it->second;
+  }
+
  private:
+  std::map<std::string, std::string> sgml_;
   const int sent_id_;
   // the following should be set, if possible, by the Translator
   int src_len_;
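For readers of this change, the sketch below shows the new hooks from a feature function's point of view: ModelSet::PrepareForInput calls FeatureFunction::PrepareForInput once per input sentence, so an override can cache data pulled from the SGML markup (via the new SentenceMetadata::GetSGMLValue) before any edges are scored. This is a minimal sketch, not part of the commit: the class name DomainBias, the "domain" attribute, and the TraversalFeaturesImpl signature are illustrative assumptions.

// Sketch only (not part of this commit): a hypothetical feature function that
// uses the new per-sentence hook. The class name, feature name, the "domain"
// SGML attribute, and the TraversalFeaturesImpl signature are assumptions.
#include <string>
#include <vector>

#include "ff.h"                 // FeatureFunction (declares PrepareForInput)
#include "fdict.h"              // FD::Convert
#include "hg.h"                 // Hypergraph::Edge
#include "sentence_metadata.h"  // SentenceMetadata::GetSGMLValue

class DomainBias : public FeatureFunction {
 public:
  DomainBias() : fid_(FD::Convert("DomainBias")), active_(false) {}

  // Called once per input, before any feature evaluation on the forest;
  // ModelSet::PrepareForInput dispatches here for every model in the set.
  virtual void PrepareForInput(const SentenceMetadata& smeta) {
    // e.g. input segment: <seg id="3" domain="news"> ... </seg>
    active_ = (smeta.GetSGMLValue("domain") == "news");
  }

 protected:
  // Signature copied from the ff.h being patched; check it against your tree.
  virtual void TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
                                     const Hypergraph::Edge& /* edge */,
                                     const std::vector<const void*>& /* ant_contexts */,
                                     SparseVector<double>* features,
                                     SparseVector<double>* /* estimated_features */,
                                     void* /* out_context */) const {
    if (active_) features->set_value(fid_, 1.0);  // fires only on "news" inputs
  }

 private:
  const int fid_;   // feature id for "DomainBias"
  bool active_;     // per-sentence flag, refreshed by PrepareForInput
};

The decoder calls the hook on both the prelm and late model sets before rescoring (see the decoder.cc hunks above), so an override like this is refreshed for every input sentence.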
