summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-19 19:52:07 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-19 19:52:07 +0000
commitb95e56adb2197ccc8ff60b4015c020011009d14c (patch)
tree89acb95777c0eeca7fc9243507ad100fe4ad7e82 /decoder
parent3e9de500d461fa9813295664b847d217f1e8f865 (diff)
big refactor
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@649 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rw-r--r--decoder/cdec.cc870
-rw-r--r--decoder/cdec_ff.cc6
-rw-r--r--decoder/decoder.cc736
-rw-r--r--decoder/decoder.h16
-rw-r--r--decoder/ff_factory.h2
5 files changed, 731 insertions, 899 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 56e103aa..97ec6798 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -1,888 +1,28 @@
#include <iostream>
-#include <fstream>
-#include <tr1/unordered_map>
-#include <tr1/unordered_set>
-#include <boost/shared_ptr.hpp>
-#include <boost/program_options.hpp>
-#include <boost/program_options/variables_map.hpp>
-
-#include "decoder.h"
-#include "oracle_bleu.h"
-#include "timing_stats.h"
-#include "translator.h"
-#include "phrasebased_translator.h"
-#include "aligner.h"
-#include "stringlib.h"
-#include "forest_writer.h"
-#include "hg_io.h"
#include "filelib.h"
-#include "sampler.h"
-#include "sparse_vector.h"
-#include "tagger.h"
-#include "lextrans.h"
-#include "lexalign.h"
-#include "csplit.h"
-#include "weights.h"
-#include "tdict.h"
-#include "ff.h"
-#include "ff_fsa_dynamic.h"
-#include "ff_factory.h"
-#include "hg_intersect.h"
-#include "apply_models.h"
-#include "viterbi.h"
-#include "kbest.h"
-#include "inside_outside.h"
-#include "exp_semiring.h"
-#include "sentence_metadata.h"
-#include "scorer.h"
-#include "apply_fsa_models.h"
-#include "program_options.h"
-#include "cfg_options.h"
-
-CFGOptions cfg_options;
+#include "decoder.h"
using namespace std;
-using namespace std::tr1;
-using boost::shared_ptr;
-namespace po = boost::program_options;
-
-vector<string> cfg_files;
-bool show_config=false;
-bool show_weights=false;
-bool verbose_feature_functions=true;
-
-// some globals ...
-boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
-static const double kMINUS_EPSILON = -1e-6; // don't be too strict
-
-namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); }
-namespace NgramCache { void Clear(); }
-
-void ShowBanner() {
- cerr << "cdec v1.0 (c) 2009-2010 by Chris Dyer\n";
-}
-
-void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) {
- string sref;
- ParseTranslatorInput(line, input, &sref);
- if (sref.size() > 0) {
- assert(ref);
- LatticeTools::ConvertTextOrPLF(sref, ref);
- }
-}
-
-void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
- for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
- trg->set_value(it->first, it->second);
-}
-
-inline string str(char const* name,po::variables_map const& conf) {
- return conf[name].as<string>();
-}
-
-shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
- string ff, param;
- SplitCommandAndParam(ffp, &ff, &param);
- cerr << pre << "feature: " << ff;
- if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
- else cerr << " (no config parameters)\n";
- shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param);
- if (!pf) exit(1);
- int nbyte=pf->NumBytesContext();
- if (verbose_feature_functions)
- cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl;
- return pf;
-}
-
-shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
- string ff, param;
- SplitCommandAndParam(ffp, &ff, &param);
- cerr << "FSA Feature: " << ff;
- if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
- else cerr << " (no config parameters)\n";
- shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param);
- if (!pf) exit(1);
- if (verbose_feature_functions)
- cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl;
- return pf;
-}
-
-// print just the --long_opt names suitable for bash compgen
-void print_options(std::ostream &out,po::options_description const& opts) {
- typedef std::vector< shared_ptr<po::option_description> > Ds;
- Ds const& ds=opts.options();
- out << '"';
- for (unsigned i=0;i<ds.size();++i) {
- if (i) out<<' ';
- out<<"--"<<ds[i]->long_name();
- }
- out << '"';
-}
-
-void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* confp) {
- po::variables_map &conf=*confp;
- po::options_description opts("Configuration options");
- opts.add_options()
- ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexTrans (lexical translation model, also disc training), CSplit (compound splitting), Tagger (sequence labeling), LexAlign (alignment only, or EM training)")
- ("input,i",po::value<string>()->default_value("-"),"Source file")
- ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
- ("weights,w",po::value<string>(),"Feature weights file")
- ("prelm_weights",po::value<string>(),"Feature weights file for prelm_beam_prune. Requires --weights.")
- ("prelm_copy_weights","use --weights as value for --prelm_weights.")
- ("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function")
- ("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.")
- ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
- ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file")
- ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
- ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
- ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
- ("list_feature_functions,L","List available feature functions")
- ("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
- ("k_best,k",po::value<int>(),"Extract the k best derivations")
- ("unique_k_best,r", "Unique k-best translation list")
- ("aligner,a", "Run as a word/phrase aligner (src & ref required)")
- ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
- ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
- ("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)")
- ("scfg_extra_glue_grammar", po::value<string>(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)")
- ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)")
- ("scfg_default_nt,d",po::value<string>()->default_value("X"),"Default non-terminal symbol in SCFG")
- ("scfg_max_span_limit,S",po::value<int>()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)")
- ("show_config", po::bool_switch(&show_config), "show contents of loaded -c config files.")
- ("show_weights", po::bool_switch(&show_weights), "show effective feature weights")
- ("show_joshua_visualization,J", "Produce output compatible with the Joshua visualization tools")
- ("show_tree_structure", "Show the Viterbi derivation structure")
- ("show_expected_length", "Show the expected translation length under the model")
- ("show_partition,z", "Compute and show the partition (inside score)")
- ("show_cfg_search_space", "Show the search space as a CFG")
- ("show_features","Show the feature vector for the viterbi translation")
- ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
- ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
- ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)")
- ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
- ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
- ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
- ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
- ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
- ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
- ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit). note: for the same pop_limit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU. TODO: test under more conditions, or try idea with different formula, or prob. cube beams.")
- ("lexalign_use_null", "Support source-side null words in lexical translation")
- ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
- ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
- ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
- ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
- ("graphviz","Show (constrained) translation forest in GraphViz format")
- ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
- ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
- ("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
- ("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)")
- ("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM")
- ("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules")
- ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
- ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
- ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
- ("forest_output,O",po::value<string>(),"Directory to write forests to")
- ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
- ob.AddOptions(&opts);
- po::options_description cfgo(cfg_options.description());
- cfg_options.AddOptions(&cfgo);
- po::options_description clo("Command line options");
- clo.add_options()
- ("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
- ("help,h", "Print this help message and exit")
- ("usage,u", po::value<string>(), "Describe a feature function type")
- ("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'")
- ;
-
- po::options_description dconfig_options, dcmdline_options;
- dconfig_options.add(opts).add(cfgo);
- //add(opts).add(cfgo)
- dcmdline_options.add(dconfig_options).add(clo);
- argv_minus_to_underscore(argc,argv);
- po::store(parse_command_line(argc, argv, dcmdline_options), conf);
- if (conf.count("compgen")) {
- print_options(cout,dcmdline_options);
- cout << endl;
- exit(0);
- }
- ShowBanner();
- if (conf.count("show_config")) // special handling needed because we only want to notify() once.
- show_config=true;
- if (conf.count("config")) {
- typedef vector<string> Cs;
- Cs cs=conf["config"].as<Cs>();
- for (int i=0;i<cs.size();++i) {
- string cfg=cs[i];
- cerr << "Configuration file: " << cfg << endl;
- ReadFile conff(cfg);
- po::store(po::parse_config_file(*conff, dconfig_options), conf);
- }
- }
- po::notify(conf);
- if (show_config && !cfg_files.empty()) {
- cerr<< "\nConfig files:\n\n";
- for (int i=0;i<cfg_files.size();++i) {
- string cfg=cfg_files[i];
- cerr << "Configuration file: " << cfg << endl;
- CopyFile(cfg,cerr);
- cerr << "(end config "<<cfg<<"\n\n";
- }
- cerr <<"Command line:";
- for (int i=0;i<argc;++i)
- cerr<<" "<<argv[i];
- cerr << "\n\n";
- }
-
-
- if (conf.count("list_feature_functions")) {
- cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n";
- ff_registry.DisplayList();
- cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n";
- fsa_ff_registry.DisplayList();
- cerr << endl;
- exit(1);
- }
-
- if (conf.count("usage")) {
- ff_usage(str("usage",conf));
- exit(0);
- }
- if (conf.count("help")) {
- cout << dcmdline_options << endl;
- exit(0);
- }
- if (conf.count("help") || conf.count("formalism") == 0) {
- cerr << dcmdline_options << endl;
- exit(1);
- }
-
- const string formalism = LowercaseString(str("formalism",conf));
- if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") {
- cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n";
- cerr << dcmdline_options << endl;
- exit(1);
- }
-
-}
-
-// TODO move out of cdec into some sampling decoder file
-void SampleRecurse(const Hypergraph& hg, const vector<SampleSet<prob_t> >& ss, int n, vector<WordID>* out) {
- const SampleSet<prob_t>& s = ss[n];
- int i = rng->SelectSample(s);
- const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]];
- vector<vector<WordID> > ants(edge.tail_nodes_.size());
- for (int j = 0; j < ants.size(); ++j)
- SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]);
-
- vector<const vector<WordID>*> pants(ants.size());
- for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j];
- edge.rule_->ESubstitute(pants, out);
-}
-
-struct SampleSort {
- bool operator()(const pair<int,string>& a, const pair<int,string>& b) const {
- return a.first > b.first;
- }
-};
-
-// TODO move out of cdec into some sampling decoder file
-void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
- unordered_map<string, int, boost::hash<string> > m;
- hg->PushWeightsToGoal();
- const int num_nodes = hg->nodes_.size();
- vector<SampleSet<prob_t> > ss(num_nodes);
- for (int i = 0; i < num_nodes; ++i) {
- SampleSet<prob_t>& s = ss[i];
- const vector<int>& in_edges = hg->nodes_[i].in_edges_;
- for (int j = 0; j < in_edges.size(); ++j) {
- s.add(hg->edges_[in_edges[j]].edge_prob_);
- }
- }
- for (int i = 0; i < samples; ++i) {
- vector<WordID> yield;
- SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield);
- const string trans = TD::GetString(yield);
- ++m[trans];
- }
- vector<pair<int, string> > dist;
- for (unordered_map<string, int, boost::hash<string> >::iterator i = m.begin();
- i != m.end(); ++i) {
- dist.push_back(make_pair(i->second, i->first));
- }
- sort(dist.begin(), dist.end(), SampleSort());
- if (k) {
- for (int i = 0; i < k; ++i)
- cout << dist[i].first << " ||| " << dist[i].second << endl;
- } else {
- cout << dist[0].second << endl;
- }
-}
-
-
-
-struct ELengthWeightFunction {
- double operator()(const Hypergraph::Edge& e) const {
- return e.rule_->ELength() - e.rule_->Arity();
- }
-};
-
-
-struct TRPHash {
- size_t operator()(const TRulePtr& o) const { return reinterpret_cast<size_t>(o.get()); }
-};
-static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) {
- static unordered_set<TRulePtr, TRPHash> written;
- for (int i = 0; i < hg.edges_.size(); ++i) {
- const TRulePtr& rule = hg.edges_[i].rule_;
- if (written.insert(rule).second) {
- (*os) << rule->AsString() << endl;
- }
- }
-}
void register_feature_functions();
-bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1)
-{
- if (conf.count(name)) {
- *val=conf[name].as<double>()*(scale_srclen?srclen:1);
- return true;
- }
- return false;
-}
-
-bool prelm_weights_string(po::variables_map const& conf,string &s)
-{
- if (conf.count("prelm_weights")) {
- s=str("prelm_weights",conf);
- return true;
- }
- if (conf.count("prelm_copy_weights")) {
- s=str("weights",conf);
- return true;
- }
- return false;
-}
-
-
-void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,WeightVector *weights=0,bool show_deriv=false) {
- cerr << viterbi_stats(forest,name,true,show_tree,show_deriv);
- if (show_features) {
- cerr << name<<" features: ";
-/* Hypergraph::Edge const* best=forest.ViterbiGoalEdge();
- if (!best)
- cerr << name<<" has no goal edge.";
- else
- cerr<<best->feature_values_;
-*/
- cerr << ViterbiFeatures(forest,weights);
- cerr << endl;
- }
-}
-
-void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights,bool sd=false) {
- WeightVector fw(feature_weights);
- forest_stats(forest,name,show_tree,show_features,&fw,sd);
-}
-
-
-void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) {
- double beam_prune=0,density_prune=0;
- bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen);
- bool use_density_prune=beam_param(conf,ndensity,&density_prune);
- if (use_beam_prune || use_density_prune) {
- double presize=forest.edges_.size();
- vector<bool> preserve_mask,*pm=0;
- if (conf.count("csplit_preserve_full_word")) {
- preserve_mask.resize(forest.edges_.size());
- preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
- pm=&preserve_mask;
- }
- forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1,conf["promise_power"].as<double>());
- if (!forestname.empty()) forestname=" "+forestname;
- forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false);
- cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
- }
-}
-
-void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) {
- cerr<<header<<": ";
- ms.show_features(cerr,cerr,conf.count("warn_0_weight"));
-}
-
-template <class V>
-bool store_conf(po::variables_map const& conf,std::string const& name,V *v) {
- if (conf.count(name)) {
- *v=conf[name].as<V>();
- return true;
- }
- return false;
-}
-
-
int main(int argc, char** argv) {
register_feature_functions();
- po::variables_map conf;
- OracleBleu oracle;
-
- InitCommandLine(argc, argv, oracle, &conf);
- const bool write_gradient = conf.count("cll_gradient");
- const bool feature_expectations = conf.count("feature_expectations");
- if (write_gradient && feature_expectations) {
- cerr << "You can only specify --gradient or --feature_expectations, not both!\n";
- exit(1);
- }
- const bool output_training_vector = (write_gradient || feature_expectations);
-
- boost::shared_ptr<Translator> translator;
- const string formalism = LowercaseString(str("formalism",conf));
- const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
- if (csplit_preserve_full_word &&
- (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) {
- cerr << "--csplit_preserve_full_word should only be "
- << "used with csplit AND --*_prune!\n";
- exit(1);
- }
- const bool csplit_output_plf = conf.count("csplit_output_plf");
- if (csplit_output_plf && formalism != "csplit") {
- cerr << "--csplit_output_plf should only be used with csplit!\n";
- exit(1);
- }
+ Decoder decoder(argc, argv);
- const string input = str("input",conf);
+ const string input = decoder.GetConf()["input"].as<string>();
cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl;
ReadFile in_read(input);
istream *in = in_read.stream();
assert(*in);
- // load feature weights (and possibly freeze feature set)
- vector<double> feature_weights,prelm_feature_weights;
- Weights w,prelm_w;
- bool has_prelm_models = false;
- if (conf.count("weights")) {
- w.InitFromFile(str("weights",conf));
- feature_weights.resize(FD::NumFeats());
- w.InitVector(&feature_weights);
- string plmw;
- if (prelm_weights_string(conf,plmw)) {
- has_prelm_models = true;
- prelm_w.InitFromFile(plmw);
- prelm_feature_weights.resize(FD::NumFeats());
- prelm_w.InitVector(&prelm_feature_weights);
- if (show_weights)
- cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl;
- }
- if (show_weights)
- cerr << "+LM weights: " << WeightVector(feature_weights)<<endl;
- }
- bool warn0=conf.count("warn_0_weight");
- bool freeze=!conf.count("no_freeze_feature_set");
- bool early_freeze=freeze && !warn0;
- bool late_freeze=freeze && warn0;
- if (early_freeze) {
- cerr << "Freezing feature set (use --no_freeze_feature_set or --warn_0_weight to prevent)." << endl;
- FD::Freeze(); // this means we can't see the feature names of not-weighted features
- }
-
- // set up translation back end
- if (formalism == "scfg")
- translator.reset(new SCFGTranslator(conf));
- else if (formalism == "fst")
- translator.reset(new FSTTranslator(conf));
- else if (formalism == "pb")
- translator.reset(new PhraseBasedTranslator(conf));
- else if (formalism == "csplit")
- translator.reset(new CompoundSplit(conf));
- else if (formalism == "lextrans")
- translator.reset(new LexicalTrans(conf));
- else if (formalism == "lexalign")
- translator.reset(new LexicalAlign(conf));
- else if (formalism == "tagger")
- translator.reset(new Tagger(conf));
- else
- assert(!"error");
-
- // set up additional scoring features
- vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs;
- vector<const FeatureFunction*> late_ffs,prelm_ffs;
- if (conf.count("feature_function") > 0) {
- vector<string> add_ffs;
-// const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >();
- store_conf(conf,"feature_function",&add_ffs);
- for (int i = 0; i < add_ffs.size(); ++i) {
- pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions));
- FeatureFunction const* p=pffs.back().get();
- late_ffs.push_back(p);
- if (has_prelm_models) {
- if (p->NumBytesContext()==0)
- prelm_ffs.push_back(p);
- else
- cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl;
- }
- }
- }
- if (conf.count("prelm_feature_function") > 0) {
- vector<string> add_ffs;
- store_conf(conf,"prelm_feature_function",&add_ffs);
-// const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >();
- for (int i = 0; i < add_ffs.size(); ++i) {
- prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only "));
- prelm_ffs.push_back(prelm_only_ffs.back().get());
- }
- }
-
- vector<shared_ptr<FsaFeatureFunction> > fsa_ffs;
- vector<string> fsa_names;
- store_conf(conf,"fsa_feature_function",&fsa_names);
- for (int i=0;i<fsa_names.size();++i)
- fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA "));
- if (fsa_ffs.size()>1) {
- //FIXME: support N fsa ffs.
- cerr<<"Only the first fsa FF will be used (FIXME).\n";
- fsa_ffs.resize(1);
- }
- if (!fsa_ffs.empty()) {
- cerr<<"FSA: ";
- show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true);
- }
-
- if (late_freeze) {
- cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl;
- FD::Freeze(); // this means we can't see the feature names of not-weighted features
- }
-
- if (has_prelm_models)
- cerr << "prelm rescoring with "<<prelm_ffs.size()<<" 0-state feature functions. +LM pass will use "<<late_ffs.size()<<" features (not counting rule features)."<<endl;
-
- ModelSet late_models(feature_weights, late_ffs);
- show_models(conf,late_models,"late ");
- ModelSet prelm_models(prelm_feature_weights, prelm_ffs);
- if (has_prelm_models)
- show_models(conf,prelm_models,"prelm ");
-
- int palg = 1;
- if (LowercaseString(str("intersection_strategy",conf)) == "full") {
- palg = 0;
- cerr << "Using full intersection (no pruning).\n";
- }
- int pop_limit=conf["cubepruning_pop_limit"].as<int>();
- const IntersectionConfiguration inter_conf(palg, pop_limit);
-
- const int sample_max_trans = conf.count("max_translation_sample") ?
- conf["max_translation_sample"].as<int>() : 0;
- if (sample_max_trans)
- rng.reset(new RandomNumberGenerator<boost::mt19937>);
- const bool aligner_mode = conf.count("aligner");
- const bool minimal_forests = conf.count("minimal_forests");
- const bool graphviz = conf.count("graphviz");
- const bool joshua_viz = conf.count("show_joshua_visualization");
- const bool encode_b64 = str("vector_format",conf) == "b64";
- const bool kbest = conf.count("k_best");
- const bool unique_kbest = conf.count("unique_k_best");
- const bool crf_uniform_empirical = conf.count("crf_uniform_empirical");
- const bool get_oracle_forest = conf.count("get_oracle_forest");
-
- cfg_options.Validate();
- if (get_oracle_forest)
- oracle.UseConf(conf);
-
- shared_ptr<WriteFile> extract_file;
- if (conf.count("extract_rules"))
- extract_file.reset(new WriteFile(str("extract_rules",conf)));
-
- int combine_size = conf["combine_size"].as<int>();
- if (combine_size < 1) combine_size = 1;
-
- SparseVector<prob_t> acc_vec; // accumulate gradient
- double acc_obj = 0; // accumulate objective
- int g_count = 0; // number of gradient pieces computed
- int sent_id = -1; // line counter
-
+ string buf;
while(*in) {
- NgramCache::Clear(); // clear ngram cache for remote LM (if used)
- Timer::Summarize();
- ++sent_id;
- string buf;
getline(*in, buf);
if (buf.empty()) continue;
- map<string, string> sgml;
- ProcessAndStripSGML(&buf, &sgml);
- if (sgml.find("id") != sgml.end())
- sent_id = atoi(sgml["id"].c_str());
-
- cerr << "\nINPUT: ";
- if (buf.size() < 100)
- cerr << buf << endl;
- else {
- size_t x = buf.rfind(" ", 100);
- if (x == string::npos) x = 100;
- cerr << buf.substr(0, x) << " ..." << endl;
- }
- cerr << " id = " << sent_id << endl;
- string to_translate;
- Lattice ref;
- ParseTranslatorInputLattice(buf, &to_translate, &ref);
- const unsigned srclen=NTokens(to_translate,' ');
-//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important.
- const bool has_ref = ref.size() > 0;
- SentenceMetadata smeta(sent_id, ref);
- const bool hadoop_counters = (write_gradient);
- Hypergraph forest; // -LM forest
- translator->ProcessMarkupHints(sgml);
- Timer t("Translation");
- const bool translation_successful =
- translator->Translate(to_translate, &smeta, feature_weights, &forest);
- //TODO: modify translator to incorporate all 0-state model scores immediately?
- translator->SentenceComplete();
- if (!translation_successful) {
- cerr << " NO PARSE FOUND.\n";
- if (hadoop_counters)
- cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl;
- cout << endl << flush;
- continue;
- }
- const bool show_tree_structure=conf.count("show_tree_structure");
- const bool show_features=conf.count("show_features");
- forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
- if (conf.count("show_expected_length")) {
- const PRPair<double, double> res =
- Inside<PRPair<double, double>,
- PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
- cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
- }
- if (conf.count("show_partition")) {
- const prob_t z = Inside<prob_t, EdgeProb>(forest);
- cerr << " -LM partition log(Z): " << log(z) << endl;
- }
- if (extract_file)
- ExtractRulesDedupe(forest, extract_file->stream());
-
- if (has_prelm_models) {
- Timer t("prelm rescoring");
- forest.Reweight(prelm_feature_weights);
- Hypergraph prelm_forest;
- ApplyModelSet(forest,
- smeta,
- prelm_models,
- inter_conf, // this is now reduced to exhaustive if all are stateless
- &prelm_forest);
- forest.swap(prelm_forest);
- forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it?
- forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation);
- }
-
- maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen);
-
- cfg_options.maybe_output_source(forest);
-
- bool has_late_models = !late_models.empty();
- if (has_late_models) {
- Timer t("Forest rescoring:");
- forest.Reweight(feature_weights);
- Hypergraph lm_forest;
- ApplyModelSet(forest,
- smeta,
- late_models,
- inter_conf,
- &lm_forest);
- forest.swap(lm_forest);
- forest.Reweight(feature_weights);
- forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
- }
-
- maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
-
- HgCFG hgcfg(forest);
- cfg_options.prepare(hgcfg);
- if (!fsa_ffs.empty()) {
- Timer t("Target FSA rescoring:");
- if (!has_late_models)
- forest.Reweight(feature_weights);
- Hypergraph fsa_forest;
- assert(fsa_ffs.size()==1);
- ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
- cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl;
- ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest);
- forest.swap(fsa_forest);
- forest.Reweight(feature_weights);
- forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
- }
-
- /*Oracle Rescoring*/
- if(get_oracle_forest) {
- Oracle o=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>());
- cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
- cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
- o.hope.Print(cerr," +Oracle BLEU");
- o.fear.Print(cerr," -Oracle BLEU");
- //Add 1-best translation (trans) to psuedo-doc vectors
- oracle.IncludeLastScore(&cerr);
- }
-
-
- if (conf.count("forest_output") && !has_ref) {
- ForestWriter writer(str("forest_output",conf), sent_id);
- if (FileExists(writer.fname_)) {
- cerr << " Unioning...\n";
- Hypergraph new_hg;
- {
- ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
- assert(succeeded);
- }
- new_hg.Union(forest);
- bool succeeded = writer.Write(new_hg, minimal_forests);
- assert(succeeded);
- } else {
- bool succeeded = writer.Write(forest, minimal_forests);
- assert(succeeded);
- }
- }
-
- if (sample_max_trans) {
- MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
- } else {
- if (kbest) {
- //TODO: does this work properly?
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
- } else if (csplit_output_plf) {
- cout << HypergraphIO::AsPLF(forest, false) << endl;
- } else {
- if (!graphviz && !has_ref && !joshua_viz) {
- vector<WordID> trans;
- ViterbiESentence(forest, &trans);
- cout << TD::GetString(trans) << endl << flush;
- }
- if (joshua_viz) {
- cout << sent_id << " ||| " << JoshuaVisualizationString(forest) << " ||| 1.0 ||| " << -1.0 << endl << flush;
- }
- }
- }
-
- const int max_trans_beam_size = conf.count("max_translation_beam") ?
- conf["max_translation_beam"].as<int>() : 0;
- if (max_trans_beam_size) {
- Hack::MaxTrans(forest, max_trans_beam_size);
- continue;
- }
-
- if (graphviz && !has_ref) forest.PrintGraphviz();
-
- // the following are only used if write_gradient is true!
- SparseVector<prob_t> full_exp, ref_exp, gradient;
- double log_z = 0, log_ref_z = 0;
- if (write_gradient) {
- const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &full_exp);
- log_z = log(z);
- full_exp /= z;
- }
- if (conf.count("show_cfg_search_space"))
- HypergraphIO::WriteAsCFG(forest);
- if (has_ref) {
- if (HG::Intersect(ref, &forest)) {
- cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
- cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl;
- if (crf_uniform_empirical) {
- cerr << " USING UNIFORM WEIGHTS\n";
- for (int i = 0; i < forest.edges_.size(); ++i)
- forest.edges_[i].edge_prob_=prob_t::One();
- } else {
- forest.Reweight(feature_weights);
- cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
- }
- if (hadoop_counters)
- cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl;
- if (conf.count("show_partition")) {
- const prob_t z = Inside<prob_t, EdgeProb>(forest);
- cerr << " Contst. partition log(Z): " << log(z) << endl;
- }
- //DumpKBest(sent_id, forest, 1000);
- if (conf.count("forest_output")) {
- ForestWriter writer(str("forest_output",conf), sent_id);
- if (FileExists(writer.fname_)) {
- cerr << " Unioning...\n";
- Hypergraph new_hg;
- {
- ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
- assert(succeeded);
- }
- new_hg.Union(forest);
- bool succeeded = writer.Write(new_hg, minimal_forests);
- assert(succeeded);
- } else {
- bool succeeded = writer.Write(forest, minimal_forests);
- assert(succeeded);
- }
- }
- if (aligner_mode && !output_training_vector)
- AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout);
- if (write_gradient) {
- const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);
- ref_exp /= ref_z;
- if (crf_uniform_empirical) {
- log_ref_z = ref_exp.dot(feature_weights);
- } else {
- log_ref_z = log(ref_z);
- }
- //cerr << " MODEL LOG Z: " << log_z << endl;
- //cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl;
- if ((log_z - log_ref_z) < kMINUS_EPSILON) {
- cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl;
- exit(1);
- }
- assert(!isnan(log_ref_z));
- ref_exp -= full_exp;
- acc_vec += ref_exp;
- acc_obj += (log_z - log_ref_z);
- }
- if (feature_expectations) {
- const prob_t z =
- InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);
- ref_exp /= z;
- acc_obj += log(z);
- acc_vec += ref_exp;
- }
-
- if (output_training_vector) {
- acc_vec.erase(0);
- ++g_count;
- if (g_count % combine_size == 0) {
- if (encode_b64) {
- cout << "0\t";
- SparseVector<double> dav; ConvertSV(acc_vec, &dav);
- B64::Encode(acc_obj, dav, &cout);
- cout << endl << flush;
- } else {
- cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
- }
- acc_vec.clear();
- acc_obj = 0;
- }
- }
- if (conf.count("graphviz")) forest.PrintGraphviz();
- } else {
- cerr << " REFERENCE UNREACHABLE.\n";
- if (write_gradient) {
- if (hadoop_counters)
- cerr << "reporter:counter:UserCounters,EFParseFailed,1" << endl;
- cout << endl << flush;
- }
- }
- }
- }
- if (output_training_vector && !acc_vec.empty()) {
- if (encode_b64) {
- cout << "0\t";
- SparseVector<double> dav; ConvertSV(acc_vec, &dav);
- B64::Encode(acc_obj, dav, &cout);
- cout << endl << flush;
- } else {
- cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
- }
+ decoder.Decode(buf);
}
- exit(0); // maybe this will save some destruction overhead. or g++ without cxx_atexit needed?
return 0;
}
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 98d4711f..84ba19fa 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -13,6 +13,12 @@
#include "ff_register.h"
void register_feature_functions() {
+ static bool registered = false;
+ if (registered) {
+ assert(!"register_feature_functions() called twice!");
+ }
+ registered = true;
+
//TODO: these are worthless example target FSA ffs. remove later
RegisterFsaImpl<SameFirstLetter>(true);
RegisterFsaImpl<LongerThanPrev>(true);
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 196ffa46..f90e6bc0 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -1,25 +1,94 @@
#include "decoder.h"
+#include <tr1/unordered_map>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
-#include "ff_factory.h"
-#include "cfg_options.h"
+#include "sampler.h"
#include "stringlib.h"
+#include "weights.h"
+#include "filelib.h"
+#include "fdict.h"
+#include "timing_stats.h"
+
+#include "translator.h"
+#include "phrasebased_translator.h"
+#include "tagger.h"
+#include "lextrans.h"
+#include "lexalign.h"
+#include "csplit.h"
+
+#include "lattice.h"
#include "hg.h"
+#include "sentence_metadata.h"
+#include "hg_intersect.h"
+
+#include "apply_fsa_models.h"
+#include "oracle_bleu.h"
+#include "apply_models.h"
+#include "ff.h"
+#include "ff_factory.h"
+#include "cfg_options.h"
+#include "viterbi.h"
+#include "kbest.h"
+#include "inside_outside.h"
+#include "exp_semiring.h"
+#include "sentence_metadata.h"
+#include "hg_cfg.h"
+
+#include "forest_writer.h" // TODO this section should probably be handled by an Observer
+#include "hg_io.h"
+#include "aligner.h"
+static const double kMINUS_EPSILON = -1e-6; // don't be too strict
using namespace std;
+using namespace std::tr1;
using boost::shared_ptr;
namespace po = boost::program_options;
+static bool verbose_feature_functions=true;
+
+namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); }
+namespace NgramCache { void Clear(); }
+
+void DecoderObserver::NotifySourceParseFailure(const SentenceMetadata&) {}
+void DecoderObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph*) {}
+void DecoderObserver::NotifyAlignmentFailure(const SentenceMetadata&) {}
+void DecoderObserver::NotifyAlignmentForest(const SentenceMetadata&, Hypergraph*) {}
+void DecoderObserver::NotifyDecodingComplete(const SentenceMetadata&) {}
+
+struct ELengthWeightFunction {
+ double operator()(const Hypergraph::Edge& e) const {
+ return e.rule_->ELength() - e.rule_->Arity();
+ }
+};
inline void ShowBanner() {
cerr << "cdec v1.0 (c) 2009-2010 by Chris Dyer\n";
}
+inline void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) {
+ cerr<<header<<": ";
+ ms.show_features(cerr,cerr,conf.count("warn_0_weight"));
+}
+
inline string str(char const* name,po::variables_map const& conf) {
return conf[name].as<string>();
}
+inline bool prelm_weights_string(po::variables_map const& conf,string &s) {
+ if (conf.count("prelm_weights")) {
+ s=str("prelm_weights",conf);
+ return true;
+ }
+ if (conf.count("prelm_copy_weights")) {
+ s=str("weights",conf);
+ return true;
+ }
+ return false;
+}
+
+
+
// print just the --long_opt names suitable for bash compgen
inline void print_options(std::ostream &out,po::options_description const& opts) {
typedef std::vector< shared_ptr<po::option_description> > Ds;
@@ -32,22 +101,214 @@ inline void print_options(std::ostream &out,po::options_description const& opts)
out << '"';
}
+template <class V>
+inline bool store_conf(po::variables_map const& conf,std::string const& name,V *v) {
+ if (conf.count(name)) {
+ *v=conf[name].as<V>();
+ return true;
+ }
+ return false;
+}
+
+inline shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
+ string ff, param;
+ SplitCommandAndParam(ffp, &ff, &param);
+ cerr << pre << "feature: " << ff;
+ if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
+ else cerr << " (no config parameters)\n";
+ shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param);
+ if (!pf) exit(1);
+ int nbyte=pf->NumBytesContext();
+ if (verbose_feature_functions)
+ cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl;
+ return pf;
+}
+
+inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
+ string ff, param;
+ SplitCommandAndParam(ffp, &ff, &param);
+ cerr << "FSA Feature: " << ff;
+ if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
+ else cerr << " (no config parameters)\n";
+ shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param);
+ if (!pf) exit(1);
+ if (verbose_feature_functions)
+ cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl;
+ return pf;
+}
+
struct DecoderImpl {
- DecoderImpl(int argc, char** argv, istream* cfg);
- bool Decode(const string& input) {
- return false;
+ DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg);
+ ~DecoderImpl();
+ bool Decode(const string& input, DecoderObserver*);
+ void SetWeights(const vector<double>& weights) {
}
- bool DecodeProduceHypergraph(const string& input, Hypergraph* hg) {
+
+ void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,WeightVector *weights=0,bool show_deriv=false) {
+ cerr << viterbi_stats(forest,name,true,show_tree,show_deriv);
+ if (show_features) {
+ cerr << name<<" features: ";
+/* Hypergraph::Edge const* best=forest.ViterbiGoalEdge();
+ if (!best)
+ cerr << name<<" has no goal edge.";
+ else
+ cerr<<best->feature_values_;
+*/
+ cerr << ViterbiFeatures(forest,weights);
+ cerr << endl;
+ }
+ }
+
+ void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights, bool sd=false) {
+ WeightVector fw(feature_weights);
+ forest_stats(forest,name,show_tree,show_features,&fw,sd);
+ }
+
+ bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1) {
+ if (conf.count(name)) {
+ *val=conf[name].as<double>()*(scale_srclen?srclen:1);
+ return true;
+ }
return false;
}
- void SetWeights(const vector<double>& weights) {
+
+ void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) {
+ double beam_prune=0,density_prune=0;
+ bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen);
+ bool use_density_prune=beam_param(conf,ndensity,&density_prune);
+ if (use_beam_prune || use_density_prune) {
+ double presize=forest.edges_.size();
+ vector<bool> preserve_mask,*pm=0;
+ if (conf.count("csplit_preserve_full_word")) {
+ preserve_mask.resize(forest.edges_.size());
+ preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
+ pm=&preserve_mask;
+ }
+ forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1,conf["promise_power"].as<double>());
+ if (!forestname.empty()) forestname=" "+forestname;
+ forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false);
+ cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
+ }
+ }
+
+ void SampleRecurse(const Hypergraph& hg, const vector<SampleSet<prob_t> >& ss, int n, vector<WordID>* out) {
+ const SampleSet<prob_t>& s = ss[n];
+ int i = rng->SelectSample(s);
+ const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]];
+ vector<vector<WordID> > ants(edge.tail_nodes_.size());
+ for (int j = 0; j < ants.size(); ++j)
+ SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]);
+
+ vector<const vector<WordID>*> pants(ants.size());
+ for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j];
+ edge.rule_->ESubstitute(pants, out);
+ }
+
+ struct SampleSort {
+ bool operator()(const pair<int,string>& a, const pair<int,string>& b) const {
+ return a.first > b.first;
+ }
+ };
+
+ // TODO this should be handled by an Observer
+ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
+ unordered_map<string, int, boost::hash<string> > m;
+ hg->PushWeightsToGoal();
+ const int num_nodes = hg->nodes_.size();
+ vector<SampleSet<prob_t> > ss(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ SampleSet<prob_t>& s = ss[i];
+ const vector<int>& in_edges = hg->nodes_[i].in_edges_;
+ for (int j = 0; j < in_edges.size(); ++j) {
+ s.add(hg->edges_[in_edges[j]].edge_prob_);
+ }
+ }
+ for (int i = 0; i < samples; ++i) {
+ vector<WordID> yield;
+ SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield);
+ const string trans = TD::GetString(yield);
+ ++m[trans];
+ }
+ vector<pair<int, string> > dist;
+ for (unordered_map<string, int, boost::hash<string> >::iterator i = m.begin();
+ i != m.end(); ++i) {
+ dist.push_back(make_pair(i->second, i->first));
+ }
+ sort(dist.begin(), dist.end(), SampleSort());
+ if (k) {
+ for (int i = 0; i < k; ++i)
+ cout << dist[i].first << " ||| " << dist[i].second << endl;
+ } else {
+ cout << dist[0].second << endl;
+ }
}
- po::variables_map conf;
+ void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) {
+ string sref;
+ ParseTranslatorInput(line, input, &sref);
+ if (sref.size() > 0) {
+ assert(ref);
+ LatticeTools::ConvertTextOrPLF(sref, ref);
+ }
+ }
+
+ po::variables_map& conf;
+ OracleBleu oracle;
CFGOptions cfg_options;
+ string formalism;
+ shared_ptr<Translator> translator;
+ vector<double> feature_weights,prelm_feature_weights;
+ Weights w,prelm_w;
+ vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs;
+ vector<const FeatureFunction*> late_ffs,prelm_ffs;
+ vector<shared_ptr<FsaFeatureFunction> > fsa_ffs;
+ vector<string> fsa_names;
+ ModelSet* late_models, *prelm_models;
+ IntersectionConfiguration* inter_conf;
+ shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
+ int sample_max_trans;
+ bool aligner_mode;
+ bool minimal_forests;
+ bool graphviz;
+ bool joshua_viz;
+ bool encode_b64;
+ bool kbest;
+ bool unique_kbest;
+ bool crf_uniform_empirical;
+ bool get_oracle_forest;
+ shared_ptr<WriteFile> extract_file;
+ int combine_size;
+ int sent_id;
+ SparseVector<prob_t> acc_vec; // accumulate gradient
+ double acc_obj; // accumulate objective
+ int g_count; // number of gradient pieces computed
+ bool has_prelm_models;
+ int pop_limit;
+ bool csplit_output_plf;
+ bool write_gradient; // TODO Observer
+ bool feature_expectations; // TODO Observer
+ bool output_training_vector; // TODO Observer
+
+ static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
+ for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
+ trg->set_value(it->first, it->second);
+ }
};
-DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) {
+DecoderImpl::~DecoderImpl() {
+ if (output_training_vector && !acc_vec.empty()) {
+ if (encode_b64) {
+ cout << "0\t";
+ SparseVector<double> dav; ConvertSV(acc_vec, &dav);
+ B64::Encode(acc_obj, dav, &cout);
+ cout << endl << flush;
+ } else {
+ cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
+ }
+ }
+}
+
+DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg) : conf(conf) {
if (cfg) { if (argc || argv) { cerr << "DecoderImpl() can only take a file or command line options, not both\n"; exit(1); } }
bool show_config;
bool show_weights;
@@ -64,7 +325,7 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) {
("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function")
("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.")
("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
- ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file")
+ ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file")
("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
@@ -142,7 +403,7 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) {
}
if (conf.count("show_config")) // special handling needed because we only want to notify() once.
show_config=true;
- if (conf.count("config") || cfg) {
+ if (conf.count("config") && !cfg) {
typedef vector<string> Cs;
Cs cs=conf["config"].as<Cs>();
for (int i=0;i<cs.size();++i) {
@@ -151,8 +412,8 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) {
ReadFile conff(cfg);
po::store(po::parse_config_file(*conff, dconfig_options), conf);
}
- if (cfg) po::store(po::parse_config_file(*cfg, dconfig_options), conf);
}
+ if (cfg) po::store(po::parse_config_file(*cfg, dconfig_options), conf);
po::notify(conf);
if (show_config && !cfg_files.empty()) {
cerr<< "\nConfig files:\n\n";
@@ -171,9 +432,9 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) {
if (conf.count("list_feature_functions")) {
cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n";
- // ff_registry.DisplayList(); //TODO
+ ff_registry.DisplayList(); //TODO
cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n";
- // fsa_ff_registry.DisplayList(); // TODO
+ fsa_ff_registry.DisplayList(); // TODO
cerr << endl;
exit(1);
}
@@ -191,38 +452,449 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) {
exit(1);
}
- const string formalism = LowercaseString(str("formalism",conf));
+ formalism = LowercaseString(str("formalism",conf));
if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") {
cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n";
cerr << dcmdline_options << endl;
exit(1);
}
-}
-Decoder::Decoder(istream* cfg) {
- pimpl_.reset(new DecoderImpl(0,0,cfg));
-}
+ write_gradient = conf.count("cll_gradient");
+ feature_expectations = conf.count("feature_expectations");
+ if (write_gradient && feature_expectations) {
+ cerr << "You can only specify --gradient or --feature_expectations, not both!\n";
+ exit(1);
+ }
+ output_training_vector = (write_gradient || feature_expectations);
-Decoder::Decoder(int argc, char** argv) {
- pimpl_.reset(new DecoderImpl(argc, argv, 0));
-}
+ const string formalism = LowercaseString(str("formalism",conf));
+ const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
+ if (csplit_preserve_full_word &&
+ (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) {
+ cerr << "--csplit_preserve_full_word should only be "
+ << "used with csplit AND --*_prune!\n";
+ exit(1);
+ }
+ csplit_output_plf = conf.count("csplit_output_plf");
+ if (csplit_output_plf && formalism != "csplit") {
+ cerr << "--csplit_output_plf should only be used with csplit!\n";
+ exit(1);
+ }
-Decoder::~Decoder() {}
+ // load feature weights (and possibly freeze feature set)
+ has_prelm_models = false;
+ if (conf.count("weights")) {
+ w.InitFromFile(str("weights",conf));
+ feature_weights.resize(FD::NumFeats());
+ w.InitVector(&feature_weights);
+ string plmw;
+ if (prelm_weights_string(conf,plmw)) {
+ has_prelm_models = true;
+ prelm_w.InitFromFile(plmw);
+ prelm_feature_weights.resize(FD::NumFeats());
+ prelm_w.InitVector(&prelm_feature_weights);
+ if (show_weights)
+ cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl;
+ }
+ if (show_weights)
+ cerr << "+LM weights: " << WeightVector(feature_weights)<<endl;
+ }
+ bool warn0=conf.count("warn_0_weight");
+ bool freeze=conf.count("freeze_feature_set");
+ bool early_freeze=freeze && !warn0;
+ bool late_freeze=freeze && warn0;
+ if (early_freeze) {
+ cerr << "Freezing feature set" << endl;
+ FD::Freeze(); // this means we can't see the feature names of not-weighted features
+ }
-bool Decoder::Decode(const string& input) {
- return pimpl_->Decode(input);
-}
+ // set up translation back end
+ if (formalism == "scfg")
+ translator.reset(new SCFGTranslator(conf));
+ else if (formalism == "fst")
+ translator.reset(new FSTTranslator(conf));
+ else if (formalism == "pb")
+ translator.reset(new PhraseBasedTranslator(conf));
+ else if (formalism == "csplit")
+ translator.reset(new CompoundSplit(conf));
+ else if (formalism == "lextrans")
+ translator.reset(new LexicalTrans(conf));
+ else if (formalism == "lexalign")
+ translator.reset(new LexicalAlign(conf));
+ else if (formalism == "tagger")
+ translator.reset(new Tagger(conf));
+ else
+ assert(!"error");
+
+ // set up additional scoring features
+ if (conf.count("feature_function") > 0) {
+ vector<string> add_ffs;
+// const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >();
+ store_conf(conf,"feature_function",&add_ffs);
+ for (int i = 0; i < add_ffs.size(); ++i) {
+ pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions));
+ FeatureFunction const* p=pffs.back().get();
+ late_ffs.push_back(p);
+ if (has_prelm_models) {
+ if (p->NumBytesContext()==0)
+ prelm_ffs.push_back(p);
+ else
+ cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl;
+ }
+ }
+ }
+ if (conf.count("prelm_feature_function") > 0) {
+ vector<string> add_ffs;
+ store_conf(conf,"prelm_feature_function",&add_ffs);
+// const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >();
+ for (int i = 0; i < add_ffs.size(); ++i) {
+ prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only "));
+ prelm_ffs.push_back(prelm_only_ffs.back().get());
+ }
+ }
-bool Decoder::DecodeProduceHypergraph(const string& input, Hypergraph* hg) {
- return pimpl_->DecodeProduceHypergraph(input, hg);
+ store_conf(conf,"fsa_feature_function",&fsa_names);
+ for (int i=0;i<fsa_names.size();++i)
+ fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA "));
+ if (fsa_ffs.size()>1) {
+ //FIXME: support N fsa ffs.
+ cerr<<"Only the first fsa FF will be used (FIXME).\n";
+ fsa_ffs.resize(1);
+ }
+ if (!fsa_ffs.empty()) {
+ cerr<<"FSA: ";
+ show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true);
+ }
+
+ if (late_freeze) {
+ cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl;
+ FD::Freeze(); // this means we can't see the feature names of not-weighted features
+ }
+
+ if (has_prelm_models)
+ cerr << "prelm rescoring with "<<prelm_ffs.size()<<" 0-state feature functions. +LM pass will use "<<late_ffs.size()<<" features (not counting rule features)."<<endl;
+
+ late_models = new ModelSet(feature_weights, late_ffs);
+ show_models(conf,*late_models,"late ");
+ prelm_models = new ModelSet(prelm_feature_weights, prelm_ffs);
+ if (has_prelm_models)
+ show_models(conf,*prelm_models,"prelm ");
+
+ int palg = 1;
+ if (LowercaseString(str("intersection_strategy",conf)) == "full") {
+ palg = 0;
+ cerr << "Using full intersection (no pruning).\n";
+ }
+ pop_limit=conf["cubepruning_pop_limit"].as<int>();
+ inter_conf = new IntersectionConfiguration(palg, pop_limit);
+
+ sample_max_trans = conf.count("max_translation_sample") ?
+ conf["max_translation_sample"].as<int>() : 0;
+ if (sample_max_trans)
+ rng.reset(new RandomNumberGenerator<boost::mt19937>);
+ aligner_mode = conf.count("aligner");
+ minimal_forests = conf.count("minimal_forests");
+ graphviz = conf.count("graphviz");
+ joshua_viz = conf.count("show_joshua_visualization");
+ encode_b64 = str("vector_format",conf) == "b64";
+ kbest = conf.count("k_best");
+ unique_kbest = conf.count("unique_k_best");
+ crf_uniform_empirical = conf.count("crf_uniform_empirical");
+ get_oracle_forest = conf.count("get_oracle_forest");
+
+ cfg_options.Validate();
+
+ if (conf.count("extract_rules"))
+ extract_file.reset(new WriteFile(str("extract_rules",conf)));
+
+ combine_size = conf["combine_size"].as<int>();
+ if (combine_size < 1) combine_size = 1;
+ sent_id = -1;
+ acc_obj = 0; // accumulate objective
+ g_count = 0; // number of gradient pieces computed
}
-void Decoder::SetWeights(const vector<double>& weights) {
- pimpl_->SetWeights(weights);
+Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); }
+Decoder::Decoder(int argc, char** argv) { pimpl_.reset(new DecoderImpl(conf,argc, argv, 0)); }
+Decoder::~Decoder() {}
+bool Decoder::Decode(const string& input, DecoderObserver* o) {
+ bool del = false;
+ if (!o) { o = new DecoderObserver; del = true; }
+ const bool res = pimpl_->Decode(input, o);
+ if (del) delete o;
+ return res;
}
+void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); }
+
+
+bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
+ string buf = input;
+ NgramCache::Clear(); // clear ngram cache for remote LM (if used)
+ Timer::Summarize();
+ ++sent_id;
+ map<string, string> sgml;
+ ProcessAndStripSGML(&buf, &sgml);
+ if (sgml.find("id") != sgml.end())
+ sent_id = atoi(sgml["id"].c_str());
+
+ cerr << "\nINPUT: ";
+ if (buf.size() < 100)
+ cerr << buf << endl;
+ else {
+ size_t x = buf.rfind(" ", 100);
+ if (x == string::npos) x = 100;
+ cerr << buf.substr(0, x) << " ..." << endl;
+ }
+ cerr << " id = " << sent_id << endl;
+ string to_translate;
+ Lattice ref;
+ ParseTranslatorInputLattice(buf, &to_translate, &ref);
+ const unsigned srclen=NTokens(to_translate,' ');
+//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important.
+ const bool has_ref = ref.size() > 0;
+ SentenceMetadata smeta(sent_id, ref);
+ Hypergraph forest; // -LM forest
+ translator->ProcessMarkupHints(sgml);
+ Timer t("Translation");
+ const bool translation_successful =
+ translator->Translate(to_translate, &smeta, feature_weights, &forest);
+ //TODO: modify translator to incorporate all 0-state model scores immediately?
+ translator->SentenceComplete();
+
+ if (!translation_successful) {
+ cerr << " NO PARSE FOUND.\n";
+ o->NotifySourceParseFailure(smeta);
+ o->NotifyDecodingComplete(smeta);
+ return false;
+ }
+
+ const bool show_tree_structure=conf.count("show_tree_structure");
+ const bool show_features=conf.count("show_features");
+ forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
+ if (conf.count("show_expected_length")) {
+ const PRPair<double, double> res =
+ Inside<PRPair<double, double>,
+ PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
+ cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
+ }
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " -LM partition log(Z): " << log(z) << endl;
+ }
+
+ if (has_prelm_models) {
+ Timer t("prelm rescoring");
+ forest.Reweight(prelm_feature_weights);
+ Hypergraph prelm_forest;
+ ApplyModelSet(forest,
+ smeta,
+ *prelm_models,
+ *inter_conf, // this is now reduced to exhaustive if all are stateless
+ &prelm_forest);
+ forest.swap(prelm_forest);
+ forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it?
+ forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation);
+ }
+
+ maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen);
+
+ cfg_options.maybe_output_source(forest);
+
+ bool has_late_models = !late_models->empty();
+ if (has_late_models) {
+ Timer t("Forest rescoring:");
+ forest.Reweight(feature_weights);
+ Hypergraph lm_forest;
+ ApplyModelSet(forest,
+ smeta,
+ *late_models,
+ *inter_conf,
+ &lm_forest);
+ forest.swap(lm_forest);
+ forest.Reweight(feature_weights);
+ forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
+ }
+
+ maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
+
+ HgCFG hgcfg(forest);
+ cfg_options.prepare(hgcfg);
+
+ if (!fsa_ffs.empty()) {
+ Timer t("Target FSA rescoring:");
+ if (!has_late_models)
+ forest.Reweight(feature_weights);
+ Hypergraph fsa_forest;
+ assert(fsa_ffs.size()==1);
+ ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
+ cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl;
+ ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest);
+ forest.swap(fsa_forest);
+ forest.Reweight(feature_weights);
+ forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
+ }
+
+ // Oracle Rescoring
+ if(get_oracle_forest) {
+ Oracle oc=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>());
+ cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
+ oc.hope.Print(cerr," +Oracle BLEU");
+ oc.fear.Print(cerr," -Oracle BLEU");
+ //Add 1-best translation (trans) to psuedo-doc vectors
+ oracle.IncludeLastScore(&cerr);
+ }
+
+ // TODO I think this should probably be handled by an Observer
+ if (conf.count("forest_output") && !has_ref) {
+ ForestWriter writer(str("forest_output",conf), sent_id);
+ if (FileExists(writer.fname_)) {
+ cerr << " Unioning...\n";
+ Hypergraph new_hg;
+ {
+ ReadFile rf(writer.fname_);
+ bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ assert(succeeded);
+ }
+ new_hg.Union(forest);
+ bool succeeded = writer.Write(new_hg, minimal_forests);
+ assert(succeeded);
+ } else {
+ bool succeeded = writer.Write(forest, minimal_forests);
+ assert(succeeded);
+ }
+ }
+
+ // TODO I think this should probably be handled by an Observer
+ if (sample_max_trans) {
+ MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
+ } else {
+ if (kbest) {
+ //TODO: does this work properly?
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ } else if (csplit_output_plf) {
+ cout << HypergraphIO::AsPLF(forest, false) << endl;
+ } else {
+ if (!graphviz && !has_ref && !joshua_viz) {
+ vector<WordID> trans;
+ ViterbiESentence(forest, &trans);
+ cout << TD::GetString(trans) << endl << flush;
+ }
+ if (joshua_viz) {
+ cout << sent_id << " ||| " << JoshuaVisualizationString(forest) << " ||| 1.0 ||| " << -1.0 << endl << flush;
+ }
+ }
+ }
+
+ // TODO this should be handled by an Observer
+ const int max_trans_beam_size = conf.count("max_translation_beam") ?
+ conf["max_translation_beam"].as<int>() : 0;
+ if (max_trans_beam_size) {
+ Hack::MaxTrans(forest, max_trans_beam_size);
+ return true;
+ }
+
+ // TODO this should be handled by an Observer
+ if (graphviz && !has_ref) forest.PrintGraphviz();
-void InitCommandLine(int argc, char** argv, po::variables_map* confp) {
- po::variables_map &conf=*confp;
+ // the following are only used if write_gradient is true!
+ SparseVector<prob_t> full_exp, ref_exp, gradient;
+ double log_z = 0, log_ref_z = 0;
+ if (write_gradient) {
+ const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &full_exp);
+ log_z = log(z);
+ full_exp /= z;
+ }
+ if (conf.count("show_cfg_search_space"))
+ HypergraphIO::WriteAsCFG(forest);
+ if (has_ref) {
+ if (HG::Intersect(ref, &forest)) {
+ cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl;
+ if (crf_uniform_empirical) {
+ cerr << " USING UNIFORM WEIGHTS\n";
+ for (int i = 0; i < forest.edges_.size(); ++i)
+ forest.edges_[i].edge_prob_=prob_t::One();
+ } else {
+ forest.Reweight(feature_weights);
+ cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
+ }
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " Contst. partition log(Z): " << log(z) << endl;
+ }
+ if (conf.count("forest_output")) {
+ ForestWriter writer(str("forest_output",conf), sent_id);
+ if (FileExists(writer.fname_)) {
+ cerr << " Unioning...\n";
+ Hypergraph new_hg;
+ {
+ ReadFile rf(writer.fname_);
+ bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ assert(succeeded);
+ }
+ new_hg.Union(forest);
+ bool succeeded = writer.Write(new_hg, minimal_forests);
+ assert(succeeded);
+ } else {
+ bool succeeded = writer.Write(forest, minimal_forests);
+ assert(succeeded);
+ }
+ }
+ if (aligner_mode && !output_training_vector)
+ AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout);
+ if (write_gradient) {
+ const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);
+ ref_exp /= ref_z;
+ if (crf_uniform_empirical) {
+ log_ref_z = ref_exp.dot(feature_weights);
+ } else {
+ log_ref_z = log(ref_z);
+ }
+ //cerr << " MODEL LOG Z: " << log_z << endl;
+ //cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl;
+ if ((log_z - log_ref_z) < kMINUS_EPSILON) {
+ cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl;
+ exit(1);
+ }
+ assert(!isnan(log_ref_z));
+ ref_exp -= full_exp;
+ acc_vec += ref_exp;
+ acc_obj += (log_z - log_ref_z);
+ }
+ if (feature_expectations) {
+ const prob_t z =
+ InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);
+ ref_exp /= z;
+ acc_obj += log(z);
+ acc_vec += ref_exp;
+ }
+ if (output_training_vector) {
+ acc_vec.erase(0);
+ ++g_count;
+ if (g_count % combine_size == 0) {
+ if (encode_b64) {
+ cout << "0\t";
+ SparseVector<double> dav; ConvertSV(acc_vec, &dav);
+ B64::Encode(acc_obj, dav, &cout);
+ cout << endl << flush;
+ } else {
+ cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
+ }
+ acc_vec.clear();
+ acc_obj = 0;
+ }
+ }
+ if (conf.count("graphviz")) forest.PrintGraphviz();
+ } else {
+ cerr << " REFERENCE UNREACHABLE.\n";
+ if (write_gradient) {
+ cout << endl << flush;
+ }
+ }
+ }
+ return true;
}
+
diff --git a/decoder/decoder.h b/decoder/decoder.h
index ae1b9bb0..5dd6e1aa 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -5,17 +5,29 @@
#include <string>
#include <vector>
#include <boost/shared_ptr.hpp>
+#include <boost/program_options/variables_map.hpp>
+class SentenceMetadata;
struct Hypergraph;
struct DecoderImpl;
+
+struct DecoderObserver {
+ virtual void NotifySourceParseFailure(const SentenceMetadata& smeta);
+ virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg);
+ virtual void NotifyAlignmentFailure(const SentenceMetadata& semta);
+ virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg);
+ virtual void NotifyDecodingComplete(const SentenceMetadata& smeta);
+};
+
struct Decoder {
Decoder(int argc, char** argv);
Decoder(std::istream* config_file);
- bool Decode(const std::string& input);
- bool DecodeProduceHypergraph(const std::string& input, Hypergraph* hg);
+ bool Decode(const std::string& input, DecoderObserver* observer = NULL);
void SetWeights(const std::vector<double>& weights);
~Decoder();
+ const boost::program_options::variables_map& GetConf() const { return conf; }
private:
+ boost::program_options::variables_map conf;
boost::shared_ptr<DecoderImpl> pimpl_;
};
diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h
index 5eb68c8b..92334396 100644
--- a/decoder/ff_factory.h
+++ b/decoder/ff_factory.h
@@ -20,6 +20,8 @@
#include <boost/shared_ptr.hpp>
+#include "ff_fsa_dynamic.h"
+
class FeatureFunction;
class FsaFeatureFunction;