diff options
Diffstat (limited to 'decoder/decoder.cc')
| -rw-r--r-- | decoder/decoder.cc | 22 | 
1 files changed, 21 insertions, 1 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 95ff6270..8a03c5c9 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -141,12 +141,13 @@ inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose  // and then prune the resulting (rescored) hypergraph. All feature values from previous  // passes are carried over into subsequent passes (where they may have different weights).  struct RescoringPass { -  RescoringPass() : density_prune(), beam_prune() {} +  RescoringPass() : fid_summary(), density_prune(), beam_prune() {}    shared_ptr<ModelSet> models;    shared_ptr<IntersectionConfiguration> inter_conf;    vector<const FeatureFunction*> ffs;    shared_ptr<Weights> w;      // null == use previous weights    vector<double> weight_vector; +  int fid_summary;            // 0 == no summary feature    double density_prune;       // 0 == don't density prune    double beam_prune;          // 0 == don't beam prune  }; @@ -155,6 +156,7 @@ ostream& operator<<(ostream& os, const RescoringPass& rp) {    os << "[num_fn=" << rp.ffs.size();    if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; }    if (rp.w) os << " new_weights"; +  if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary);    if (rp.density_prune) os << " density_prune=" << rp.density_prune;    if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune;    os << ']'; @@ -361,18 +363,21 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")          ("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")          ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") +        ("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")          ("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")          ("beam_prune", po::value<double>(), "Pass 1 pruning: Prune paths from scored forest, keep paths within exp(alpha>=0)")          ("weights2",po::value<string>(),"Optional pass 2")          ("feature_function2",po::value<vector<string> >()->composing(), "Optional pass 2")          ("intersection_strategy2",po::value<string>()->default_value("cube_pruning"), "Optional pass 2") +        ("summary_feature2", po::value<string>(), "Optional pass 2")          ("density_prune2", po::value<double>(), "Optional pass 2")          ("beam_prune2", po::value<double>(), "Optional pass 2")          ("weights3",po::value<string>(),"Optional pass 3")          ("feature_function3",po::value<vector<string> >()->composing(), "Optional pass 3")          ("intersection_strategy3",po::value<string>()->default_value("cube_pruning"), "Optional pass 3") +        ("summary_feature3", po::value<string>(), "Optional pass 3")          ("density_prune3", po::value<double>(), "Optional pass 3")          ("beam_prune3", po::value<double>(), "Optional pass 3") @@ -559,6 +564,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    for (int pass = 0; pass < MAX_PASSES; ++pass) {      string ws = "weights" + StringSuffixForRescoringPass(pass);      string ff = "feature_function" + StringSuffixForRescoringPass(pass); +    string sf = "summary_feature" + StringSuffixForRescoringPass(pass);      string bp = "beam_prune" + StringSuffixForRescoringPass(pass);      string dp = "density_prune" + StringSuffixForRescoringPass(pass);      bool first_pass_condition = ((pass == 0) && (conf.count(ff) || conf.count(bp) || conf.count(dp))); @@ -583,6 +589,11 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream            if (p->IsStateful()) { has_stateful = true; }          }        } +      if (conf.count(sf)) { +        rp.fid_summary = FD::Convert(conf[sf].as<string>()); +        assert(rp.fid_summary > 0); +        // TODO assert that weights for this pass have coef(fid_summary) == 0.0? +      }        if (conf.count(bp)) { rp.beam_prune = conf[bp].as<double>(); }        if (conf.count(dp)) { rp.density_prune = conf[dp].as<double>(); }        int palg = (has_stateful ? 1 : 0);  // if there are no stateful featueres, default to FULL @@ -794,6 +805,15 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {        cerr << "  " << passtr << " partition     log(Z): " << log(z) << endl;      } +    if (rp.fid_summary) { +      Hypergraph::EdgeProbs posteriors; +      const prob_t z = forest.ComputeEdgePosteriors(1.0, &posteriors); +      if (!SILENT) { cerr << "  " << passtr << " adding summary feature " << FD::Convert(rp.fid_summary) << " log(Z)=" << log(z) << endl; } +      assert(forest.edges_.size() == posteriors.size()); +      for (int i = 0; i < posteriors.size(); ++i) +        forest.edges_[i].feature_values_.set_value(rp.fid_summary, log(posteriors[i] / z)); +    } +      string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass);      string fulldp = "density_prune" + StringSuffixForRescoringPass(pass);      maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen);  | 
