diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/Makefile.am | 1 | ||||
| -rw-r--r-- | decoder/cdec_ff.cc | 3 | ||||
| -rw-r--r-- | decoder/decoder.cc | 1 | ||||
| -rw-r--r-- | decoder/ff_lexical.h | 128 | ||||
| -rw-r--r-- | decoder/ff_rules.cc | 22 | ||||
| -rw-r--r-- | decoder/ff_rules.h | 13 | ||||
| -rw-r--r-- | decoder/scfg_translator.cc | 31 | 
7 files changed, 152 insertions, 47 deletions
| diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 02e58479..e46a7120 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -50,6 +50,7 @@ libcdec_a_SOURCES = \    ff_external.h \    ff_factory.h \    ff_klm.h \ +	ff_lexical.h \    ff_lm.h \    ff_ngrams.h \    ff_parse_match.h \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 0411908f..7f7e075b 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -24,6 +24,7 @@  #include "ff_charset.h"  #include "ff_wordset.h"  #include "ff_external.h" +#include "ff_lexical.h"  void register_feature_functions() { @@ -39,13 +40,13 @@ void register_feature_functions() {    RegisterFF<SourceWordPenalty>();    RegisterFF<ArityPenalty>();    RegisterFF<BLEUModel>(); +  RegisterFF<LexicalFeatures>();    //TODO: use for all features the new Register which requires static FF::usage(false,false) give name    ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());    ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());    ff_registry.Register("RuleContextFeatures", new FFFactory<RuleContextFeatures>());    ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>()); -  ff_registry.Register("RuleWordAlignmentFeatures", new FFFactory<RuleWordAlignmentFeatures>());    ff_registry.Register("ParseMatchFeatures", new FFFactory<ParseMatchFeatures>);    ff_registry.Register("SoftSyntaxFeatures", new FFFactory<SoftSyntaxFeatures>);    ff_registry.Register("SoftSyntaxFeaturesMindist", new FFFactory<SoftSyntaxFeaturesMindist>); diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 6783cad0..c384c33f 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -366,6 +366,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("beam_prune3", po::value<double>(), "Optional pass 3")          ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") +        ("add_extra_pass_through_features,Q", 
po::value<unsigned int>()->default_value(0), "Add PassThrough{1..N} features, capped at N.")
         ("k_best,k",po::value<int>(),"Extract the k best derivations")
         ("unique_k_best,r", "Unique k-best translation list")
         ("aligner,a", "Run as a word/phrase aligner (src & ref required)")
diff --git a/decoder/ff_lexical.h b/decoder/ff_lexical.h
new file mode 100644
index 00000000..21c85b27
--- /dev/null
+++ b/decoder/ff_lexical.h
@@ -0,0 +1,153 @@
+#ifndef FF_LEXICAL_H_
+#define FF_LEXICAL_H_
+
+#include <vector>
+#include <map>
+#include "trule.h"
+#include "ff.h"
+#include "hg.h"
+#include "array2d.h"
+#include "wordid.h"
+#include <sstream>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+#include "tdict.h"
+
+// NOTE(review): a using-directive at header scope leaks into every file that
+// includes this header; it is kept so that existing includers keep compiling.
+using namespace std;
+
+namespace {
+  // Returns a copy of x with every '=' and ';' replaced by '_'. Applied to
+  // tokens before they are embedded in a feature name.
+  inline string Escape(const string& x) {
+    string y = x;
+    for (string::size_type i = 0; i < y.size(); ++i) {
+      if (y[i] == '=' || y[i] == ';') y[i] = '_';
+    }
+    return y;
+  }
+}
+
+// Sparse lexical indicator features derived from a rule's word alignment:
+//   LT:<src>:<tgt>  one per alignment point (translation)
+//   LD:<src>        one per unaligned source terminal (deletion)
+//   LI:<tgt>        one per unaligned target terminal (insertion)
+// Each family is enabled or disabled by the constructor argument.
+class LexicalFeatures : public FeatureFunction {
+public:
+  // param is either empty (all three families enabled) or exactly three
+  // whitespace-separated integers "T I D"; a non-zero value enables the
+  // translation, insertion and deletion features respectively.
+  LexicalFeatures(const std::string& param) {
+    if (param.empty()) {
+      cerr << "LexicalFeatures: using T,D,I\n";
+      T_ = true; I_ = true; D_ = true;
+    } else {
+      const vector<string> argv = SplitOnWhitespace(param);
+      // Checked unconditionally: an assert() disappears under NDEBUG and the
+      // indexing below would then read past the end of argv.
+      if (argv.size() != 3) {
+        cerr << "LexicalFeatures: expected 3 arguments (translations insertions deletions), got "
+             << argv.size() << endl;
+        abort();
+      }
+      T_ = atoi(argv[0].c_str()) != 0;
+      I_ = atoi(argv[1].c_str()) != 0;
+      D_ = atoi(argv[2].c_str()) != 0;
+      cerr << "T=" << T_ << " I=" << I_ << " D=" << D_ << endl;
+    }
+  }
+  static std::string usage(bool p,bool d) {
+    return usage_helper("LexicalFeatures","[0/1 0/1 0/1]","Sparse lexical word translation indicator features. If arguments are supplied, specify like this: translations insertions deletions",p,d);
+  }
+protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+      const HG::Edge& edge,
+      const std::vector<const void*>& ant_contexts,
+      SparseVector<double>* features,
+      SparseVector<double>* estimated_features,
+      void* context) const;
+  virtual void PrepareForInput(const SentenceMetadata& smeta);
+private:
+  // Features already computed for a rule, keyed by the rule's address;
+  // emptied by PrepareForInput.
+  mutable std::map<const TRule*, SparseVector<double> > rule2feats_;
+  bool T_;  // emit LT (translation) features
+  bool I_;  // emit LI (insertion) features
+  bool D_;  // emit LD (deletion) features
+};
+
+// Empties the per-rule feature cache.
+// Defined inline: a non-inline definition in a header is a duplicate symbol
+// as soon as two translation units include it.
+inline void LexicalFeatures::PrepareForInput(const SentenceMetadata& /* smeta */) {
+  rule2feats_.clear();
+}
+
+// Adds the lexical features of edge's rule to *features. The features of a
+// rule are computed on first use and served from rule2feats_ afterwards.
+// Features belonging to a disabled family are neither built nor looked up
+// in FD.
+inline void LexicalFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+    const HG::Edge& edge,
+    const std::vector<const void*>& /* ant_contexts */,
+    SparseVector<double>* features,
+    SparseVector<double>* /* estimated_features */,
+    void* /* context */) const {
+  const TRule* const key = edge.rule_.get();
+  if (!key) return;  // nothing to score without a rule
+  map<const TRule*, SparseVector<double> >::iterator it = rule2feats_.find(key);
+  if (it == rule2feats_.end()) {
+    const TRule& rule = *key;
+    it = rule2feats_.insert(make_pair(key, SparseVector<double>())).first;
+    SparseVector<double>& f = it->second;
+    // sf[i] / se[j]: source / target position i / j is hit by an alignment point
+    std::vector<bool> sf(edge.rule_->FLength(), false);
+    std::vector<bool> se(edge.rule_->ELength(), false);
+    // translations
+    for (unsigned i = 0; i < rule.a_.size(); ++i) {
+      const AlignmentPoint& ap = rule.a_[i];
+      const size_t s = static_cast<size_t>(ap.s_);
+      const size_t t = static_cast<size_t>(ap.t_);
+      // skip alignment points that fall outside the rule instead of indexing
+      // out of bounds (a negative index becomes a huge size_t and fails too)
+      if (s >= sf.size() || t >= se.size()) continue;
+      sf[s] = true;
+      se[t] = true;
+      if (!T_) continue;
+      const int fid = FD::Convert("LT:" + Escape(TD::Convert(rule.f_[s])) + ":" + Escape(TD::Convert(rule.e_[t])));
+      if (fid > 0) f.add_value(fid, 1.0);
+    }
+    // word deletions: unaligned source terminals
+    if (D_) {
+      for (unsigned i = 0; i < sf.size(); ++i) {
+        if (sf[i] || rule.f_[i] <= 0) continue;  // aligned, or not a terminal
+        const int fid = FD::Convert("LD:" + Escape(TD::Convert(rule.f_[i])));
+        if (fid > 0) f.add_value(fid, 1.0);
+      }
+    }
+    // word insertions: unaligned target terminals
+    if (I_) {
+      for (unsigned i = 0; i < se.size(); ++i) {
+        if (se[i] || rule.e_[i] < 1) continue;  // aligned, or not a terminal
+        const int fid = FD::Convert("LI:" + Escape(TD::Convert(rule.e_[i])));
+        if (fid > 0) f.add_value(fid, 1.0);
+      }
+    }
+  }
+  (*features) += it->second;
+}
+
+
+#endif
diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc
index 7bccf084..9533caed 100644
--- a/decoder/ff_rules.cc
+++ b/decoder/ff_rules.cc
@@ -69,28 +69,6 @@ void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
   features->add_value(it->second, 1);
 }
 
-RuleWordAlignmentFeatures::RuleWordAlignmentFeatures(const std::string& param) {
-}
-
-void RuleWordAlignmentFeatures::PrepareForInput(const SentenceMetadata& smeta) {
-}
-
-void RuleWordAlignmentFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                         const Hypergraph::Edge& edge,
-                                         const vector<const void*>& ant_contexts,
-                                         SparseVector<double>* features,
-                                         SparseVector<double>* estimated_features,
-                                         void* context) const {
-  const TRule& rule = *edge.rule_;
-  ostringstream os;
-  vector<AlignmentPoint> als = rule.als();
-  std::vector<AlignmentPoint>::const_iterator xx = als.begin();
-  for (; xx != als.end(); ++xx) {
-    os << "WA:" <<  TD::Convert(rule.f_[xx->s_]) << ":" << TD::Convert(rule.e_[xx->t_]);
-  }
-  features->add_value(FD::Convert(Escape(os.str())), 1);
-}
-
RuleSourceBigramFeatures::RuleSourceBigramFeatures(const std::string& param) {  } diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h index 324d7a39..f210dc65 100644 --- a/decoder/ff_rules.h +++ b/decoder/ff_rules.h @@ -24,19 +24,6 @@ class RuleIdentityFeatures : public FeatureFunction {    mutable std::map<const TRule*, int> rule2_fid_;  }; -class RuleWordAlignmentFeatures : public FeatureFunction { - public: -  RuleWordAlignmentFeatures(const std::string& param); - protected: -  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, -                                     const HG::Edge& edge, -                                     const std::vector<const void*>& ant_contexts, -                                     SparseVector<double>* features, -                                     SparseVector<double>* estimated_features, -                                     void* context) const; -  virtual void PrepareForInput(const SentenceMetadata& smeta); -}; -  class RuleSourceBigramFeatures : public FeatureFunction {   public:    RuleSourceBigramFeatures(const std::string& param); diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 88f62769..c3cfcaad 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -28,7 +28,7 @@ struct GlueGrammar : public TextGrammar {  };  struct PassThroughGrammar : public TextGrammar { -  PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0); +  PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0, const unsigned int num_pt_features=0);    virtual bool HasRuleForSpan(int i, int j, int distance) const;  }; @@ -56,7 +56,7 @@ bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const {    return (i == 0);  } -PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) { +PassThroughGrammar::PassThroughGrammar(const Lattice& input, const 
string& cat, const unsigned int ctf_level, const unsigned num_pt_features) {    unordered_set<WordID> ss;    for (int i = 0; i < input.size(); ++i) {      const vector<LatticeArc>& alts = input[i]; @@ -64,14 +64,21 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat,        const int j = alts[k].dist2next + i;        const string& src = TD::Convert(alts[k].label);        if (ss.count(alts[k].label) == 0) { -        int length = static_cast<int>(log(UTF8StringLen(src)) / log(1.6)) + 1; -        if (length > 6) length = 6; -        string len_feat = "PassThrough_0=1"; -        len_feat[12] += length; -        TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 " + len_feat)); -        pt->a_.push_back(AlignmentPoint(0,0)); -        AddRule(pt); -        RefineRule(pt, ctf_level); +        if (num_pt_features > 0) { +          int length = static_cast<int>(log(UTF8StringLen(src)) / log(1.6)) + 1; +          if (length > num_pt_features) length = num_pt_features; +          string len_feat = "PassThrough_0=1"; +          len_feat[12] += length; +          TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 " + len_feat)); +          pt->a_.push_back(AlignmentPoint(0,0)); +          AddRule(pt); +          RefineRule(pt, ctf_level); +        } else { +          TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 ")); +          pt->a_.push_back(AlignmentPoint(0,0)); +          AddRule(pt); +          RefineRule(pt, ctf_level); +        }          ss.insert(alts[k].label);        }      } @@ -86,6 +93,7 @@ struct SCFGTranslatorImpl {    SCFGTranslatorImpl(const boost::program_options::variables_map& conf) :        max_span_limit(conf["scfg_max_span_limit"].as<int>()),        add_pass_through_rules(conf.count("add_pass_through_rules")), +      num_pt_features(conf["add_extra_pass_through_features"].as<unsigned int>()),        
goal(conf["goal"].as<string>()),        default_nt(conf["scfg_default_nt"].as<string>()),        use_ctf_(conf.count("coarse_to_fine_beam_prune")) @@ -140,6 +148,7 @@ struct SCFGTranslatorImpl {    const int max_span_limit;    const bool add_pass_through_rules; +  const unsigned int num_pt_features;    const string goal;    const string default_nt;    const bool use_ctf_; @@ -187,7 +196,7 @@ struct SCFGTranslatorImpl {      smeta->SetSourceLength(lattice.size());      if (add_pass_through_rules){        if (!SILENT) cerr << "Adding pass through grammar" << endl; -      PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_); +      PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features);        g->SetGrammarName("PassThrough");        glist.push_back(GrammarPtr(g));      } | 
