From 72acccb0da4a54369f32217c1618527956adacac Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 5 Oct 2010 17:33:45 +0000 Subject: add PrepareForInput to ff interface, make sgml markup available to feature functions git-svn-id: https://ws10smt.googlecode.com/svn/trunk@669 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/decoder.cc | 5 ++++- decoder/ff.cc | 6 ++++++ decoder/ff.h | 8 ++++++++ decoder/scfg_translator.cc | 14 +++++++------- decoder/sentence_metadata.h | 10 ++++++++++ 5 files changed, 35 insertions(+), 8 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 537fdffa..b975a5fc 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -662,9 +662,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { //FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important. const bool has_ref = ref.size() > 0; SentenceMetadata smeta(sent_id, ref); + smeta.sgml_.swap(sgml); o->NotifyDecodingStart(smeta); Hypergraph forest; // -LM forest - translator->ProcessMarkupHints(sgml); + translator->ProcessMarkupHints(smeta.sgml_); Timer t("Translation"); const bool translation_successful = translator->Translate(to_translate, &smeta, feature_weights, &forest); @@ -696,6 +697,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Timer t("prelm rescoring"); forest.Reweight(prelm_feature_weights); Hypergraph prelm_forest; + prelm_models->PrepareForInput(smeta); ApplyModelSet(forest, smeta, *prelm_models, @@ -713,6 +715,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { bool has_late_models = !late_models->empty(); if (has_late_models) { Timer t("Forest rescoring:"); + late_models->PrepareForInput(smeta); forest.Reweight(feature_weights); Hypergraph lm_forest; ApplyModelSet(forest, diff --git a/decoder/ff.cc b/decoder/ff.cc index a32c0dcb..1258bc79 100644 --- a/decoder/ff.cc +++ b/decoder/ff.cc @@ -13,6 +13,7 @@ using namespace std; FeatureFunction::~FeatureFunction() {} +void FeatureFunction::PrepareForInput(const SentenceMetadata&) {} void FeatureFunction::FinalTraversalFeatures(const void* /* ant_state */, SparseVector* /* features */) const { @@ -163,6 +164,11 @@ ModelSet::ModelSet(const vector& w, const vector } } +void ModelSet::PrepareForInput(const SentenceMetadata& smeta) { + for (int i = 0; i < models_.size(); ++i) + const_cast(models_[i])->PrepareForInput(smeta); +} + void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, const Hypergraph& /* hg */, const FFStates& node_states, diff --git a/decoder/ff.h b/decoder/ff.h index 904b9eb8..e470f9a9 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -52,6 +52,10 @@ public: // stateless feature that doesn't depend on source span: override and return true. then your feature can be precomputed over rules. virtual bool rule_feature() const { return false; } + // called once, per input, before any feature calls to TraversalFeatures, etc. + // used to initialize sentence-specific data structures + virtual void PrepareForInput(const SentenceMetadata& smeta); + //OVERRIDE THIS: virtual Features features() const { return single_feature(FD::Convert(name_)); } // returns the number of bytes of context that this feature function will @@ -274,6 +278,10 @@ class ModelSet { Hypergraph::Edge* edge, SentenceMetadata const& smeta) const; + // this is called once before any feature functions apply to a hypergraph + // it can be used to initialize sentence-specific data structures + void PrepareForInput(const SentenceMetadata& smeta); + bool empty() const { return models_.empty(); } bool stateless() const { return !state_size_; } diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 60123e6f..afe796a5 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -34,14 +34,14 @@ struct SCFGTranslatorImpl { if(conf.count("grammar")){ vector gfiles = conf["grammar"].as >(); for (int i = 0; i < gfiles.size(); ++i) { - if (!SILENT) cerr << "Reading SCFG grammar from " << gfiles[i] << endl; - TextGrammar* g = new TextGrammar(gfiles[i]); - g->SetMaxSpan(max_span_limit); - g->SetGrammarName(gfiles[i]); - grammars.push_back(GrammarPtr(g)); - } + if (!SILENT) cerr << "Reading SCFG grammar from " << gfiles[i] << endl; + TextGrammar* g = new TextGrammar(gfiles[i]); + g->SetMaxSpan(max_span_limit); + g->SetGrammarName(gfiles[i]); + grammars.push_back(GrammarPtr(g)); + } + if (!SILENT) cerr << endl; } - cerr << std::endl; if (conf.count("scfg_extra_glue_grammar")) { GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as()); g->SetGrammarName("ExtraGlueGrammar"); diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h index c9a78dd3..eab9f15d 100644 --- a/decoder/sentence_metadata.h +++ b/decoder/sentence_metadata.h @@ -1,11 +1,14 @@ #ifndef _SENTENCE_METADATA_H_ #define _SENTENCE_METADATA_H_ +#include +#include #include #include "lattice.h" #include "scorer.h" struct SentenceMetadata { + friend class DecoderImpl; SentenceMetadata(int id, const Lattice& ref) : sent_id_(id), src_len_(-1), @@ -42,7 +45,14 @@ struct SentenceMetadata { const DocScorer& GetDocScorer() const { return *ds; } double GetDocLen() const {return doc_len;} + std::string GetSGMLValue(const std::string& key) const { + std::map::const_iterator it = sgml_.find(key); + if (it == sgml_.end()) return ""; + return it->second; + } + private: + std::map sgml_; const int sent_id_; // the following should be set, if possible, by the Translator int src_len_; -- cgit v1.2.3