diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-19 19:52:07 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-19 19:52:07 +0000 |
commit | ad307cc72c80f9fc81affe8723682b685f60914f (patch) | |
tree | 2fb3d095da2e1fc097cda8895c318d9caee5f7ce /decoder/decoder.cc | |
parent | 31180b4b707aba7c18a78670846ae3d2f3ff0b09 (diff) |
big refactor
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@649 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r-- | decoder/decoder.cc | 736 |
1 files changed, 704 insertions, 32 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 196ffa46..f90e6bc0 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -1,25 +1,94 @@ #include "decoder.h" +#include <tr1/unordered_map> #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> -#include "ff_factory.h" -#include "cfg_options.h" +#include "sampler.h" #include "stringlib.h" +#include "weights.h" +#include "filelib.h" +#include "fdict.h" +#include "timing_stats.h" + +#include "translator.h" +#include "phrasebased_translator.h" +#include "tagger.h" +#include "lextrans.h" +#include "lexalign.h" +#include "csplit.h" + +#include "lattice.h" #include "hg.h" +#include "sentence_metadata.h" +#include "hg_intersect.h" + +#include "apply_fsa_models.h" +#include "oracle_bleu.h" +#include "apply_models.h" +#include "ff.h" +#include "ff_factory.h" +#include "cfg_options.h" +#include "viterbi.h" +#include "kbest.h" +#include "inside_outside.h" +#include "exp_semiring.h" +#include "sentence_metadata.h" +#include "hg_cfg.h" + +#include "forest_writer.h" // TODO this section should probably be handled by an Observer +#include "hg_io.h" +#include "aligner.h" +static const double kMINUS_EPSILON = -1e-6; // don't be too strict using namespace std; +using namespace std::tr1; using boost::shared_ptr; namespace po = boost::program_options; +static bool verbose_feature_functions=true; + +namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); } +namespace NgramCache { void Clear(); } + +void DecoderObserver::NotifySourceParseFailure(const SentenceMetadata&) {} +void DecoderObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph*) {} +void DecoderObserver::NotifyAlignmentFailure(const SentenceMetadata&) {} +void DecoderObserver::NotifyAlignmentForest(const SentenceMetadata&, Hypergraph*) {} +void DecoderObserver::NotifyDecodingComplete(const SentenceMetadata&) {} + +struct ELengthWeightFunction { + double operator()(const Hypergraph::Edge& e) const { + return e.rule_->ELength() - e.rule_->Arity(); + } +}; inline void ShowBanner() { cerr << "cdec v1.0 (c) 2009-2010 by Chris Dyer\n"; } +inline void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) { + cerr<<header<<": "; + ms.show_features(cerr,cerr,conf.count("warn_0_weight")); +} + inline string str(char const* name,po::variables_map const& conf) { return conf[name].as<string>(); } +inline bool prelm_weights_string(po::variables_map const& conf,string &s) { + if (conf.count("prelm_weights")) { + s=str("prelm_weights",conf); + return true; + } + if (conf.count("prelm_copy_weights")) { + s=str("weights",conf); + return true; + } + return false; +} + + + // print just the --long_opt names suitable for bash compgen inline void print_options(std::ostream &out,po::options_description const& opts) { typedef std::vector< shared_ptr<po::option_description> > Ds; @@ -32,22 +101,214 @@ inline void print_options(std::ostream &out,po::options_description const& opts) out << '"'; } +template <class V> +inline bool store_conf(po::variables_map const& conf,std::string const& name,V *v) { + if (conf.count(name)) { + *v=conf[name].as<V>(); + return true; + } + return false; +} + +inline shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { + string ff, param; + SplitCommandAndParam(ffp, &ff, ¶m); + cerr << pre << "feature: " << ff; + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param); + if (!pf) exit(1); + int nbyte=pf->NumBytesContext(); + if (verbose_feature_functions) + cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl; + return pf; +} + +inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { + string ff, param; + SplitCommandAndParam(ffp, &ff, ¶m); + cerr << "FSA Feature: " << ff; + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param); + if (!pf) exit(1); + if (verbose_feature_functions) + cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl; + return pf; +} + struct DecoderImpl { - DecoderImpl(int argc, char** argv, istream* cfg); - bool Decode(const string& input) { - return false; + DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg); + ~DecoderImpl(); + bool Decode(const string& input, DecoderObserver*); + void SetWeights(const vector<double>& weights) { } - bool DecodeProduceHypergraph(const string& input, Hypergraph* hg) { + + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,WeightVector *weights=0,bool show_deriv=false) { + cerr << viterbi_stats(forest,name,true,show_tree,show_deriv); + if (show_features) { + cerr << name<<" features: "; +/* Hypergraph::Edge const* best=forest.ViterbiGoalEdge(); + if (!best) + cerr << name<<" has no goal edge."; + else + cerr<<best->feature_values_; +*/ + cerr << ViterbiFeatures(forest,weights); + cerr << endl; + } + } + + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights, bool sd=false) { + WeightVector fw(feature_weights); + forest_stats(forest,name,show_tree,show_features,&fw,sd); + } + + bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1) { + if (conf.count(name)) { + *val=conf[name].as<double>()*(scale_srclen?srclen:1); + return true; + } return false; } - void SetWeights(const vector<double>& weights) { + + void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) { + double beam_prune=0,density_prune=0; + bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen); + bool use_density_prune=beam_param(conf,ndensity,&density_prune); + if (use_beam_prune || use_density_prune) { + double presize=forest.edges_.size(); + vector<bool> preserve_mask,*pm=0; + if (conf.count("csplit_preserve_full_word")) { + preserve_mask.resize(forest.edges_.size()); + preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; + pm=&preserve_mask; + } + forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1,conf["promise_power"].as<double>()); + if (!forestname.empty()) forestname=" "+forestname; + forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false); + cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl; + } + } + + void SampleRecurse(const Hypergraph& hg, const vector<SampleSet<prob_t> >& ss, int n, vector<WordID>* out) { + const SampleSet<prob_t>& s = ss[n]; + int i = rng->SelectSample(s); + const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]]; + vector<vector<WordID> > ants(edge.tail_nodes_.size()); + for (int j = 0; j < ants.size(); ++j) + SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]); + + vector<const vector<WordID>*> pants(ants.size()); + for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j]; + edge.rule_->ESubstitute(pants, out); + } + + struct SampleSort { + bool operator()(const pair<int,string>& a, const pair<int,string>& b) const { + return a.first > b.first; + } + }; + + // TODO this should be handled by an Observer + void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { + unordered_map<string, int, boost::hash<string> > m; + hg->PushWeightsToGoal(); + const int num_nodes = hg->nodes_.size(); + vector<SampleSet<prob_t> > ss(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + SampleSet<prob_t>& s = ss[i]; + const vector<int>& in_edges = hg->nodes_[i].in_edges_; + for (int j = 0; j < in_edges.size(); ++j) { + s.add(hg->edges_[in_edges[j]].edge_prob_); + } + } + for (int i = 0; i < samples; ++i) { + vector<WordID> yield; + SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield); + const string trans = TD::GetString(yield); + ++m[trans]; + } + vector<pair<int, string> > dist; + for (unordered_map<string, int, boost::hash<string> >::iterator i = m.begin(); + i != m.end(); ++i) { + dist.push_back(make_pair(i->second, i->first)); + } + sort(dist.begin(), dist.end(), SampleSort()); + if (k) { + for (int i = 0; i < k; ++i) + cout << dist[i].first << " ||| " << dist[i].second << endl; + } else { + cout << dist[0].second << endl; + } } - po::variables_map conf; + void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) { + string sref; + ParseTranslatorInput(line, input, &sref); + if (sref.size() > 0) { + assert(ref); + LatticeTools::ConvertTextOrPLF(sref, ref); + } + } + + po::variables_map& conf; + OracleBleu oracle; CFGOptions cfg_options; + string formalism; + shared_ptr<Translator> translator; + vector<double> feature_weights,prelm_feature_weights; + Weights w,prelm_w; + vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs; + vector<const FeatureFunction*> late_ffs,prelm_ffs; + vector<shared_ptr<FsaFeatureFunction> > fsa_ffs; + vector<string> fsa_names; + ModelSet* late_models, *prelm_models; + IntersectionConfiguration* inter_conf; + shared_ptr<RandomNumberGenerator<boost::mt19937> > rng; + int sample_max_trans; + bool aligner_mode; + bool minimal_forests; + bool graphviz; + bool joshua_viz; + bool encode_b64; + bool kbest; + bool unique_kbest; + bool crf_uniform_empirical; + bool get_oracle_forest; + shared_ptr<WriteFile> extract_file; + int combine_size; + int sent_id; + SparseVector<prob_t> acc_vec; // accumulate gradient + double acc_obj; // accumulate objective + int g_count; // number of gradient pieces computed + bool has_prelm_models; + int pop_limit; + bool csplit_output_plf; + bool write_gradient; // TODO Observer + bool feature_expectations; // TODO Observer + bool output_training_vector; // TODO Observer + + static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) { + for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) + trg->set_value(it->first, it->second); + } }; -DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { +DecoderImpl::~DecoderImpl() { + if (output_training_vector && !acc_vec.empty()) { + if (encode_b64) { + cout << "0\t"; + SparseVector<double> dav; ConvertSV(acc_vec, &dav); + B64::Encode(acc_obj, dav, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + } +} + +DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg) : conf(conf) { if (cfg) { if (argc || argv) { cerr << "DecoderImpl() can only take a file or command line options, not both\n"; exit(1); } } bool show_config; bool show_weights; @@ -64,7 +325,7 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { ("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function") ("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.") ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") - ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file") + ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file") ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)") ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)") ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names() @@ -142,7 +403,7 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { } if (conf.count("show_config")) // special handling needed because we only want to notify() once. show_config=true; - if (conf.count("config") || cfg) { + if (conf.count("config") && !cfg) { typedef vector<string> Cs; Cs cs=conf["config"].as<Cs>(); for (int i=0;i<cs.size();++i) { @@ -151,8 +412,8 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { ReadFile conff(cfg); po::store(po::parse_config_file(*conff, dconfig_options), conf); } - if (cfg) po::store(po::parse_config_file(*cfg, dconfig_options), conf); } + if (cfg) po::store(po::parse_config_file(*cfg, dconfig_options), conf); po::notify(conf); if (show_config && !cfg_files.empty()) { cerr<< "\nConfig files:\n\n"; @@ -171,9 +432,9 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { if (conf.count("list_feature_functions")) { cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n"; - // ff_registry.DisplayList(); //TODO + ff_registry.DisplayList(); //TODO cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n"; - // fsa_ff_registry.DisplayList(); // TODO + fsa_ff_registry.DisplayList(); // TODO cerr << endl; exit(1); } @@ -191,38 +452,449 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { exit(1); } - const string formalism = LowercaseString(str("formalism",conf)); + formalism = LowercaseString(str("formalism",conf)); if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } -} -Decoder::Decoder(istream* cfg) { - pimpl_.reset(new DecoderImpl(0,0,cfg)); -} + write_gradient = conf.count("cll_gradient"); + feature_expectations = conf.count("feature_expectations"); + if (write_gradient && feature_expectations) { + cerr << "You can only specify --gradient or --feature_expectations, not both!\n"; + exit(1); + } + output_training_vector = (write_gradient || feature_expectations); -Decoder::Decoder(int argc, char** argv) { - pimpl_.reset(new DecoderImpl(argc, argv, 0)); -} + const string formalism = LowercaseString(str("formalism",conf)); + const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word"); + if (csplit_preserve_full_word && + (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) { + cerr << "--csplit_preserve_full_word should only be " + << "used with csplit AND --*_prune!\n"; + exit(1); + } + csplit_output_plf = conf.count("csplit_output_plf"); + if (csplit_output_plf && formalism != "csplit") { + cerr << "--csplit_output_plf should only be used with csplit!\n"; + exit(1); + } -Decoder::~Decoder() {} + // load feature weights (and possibly freeze feature set) + has_prelm_models = false; + if (conf.count("weights")) { + w.InitFromFile(str("weights",conf)); + feature_weights.resize(FD::NumFeats()); + w.InitVector(&feature_weights); + string plmw; + if (prelm_weights_string(conf,plmw)) { + has_prelm_models = true; + prelm_w.InitFromFile(plmw); + prelm_feature_weights.resize(FD::NumFeats()); + prelm_w.InitVector(&prelm_feature_weights); + if (show_weights) + cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl; + } + if (show_weights) + cerr << "+LM weights: " << WeightVector(feature_weights)<<endl; + } + bool warn0=conf.count("warn_0_weight"); + bool freeze=conf.count("freeze_feature_set"); + bool early_freeze=freeze && !warn0; + bool late_freeze=freeze && warn0; + if (early_freeze) { + cerr << "Freezing feature set" << endl; + FD::Freeze(); // this means we can't see the feature names of not-weighted features + } -bool Decoder::Decode(const string& input) { - return pimpl_->Decode(input); -} + // set up translation back end + if (formalism == "scfg") + translator.reset(new SCFGTranslator(conf)); + else if (formalism == "fst") + translator.reset(new FSTTranslator(conf)); + else if (formalism == "pb") + translator.reset(new PhraseBasedTranslator(conf)); + else if (formalism == "csplit") + translator.reset(new CompoundSplit(conf)); + else if (formalism == "lextrans") + translator.reset(new LexicalTrans(conf)); + else if (formalism == "lexalign") + translator.reset(new LexicalAlign(conf)); + else if (formalism == "tagger") + translator.reset(new Tagger(conf)); + else + assert(!"error"); + + // set up additional scoring features + if (conf.count("feature_function") > 0) { + vector<string> add_ffs; +// const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >(); + store_conf(conf,"feature_function",&add_ffs); + for (int i = 0; i < add_ffs.size(); ++i) { + pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions)); + FeatureFunction const* p=pffs.back().get(); + late_ffs.push_back(p); + if (has_prelm_models) { + if (p->NumBytesContext()==0) + prelm_ffs.push_back(p); + else + cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl; + } + } + } + if (conf.count("prelm_feature_function") > 0) { + vector<string> add_ffs; + store_conf(conf,"prelm_feature_function",&add_ffs); +// const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >(); + for (int i = 0; i < add_ffs.size(); ++i) { + prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only ")); + prelm_ffs.push_back(prelm_only_ffs.back().get()); + } + } -bool Decoder::DecodeProduceHypergraph(const string& input, Hypergraph* hg) { - return pimpl_->DecodeProduceHypergraph(input, hg); + store_conf(conf,"fsa_feature_function",&fsa_names); + for (int i=0;i<fsa_names.size();++i) + fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA ")); + if (fsa_ffs.size()>1) { + //FIXME: support N fsa ffs. + cerr<<"Only the first fsa FF will be used (FIXME).\n"; + fsa_ffs.resize(1); + } + if (!fsa_ffs.empty()) { + cerr<<"FSA: "; + show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true); + } + + if (late_freeze) { + cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl; + FD::Freeze(); // this means we can't see the feature names of not-weighted features + } + + if (has_prelm_models) + cerr << "prelm rescoring with "<<prelm_ffs.size()<<" 0-state feature functions. +LM pass will use "<<late_ffs.size()<<" features (not counting rule features)."<<endl; + + late_models = new ModelSet(feature_weights, late_ffs); + show_models(conf,*late_models,"late "); + prelm_models = new ModelSet(prelm_feature_weights, prelm_ffs); + if (has_prelm_models) + show_models(conf,*prelm_models,"prelm "); + + int palg = 1; + if (LowercaseString(str("intersection_strategy",conf)) == "full") { + palg = 0; + cerr << "Using full intersection (no pruning).\n"; + } + pop_limit=conf["cubepruning_pop_limit"].as<int>(); + inter_conf = new IntersectionConfiguration(palg, pop_limit); + + sample_max_trans = conf.count("max_translation_sample") ? + conf["max_translation_sample"].as<int>() : 0; + if (sample_max_trans) + rng.reset(new RandomNumberGenerator<boost::mt19937>); + aligner_mode = conf.count("aligner"); + minimal_forests = conf.count("minimal_forests"); + graphviz = conf.count("graphviz"); + joshua_viz = conf.count("show_joshua_visualization"); + encode_b64 = str("vector_format",conf) == "b64"; + kbest = conf.count("k_best"); + unique_kbest = conf.count("unique_k_best"); + crf_uniform_empirical = conf.count("crf_uniform_empirical"); + get_oracle_forest = conf.count("get_oracle_forest"); + + cfg_options.Validate(); + + if (conf.count("extract_rules")) + extract_file.reset(new WriteFile(str("extract_rules",conf))); + + combine_size = conf["combine_size"].as<int>(); + if (combine_size < 1) combine_size = 1; + sent_id = -1; + acc_obj = 0; // accumulate objective + g_count = 0; // number of gradient pieces computed } -void Decoder::SetWeights(const vector<double>& weights) { - pimpl_->SetWeights(weights); +Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); } +Decoder::Decoder(int argc, char** argv) { pimpl_.reset(new DecoderImpl(conf,argc, argv, 0)); } +Decoder::~Decoder() {} +bool Decoder::Decode(const string& input, DecoderObserver* o) { + bool del = false; + if (!o) { o = new DecoderObserver; del = true; } + const bool res = pimpl_->Decode(input, o); + if (del) delete o; + return res; } +void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); } + + +bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { + string buf = input; + NgramCache::Clear(); // clear ngram cache for remote LM (if used) + Timer::Summarize(); + ++sent_id; + map<string, string> sgml; + ProcessAndStripSGML(&buf, &sgml); + if (sgml.find("id") != sgml.end()) + sent_id = atoi(sgml["id"].c_str()); + + cerr << "\nINPUT: "; + if (buf.size() < 100) + cerr << buf << endl; + else { + size_t x = buf.rfind(" ", 100); + if (x == string::npos) x = 100; + cerr << buf.substr(0, x) << " ..." << endl; + } + cerr << " id = " << sent_id << endl; + string to_translate; + Lattice ref; + ParseTranslatorInputLattice(buf, &to_translate, &ref); + const unsigned srclen=NTokens(to_translate,' '); +//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important. + const bool has_ref = ref.size() > 0; + SentenceMetadata smeta(sent_id, ref); + Hypergraph forest; // -LM forest + translator->ProcessMarkupHints(sgml); + Timer t("Translation"); + const bool translation_successful = + translator->Translate(to_translate, &smeta, feature_weights, &forest); + //TODO: modify translator to incorporate all 0-state model scores immediately? + translator->SentenceComplete(); + + if (!translation_successful) { + cerr << " NO PARSE FOUND.\n"; + o->NotifySourceParseFailure(smeta); + o->NotifyDecodingComplete(smeta); + return false; + } + + const bool show_tree_structure=conf.count("show_tree_structure"); + const bool show_features=conf.count("show_features"); + forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + if (conf.count("show_expected_length")) { + const PRPair<double, double> res = + Inside<PRPair<double, double>, + PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest); + cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + } + if (conf.count("show_partition")) { + const prob_t z = Inside<prob_t, EdgeProb>(forest); + cerr << " -LM partition log(Z): " << log(z) << endl; + } + + if (has_prelm_models) { + Timer t("prelm rescoring"); + forest.Reweight(prelm_feature_weights); + Hypergraph prelm_forest; + ApplyModelSet(forest, + smeta, + *prelm_models, + *inter_conf, // this is now reduced to exhaustive if all are stateless + &prelm_forest); + forest.swap(prelm_forest); + forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it? + forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation); + } + + maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen); + + cfg_options.maybe_output_source(forest); + + bool has_late_models = !late_models->empty(); + if (has_late_models) { + Timer t("Forest rescoring:"); + forest.Reweight(feature_weights); + Hypergraph lm_forest; + ApplyModelSet(forest, + smeta, + *late_models, + *inter_conf, + &lm_forest); + forest.swap(lm_forest); + forest.Reweight(feature_weights); + forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + } + + maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); + + HgCFG hgcfg(forest); + cfg_options.prepare(hgcfg); + + if (!fsa_ffs.empty()) { + Timer t("Target FSA rescoring:"); + if (!has_late_models) + forest.Reweight(feature_weights); + Hypergraph fsa_forest; + assert(fsa_ffs.size()==1); + ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit); + cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl; + ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest); + forest.swap(fsa_forest); + forest.Reweight(feature_weights); + forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + } + + // Oracle Rescoring + if(get_oracle_forest) { + Oracle oc=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>()); + cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; + oc.hope.Print(cerr," +Oracle BLEU"); + oc.fear.Print(cerr," -Oracle BLEU"); + //Add 1-best translation (trans) to psuedo-doc vectors + oracle.IncludeLastScore(&cerr); + } + + // TODO I think this should probably be handled by an Observer + if (conf.count("forest_output") && !has_ref) { + ForestWriter writer(str("forest_output",conf), sent_id); + if (FileExists(writer.fname_)) { + cerr << " Unioning...\n"; + Hypergraph new_hg; + { + ReadFile rf(writer.fname_); + bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + assert(succeeded); + } + new_hg.Union(forest); + bool succeeded = writer.Write(new_hg, minimal_forests); + assert(succeeded); + } else { + bool succeeded = writer.Write(forest, minimal_forests); + assert(succeeded); + } + } + + // TODO I think this should probably be handled by an Observer + if (sample_max_trans) { + MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0); + } else { + if (kbest) { + //TODO: does this work properly? + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); + } else if (csplit_output_plf) { + cout << HypergraphIO::AsPLF(forest, false) << endl; + } else { + if (!graphviz && !has_ref && !joshua_viz) { + vector<WordID> trans; + ViterbiESentence(forest, &trans); + cout << TD::GetString(trans) << endl << flush; + } + if (joshua_viz) { + cout << sent_id << " ||| " << JoshuaVisualizationString(forest) << " ||| 1.0 ||| " << -1.0 << endl << flush; + } + } + } + + // TODO this should be handled by an Observer + const int max_trans_beam_size = conf.count("max_translation_beam") ? + conf["max_translation_beam"].as<int>() : 0; + if (max_trans_beam_size) { + Hack::MaxTrans(forest, max_trans_beam_size); + return true; + } + + // TODO this should be handled by an Observer + if (graphviz && !has_ref) forest.PrintGraphviz(); -void InitCommandLine(int argc, char** argv, po::variables_map* confp) { - po::variables_map &conf=*confp; + // the following are only used if write_gradient is true! + SparseVector<prob_t> full_exp, ref_exp, gradient; + double log_z = 0, log_ref_z = 0; + if (write_gradient) { + const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &full_exp); + log_z = log(z); + full_exp /= z; + } + if (conf.count("show_cfg_search_space")) + HypergraphIO::WriteAsCFG(forest); + if (has_ref) { + if (HG::Intersect(ref, &forest)) { + cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl; + if (crf_uniform_empirical) { + cerr << " USING UNIFORM WEIGHTS\n"; + for (int i = 0; i < forest.edges_.size(); ++i) + forest.edges_[i].edge_prob_=prob_t::One(); + } else { + forest.Reweight(feature_weights); + cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; + } + if (conf.count("show_partition")) { + const prob_t z = Inside<prob_t, EdgeProb>(forest); + cerr << " Contst. partition log(Z): " << log(z) << endl; + } + if (conf.count("forest_output")) { + ForestWriter writer(str("forest_output",conf), sent_id); + if (FileExists(writer.fname_)) { + cerr << " Unioning...\n"; + Hypergraph new_hg; + { + ReadFile rf(writer.fname_); + bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + assert(succeeded); + } + new_hg.Union(forest); + bool succeeded = writer.Write(new_hg, minimal_forests); + assert(succeeded); + } else { + bool succeeded = writer.Write(forest, minimal_forests); + assert(succeeded); + } + } + if (aligner_mode && !output_training_vector) + AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout); + if (write_gradient) { + const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); + ref_exp /= ref_z; + if (crf_uniform_empirical) { + log_ref_z = ref_exp.dot(feature_weights); + } else { + log_ref_z = log(ref_z); + } + //cerr << " MODEL LOG Z: " << log_z << endl; + //cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl; + if ((log_z - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl; + exit(1); + } + assert(!isnan(log_ref_z)); + ref_exp -= full_exp; + acc_vec += ref_exp; + acc_obj += (log_z - log_ref_z); + } + if (feature_expectations) { + const prob_t z = + InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); + ref_exp /= z; + acc_obj += log(z); + acc_vec += ref_exp; + } + if (output_training_vector) { + acc_vec.erase(0); + ++g_count; + if (g_count % combine_size == 0) { + if (encode_b64) { + cout << "0\t"; + SparseVector<double> dav; ConvertSV(acc_vec, &dav); + B64::Encode(acc_obj, dav, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + acc_vec.clear(); + acc_obj = 0; + } + } + if (conf.count("graphviz")) forest.PrintGraphviz(); + } else { + cerr << " REFERENCE UNREACHABLE.\n"; + if (write_gradient) { + cout << endl << flush; + } + } + } + return true; } + |