summaryrefslogtreecommitdiff
path: root/decoder/decoder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r--decoder/decoder.cc64
1 files changed, 30 insertions, 34 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 25eb2de4..4d4b6245 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -159,8 +159,7 @@ struct RescoringPass {
shared_ptr<ModelSet> models;
shared_ptr<IntersectionConfiguration> inter_conf;
vector<const FeatureFunction*> ffs;
- shared_ptr<Weights> w; // null == use previous weights
- vector<double> weight_vector;
+ shared_ptr<vector<weight_t> > weight_vector;
int fid_summary; // 0 == no summary feature
double density_prune; // 0 == don't density prune
double beam_prune; // 0 == don't beam prune
@@ -169,7 +168,7 @@ struct RescoringPass {
ostream& operator<<(ostream& os, const RescoringPass& rp) {
os << "[num_fn=" << rp.ffs.size();
if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; }
- if (rp.w) os << " new_weights";
+ //if (rp.weight_vector.size() > 0) os << " new_weights";
if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary);
if (rp.density_prune) os << " density_prune=" << rp.density_prune;
if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune;
@@ -181,13 +180,8 @@ struct DecoderImpl {
DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg);
~DecoderImpl();
bool Decode(const string& input, DecoderObserver*);
- void SetWeights(const vector<double>& weights) {
- init_weights = weights;
- for (int i = 0; i < rescoring_passes.size(); ++i) {
- if (rescoring_passes[i].models)
- rescoring_passes[i].models->SetWeights(weights);
- rescoring_passes[i].weight_vector = weights;
- }
+ vector<weight_t>& CurrentWeightVector() {
+ return *rescoring_passes.back().weight_vector;
}
void SetId(int next_sent_id) { sent_id = next_sent_id - 1; }
@@ -300,8 +294,7 @@ struct DecoderImpl {
OracleBleu oracle;
string formalism;
shared_ptr<Translator> translator;
- Weights w_init_weights; // used with initial parse
- vector<double> init_weights; // weights used with initial parse
+ shared_ptr<vector<weight_t> > init_weights; // weights used with initial parse
vector<shared_ptr<FeatureFunction> > pffs;
#ifdef FSA_RESCORING
CFGOptions cfg_options;
@@ -557,13 +550,18 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
exit(1);
}
- // load initial feature weights (and possibly freeze feature set)
- if (conf.count("weights")) {
- w_init_weights.InitFromFile(str("weights",conf));
- w_init_weights.InitVector(&init_weights);
- init_weights.resize(FD::NumFeats());
+ // load perfect hash function for features
+ if (conf.count("cmph_perfect_feature_hash")) {
+ cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
+ FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
+ cerr << " " << FD::NumFeats() << " features in map\n";
}
+ // load initial feature weights (and possibly freeze feature set)
+ init_weights.reset(new vector<weight_t>);
+ if (conf.count("weights"))
+ Weights::InitFromFile(str("weights",conf), init_weights.get());
+
// cube pruning pop-limit: we may want to configure this on a per-pass basis
pop_limit = conf["cubepruning_pop_limit"].as<int>();
@@ -582,9 +580,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
RescoringPass& rp = rescoring_passes.back();
// only configure new weights if pass > 0, otherwise we reuse the initial chart weights
if (nth_pass_condition && conf.count(ws)) {
- rp.w.reset(new Weights);
- rp.w->InitFromFile(str(ws.c_str(), conf));
- rp.w->InitVector(&rp.weight_vector);
+ rp.weight_vector.reset(new vector<weight_t>());
+ Weights::InitFromFile(str(ws.c_str(), conf), rp.weight_vector.get());
}
bool has_stateful = false;
if (conf.count(ff)) {
@@ -624,11 +621,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
// set up weight vectors since later phases may reuse weights from earlier phases
- const vector<double>* prev = &init_weights;
+ shared_ptr<vector<weight_t> > prev_weights = init_weights;
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
RescoringPass& rp = rescoring_passes[pass];
- if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; }
- rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs));
+ if (!rp.weight_vector) {
+ rp.weight_vector = prev_weights;
+ } else {
+ prev_weights = rp.weight_vector;
+ }
+ rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs));
string ps = "Pass1 "; ps[4] += pass;
if (!SILENT) show_models(conf,*rp.models,ps.c_str());
}
@@ -650,12 +651,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
FD::Freeze(); // this means we can't see the feature names of not-weighted features
}
- if (conf.count("cmph_perfect_feature_hash")) {
- cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
- FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
- cerr << " " << FD::NumFeats() << " features in map\n";
- }
-
// set up translation back end
if (formalism == "scfg")
translator.reset(new SCFGTranslator(conf));
@@ -685,7 +680,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
if (!fsa_ffs.empty()) {
cerr<<"FSA: ";
- show_all_features(fsa_ffs,init_weights,cerr,cerr,true,true);
+ show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true);
}
#endif
@@ -733,7 +728,8 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {
if (del) delete o;
return res;
}
-void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); }
+vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); }
+const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); }
void Decoder::SetSupplementalGrammar(const std::string& grammar_string) {
assert(pimpl_->translator->GetDecoderType() == "SCFG");
static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string);
@@ -774,7 +770,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
translator->ProcessMarkupHints(smeta.sgml_);
Timer t("Translation");
const bool translation_successful =
- translator->Translate(to_translate, &smeta, init_weights, &forest);
+ translator->Translate(to_translate, &smeta, *init_weights, &forest);
translator->SentenceComplete();
if (!translation_successful) {
@@ -812,7 +808,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
const RescoringPass& rp = rescoring_passes[pass];
- const vector<double>& cur_weights = rp.weight_vector;
+ const vector<weight_t>& cur_weights = *rp.weight_vector;
if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl;
#ifdef FSA_RESCORING
cfg_options.maybe_output_source(forest);
@@ -933,7 +929,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
#endif
}
- const vector<double>& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector);
+ const vector<double>& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector);
// Oracle Rescoring
if(get_oracle_forest) {