From 471b69d38404452e5150d1955d4dc96744cbbeda Mon Sep 17 00:00:00 2001
From: graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Thu, 8 Jul 2010 23:27:05 +0000
Subject: feature functions support e.g. --usage=LanguageModel, report feature ids, warn about duplicate ids, 0-expand weight vector for no segfault, --warn_0_weight, and know their own names to simplify registration

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@192 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 decoder/cdec.cc       | 126 ++++++++++++++++++++++++++++++++------------------
 decoder/cdec_ff.cc    |   9 ++--
 decoder/ff.cc         |  64 ++++++++++++++++++++++++-
 decoder/ff.h          |  30 ++++++++++--
 decoder/ff_factory.cc |   1 +
 decoder/ff_lm.cc      |   4 ++
 decoder/ff_lm.h       |   1 +
 7 files changed, 181 insertions(+), 54 deletions(-)
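[Editor's note] For orientation before the diff: the commit's core contract is that every feature function can (a) describe itself through a static usage(bool,bool) and (b) report the feature ids it writes through a virtual features(), which is what lets the decoder print ids, catch duplicates, and 0-expand the weight vector. A minimal sketch of the implementor's side of that contract, written against the signatures this patch introduces in decoder/ff.h (MyPenalty is a made-up example, not part of the patch; FD::Convert is cdec's feature-name interner from fdict.h):

#include <string>
#include <vector>
#include "ff.h"     // patched FeatureFunction base: usage_helper, single_feature, Features
#include "fdict.h"  // FD::Convert: feature name -> feature id

// Hypothetical single-id, stateless feature following the new conventions.
class MyPenalty : public FeatureFunction {
 public:
  MyPenalty(const std::string& /*param*/) : fid_(FD::Convert("MyPenalty")) {}
  // Name + docs; this is what --usage=MyPenalty would print via the registry.
  static std::string usage(bool p, bool d) {
    return usage_helper("MyPenalty", "", "example single-id local feature", p, d);
  }
  // Report the one feature id this function writes, so ModelSet::all_features
  // can detect duplicate ids and size the weight vector.
  Features features() const { return single_feature(fid_); }
 protected:
  // Fire once per applied rule (stateless, so no context is touched).
  virtual void TraversalFeaturesImpl(const SentenceMetadata& /*smeta*/,
                                     const Hypergraph::Edge& /*edge*/,
                                     const std::vector<const void*>& /*ant_states*/,
                                     SparseVector<double>* features,
                                     SparseVector<double>* /*estimated_features*/,
                                     void* /*context*/) const {
    features->set_value(fid_, 1.0);
  }
 private:
  const WordID fid_;
};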
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 54e24792..919751a2 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -56,7 +56,28 @@ void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
     trg->set_value(it->first, it->second);
 }
 
-void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+
+inline string str(char const* name,po::variables_map const& conf) {
+  return conf[name].as<string>();
+}
+
+shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
+  string ff, param;
+  SplitCommandAndParam(ffp, &ff, &param);
+  cerr << "Feature: " << ff;
+  if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
+  else cerr << " (no config parameters)\n";
+  shared_ptr<FeatureFunction> pf = global_ff_registry->Create(ff, param);
+  if (!pf)
+    exit(1);
+  int nbyte=pf->NumBytesContext();
+  if (verbose_feature_functions)
+    cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl;
+  return pf;
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* confp) {
+  po::variables_map &conf=*confp;
   po::options_description opts("Configuration options");
   opts.add_options()
         ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexTrans (lexical translation model, also disc training), CSplit (compound splitting), Tagger (sequence labeling), LexAlign (alignment only, or EM training)")
@@ -65,8 +86,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
         ("weights,w",po::value<string>(),"Feature weights file")
         ("prelm_weights",po::value<string>(),"Feature weights file for prelm_beam_prune. Requires --weights.")
         ("prelm_copy_weights","use --weights as value for --prelm_weights.")
+        ("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function")
         ("keep_prelm_cube_order","when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights.  only meaningful if --prelm_weights given.  UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.")
-
+        ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
         ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file")
         ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
         ("list_feature_functions,L","List available feature functions")
@@ -111,33 +133,44 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   po::options_description clo("Command line options");
   clo.add_options()
         ("config,c", po::value<string>(), "Configuration file")
-        ("help,h", "Print this help message and exit");
+        ("help,h", "Print this help message and exit")
+        ("usage", po::value<string>(), "Describe a feature function type")
+        ;
+
   po::options_description dconfig_options, dcmdline_options;
   dconfig_options.add(opts);
   dcmdline_options.add(opts).add(clo);
-  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
-  if (conf->count("config")) {
-    const string cfg = (*conf)["config"].as<string>();
+  po::store(parse_command_line(argc, argv, dcmdline_options), conf);
+  if (conf.count("config")) {
+    const string cfg = str("config",conf);
     cerr << "Configuration file: " << cfg << endl;
     ifstream config(cfg.c_str());
-    po::store(po::parse_config_file(config, dconfig_options), *conf);
+    po::store(po::parse_config_file(config, dconfig_options), conf);
   }
-  po::notify(*conf);
+  po::notify(conf);
 
-  if (conf->count("list_feature_functions")) {
+  if (conf.count("list_feature_functions")) {
     cerr << "Available feature functions (specify with -F):\n";
     global_ff_registry->DisplayList();
     cerr << endl;
     exit(1);
   }
-  if (conf->count("help") || conf->count("formalism") == 0) {
+  if (conf.count("usage")) {
+    cout<<global_ff_registry->usage(str("usage",conf),true,true)<<endl;
+    exit(0);
+  }
+
+  if (conf.count("help") || conf.count("formalism") == 0) {
     cerr << dcmdline_options << endl;
     exit(1);
   }
-  const string formalism = LowercaseString((*conf)["formalism"].as<string>());
+  const string formalism = LowercaseString(str("formalism",conf));
   if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") {
     cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n";
     cerr << dcmdline_options << endl;
@@ -256,18 +289,17 @@ bool beam_param(po::variables_map const& conf,string const& name,double *val,boo
 
 bool prelm_weights_string(po::variables_map const& conf,string &s)
 {
   if (conf.count("prelm_weights")) {
-    s=conf["prelm_weights"].as<string>();
+    s=str("prelm_weights",conf);
     return true;
   }
   if (conf.count("prelm_copy_weights")) {
-    s=conf["weights"].as<string>();
+    s=str("weights",conf);
     return true;
   }
   return false;
 }
 
-
 void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,FeatureWeights *weights=0) {
   cerr << viterbi_stats(forest,name,true,show_tree);
   if (show_features) {
@@ -305,6 +337,10 @@ void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,s
   }
 }
 
+void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) {
+  cerr<<header<<": ";
+  ms.show_features(cerr,cerr,conf.count("warn_0_weight"));
+}
 
 int main(int argc, char** argv) {
   po::variables_map conf;
@@ -318,7 +354,7 @@ int main(int argc, char** argv) {
   boost::shared_ptr<Translator> translator;
-  const string formalism = LowercaseString(conf["formalism"].as<string>());
+  const string formalism = LowercaseString(str("formalism",conf));
   const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
   if (csplit_preserve_full_word &&
       (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) {
@@ -341,7 +377,7 @@ int main(int argc, char** argv) {
   Weights w,prelm_w;
   bool has_prelm_models = false;
   if (conf.count("weights")) {
-    w.InitFromFile(conf["weights"].as<string>());
+    w.InitFromFile(str("weights",conf));
     feature_weights.resize(FD::NumFeats());
     w.InitVector(&feature_weights);
     string plmw;
@@ -350,13 +386,9 @@ int main(int argc, char** argv) {
       prelm_w.InitFromFile(plmw);
       prelm_feature_weights.resize(FD::NumFeats());
       prelm_w.InitVector(&prelm_feature_weights);
-      cerr << "prelm_weights: " << FeatureVector(prelm_feature_weights)<<endl;
     }
-    cerr << "+LM weights: " << FeatureVector(feature_weights)<<endl;
   }
 
-  vector<shared_ptr<FeatureFunction> > pffs;
+  vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs;
   vector<FeatureFunction const*> late_ffs,prelm_ffs;
   if (conf.count("feature_function") > 0) {
     const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >();
     for (int i = 0; i < add_ffs.size(); ++i) {
-      string ff, param;
-      SplitCommandAndParam(add_ffs[i], &ff, &param);
-      cerr << "Feature: " << ff;
-      if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
-      else cerr << " (no config parameters)\n";
-      shared_ptr<FeatureFunction> pff = global_ff_registry->Create(ff, param);
-      FeatureFunction const* p=pff.get();
-      if (!p) { exit(1); }
-      // TODO check that multiple features aren't trying to set the same fid
-      pffs.push_back(pff);
+      pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions));
+      FeatureFunction const* p=pffs.back().get();
       late_ffs.push_back(p);
-      int nbyte=p->NumBytesContext();
-      if (verbose_feature_functions)
-        cerr<<"State is "<<nbyte<<" bytes for feature "<<add_ffs[i]<<endl;
-      if (nbyte==0)
+      if (p->NumBytesContext()==0)
         prelm_ffs.push_back(p);
       else
-        cerr << "Excluding stateful feature from prelm pruning: "<<ff<<endl;
+        cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl;
     }
   }
+  if (conf.count("prelm_feature_function") > 0) {
+    const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >();
+    for (int i = 0; i < add_ffs.size(); ++i) {
+      prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only "));
+      prelm_ffs.push_back(prelm_only_ffs.back().get());
+    }
+  }
+  if (has_prelm_models)
+    cerr << "prelm rescoring with "<<prelm_ffs.size()<<" feature functions"<<endl;
 
   ModelSet late_models(feature_weights, late_ffs);
+  show_models(conf,late_models,"late models");
+  ModelSet prelm_models(prelm_feature_weights, prelm_ffs);
+  if (has_prelm_models)
+    show_models(conf,prelm_models,"prelm models");
 
   int palg = 1;
-  if (LowercaseString(conf["intersection_strategy"].as<string>()) == "full") {
+  if (LowercaseString(str("intersection_strategy",conf)) == "full") {
     palg = 0;
     cerr << "Using full intersection (no pruning).\n";
   }
@@ -426,17 +463,17 @@ int main(int argc, char** argv) {
   const bool minimal_forests = conf.count("minimal_forests");
   const bool graphviz = conf.count("graphviz");
   const bool joshua_viz = conf.count("show_joshua_visualization");
-  const bool encode_b64 = conf["vector_format"].as<string>() == "b64";
+  const bool encode_b64 = str("vector_format",conf) == "b64";
   const bool kbest = conf.count("k_best");
   const bool unique_kbest = conf.count("unique_k_best");
   const bool crf_uniform_empirical = conf.count("crf_uniform_empirical");
   shared_ptr<WriteFile> extract_file;
   if (conf.count("extract_rules"))
-    extract_file.reset(new WriteFile(conf["extract_rules"].as<string>()));
+    extract_file.reset(new WriteFile(str("extract_rules",conf)));
   int combine_size = conf["combine_size"].as<int>();
   if (combine_size < 1) combine_size = 1;
-  const string input = conf["input"].as<string>();
+  const string input = str("input",conf);
   cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl;
   ReadFile in_read(input);
   istream *in = in_read.stream();
@@ -506,7 +543,6 @@ int main(int argc, char** argv) {
       ExtractRulesDedupe(forest, extract_file->stream());
 
     if (has_prelm_models) {
-      ModelSet prelm_models(prelm_feature_weights, prelm_ffs);
       Timer t("prelm rescoring");
       forest.Reweight(prelm_feature_weights);
       forest.SortInEdgesByEdgeWeights();
@@ -544,7 +580,7 @@ int main(int argc, char** argv) {
     maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
 
     if (conf.count("forest_output") && !has_ref) {
-      ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
+      ForestWriter writer(str("forest_output",conf), sent_id);
       if (FileExists(writer.fname_)) {
         cerr << "  Unioning...\n";
         Hypergraph new_hg;
@@ -621,7 +657,7 @@ int main(int argc, char** argv) {
       }
     //DumpKBest(sent_id, forest, 1000);
     if (conf.count("forest_output")) {
-      ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
+      ForestWriter writer(str("forest_output",conf), sent_id);
       if (FileExists(writer.fname_)) {
         cerr << "  Unioning...\n";
         Hypergraph new_hg;
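[Editor's note] Both --feature_function and the new --prelm_feature_function are declared with po::value<vector<string> >()->composing(), which is what lets feature specs given repeatedly, and split across the config file and the command line, accumulate into one vector. A self-contained sketch of just that boost::program_options mechanism (illustrative, not cdec code):

#include <boost/program_options.hpp>
#include <iostream>
#include <string>
#include <vector>
namespace po = boost::program_options;

int main(int argc, char** argv) {
  po::options_description opts("Options");
  opts.add_options()
    ("feature_function,F",
     po::value<std::vector<std::string> >()->composing(),
     "may be repeated; composing() merges occurrences from all sources");
  po::variables_map conf;
  po::store(po::parse_command_line(argc, argv, opts), conf);
  po::notify(conf);
  if (conf.count("feature_function")) {
    const std::vector<std::string>& ffs =
        conf["feature_function"].as<std::vector<std::string> >();
    for (size_t i = 0; i < ffs.size(); ++i)
      std::cout << "FF spec: " << ffs[i] << '\n';  // e.g. "LanguageModel 3gram.lm"
  }
  return 0;
}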
"STDIN" : input.c_str()) << endl; ReadFile in_read(input); istream *in = in_read.stream(); @@ -506,7 +543,6 @@ int main(int argc, char** argv) { ExtractRulesDedupe(forest, extract_file->stream()); if (has_prelm_models) { - ModelSet prelm_models(prelm_feature_weights, prelm_ffs); Timer t("prelm rescoring"); forest.Reweight(prelm_feature_weights); forest.SortInEdgesByEdgeWeights(); @@ -544,7 +580,7 @@ int main(int argc, char** argv) { maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); if (conf.count("forest_output") && !has_ref) { - ForestWriter writer(conf["forest_output"].as(), sent_id); + ForestWriter writer(str("forest_output",conf), sent_id); if (FileExists(writer.fname_)) { cerr << " Unioning...\n"; Hypergraph new_hg; @@ -621,7 +657,7 @@ int main(int argc, char** argv) { } //DumpKBest(sent_id, forest, 1000); if (conf.count("forest_output")) { - ForestWriter writer(conf["forest_output"].as(), sent_id); + ForestWriter writer(str("forest_output",conf), sent_id); if (FileExists(writer.fname_)) { cerr << " Unioning...\n"; Hypergraph new_hg; diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 8cf2f2fd..077956a8 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -12,13 +12,14 @@ boost::shared_ptr global_ff_registry; void register_feature_functions() { global_ff_registry->Register(new FFFactory); - //TODO: define usage(false,false) for each of the below + + //TODO: use for all features the new Register which requires usage(...) #ifdef HAVE_RANDLM global_ff_registry->Register("RandLM", new FFFactory); #endif - global_ff_registry->Register("WordPenalty", new FFFactory); - global_ff_registry->Register("SourceWordPenalty", new FFFactory); - global_ff_registry->Register("ArityPenalty", new FFFactory); + global_ff_registry->Register(new FFFactory); + global_ff_registry->Register(new FFFactory); + global_ff_registry->Register(new FFFactory); global_ff_registry->Register("RuleShape", new FFFactory); global_ff_registry->Register("RelativeSentencePosition", new FFFactory); global_ff_registry->Register("Model2BinaryFeatures", new FFFactory); diff --git a/decoder/ff.cc b/decoder/ff.cc index 73dbbdc9..3f433dfb 100644 --- a/decoder/ff.cc +++ b/decoder/ff.cc @@ -29,6 +29,55 @@ string FeatureFunction::usage_helper(std::string const& name,std::string const& return r; } +FeatureFunction::Features FeatureFunction::single_feature(WordID feat) { + return Features(1,feat); +} + +FeatureFunction::Features ModelSet::all_features(std::ostream *warn) { + typedef FeatureFunction::Features FFS; + FFS ffs; +#define WARNFF(x) do { if (warn) { *warn << "WARNING: "<< x ; *warn< FFM; + FFM ff_from; + for (unsigned i=0;i= weights_.size()) + weights_.resize(fid+1); + pair i_new=ff_from.insert(FFM::value_type(fid,ffname)); + if (i_new.second) + ffs.push_back(fid); + else { + WARNFF(ffname<<" models["<second); + } + } + } + return ffs; +#undef WARNFF +} + +void ModelSet::show_features(std::ostream &out,std::ostream &warn,bool warn_zero_wt) +{ + typedef FeatureFunction::Features FFS; + FFS ffs=all_features(&warn); + out << "Weight Feature\n"; + for (unsigned i=0;i& ant_states, @@ -75,12 +133,16 @@ void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, ArityPenalty::ArityPenalty(const std::string& /* param */) : value_(-1.0 / log(10)) { string fname = "Arity_X"; - for (int i = 0; i < 10; ++i) { + for (int i = 0; i < N_ARITIES; ++i) { fname[6]=i + '0'; fids_[i] = FD::Convert(fname); } } +FeatureFunction::Features ArityPenalty::features() const { + return 
diff --git a/decoder/ff.cc b/decoder/ff.cc
index 73dbbdc9..3f433dfb 100644
--- a/decoder/ff.cc
+++ b/decoder/ff.cc
@@ -29,6 +29,55 @@ string FeatureFunction::usage_helper(std::string const& name,std::string const&
   return r;
 }
 
+FeatureFunction::Features FeatureFunction::single_feature(WordID feat) {
+  return Features(1,feat);
+}
+
+FeatureFunction::Features ModelSet::all_features(std::ostream *warn) {
+  typedef FeatureFunction::Features FFS;
+  FFS ffs;
+#define WARNFF(x) do { if (warn) { *warn << "WARNING: "<< x ; *warn<<endl; } } while(0)
+  typedef std::map<WordID,string> FFM;
+  FFM ff_from;
+  for (unsigned i=0;i<models_.size();++i) {
+    FeatureFunction const& ff=*models_[i];
+    string const& ffname=ff.name;
+    FFS si=ff.features();
+    if (si.empty()) {
+      WARNFF(ffname<<" doesn't yet report any feature IDs");
+    }
+    for (unsigned j=0;j<si.size();++j) {
+      WordID fid=si[j];
+      if (fid >= weights_.size())
+        weights_.resize(fid+1);
+      pair<FFM::iterator,bool> i_new=ff_from.insert(FFM::value_type(fid,ffname));
+      if (i_new.second) {
+        ffs.push_back(fid);
+      } else {
+        WARNFF(ffname<<" models["<<i<<"] tried to define feature "<<FD::Convert(fid)<<" already defined by "<<i_new.first->second);
+      }
+    }
+  }
+  return ffs;
+#undef WARNFF
+}
+
+void ModelSet::show_features(std::ostream &out,std::ostream &warn,bool warn_zero_wt)
+{
+  typedef FeatureFunction::Features FFS;
+  FFS ffs=all_features(&warn);
+  out << "Weight  Feature\n";
+  for (unsigned i=0;i<ffs.size();++i) {
+    WordID fid=ffs[i];
+    double wt=weights_[fid];
+    if (warn_zero_wt && wt==0)
+      warn<<"WARNING: "<<FD::Convert(fid)<<" has 0 weight."<<endl;
+    out<<wt<<"  "<<FD::Convert(fid)<<endl;
+  }
+}
+
 
 WordPenalty::WordPenalty(const std::string& param) :
     fid_(FD::Convert("WordPenalty")),
@@ -62,10 +111,19 @@ void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
   features->set_value(fid_, edge.rule_->EWords() * value_);
 }
 
+FeatureFunction::Features WordPenalty::features() const {
+  return single_feature(fid_);
+}
+
 SourceWordPenalty::SourceWordPenalty(const std::string& param) :
     fid_(FD::Convert("SourceWordPenalty")),
     value_(-1.0 / log(10)) {}
 
+FeatureFunction::Features SourceWordPenalty::features() const {
+  return single_feature(fid_);
+}
+
 void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                               const Hypergraph::Edge& edge,
                                               const std::vector<const void*>& ant_states,
@@ -75,12 +133,16 @@ void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
 
 ArityPenalty::ArityPenalty(const std::string& /* param */) :
     value_(-1.0 / log(10)) {
   string fname = "Arity_X";
-  for (int i = 0; i < 10; ++i) {
+  for (int i = 0; i < N_ARITIES; ++i) {
     fname[6]=i + '0';
     fids_[i] = FD::Convert(fname);
   }
 }
+FeatureFunction::Features ArityPenalty::features() const {
+  return Features(&fids_[0],&fids_[N_ARITIES]);
+}
+
 void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                          const Hypergraph::Edge& edge,
                                          const std::vector<const void*>& ant_states,
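[Editor's note] The weights_.resize(fid+1) inside all_features is the "0-expand weight vector for no segfault" from the commit message: a feature id minted after the weight file was read would otherwise index past the end of weights_ in ModelSet. A standalone illustration of the invariant (hypothetical helper, not cdec code):

#include <cstdio>
#include <vector>

// Look up a weight, growing the vector so a feature id that was never
// assigned a weight reads as 0 instead of indexing out of bounds.
double weight_for(std::vector<double>& weights, int fid) {
  if (fid >= static_cast<int>(weights.size()))
    weights.resize(fid + 1);  // 0-expand: new slots default to 0.0
  return weights[fid];
}

int main() {
  std::vector<double> weights(3, 0.5);          // as if loaded from --weights
  std::printf("%g\n", weight_for(weights, 7));  // prints 0, not a crash
  return 0;
}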
diff --git a/decoder/ff.h b/decoder/ff.h
index c6c9cf8f..6f8b8626 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -15,6 +15,7 @@ class FeatureFunction; // see definition below
 // FinalTraversalFeatures(...)
 class FeatureFunction {
  public:
+  std::string name; // set by FF factory using usage()
   FeatureFunction() : state_size_() {}
   explicit FeatureFunction(int state_size) : state_size_(state_size) {}
   virtual ~FeatureFunction();
@@ -24,12 +25,14 @@ class FeatureFunction {
     return usage_helper("FIXME_feature_needs_name","[no parameters]","[no documentation yet]",show_params,show_details);
   }
 
-  static std::string usage_helper(std::string const& name,std::string const& params,std::string const& details,bool show_params,bool show_details);
+  typedef std::vector<WordID> Features; // set of features ids
+protected:
+  static std::string usage_helper(std::string const& name,std::string const& params,std::string const& details,bool show_params,bool show_details);
+  static Features single_feature(WordID feat);
 public:
-  typedef std::vector<WordID> Features;
-  virtual Features features() { return Features(); }
+  virtual Features features() const { return Features(); }
   // returns the number of bytes of context that this feature function will
   // (maximally) use.  By default, 0 ("stateless" models in Hiero/Joshua).
   // NOTE: this value is fixed for the instance of your class, you cannot
@@ -87,7 +90,11 @@ public:
 // add value_
 class WordPenalty : public FeatureFunction {
  public:
+  Features features() const;
   WordPenalty(const std::string& param);
+  static std::string usage(bool p,bool d) {
+    return usage_helper("WordPenalty","","number of target words (local feature)",p,d);
+  }
 protected:
   virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                      const Hypergraph::Edge& edge,
@@ -102,7 +109,11 @@ class WordPenalty : public FeatureFunction {
 
 class SourceWordPenalty : public FeatureFunction {
  public:
+  Features features() const;
   SourceWordPenalty(const std::string& param);
+  static std::string usage(bool p,bool d) {
+    return usage_helper("SourceWordPenalty","","number of source words (local feature, and meaningless except when input has non-constant number of source words, e.g. segmentation/morphology/speech recognition lattice)",p,d);
+  }
 protected:
   virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                      const Hypergraph::Edge& edge,
@@ -117,7 +128,12 @@ class SourceWordPenalty : public FeatureFunction {
 
 class ArityPenalty : public FeatureFunction {
  public:
+  Features features() const;
   ArityPenalty(const std::string& param);
+  static std::string usage(bool p,bool d) {
+    return usage_helper("ArityPenalty","","Indicator feature Arity_N=1 for rule of arity N (local feature)",p,d);
+  }
+
 protected:
   virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                      const Hypergraph::Edge& edge,
@@ -126,7 +142,10 @@ class ArityPenalty : public FeatureFunction {
                                      SparseVector<double>* estimated_features,
                                      void* context) const;
  private:
-  int fids_[10];
+  enum {N_ARITIES=10};
+
+
+  int fids_[N_ARITIES];
   const double value_;
 };
@@ -153,6 +172,9 @@ class ModelSet {
                          Hypergraph::Edge* edge) const;
 
   bool empty() const { return models_.empty(); }
+
+  FeatureFunction::Features all_features(std::ostream *warnings=0); // this will warn about duplicate features as well (one function overwrites the feature of another).  also resizes weights_ so it is large enough to hold the (0) weight for the largest reported feature id
+  void show_features(std::ostream &out,std::ostream &warn,bool warn_zero_wt=true); //show features and weights
  private:
   std::vector<const FeatureFunction*> models_;
   std::vector<double> weights_;
diff --git a/decoder/ff_factory.cc b/decoder/ff_factory.cc
index d66cd883..fe733ca5 100644
--- a/decoder/ff_factory.cc
+++ b/decoder/ff_factory.cc
@@ -28,6 +28,7 @@ shared_ptr<FeatureFunction> FFRegistry::Create(const string& ffname, const strin
     cerr << "I don't know how to create feature " << ffname << endl;
   } else {
     res = it->second->Create(param);
+    res->name=ffname;
   }
   return res;
 }
diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc
index 9e6f02b7..0590fa7e 100644
--- a/decoder/ff_lm.cc
+++ b/decoder/ff_lm.cc
@@ -532,6 +532,10 @@ LanguageModel::LanguageModel(const string& param) {
   SetStateSize(LanguageModelImpl::OrderToStateSize(order));
 }
 
+FeatureFunction::Features LanguageModel::features() const {
+  return single_feature(fid_);
+}
+
 LanguageModel::~LanguageModel() {
   delete pimpl_;
 }
diff --git a/decoder/ff_lm.h b/decoder/ff_lm.h
index 5ea41068..935e283c 100644
--- a/decoder/ff_lm.h
+++ b/decoder/ff_lm.h
@@ -19,6 +19,7 @@ class LanguageModel : public FeatureFunction {
                                  SparseVector<double>* features) const;
   std::string DebugStateToString(const void* state) const;
   static std::string usage(bool param,bool verbose);
+  Features features() const;
  protected:
   virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                      const Hypergraph::Edge& edge,
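[Editor's note] Taken together: --usage=NAME looks the factory up in the registry and prints the class's usage text, and the name half of that text is what keys registration and gets stamped into FeatureFunction::name by FFRegistry::Create. A small driver calling the new static entry points directly (assumes the patched headers; exact output formatting is whatever usage_helper produces):

#include <iostream>
#include "ff.h"  // patched decoder/ff.h with per-class usage()

int main() {
  // Equivalent of --usage=WordPenalty: full parameter + detail text.
  std::cout << WordPenalty::usage(true, true) << std::endl;
  // Name-only form, presumably what registration keys on.
  std::cout << ArityPenalty::usage(false, false) << std::endl;
  return 0;
}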