diff options
Diffstat (limited to 'decoder/cdec.cc')
-rw-r--r-- | decoder/cdec.cc | 51 |
1 files changed, 43 insertions, 8 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 876dee18..e896a484 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -34,6 +34,7 @@ #include "exp_semiring.h" #include "sentence_metadata.h" #include "../vest/scorer.h" +#include "apply_fsa_models.h" using namespace std; using namespace std::tr1; @@ -69,15 +70,27 @@ shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_feature_funct cerr << "Feature: " << ff; if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; else cerr << " (no config parameters)\n"; - shared_ptr<FeatureFunction> pf = global_ff_registry->Create(ff, param); - if (!pf) - exit(1); + shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param); + if (!pf) exit(1); int nbyte=pf->NumBytesContext(); if (verbose_feature_functions) cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl; return pf; } +shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { + string ff, param; + SplitCommandAndParam(ffp, &ff, ¶m); + cerr << "FSA Feature: " << ff; + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param); + if (!pf) exit(1); + if (verbose_feature_functions) + cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl; + return pf; +} + // print just the --long_opt names suitable for bash compgen void print_options(std::ostream &out,po::options_description const& opts) { typedef std::vector< shared_ptr<po::option_description> > Ds; @@ -106,6 +119,7 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file") ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)") + ("fsa_feature_function",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)") ("list_feature_functions,L","List available feature functions") ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") ("k_best,k",po::value<int>(),"Extract the k best derivations") @@ -185,13 +199,15 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c if (conf.count("list_feature_functions")) { cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n"; - global_ff_registry->DisplayList(); + ff_registry.DisplayList(); + cerr << "Available feature functions (specify with --fsa_feature_function):\n"; + fsa_ff_registry.DisplayList(); cerr << endl; exit(1); } if (conf.count("usage")) { - cout<<global_ff_registry->usage(str("usage",conf),true,true)<<endl; + cout<<ff_registry.usage(str("usage",conf),true,true)<<endl; exit(0); } if (conf.count("help")) { @@ -358,8 +374,17 @@ void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) } +template <class V> +bool store_conf(po::variables_map const& conf,std::string const& name,V *v) { + if (conf.count(name)) { + *v=conf[name].as<V>(); + return true; + } + return false; +} + + int main(int argc, char** argv) { - global_ff_registry.reset(new FFRegistry); register_feature_functions(); po::variables_map conf; OracleBleu oracle; @@ -441,7 +466,6 @@ int main(int argc, char** argv) { // set up additional scoring features vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs; - vector<const FeatureFunction*> late_ffs,prelm_ffs; if (conf.count("feature_function") > 0) { const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >(); @@ -454,7 +478,7 @@ int main(int argc, char** argv) { prelm_ffs.push_back(p); else cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl; -} + } } } if (conf.count("prelm_feature_function") > 0) { @@ -465,6 +489,17 @@ int main(int argc, char** argv) { } } + vector<shared_ptr<FsaFeatureFunction> > fsa_ffs; + vector<string> fsa_names; + store_conf(conf,"fsa_feature_function",&fsa_names); + if (fsa_ffs.size()>1) { + //FIXME: support N fsa ffs. + cerr<<"Only the first fsa FF will be used (FIXME).\n"; + fsa_names.resize(1); + for (int i=0;i<fsa_names.size();++i) + fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA ")); + } + if (late_freeze) { cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl; FD::Freeze(); // this means we can't see the feature names of not-weighted features |