From df7cae6928285c7902cd1f8a5244c9ffcc5ae499 Mon Sep 17 00:00:00 2001 From: graehl Date: Sat, 7 Aug 2010 03:41:32 +0000 Subject: apply fsa models (so far only by bottom up) in cdec git-svn-id: https://ws10smt.googlecode.com/svn/trunk@487 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/Makefile.am | 1 + decoder/apply_fsa_models.cc | 102 ++++++++++++++++++++++++++++++++++++++++++++ decoder/apply_fsa_models.h | 29 ++++++++++++- decoder/apply_models.h | 6 +++ decoder/cdec.cc | 29 ++++++++++--- decoder/feature_vector.h | 4 ++ decoder/stringlib.h | 12 ++++++ 7 files changed, 176 insertions(+), 7 deletions(-) create mode 100755 decoder/apply_fsa_models.cc (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 189e28b0..88d6d17a 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -44,6 +44,7 @@ rule_lexer.cc: rule_lexer.l noinst_LIBRARIES = libcdec.a libcdec_a_SOURCES = \ + apply_fsa_models.cc \ rule_lexer.cc \ fst_translator.cc \ csplit.cc \ diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc new file mode 100755 index 00000000..27773b0d --- /dev/null +++ b/decoder/apply_fsa_models.cc @@ -0,0 +1,102 @@ +#include "apply_fsa_models.h" +#include "hg.h" +#include "ff_fsa_dynamic.h" +#include "feature_vector.h" +#include "stringlib.h" +#include "apply_models.h" +#include +#include + +using namespace std; + +struct ApplyFsa { + ApplyFsa(const Hypergraph& ih, + const SentenceMetadata& smeta, + const FsaFeatureFunction& fsa, + DenseWeightVector const& weights, + ApplyFsaBy const& cfg, + Hypergraph* oh) + :ih(ih),smeta(smeta),fsa(fsa),weights(weights),cfg(cfg),oh(oh) + { +// sparse_to_dense(weight_vector,&weights); + Init(); + } + void Init() { + ApplyBottomUp(); + //TODO: implement l->r + } + void ApplyBottomUp() { + assert(cfg.IsBottomUp()); + vector ffs; + ModelSet models(weights, ffs); + IntersectionConfiguration i(cfg.BottomUpAlgorithm(),cfg.pop_limit); + ApplyModelSet(ih,smeta,models,i,oh); + } +private: + const Hypergraph& ih; + const SentenceMetadata& smeta; + const FsaFeatureFunction& fsa; +// WeightVector weight_vector; + DenseWeightVector weights; + ApplyFsaBy cfg; + Hypergraph* oh; +}; + + +void ApplyFsaModels(const Hypergraph& ih, + const SentenceMetadata& smeta, + const FsaFeatureFunction& fsa, + DenseWeightVector const& weight_vector, + ApplyFsaBy const& cfg, + Hypergraph* oh) +{ + ApplyFsa a(ih,smeta,fsa,weight_vector,cfg,oh); +} + + +namespace { +char const* anames[]={ + "BU_CUBE", + "BU_FULL", + "EARLEY", + 0 +}; +} + +//TODO: named enum type in boost? + +std::string ApplyFsaBy::name() const { + return anames[algorithm]; +} + +std::string ApplyFsaBy::all_names() { + std::ostringstream o; + for (int i=0;i=0); + assert (i +#include "feature_vector.h" struct FsaFeatureFunction; struct Hypergraph; struct SentenceMetadata; +struct ApplyFsaBy { + enum { + BU_CUBE, + BU_FULL, + EARLEY, + N_ALGORITHMS + }; + int pop_limit; // only applies to BU_FULL so far + bool IsBottomUp() const { + return algorithm==BU_FULL || algorithm==BU_CUBE; + } + int BottomUpAlgorithm() const; + int algorithm; + std::string name() const; + friend inline std::ostream &operator << (std::ostream &o,ApplyFsaBy const& c) { + return o << c.name(); + } + explicit ApplyFsaBy(int alg, int poplimit=200); + ApplyFsaBy(std::string const& name, int poplimit=200); + ApplyFsaBy(const ApplyFsaBy &o) : algorithm(o.algorithm) { } + static std::string all_names(); // space separated +}; + + void ApplyFsaModels(const Hypergraph& in, const SentenceMetadata& smeta, const FsaFeatureFunction& fsa, + DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this) + ApplyFsaBy const& cfg, Hypergraph* out); #endif diff --git a/decoder/apply_models.h b/decoder/apply_models.h index 61a5b8f7..81fa068e 100644 --- a/decoder/apply_models.h +++ b/decoder/apply_models.h @@ -8,6 +8,12 @@ struct SentenceMetadata; struct exhaustive_t {}; struct IntersectionConfiguration { +enum { + FULL, + CUBE, + N_ALGORITHMS +}; + const int algorithm; // 0 = full intersection, 1 = cube pruning const int pop_limit; // max number of pops off the heap at each node IntersectionConfiguration(int alg, int k) : algorithm(alg), pop_limit(k) {} diff --git a/decoder/cdec.cc b/decoder/cdec.cc index a7c99307..72f0b95e 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -25,6 +25,7 @@ #include "weights.h" #include "tdict.h" #include "ff.h" +#include "ff_fsa_dynamic.h" #include "ff_factory.h" #include "hg_intersect.h" #include "apply_models.h" @@ -119,7 +120,8 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file") ("feature_function,F",po::value >()->composing(), "Additional feature function(s) (-L for list)") - ("fsa_feature_function",po::value >()->composing(), "Additional FSA feature function(s) (-L for list)") + ("fsa_feature_function,A",po::value >()->composing(), "Additional FSA feature function(s) (-L for list)") + ("apply_fsa_by",po::value()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names() ("list_feature_functions,L","List available feature functions") ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") ("k_best,k",po::value(),"Extract the k best derivations") @@ -147,7 +149,7 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails") ("beam_prune", po::value(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") - ("promise_power",po::value()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit). note: for the same poplimit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU. TODO: test under more conditions, or try idea with different formula, or prob. cube beams.") + ("promise_power",po::value()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit). note: for the same pop_limit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU. TODO: test under more conditions, or try idea with different formula, or prob. cube beams.") ("lexalign_use_null", "Support source-side null words in lexical translation") ("tagger_tagset,t", po::value(), "(Tagger) file containing tag set") ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") @@ -519,7 +521,8 @@ int main(int argc, char** argv) { palg = 0; cerr << "Using full intersection (no pruning).\n"; } - const IntersectionConfiguration inter_conf(palg, conf["cubepruning_pop_limit"].as()); + int pop_limit=conf["cubepruning_pop_limit"].as(); + const IntersectionConfiguration inter_conf(palg, pop_limit); const int sample_max_trans = conf.count("max_translation_sample") ? conf["max_translation_sample"].as() : 0; @@ -619,7 +622,7 @@ int main(int argc, char** argv) { inter_conf, // this is now reduced to exhaustive if all are stateless &prelm_forest); forest.swap(prelm_forest); - forest.Reweight(prelm_feature_weights); + forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it? forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation); } @@ -642,8 +645,20 @@ int main(int argc, char** argv) { maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); - vector trans; - ViterbiESentence(forest, &trans); + + + if (!fsa_ffs.empty()) { + Timer t("Target FSA rescoring:"); + if (!has_late_models) + forest.Reweight(feature_weights); + Hypergraph fsa_forest; + assert(fsa_ffs.size()==1); + ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit); + ApplyFsaModels(forest,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest); + forest.swap(fsa_forest); + forest.Reweight(feature_weights); + forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + } /*Oracle Rescoring*/ @@ -687,6 +702,8 @@ int main(int argc, char** argv) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { if (!graphviz && !has_ref && !joshua_viz) { + vector trans; + ViterbiESentence(forest, &trans); cout << TD::GetString(trans) << endl << flush; } if (joshua_viz) { diff --git a/decoder/feature_vector.h b/decoder/feature_vector.h index 1b272506..be378a6a 100755 --- a/decoder/feature_vector.h +++ b/decoder/feature_vector.h @@ -11,4 +11,8 @@ typedef SparseVector FeatureVector; typedef SparseVector WeightVector; typedef std::vector DenseWeightVector; +inline void sparse_to_dense(WeightVector const& wv,DenseWeightVector *dv) { + wv.init_vector(dv); +} + #endif diff --git a/decoder/stringlib.h b/decoder/stringlib.h index b3097bd1..53e6fe50 100644 --- a/decoder/stringlib.h +++ b/decoder/stringlib.h @@ -18,6 +18,18 @@ #include #include #include +#include + +struct toupperc { + inline char operator()(char c) const { + return std::toupper(c); + } +}; + +inline std::string toupper(std::string s) { + std::transform(s.begin(),s.end(),s.begin(),toupperc()); + return s; +} template inline bool match_begin(Istr bstr,Istr estr,Isubstr bsub,Isubstr esub) -- cgit v1.2.3