diff options
-rw-r--r-- | decoder/cdec.cc | 870 | ||||
-rw-r--r-- | decoder/cdec_ff.cc | 6 | ||||
-rw-r--r-- | decoder/decoder.cc | 736 | ||||
-rw-r--r-- | decoder/decoder.h | 16 | ||||
-rw-r--r-- | decoder/ff_factory.h | 2 |
5 files changed, 731 insertions, 899 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 56e103aa..97ec6798 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -1,888 +1,28 @@ #include <iostream> -#include <fstream> -#include <tr1/unordered_map> -#include <tr1/unordered_set> -#include <boost/shared_ptr.hpp> -#include <boost/program_options.hpp> -#include <boost/program_options/variables_map.hpp> - -#include "decoder.h" -#include "oracle_bleu.h" -#include "timing_stats.h" -#include "translator.h" -#include "phrasebased_translator.h" -#include "aligner.h" -#include "stringlib.h" -#include "forest_writer.h" -#include "hg_io.h" #include "filelib.h" -#include "sampler.h" -#include "sparse_vector.h" -#include "tagger.h" -#include "lextrans.h" -#include "lexalign.h" -#include "csplit.h" -#include "weights.h" -#include "tdict.h" -#include "ff.h" -#include "ff_fsa_dynamic.h" -#include "ff_factory.h" -#include "hg_intersect.h" -#include "apply_models.h" -#include "viterbi.h" -#include "kbest.h" -#include "inside_outside.h" -#include "exp_semiring.h" -#include "sentence_metadata.h" -#include "scorer.h" -#include "apply_fsa_models.h" -#include "program_options.h" -#include "cfg_options.h" - -CFGOptions cfg_options; +#include "decoder.h" using namespace std; -using namespace std::tr1; -using boost::shared_ptr; -namespace po = boost::program_options; - -vector<string> cfg_files; -bool show_config=false; -bool show_weights=false; -bool verbose_feature_functions=true; - -// some globals ... -boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng; -static const double kMINUS_EPSILON = -1e-6; // don't be too strict - -namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); } -namespace NgramCache { void Clear(); } - -void ShowBanner() { - cerr << "cdec v1.0 (c) 2009-2010 by Chris Dyer\n"; -} - -void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) { - string sref; - ParseTranslatorInput(line, input, &sref); - if (sref.size() > 0) { - assert(ref); - LatticeTools::ConvertTextOrPLF(sref, ref); - } -} - -void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) { - for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) - trg->set_value(it->first, it->second); -} - -inline string str(char const* name,po::variables_map const& conf) { - return conf[name].as<string>(); -} - -shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { - string ff, param; - SplitCommandAndParam(ffp, &ff, ¶m); - cerr << pre << "feature: " << ff; - if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; - else cerr << " (no config parameters)\n"; - shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param); - if (!pf) exit(1); - int nbyte=pf->NumBytesContext(); - if (verbose_feature_functions) - cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl; - return pf; -} - -shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { - string ff, param; - SplitCommandAndParam(ffp, &ff, ¶m); - cerr << "FSA Feature: " << ff; - if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; - else cerr << " (no config parameters)\n"; - shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param); - if (!pf) exit(1); - if (verbose_feature_functions) - cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl; - return pf; -} - -// print just the --long_opt names suitable for bash compgen -void print_options(std::ostream &out,po::options_description const& opts) { - typedef std::vector< shared_ptr<po::option_description> > Ds; - Ds const& ds=opts.options(); - out << '"'; - for (unsigned i=0;i<ds.size();++i) { - if (i) out<<' '; - out<<"--"<<ds[i]->long_name(); - } - out << '"'; -} - -void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* confp) { - po::variables_map &conf=*confp; - po::options_description opts("Configuration options"); - opts.add_options() - ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexTrans (lexical translation model, also disc training), CSplit (compound splitting), Tagger (sequence labeling), LexAlign (alignment only, or EM training)") - ("input,i",po::value<string>()->default_value("-"),"Source file") - ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") - ("weights,w",po::value<string>(),"Feature weights file") - ("prelm_weights",po::value<string>(),"Feature weights file for prelm_beam_prune. Requires --weights.") - ("prelm_copy_weights","use --weights as value for --prelm_weights.") - ("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function") - ("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.") - ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") - ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file") - ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)") - ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)") - ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names() - ("list_feature_functions,L","List available feature functions") - ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") - ("k_best,k",po::value<int>(),"Extract the k best derivations") - ("unique_k_best,r", "Unique k-best translation list") - ("aligner,a", "Run as a word/phrase aligner (src & ref required)") - ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") - ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node") - ("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)") - ("scfg_extra_glue_grammar", po::value<string>(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)") - ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)") - ("scfg_default_nt,d",po::value<string>()->default_value("X"),"Default non-terminal symbol in SCFG") - ("scfg_max_span_limit,S",po::value<int>()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)") - ("show_config", po::bool_switch(&show_config), "show contents of loaded -c config files.") - ("show_weights", po::bool_switch(&show_weights), "show effective feature weights") - ("show_joshua_visualization,J", "Produce output compatible with the Joshua visualization tools") - ("show_tree_structure", "Show the Viterbi derivation structure") - ("show_expected_length", "Show the expected translation length under the model") - ("show_partition,z", "Compute and show the partition (inside score)") - ("show_cfg_search_space", "Show the search space as a CFG") - ("show_features","Show the feature vector for the viterbi translation") - ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") - ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") - ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)") - ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") - ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") - ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse") - ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails") - ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") - ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") - ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit). note: for the same pop_limit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU. TODO: test under more conditions, or try idea with different formula, or prob. cube beams.") - ("lexalign_use_null", "Support source-side null words in lexical translation") - ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set") - ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") - ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") - ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file") - ("graphviz","Show (constrained) translation forest in GraphViz format") - ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart") - ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") - ("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion") - ("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)") - ("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM") - ("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules") - ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") - ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") - ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") - ("forest_output,O",po::value<string>(),"Directory to write forests to") - ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc."); - ob.AddOptions(&opts); - po::options_description cfgo(cfg_options.description()); - cfg_options.AddOptions(&cfgo); - po::options_description clo("Command line options"); - clo.add_options() - ("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority") - ("help,h", "Print this help message and exit") - ("usage,u", po::value<string>(), "Describe a feature function type") - ("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'") - ; - - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts).add(cfgo); - //add(opts).add(cfgo) - dcmdline_options.add(dconfig_options).add(clo); - argv_minus_to_underscore(argc,argv); - po::store(parse_command_line(argc, argv, dcmdline_options), conf); - if (conf.count("compgen")) { - print_options(cout,dcmdline_options); - cout << endl; - exit(0); - } - ShowBanner(); - if (conf.count("show_config")) // special handling needed because we only want to notify() once. - show_config=true; - if (conf.count("config")) { - typedef vector<string> Cs; - Cs cs=conf["config"].as<Cs>(); - for (int i=0;i<cs.size();++i) { - string cfg=cs[i]; - cerr << "Configuration file: " << cfg << endl; - ReadFile conff(cfg); - po::store(po::parse_config_file(*conff, dconfig_options), conf); - } - } - po::notify(conf); - if (show_config && !cfg_files.empty()) { - cerr<< "\nConfig files:\n\n"; - for (int i=0;i<cfg_files.size();++i) { - string cfg=cfg_files[i]; - cerr << "Configuration file: " << cfg << endl; - CopyFile(cfg,cerr); - cerr << "(end config "<<cfg<<"\n\n"; - } - cerr <<"Command line:"; - for (int i=0;i<argc;++i) - cerr<<" "<<argv[i]; - cerr << "\n\n"; - } - - - if (conf.count("list_feature_functions")) { - cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n"; - ff_registry.DisplayList(); - cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n"; - fsa_ff_registry.DisplayList(); - cerr << endl; - exit(1); - } - - if (conf.count("usage")) { - ff_usage(str("usage",conf)); - exit(0); - } - if (conf.count("help")) { - cout << dcmdline_options << endl; - exit(0); - } - if (conf.count("help") || conf.count("formalism") == 0) { - cerr << dcmdline_options << endl; - exit(1); - } - - const string formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; - cerr << dcmdline_options << endl; - exit(1); - } - -} - -// TODO move out of cdec into some sampling decoder file -void SampleRecurse(const Hypergraph& hg, const vector<SampleSet<prob_t> >& ss, int n, vector<WordID>* out) { - const SampleSet<prob_t>& s = ss[n]; - int i = rng->SelectSample(s); - const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]]; - vector<vector<WordID> > ants(edge.tail_nodes_.size()); - for (int j = 0; j < ants.size(); ++j) - SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]); - - vector<const vector<WordID>*> pants(ants.size()); - for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j]; - edge.rule_->ESubstitute(pants, out); -} - -struct SampleSort { - bool operator()(const pair<int,string>& a, const pair<int,string>& b) const { - return a.first > b.first; - } -}; - -// TODO move out of cdec into some sampling decoder file -void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { - unordered_map<string, int, boost::hash<string> > m; - hg->PushWeightsToGoal(); - const int num_nodes = hg->nodes_.size(); - vector<SampleSet<prob_t> > ss(num_nodes); - for (int i = 0; i < num_nodes; ++i) { - SampleSet<prob_t>& s = ss[i]; - const vector<int>& in_edges = hg->nodes_[i].in_edges_; - for (int j = 0; j < in_edges.size(); ++j) { - s.add(hg->edges_[in_edges[j]].edge_prob_); - } - } - for (int i = 0; i < samples; ++i) { - vector<WordID> yield; - SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield); - const string trans = TD::GetString(yield); - ++m[trans]; - } - vector<pair<int, string> > dist; - for (unordered_map<string, int, boost::hash<string> >::iterator i = m.begin(); - i != m.end(); ++i) { - dist.push_back(make_pair(i->second, i->first)); - } - sort(dist.begin(), dist.end(), SampleSort()); - if (k) { - for (int i = 0; i < k; ++i) - cout << dist[i].first << " ||| " << dist[i].second << endl; - } else { - cout << dist[0].second << endl; - } -} - - - -struct ELengthWeightFunction { - double operator()(const Hypergraph::Edge& e) const { - return e.rule_->ELength() - e.rule_->Arity(); - } -}; - - -struct TRPHash { - size_t operator()(const TRulePtr& o) const { return reinterpret_cast<size_t>(o.get()); } -}; -static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) { - static unordered_set<TRulePtr, TRPHash> written; - for (int i = 0; i < hg.edges_.size(); ++i) { - const TRulePtr& rule = hg.edges_[i].rule_; - if (written.insert(rule).second) { - (*os) << rule->AsString() << endl; - } - } -} void register_feature_functions(); -bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1) -{ - if (conf.count(name)) { - *val=conf[name].as<double>()*(scale_srclen?srclen:1); - return true; - } - return false; -} - -bool prelm_weights_string(po::variables_map const& conf,string &s) -{ - if (conf.count("prelm_weights")) { - s=str("prelm_weights",conf); - return true; - } - if (conf.count("prelm_copy_weights")) { - s=str("weights",conf); - return true; - } - return false; -} - - -void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,WeightVector *weights=0,bool show_deriv=false) { - cerr << viterbi_stats(forest,name,true,show_tree,show_deriv); - if (show_features) { - cerr << name<<" features: "; -/* Hypergraph::Edge const* best=forest.ViterbiGoalEdge(); - if (!best) - cerr << name<<" has no goal edge."; - else - cerr<<best->feature_values_; -*/ - cerr << ViterbiFeatures(forest,weights); - cerr << endl; - } -} - -void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights,bool sd=false) { - WeightVector fw(feature_weights); - forest_stats(forest,name,show_tree,show_features,&fw,sd); -} - - -void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) { - double beam_prune=0,density_prune=0; - bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen); - bool use_density_prune=beam_param(conf,ndensity,&density_prune); - if (use_beam_prune || use_density_prune) { - double presize=forest.edges_.size(); - vector<bool> preserve_mask,*pm=0; - if (conf.count("csplit_preserve_full_word")) { - preserve_mask.resize(forest.edges_.size()); - preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; - pm=&preserve_mask; - } - forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1,conf["promise_power"].as<double>()); - if (!forestname.empty()) forestname=" "+forestname; - forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false); - cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl; - } -} - -void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) { - cerr<<header<<": "; - ms.show_features(cerr,cerr,conf.count("warn_0_weight")); -} - -template <class V> -bool store_conf(po::variables_map const& conf,std::string const& name,V *v) { - if (conf.count(name)) { - *v=conf[name].as<V>(); - return true; - } - return false; -} - - int main(int argc, char** argv) { register_feature_functions(); - po::variables_map conf; - OracleBleu oracle; - - InitCommandLine(argc, argv, oracle, &conf); - const bool write_gradient = conf.count("cll_gradient"); - const bool feature_expectations = conf.count("feature_expectations"); - if (write_gradient && feature_expectations) { - cerr << "You can only specify --gradient or --feature_expectations, not both!\n"; - exit(1); - } - const bool output_training_vector = (write_gradient || feature_expectations); - - boost::shared_ptr<Translator> translator; - const string formalism = LowercaseString(str("formalism",conf)); - const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word"); - if (csplit_preserve_full_word && - (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) { - cerr << "--csplit_preserve_full_word should only be " - << "used with csplit AND --*_prune!\n"; - exit(1); - } - const bool csplit_output_plf = conf.count("csplit_output_plf"); - if (csplit_output_plf && formalism != "csplit") { - cerr << "--csplit_output_plf should only be used with csplit!\n"; - exit(1); - } + Decoder decoder(argc, argv); - const string input = str("input",conf); + const string input = decoder.GetConf()["input"].as<string>(); cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl; ReadFile in_read(input); istream *in = in_read.stream(); assert(*in); - // load feature weights (and possibly freeze feature set) - vector<double> feature_weights,prelm_feature_weights; - Weights w,prelm_w; - bool has_prelm_models = false; - if (conf.count("weights")) { - w.InitFromFile(str("weights",conf)); - feature_weights.resize(FD::NumFeats()); - w.InitVector(&feature_weights); - string plmw; - if (prelm_weights_string(conf,plmw)) { - has_prelm_models = true; - prelm_w.InitFromFile(plmw); - prelm_feature_weights.resize(FD::NumFeats()); - prelm_w.InitVector(&prelm_feature_weights); - if (show_weights) - cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl; - } - if (show_weights) - cerr << "+LM weights: " << WeightVector(feature_weights)<<endl; - } - bool warn0=conf.count("warn_0_weight"); - bool freeze=!conf.count("no_freeze_feature_set"); - bool early_freeze=freeze && !warn0; - bool late_freeze=freeze && warn0; - if (early_freeze) { - cerr << "Freezing feature set (use --no_freeze_feature_set or --warn_0_weight to prevent)." << endl; - FD::Freeze(); // this means we can't see the feature names of not-weighted features - } - - // set up translation back end - if (formalism == "scfg") - translator.reset(new SCFGTranslator(conf)); - else if (formalism == "fst") - translator.reset(new FSTTranslator(conf)); - else if (formalism == "pb") - translator.reset(new PhraseBasedTranslator(conf)); - else if (formalism == "csplit") - translator.reset(new CompoundSplit(conf)); - else if (formalism == "lextrans") - translator.reset(new LexicalTrans(conf)); - else if (formalism == "lexalign") - translator.reset(new LexicalAlign(conf)); - else if (formalism == "tagger") - translator.reset(new Tagger(conf)); - else - assert(!"error"); - - // set up additional scoring features - vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs; - vector<const FeatureFunction*> late_ffs,prelm_ffs; - if (conf.count("feature_function") > 0) { - vector<string> add_ffs; -// const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >(); - store_conf(conf,"feature_function",&add_ffs); - for (int i = 0; i < add_ffs.size(); ++i) { - pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions)); - FeatureFunction const* p=pffs.back().get(); - late_ffs.push_back(p); - if (has_prelm_models) { - if (p->NumBytesContext()==0) - prelm_ffs.push_back(p); - else - cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl; - } - } - } - if (conf.count("prelm_feature_function") > 0) { - vector<string> add_ffs; - store_conf(conf,"prelm_feature_function",&add_ffs); -// const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >(); - for (int i = 0; i < add_ffs.size(); ++i) { - prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only ")); - prelm_ffs.push_back(prelm_only_ffs.back().get()); - } - } - - vector<shared_ptr<FsaFeatureFunction> > fsa_ffs; - vector<string> fsa_names; - store_conf(conf,"fsa_feature_function",&fsa_names); - for (int i=0;i<fsa_names.size();++i) - fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA ")); - if (fsa_ffs.size()>1) { - //FIXME: support N fsa ffs. - cerr<<"Only the first fsa FF will be used (FIXME).\n"; - fsa_ffs.resize(1); - } - if (!fsa_ffs.empty()) { - cerr<<"FSA: "; - show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true); - } - - if (late_freeze) { - cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl; - FD::Freeze(); // this means we can't see the feature names of not-weighted features - } - - if (has_prelm_models) - cerr << "prelm rescoring with "<<prelm_ffs.size()<<" 0-state feature functions. +LM pass will use "<<late_ffs.size()<<" features (not counting rule features)."<<endl; - - ModelSet late_models(feature_weights, late_ffs); - show_models(conf,late_models,"late "); - ModelSet prelm_models(prelm_feature_weights, prelm_ffs); - if (has_prelm_models) - show_models(conf,prelm_models,"prelm "); - - int palg = 1; - if (LowercaseString(str("intersection_strategy",conf)) == "full") { - palg = 0; - cerr << "Using full intersection (no pruning).\n"; - } - int pop_limit=conf["cubepruning_pop_limit"].as<int>(); - const IntersectionConfiguration inter_conf(palg, pop_limit); - - const int sample_max_trans = conf.count("max_translation_sample") ? - conf["max_translation_sample"].as<int>() : 0; - if (sample_max_trans) - rng.reset(new RandomNumberGenerator<boost::mt19937>); - const bool aligner_mode = conf.count("aligner"); - const bool minimal_forests = conf.count("minimal_forests"); - const bool graphviz = conf.count("graphviz"); - const bool joshua_viz = conf.count("show_joshua_visualization"); - const bool encode_b64 = str("vector_format",conf) == "b64"; - const bool kbest = conf.count("k_best"); - const bool unique_kbest = conf.count("unique_k_best"); - const bool crf_uniform_empirical = conf.count("crf_uniform_empirical"); - const bool get_oracle_forest = conf.count("get_oracle_forest"); - - cfg_options.Validate(); - if (get_oracle_forest) - oracle.UseConf(conf); - - shared_ptr<WriteFile> extract_file; - if (conf.count("extract_rules")) - extract_file.reset(new WriteFile(str("extract_rules",conf))); - - int combine_size = conf["combine_size"].as<int>(); - if (combine_size < 1) combine_size = 1; - - SparseVector<prob_t> acc_vec; // accumulate gradient - double acc_obj = 0; // accumulate objective - int g_count = 0; // number of gradient pieces computed - int sent_id = -1; // line counter - + string buf; while(*in) { - NgramCache::Clear(); // clear ngram cache for remote LM (if used) - Timer::Summarize(); - ++sent_id; - string buf; getline(*in, buf); if (buf.empty()) continue; - map<string, string> sgml; - ProcessAndStripSGML(&buf, &sgml); - if (sgml.find("id") != sgml.end()) - sent_id = atoi(sgml["id"].c_str()); - - cerr << "\nINPUT: "; - if (buf.size() < 100) - cerr << buf << endl; - else { - size_t x = buf.rfind(" ", 100); - if (x == string::npos) x = 100; - cerr << buf.substr(0, x) << " ..." << endl; - } - cerr << " id = " << sent_id << endl; - string to_translate; - Lattice ref; - ParseTranslatorInputLattice(buf, &to_translate, &ref); - const unsigned srclen=NTokens(to_translate,' '); -//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important. - const bool has_ref = ref.size() > 0; - SentenceMetadata smeta(sent_id, ref); - const bool hadoop_counters = (write_gradient); - Hypergraph forest; // -LM forest - translator->ProcessMarkupHints(sgml); - Timer t("Translation"); - const bool translation_successful = - translator->Translate(to_translate, &smeta, feature_weights, &forest); - //TODO: modify translator to incorporate all 0-state model scores immediately? - translator->SentenceComplete(); - if (!translation_successful) { - cerr << " NO PARSE FOUND.\n"; - if (hadoop_counters) - cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl; - cout << endl << flush; - continue; - } - const bool show_tree_structure=conf.count("show_tree_structure"); - const bool show_features=conf.count("show_features"); - forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); - if (conf.count("show_expected_length")) { - const PRPair<double, double> res = - Inside<PRPair<double, double>, - PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest); - cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; - } - if (conf.count("show_partition")) { - const prob_t z = Inside<prob_t, EdgeProb>(forest); - cerr << " -LM partition log(Z): " << log(z) << endl; - } - if (extract_file) - ExtractRulesDedupe(forest, extract_file->stream()); - - if (has_prelm_models) { - Timer t("prelm rescoring"); - forest.Reweight(prelm_feature_weights); - Hypergraph prelm_forest; - ApplyModelSet(forest, - smeta, - prelm_models, - inter_conf, // this is now reduced to exhaustive if all are stateless - &prelm_forest); - forest.swap(prelm_forest); - forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it? - forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation); - } - - maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen); - - cfg_options.maybe_output_source(forest); - - bool has_late_models = !late_models.empty(); - if (has_late_models) { - Timer t("Forest rescoring:"); - forest.Reweight(feature_weights); - Hypergraph lm_forest; - ApplyModelSet(forest, - smeta, - late_models, - inter_conf, - &lm_forest); - forest.swap(lm_forest); - forest.Reweight(feature_weights); - forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); - } - - maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); - - HgCFG hgcfg(forest); - cfg_options.prepare(hgcfg); - if (!fsa_ffs.empty()) { - Timer t("Target FSA rescoring:"); - if (!has_late_models) - forest.Reweight(feature_weights); - Hypergraph fsa_forest; - assert(fsa_ffs.size()==1); - ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit); - cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl; - ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest); - forest.swap(fsa_forest); - forest.Reweight(feature_weights); - forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); - } - - /*Oracle Rescoring*/ - if(get_oracle_forest) { - Oracle o=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>()); - cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; - o.hope.Print(cerr," +Oracle BLEU"); - o.fear.Print(cerr," -Oracle BLEU"); - //Add 1-best translation (trans) to psuedo-doc vectors - oracle.IncludeLastScore(&cerr); - } - - - if (conf.count("forest_output") && !has_ref) { - ForestWriter writer(str("forest_output",conf), sent_id); - if (FileExists(writer.fname_)) { - cerr << " Unioning...\n"; - Hypergraph new_hg; - { - ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); - assert(succeeded); - } - new_hg.Union(forest); - bool succeeded = writer.Write(new_hg, minimal_forests); - assert(succeeded); - } else { - bool succeeded = writer.Write(forest, minimal_forests); - assert(succeeded); - } - } - - if (sample_max_trans) { - MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0); - } else { - if (kbest) { - //TODO: does this work properly? - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); - } else if (csplit_output_plf) { - cout << HypergraphIO::AsPLF(forest, false) << endl; - } else { - if (!graphviz && !has_ref && !joshua_viz) { - vector<WordID> trans; - ViterbiESentence(forest, &trans); - cout << TD::GetString(trans) << endl << flush; - } - if (joshua_viz) { - cout << sent_id << " ||| " << JoshuaVisualizationString(forest) << " ||| 1.0 ||| " << -1.0 << endl << flush; - } - } - } - - const int max_trans_beam_size = conf.count("max_translation_beam") ? - conf["max_translation_beam"].as<int>() : 0; - if (max_trans_beam_size) { - Hack::MaxTrans(forest, max_trans_beam_size); - continue; - } - - if (graphviz && !has_ref) forest.PrintGraphviz(); - - // the following are only used if write_gradient is true! - SparseVector<prob_t> full_exp, ref_exp, gradient; - double log_z = 0, log_ref_z = 0; - if (write_gradient) { - const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &full_exp); - log_z = log(z); - full_exp /= z; - } - if (conf.count("show_cfg_search_space")) - HypergraphIO::WriteAsCFG(forest); - if (has_ref) { - if (HG::Intersect(ref, &forest)) { - cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl; - if (crf_uniform_empirical) { - cerr << " USING UNIFORM WEIGHTS\n"; - for (int i = 0; i < forest.edges_.size(); ++i) - forest.edges_[i].edge_prob_=prob_t::One(); - } else { - forest.Reweight(feature_weights); - cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; - } - if (hadoop_counters) - cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl; - if (conf.count("show_partition")) { - const prob_t z = Inside<prob_t, EdgeProb>(forest); - cerr << " Contst. partition log(Z): " << log(z) << endl; - } - //DumpKBest(sent_id, forest, 1000); - if (conf.count("forest_output")) { - ForestWriter writer(str("forest_output",conf), sent_id); - if (FileExists(writer.fname_)) { - cerr << " Unioning...\n"; - Hypergraph new_hg; - { - ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); - assert(succeeded); - } - new_hg.Union(forest); - bool succeeded = writer.Write(new_hg, minimal_forests); - assert(succeeded); - } else { - bool succeeded = writer.Write(forest, minimal_forests); - assert(succeeded); - } - } - if (aligner_mode && !output_training_vector) - AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout); - if (write_gradient) { - const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); - ref_exp /= ref_z; - if (crf_uniform_empirical) { - log_ref_z = ref_exp.dot(feature_weights); - } else { - log_ref_z = log(ref_z); - } - //cerr << " MODEL LOG Z: " << log_z << endl; - //cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl; - if ((log_z - log_ref_z) < kMINUS_EPSILON) { - cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl; - exit(1); - } - assert(!isnan(log_ref_z)); - ref_exp -= full_exp; - acc_vec += ref_exp; - acc_obj += (log_z - log_ref_z); - } - if (feature_expectations) { - const prob_t z = - InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); - ref_exp /= z; - acc_obj += log(z); - acc_vec += ref_exp; - } - - if (output_training_vector) { - acc_vec.erase(0); - ++g_count; - if (g_count % combine_size == 0) { - if (encode_b64) { - cout << "0\t"; - SparseVector<double> dav; ConvertSV(acc_vec, &dav); - B64::Encode(acc_obj, dav, &cout); - cout << endl << flush; - } else { - cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; - } - acc_vec.clear(); - acc_obj = 0; - } - } - if (conf.count("graphviz")) forest.PrintGraphviz(); - } else { - cerr << " REFERENCE UNREACHABLE.\n"; - if (write_gradient) { - if (hadoop_counters) - cerr << "reporter:counter:UserCounters,EFParseFailed,1" << endl; - cout << endl << flush; - } - } - } - } - if (output_training_vector && !acc_vec.empty()) { - if (encode_b64) { - cout << "0\t"; - SparseVector<double> dav; ConvertSV(acc_vec, &dav); - B64::Encode(acc_obj, dav, &cout); - cout << endl << flush; - } else { - cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; - } + decoder.Decode(buf); } - exit(0); // maybe this will save some destruction overhead. or g++ without cxx_atexit needed? return 0; } diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 98d4711f..84ba19fa 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -13,6 +13,12 @@ #include "ff_register.h" void register_feature_functions() { + static bool registered = false; + if (registered) { + assert(!"register_feature_functions() called twice!"); + } + registered = true; + //TODO: these are worthless example target FSA ffs. remove later RegisterFsaImpl<SameFirstLetter>(true); RegisterFsaImpl<LongerThanPrev>(true); diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 196ffa46..f90e6bc0 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -1,25 +1,94 @@ #include "decoder.h" +#include <tr1/unordered_map> #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> -#include "ff_factory.h" -#include "cfg_options.h" +#include "sampler.h" #include "stringlib.h" +#include "weights.h" +#include "filelib.h" +#include "fdict.h" +#include "timing_stats.h" + +#include "translator.h" +#include "phrasebased_translator.h" +#include "tagger.h" +#include "lextrans.h" +#include "lexalign.h" +#include "csplit.h" + +#include "lattice.h" #include "hg.h" +#include "sentence_metadata.h" +#include "hg_intersect.h" + +#include "apply_fsa_models.h" +#include "oracle_bleu.h" +#include "apply_models.h" +#include "ff.h" +#include "ff_factory.h" +#include "cfg_options.h" +#include "viterbi.h" +#include "kbest.h" +#include "inside_outside.h" +#include "exp_semiring.h" +#include "sentence_metadata.h" +#include "hg_cfg.h" + +#include "forest_writer.h" // TODO this section should probably be handled by an Observer +#include "hg_io.h" +#include "aligner.h" +static const double kMINUS_EPSILON = -1e-6; // don't be too strict using namespace std; +using namespace std::tr1; using boost::shared_ptr; namespace po = boost::program_options; +static bool verbose_feature_functions=true; + +namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); } +namespace NgramCache { void Clear(); } + +void DecoderObserver::NotifySourceParseFailure(const SentenceMetadata&) {} +void DecoderObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph*) {} +void DecoderObserver::NotifyAlignmentFailure(const SentenceMetadata&) {} +void DecoderObserver::NotifyAlignmentForest(const SentenceMetadata&, Hypergraph*) {} +void DecoderObserver::NotifyDecodingComplete(const SentenceMetadata&) {} + +struct ELengthWeightFunction { + double operator()(const Hypergraph::Edge& e) const { + return e.rule_->ELength() - e.rule_->Arity(); + } +}; inline void ShowBanner() { cerr << "cdec v1.0 (c) 2009-2010 by Chris Dyer\n"; } +inline void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) { + cerr<<header<<": "; + ms.show_features(cerr,cerr,conf.count("warn_0_weight")); +} + inline string str(char const* name,po::variables_map const& conf) { return conf[name].as<string>(); } +inline bool prelm_weights_string(po::variables_map const& conf,string &s) { + if (conf.count("prelm_weights")) { + s=str("prelm_weights",conf); + return true; + } + if (conf.count("prelm_copy_weights")) { + s=str("weights",conf); + return true; + } + return false; +} + + + // print just the --long_opt names suitable for bash compgen inline void print_options(std::ostream &out,po::options_description const& opts) { typedef std::vector< shared_ptr<po::option_description> > Ds; @@ -32,22 +101,214 @@ inline void print_options(std::ostream &out,po::options_description const& opts) out << '"'; } +template <class V> +inline bool store_conf(po::variables_map const& conf,std::string const& name,V *v) { + if (conf.count(name)) { + *v=conf[name].as<V>(); + return true; + } + return false; +} + +inline shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { + string ff, param; + SplitCommandAndParam(ffp, &ff, ¶m); + cerr << pre << "feature: " << ff; + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param); + if (!pf) exit(1); + int nbyte=pf->NumBytesContext(); + if (verbose_feature_functions) + cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl; + return pf; +} + +inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { + string ff, param; + SplitCommandAndParam(ffp, &ff, ¶m); + cerr << "FSA Feature: " << ff; + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param); + if (!pf) exit(1); + if (verbose_feature_functions) + cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl; + return pf; +} + struct DecoderImpl { - DecoderImpl(int argc, char** argv, istream* cfg); - bool Decode(const string& input) { - return false; + DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg); + ~DecoderImpl(); + bool Decode(const string& input, DecoderObserver*); + void SetWeights(const vector<double>& weights) { } - bool DecodeProduceHypergraph(const string& input, Hypergraph* hg) { + + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,WeightVector *weights=0,bool show_deriv=false) { + cerr << viterbi_stats(forest,name,true,show_tree,show_deriv); + if (show_features) { + cerr << name<<" features: "; +/* Hypergraph::Edge const* best=forest.ViterbiGoalEdge(); + if (!best) + cerr << name<<" has no goal edge."; + else + cerr<<best->feature_values_; +*/ + cerr << ViterbiFeatures(forest,weights); + cerr << endl; + } + } + + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights, bool sd=false) { + WeightVector fw(feature_weights); + forest_stats(forest,name,show_tree,show_features,&fw,sd); + } + + bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1) { + if (conf.count(name)) { + *val=conf[name].as<double>()*(scale_srclen?srclen:1); + return true; + } return false; } - void SetWeights(const vector<double>& weights) { + + void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) { + double beam_prune=0,density_prune=0; + bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen); + bool use_density_prune=beam_param(conf,ndensity,&density_prune); + if (use_beam_prune || use_density_prune) { + double presize=forest.edges_.size(); + vector<bool> preserve_mask,*pm=0; + if (conf.count("csplit_preserve_full_word")) { + preserve_mask.resize(forest.edges_.size()); + preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; + pm=&preserve_mask; + } + forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1,conf["promise_power"].as<double>()); + if (!forestname.empty()) forestname=" "+forestname; + forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false); + cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl; + } + } + + void SampleRecurse(const Hypergraph& hg, const vector<SampleSet<prob_t> >& ss, int n, vector<WordID>* out) { + const SampleSet<prob_t>& s = ss[n]; + int i = rng->SelectSample(s); + const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]]; + vector<vector<WordID> > ants(edge.tail_nodes_.size()); + for (int j = 0; j < ants.size(); ++j) + SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]); + + vector<const vector<WordID>*> pants(ants.size()); + for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j]; + edge.rule_->ESubstitute(pants, out); + } + + struct SampleSort { + bool operator()(const pair<int,string>& a, const pair<int,string>& b) const { + return a.first > b.first; + } + }; + + // TODO this should be handled by an Observer + void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { + unordered_map<string, int, boost::hash<string> > m; + hg->PushWeightsToGoal(); + const int num_nodes = hg->nodes_.size(); + vector<SampleSet<prob_t> > ss(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + SampleSet<prob_t>& s = ss[i]; + const vector<int>& in_edges = hg->nodes_[i].in_edges_; + for (int j = 0; j < in_edges.size(); ++j) { + s.add(hg->edges_[in_edges[j]].edge_prob_); + } + } + for (int i = 0; i < samples; ++i) { + vector<WordID> yield; + SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield); + const string trans = TD::GetString(yield); + ++m[trans]; + } + vector<pair<int, string> > dist; + for (unordered_map<string, int, boost::hash<string> >::iterator i = m.begin(); + i != m.end(); ++i) { + dist.push_back(make_pair(i->second, i->first)); + } + sort(dist.begin(), dist.end(), SampleSort()); + if (k) { + for (int i = 0; i < k; ++i) + cout << dist[i].first << " ||| " << dist[i].second << endl; + } else { + cout << dist[0].second << endl; + } } - po::variables_map conf; + void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) { + string sref; + ParseTranslatorInput(line, input, &sref); + if (sref.size() > 0) { + assert(ref); + LatticeTools::ConvertTextOrPLF(sref, ref); + } + } + + po::variables_map& conf; + OracleBleu oracle; CFGOptions cfg_options; + string formalism; + shared_ptr<Translator> translator; + vector<double> feature_weights,prelm_feature_weights; + Weights w,prelm_w; + vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs; + vector<const FeatureFunction*> late_ffs,prelm_ffs; + vector<shared_ptr<FsaFeatureFunction> > fsa_ffs; + vector<string> fsa_names; + ModelSet* late_models, *prelm_models; + IntersectionConfiguration* inter_conf; + shared_ptr<RandomNumberGenerator<boost::mt19937> > rng; + int sample_max_trans; + bool aligner_mode; + bool minimal_forests; + bool graphviz; + bool joshua_viz; + bool encode_b64; + bool kbest; + bool unique_kbest; + bool crf_uniform_empirical; + bool get_oracle_forest; + shared_ptr<WriteFile> extract_file; + int combine_size; + int sent_id; + SparseVector<prob_t> acc_vec; // accumulate gradient + double acc_obj; // accumulate objective + int g_count; // number of gradient pieces computed + bool has_prelm_models; + int pop_limit; + bool csplit_output_plf; + bool write_gradient; // TODO Observer + bool feature_expectations; // TODO Observer + bool output_training_vector; // TODO Observer + + static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) { + for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) + trg->set_value(it->first, it->second); + } }; -DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { +DecoderImpl::~DecoderImpl() { + if (output_training_vector && !acc_vec.empty()) { + if (encode_b64) { + cout << "0\t"; + SparseVector<double> dav; ConvertSV(acc_vec, &dav); + B64::Encode(acc_obj, dav, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + } +} + +DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg) : conf(conf) { if (cfg) { if (argc || argv) { cerr << "DecoderImpl() can only take a file or command line options, not both\n"; exit(1); } } bool show_config; bool show_weights; @@ -64,7 +325,7 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { ("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function") ("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.") ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") - ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file") + ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file") ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)") ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)") ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names() @@ -142,7 +403,7 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { } if (conf.count("show_config")) // special handling needed because we only want to notify() once. show_config=true; - if (conf.count("config") || cfg) { + if (conf.count("config") && !cfg) { typedef vector<string> Cs; Cs cs=conf["config"].as<Cs>(); for (int i=0;i<cs.size();++i) { @@ -151,8 +412,8 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { ReadFile conff(cfg); po::store(po::parse_config_file(*conff, dconfig_options), conf); } - if (cfg) po::store(po::parse_config_file(*cfg, dconfig_options), conf); } + if (cfg) po::store(po::parse_config_file(*cfg, dconfig_options), conf); po::notify(conf); if (show_config && !cfg_files.empty()) { cerr<< "\nConfig files:\n\n"; @@ -171,9 +432,9 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { if (conf.count("list_feature_functions")) { cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n"; - // ff_registry.DisplayList(); //TODO + ff_registry.DisplayList(); //TODO cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n"; - // fsa_ff_registry.DisplayList(); // TODO + fsa_ff_registry.DisplayList(); // TODO cerr << endl; exit(1); } @@ -191,38 +452,449 @@ DecoderImpl::DecoderImpl(int argc, char** argv, istream* cfg) { exit(1); } - const string formalism = LowercaseString(str("formalism",conf)); + formalism = LowercaseString(str("formalism",conf)); if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") { cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } -} -Decoder::Decoder(istream* cfg) { - pimpl_.reset(new DecoderImpl(0,0,cfg)); -} + write_gradient = conf.count("cll_gradient"); + feature_expectations = conf.count("feature_expectations"); + if (write_gradient && feature_expectations) { + cerr << "You can only specify --gradient or --feature_expectations, not both!\n"; + exit(1); + } + output_training_vector = (write_gradient || feature_expectations); -Decoder::Decoder(int argc, char** argv) { - pimpl_.reset(new DecoderImpl(argc, argv, 0)); -} + const string formalism = LowercaseString(str("formalism",conf)); + const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word"); + if (csplit_preserve_full_word && + (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) { + cerr << "--csplit_preserve_full_word should only be " + << "used with csplit AND --*_prune!\n"; + exit(1); + } + csplit_output_plf = conf.count("csplit_output_plf"); + if (csplit_output_plf && formalism != "csplit") { + cerr << "--csplit_output_plf should only be used with csplit!\n"; + exit(1); + } -Decoder::~Decoder() {} + // load feature weights (and possibly freeze feature set) + has_prelm_models = false; + if (conf.count("weights")) { + w.InitFromFile(str("weights",conf)); + feature_weights.resize(FD::NumFeats()); + w.InitVector(&feature_weights); + string plmw; + if (prelm_weights_string(conf,plmw)) { + has_prelm_models = true; + prelm_w.InitFromFile(plmw); + prelm_feature_weights.resize(FD::NumFeats()); + prelm_w.InitVector(&prelm_feature_weights); + if (show_weights) + cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl; + } + if (show_weights) + cerr << "+LM weights: " << WeightVector(feature_weights)<<endl; + } + bool warn0=conf.count("warn_0_weight"); + bool freeze=conf.count("freeze_feature_set"); + bool early_freeze=freeze && !warn0; + bool late_freeze=freeze && warn0; + if (early_freeze) { + cerr << "Freezing feature set" << endl; + FD::Freeze(); // this means we can't see the feature names of not-weighted features + } -bool Decoder::Decode(const string& input) { - return pimpl_->Decode(input); -} + // set up translation back end + if (formalism == "scfg") + translator.reset(new SCFGTranslator(conf)); + else if (formalism == "fst") + translator.reset(new FSTTranslator(conf)); + else if (formalism == "pb") + translator.reset(new PhraseBasedTranslator(conf)); + else if (formalism == "csplit") + translator.reset(new CompoundSplit(conf)); + else if (formalism == "lextrans") + translator.reset(new LexicalTrans(conf)); + else if (formalism == "lexalign") + translator.reset(new LexicalAlign(conf)); + else if (formalism == "tagger") + translator.reset(new Tagger(conf)); + else + assert(!"error"); + + // set up additional scoring features + if (conf.count("feature_function") > 0) { + vector<string> add_ffs; +// const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >(); + store_conf(conf,"feature_function",&add_ffs); + for (int i = 0; i < add_ffs.size(); ++i) { + pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions)); + FeatureFunction const* p=pffs.back().get(); + late_ffs.push_back(p); + if (has_prelm_models) { + if (p->NumBytesContext()==0) + prelm_ffs.push_back(p); + else + cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl; + } + } + } + if (conf.count("prelm_feature_function") > 0) { + vector<string> add_ffs; + store_conf(conf,"prelm_feature_function",&add_ffs); +// const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >(); + for (int i = 0; i < add_ffs.size(); ++i) { + prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only ")); + prelm_ffs.push_back(prelm_only_ffs.back().get()); + } + } -bool Decoder::DecodeProduceHypergraph(const string& input, Hypergraph* hg) { - return pimpl_->DecodeProduceHypergraph(input, hg); + store_conf(conf,"fsa_feature_function",&fsa_names); + for (int i=0;i<fsa_names.size();++i) + fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA ")); + if (fsa_ffs.size()>1) { + //FIXME: support N fsa ffs. + cerr<<"Only the first fsa FF will be used (FIXME).\n"; + fsa_ffs.resize(1); + } + if (!fsa_ffs.empty()) { + cerr<<"FSA: "; + show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true); + } + + if (late_freeze) { + cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl; + FD::Freeze(); // this means we can't see the feature names of not-weighted features + } + + if (has_prelm_models) + cerr << "prelm rescoring with "<<prelm_ffs.size()<<" 0-state feature functions. +LM pass will use "<<late_ffs.size()<<" features (not counting rule features)."<<endl; + + late_models = new ModelSet(feature_weights, late_ffs); + show_models(conf,*late_models,"late "); + prelm_models = new ModelSet(prelm_feature_weights, prelm_ffs); + if (has_prelm_models) + show_models(conf,*prelm_models,"prelm "); + + int palg = 1; + if (LowercaseString(str("intersection_strategy",conf)) == "full") { + palg = 0; + cerr << "Using full intersection (no pruning).\n"; + } + pop_limit=conf["cubepruning_pop_limit"].as<int>(); + inter_conf = new IntersectionConfiguration(palg, pop_limit); + + sample_max_trans = conf.count("max_translation_sample") ? + conf["max_translation_sample"].as<int>() : 0; + if (sample_max_trans) + rng.reset(new RandomNumberGenerator<boost::mt19937>); + aligner_mode = conf.count("aligner"); + minimal_forests = conf.count("minimal_forests"); + graphviz = conf.count("graphviz"); + joshua_viz = conf.count("show_joshua_visualization"); + encode_b64 = str("vector_format",conf) == "b64"; + kbest = conf.count("k_best"); + unique_kbest = conf.count("unique_k_best"); + crf_uniform_empirical = conf.count("crf_uniform_empirical"); + get_oracle_forest = conf.count("get_oracle_forest"); + + cfg_options.Validate(); + + if (conf.count("extract_rules")) + extract_file.reset(new WriteFile(str("extract_rules",conf))); + + combine_size = conf["combine_size"].as<int>(); + if (combine_size < 1) combine_size = 1; + sent_id = -1; + acc_obj = 0; // accumulate objective + g_count = 0; // number of gradient pieces computed } -void Decoder::SetWeights(const vector<double>& weights) { - pimpl_->SetWeights(weights); +Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); } +Decoder::Decoder(int argc, char** argv) { pimpl_.reset(new DecoderImpl(conf,argc, argv, 0)); } +Decoder::~Decoder() {} +bool Decoder::Decode(const string& input, DecoderObserver* o) { + bool del = false; + if (!o) { o = new DecoderObserver; del = true; } + const bool res = pimpl_->Decode(input, o); + if (del) delete o; + return res; } +void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); } + + +bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { + string buf = input; + NgramCache::Clear(); // clear ngram cache for remote LM (if used) + Timer::Summarize(); + ++sent_id; + map<string, string> sgml; + ProcessAndStripSGML(&buf, &sgml); + if (sgml.find("id") != sgml.end()) + sent_id = atoi(sgml["id"].c_str()); + + cerr << "\nINPUT: "; + if (buf.size() < 100) + cerr << buf << endl; + else { + size_t x = buf.rfind(" ", 100); + if (x == string::npos) x = 100; + cerr << buf.substr(0, x) << " ..." << endl; + } + cerr << " id = " << sent_id << endl; + string to_translate; + Lattice ref; + ParseTranslatorInputLattice(buf, &to_translate, &ref); + const unsigned srclen=NTokens(to_translate,' '); +//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important. + const bool has_ref = ref.size() > 0; + SentenceMetadata smeta(sent_id, ref); + Hypergraph forest; // -LM forest + translator->ProcessMarkupHints(sgml); + Timer t("Translation"); + const bool translation_successful = + translator->Translate(to_translate, &smeta, feature_weights, &forest); + //TODO: modify translator to incorporate all 0-state model scores immediately? + translator->SentenceComplete(); + + if (!translation_successful) { + cerr << " NO PARSE FOUND.\n"; + o->NotifySourceParseFailure(smeta); + o->NotifyDecodingComplete(smeta); + return false; + } + + const bool show_tree_structure=conf.count("show_tree_structure"); + const bool show_features=conf.count("show_features"); + forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + if (conf.count("show_expected_length")) { + const PRPair<double, double> res = + Inside<PRPair<double, double>, + PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest); + cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + } + if (conf.count("show_partition")) { + const prob_t z = Inside<prob_t, EdgeProb>(forest); + cerr << " -LM partition log(Z): " << log(z) << endl; + } + + if (has_prelm_models) { + Timer t("prelm rescoring"); + forest.Reweight(prelm_feature_weights); + Hypergraph prelm_forest; + ApplyModelSet(forest, + smeta, + *prelm_models, + *inter_conf, // this is now reduced to exhaustive if all are stateless + &prelm_forest); + forest.swap(prelm_forest); + forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it? + forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation); + } + + maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen); + + cfg_options.maybe_output_source(forest); + + bool has_late_models = !late_models->empty(); + if (has_late_models) { + Timer t("Forest rescoring:"); + forest.Reweight(feature_weights); + Hypergraph lm_forest; + ApplyModelSet(forest, + smeta, + *late_models, + *inter_conf, + &lm_forest); + forest.swap(lm_forest); + forest.Reweight(feature_weights); + forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + } + + maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); + + HgCFG hgcfg(forest); + cfg_options.prepare(hgcfg); + + if (!fsa_ffs.empty()) { + Timer t("Target FSA rescoring:"); + if (!has_late_models) + forest.Reweight(feature_weights); + Hypergraph fsa_forest; + assert(fsa_ffs.size()==1); + ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit); + cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl; + ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest); + forest.swap(fsa_forest); + forest.Reweight(feature_weights); + forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + } + + // Oracle Rescoring + if(get_oracle_forest) { + Oracle oc=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>()); + cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; + oc.hope.Print(cerr," +Oracle BLEU"); + oc.fear.Print(cerr," -Oracle BLEU"); + //Add 1-best translation (trans) to psuedo-doc vectors + oracle.IncludeLastScore(&cerr); + } + + // TODO I think this should probably be handled by an Observer + if (conf.count("forest_output") && !has_ref) { + ForestWriter writer(str("forest_output",conf), sent_id); + if (FileExists(writer.fname_)) { + cerr << " Unioning...\n"; + Hypergraph new_hg; + { + ReadFile rf(writer.fname_); + bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + assert(succeeded); + } + new_hg.Union(forest); + bool succeeded = writer.Write(new_hg, minimal_forests); + assert(succeeded); + } else { + bool succeeded = writer.Write(forest, minimal_forests); + assert(succeeded); + } + } + + // TODO I think this should probably be handled by an Observer + if (sample_max_trans) { + MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0); + } else { + if (kbest) { + //TODO: does this work properly? + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); + } else if (csplit_output_plf) { + cout << HypergraphIO::AsPLF(forest, false) << endl; + } else { + if (!graphviz && !has_ref && !joshua_viz) { + vector<WordID> trans; + ViterbiESentence(forest, &trans); + cout << TD::GetString(trans) << endl << flush; + } + if (joshua_viz) { + cout << sent_id << " ||| " << JoshuaVisualizationString(forest) << " ||| 1.0 ||| " << -1.0 << endl << flush; + } + } + } + + // TODO this should be handled by an Observer + const int max_trans_beam_size = conf.count("max_translation_beam") ? + conf["max_translation_beam"].as<int>() : 0; + if (max_trans_beam_size) { + Hack::MaxTrans(forest, max_trans_beam_size); + return true; + } + + // TODO this should be handled by an Observer + if (graphviz && !has_ref) forest.PrintGraphviz(); -void InitCommandLine(int argc, char** argv, po::variables_map* confp) { - po::variables_map &conf=*confp; + // the following are only used if write_gradient is true! + SparseVector<prob_t> full_exp, ref_exp, gradient; + double log_z = 0, log_ref_z = 0; + if (write_gradient) { + const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &full_exp); + log_z = log(z); + full_exp /= z; + } + if (conf.count("show_cfg_search_space")) + HypergraphIO::WriteAsCFG(forest); + if (has_ref) { + if (HG::Intersect(ref, &forest)) { + cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl; + if (crf_uniform_empirical) { + cerr << " USING UNIFORM WEIGHTS\n"; + for (int i = 0; i < forest.edges_.size(); ++i) + forest.edges_[i].edge_prob_=prob_t::One(); + } else { + forest.Reweight(feature_weights); + cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; + } + if (conf.count("show_partition")) { + const prob_t z = Inside<prob_t, EdgeProb>(forest); + cerr << " Contst. partition log(Z): " << log(z) << endl; + } + if (conf.count("forest_output")) { + ForestWriter writer(str("forest_output",conf), sent_id); + if (FileExists(writer.fname_)) { + cerr << " Unioning...\n"; + Hypergraph new_hg; + { + ReadFile rf(writer.fname_); + bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + assert(succeeded); + } + new_hg.Union(forest); + bool succeeded = writer.Write(new_hg, minimal_forests); + assert(succeeded); + } else { + bool succeeded = writer.Write(forest, minimal_forests); + assert(succeeded); + } + } + if (aligner_mode && !output_training_vector) + AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout); + if (write_gradient) { + const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); + ref_exp /= ref_z; + if (crf_uniform_empirical) { + log_ref_z = ref_exp.dot(feature_weights); + } else { + log_ref_z = log(ref_z); + } + //cerr << " MODEL LOG Z: " << log_z << endl; + //cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl; + if ((log_z - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl; + exit(1); + } + assert(!isnan(log_ref_z)); + ref_exp -= full_exp; + acc_vec += ref_exp; + acc_obj += (log_z - log_ref_z); + } + if (feature_expectations) { + const prob_t z = + InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); + ref_exp /= z; + acc_obj += log(z); + acc_vec += ref_exp; + } + if (output_training_vector) { + acc_vec.erase(0); + ++g_count; + if (g_count % combine_size == 0) { + if (encode_b64) { + cout << "0\t"; + SparseVector<double> dav; ConvertSV(acc_vec, &dav); + B64::Encode(acc_obj, dav, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + acc_vec.clear(); + acc_obj = 0; + } + } + if (conf.count("graphviz")) forest.PrintGraphviz(); + } else { + cerr << " REFERENCE UNREACHABLE.\n"; + if (write_gradient) { + cout << endl << flush; + } + } + } + return true; } + diff --git a/decoder/decoder.h b/decoder/decoder.h index ae1b9bb0..5dd6e1aa 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -5,17 +5,29 @@ #include <string> #include <vector> #include <boost/shared_ptr.hpp> +#include <boost/program_options/variables_map.hpp> +class SentenceMetadata; struct Hypergraph; struct DecoderImpl; + +struct DecoderObserver { + virtual void NotifySourceParseFailure(const SentenceMetadata& smeta); + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg); + virtual void NotifyAlignmentFailure(const SentenceMetadata& semta); + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg); + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta); +}; + struct Decoder { Decoder(int argc, char** argv); Decoder(std::istream* config_file); - bool Decode(const std::string& input); - bool DecodeProduceHypergraph(const std::string& input, Hypergraph* hg); + bool Decode(const std::string& input, DecoderObserver* observer = NULL); void SetWeights(const std::vector<double>& weights); ~Decoder(); + const boost::program_options::variables_map& GetConf() const { return conf; } private: + boost::program_options::variables_map conf; boost::shared_ptr<DecoderImpl> pimpl_; }; diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h index 5eb68c8b..92334396 100644 --- a/decoder/ff_factory.h +++ b/decoder/ff_factory.h @@ -20,6 +20,8 @@ #include <boost/shared_ptr.hpp> +#include "ff_fsa_dynamic.h" + class FeatureFunction; class FsaFeatureFunction; |