summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/decoder.cc5
-rw-r--r--decoder/ff.cc6
-rw-r--r--decoder/ff.h8
-rw-r--r--decoder/scfg_translator.cc14
-rw-r--r--decoder/sentence_metadata.h10
5 files changed, 35 insertions, 8 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 537fdffa..b975a5fc 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -662,9 +662,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important.
const bool has_ref = ref.size() > 0;
SentenceMetadata smeta(sent_id, ref);
+ smeta.sgml_.swap(sgml);
o->NotifyDecodingStart(smeta);
Hypergraph forest; // -LM forest
- translator->ProcessMarkupHints(sgml);
+ translator->ProcessMarkupHints(smeta.sgml_);
Timer t("Translation");
const bool translation_successful =
translator->Translate(to_translate, &smeta, feature_weights, &forest);
@@ -696,6 +697,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Timer t("prelm rescoring");
forest.Reweight(prelm_feature_weights);
Hypergraph prelm_forest;
+ prelm_models->PrepareForInput(smeta);
ApplyModelSet(forest,
smeta,
*prelm_models,
@@ -713,6 +715,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
bool has_late_models = !late_models->empty();
if (has_late_models) {
Timer t("Forest rescoring:");
+ late_models->PrepareForInput(smeta);
forest.Reweight(feature_weights);
Hypergraph lm_forest;
ApplyModelSet(forest,
diff --git a/decoder/ff.cc b/decoder/ff.cc
index a32c0dcb..1258bc79 100644
--- a/decoder/ff.cc
+++ b/decoder/ff.cc
@@ -13,6 +13,7 @@ using namespace std;
FeatureFunction::~FeatureFunction() {}
+void FeatureFunction::PrepareForInput(const SentenceMetadata&) {}
void FeatureFunction::FinalTraversalFeatures(const void* /* ant_state */,
SparseVector<double>* /* features */) const {
@@ -163,6 +164,11 @@ ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>
}
}
+void ModelSet::PrepareForInput(const SentenceMetadata& smeta) {
+ for (int i = 0; i < models_.size(); ++i)
+ const_cast<FeatureFunction*>(models_[i])->PrepareForInput(smeta);
+}
+
void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
const Hypergraph& /* hg */,
const FFStates& node_states,
diff --git a/decoder/ff.h b/decoder/ff.h
index 904b9eb8..e470f9a9 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -52,6 +52,10 @@ public:
// stateless feature that doesn't depend on source span: override and return true. then your feature can be precomputed over rules.
virtual bool rule_feature() const { return false; }
+ // called once, per input, before any feature calls to TraversalFeatures, etc.
+ // used to initialize sentence-specific data structures
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+
//OVERRIDE THIS:
virtual Features features() const { return single_feature(FD::Convert(name_)); }
// returns the number of bytes of context that this feature function will
@@ -274,6 +278,10 @@ class ModelSet {
Hypergraph::Edge* edge,
SentenceMetadata const& smeta) const;
+ // this is called once before any feature functions apply to a hypergraph
+ // it can be used to initialize sentence-specific data structures
+ void PrepareForInput(const SentenceMetadata& smeta);
+
bool empty() const { return models_.empty(); }
bool stateless() const { return !state_size_; }
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 60123e6f..afe796a5 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -34,14 +34,14 @@ struct SCFGTranslatorImpl {
if(conf.count("grammar")){
vector<string> gfiles = conf["grammar"].as<vector<string> >();
for (int i = 0; i < gfiles.size(); ++i) {
- if (!SILENT) cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
- TextGrammar* g = new TextGrammar(gfiles[i]);
- g->SetMaxSpan(max_span_limit);
- g->SetGrammarName(gfiles[i]);
- grammars.push_back(GrammarPtr(g));
- }
+ if (!SILENT) cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+ TextGrammar* g = new TextGrammar(gfiles[i]);
+ g->SetMaxSpan(max_span_limit);
+ g->SetGrammarName(gfiles[i]);
+ grammars.push_back(GrammarPtr(g));
+ }
+ if (!SILENT) cerr << endl;
}
- cerr << std::endl;
if (conf.count("scfg_extra_glue_grammar")) {
GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>());
g->SetGrammarName("ExtraGlueGrammar");
diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h
index c9a78dd3..eab9f15d 100644
--- a/decoder/sentence_metadata.h
+++ b/decoder/sentence_metadata.h
@@ -1,11 +1,14 @@
#ifndef _SENTENCE_METADATA_H_
#define _SENTENCE_METADATA_H_
+#include <string>
+#include <map>
#include <cassert>
#include "lattice.h"
#include "scorer.h"
struct SentenceMetadata {
+ friend class DecoderImpl;
SentenceMetadata(int id, const Lattice& ref) :
sent_id_(id),
src_len_(-1),
@@ -42,7 +45,14 @@ struct SentenceMetadata {
const DocScorer& GetDocScorer() const { return *ds; }
double GetDocLen() const {return doc_len;}
+ std::string GetSGMLValue(const std::string& key) const {
+ std::map<std::string, std::string>::const_iterator it = sgml_.find(key);
+ if (it == sgml_.end()) return "";
+ return it->second;
+ }
+
private:
+ std::map<std::string, std::string> sgml_;
const int sent_id_;
// the following should be set, if possible, by the Translator
int src_len_;