summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-09-07 13:57:52 -0400
committerChris Dyer <redpony@gmail.com>2014-09-07 13:57:52 -0400
commitffd0096320770325a8925dd17453d1fdd9375bb9 (patch)
tree200b1a6d1b14853d8ed6acb649ab9add881cc99b /decoder
parent49c105dfc1fc3a0334d03de4d361abf23a6f1898 (diff)
parente6f2dd6892e277d0a868c22f726c4a83c86da016 (diff)
Merge pull request #50 from pks/master
alignment features, PassThroughN features, dtrain update, mira qsub, and pro fix
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am1
-rw-r--r--decoder/cdec_ff.cc3
-rw-r--r--decoder/decoder.cc1
-rw-r--r--decoder/ff_lexical.h128
-rw-r--r--decoder/ff_rules.cc22
-rw-r--r--decoder/ff_rules.h13
-rw-r--r--decoder/scfg_translator.cc31
7 files changed, 152 insertions, 47 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 02e58479..e46a7120 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -50,6 +50,7 @@ libcdec_a_SOURCES = \
ff_external.h \
ff_factory.h \
ff_klm.h \
+ ff_lexical.h \
ff_lm.h \
ff_ngrams.h \
ff_parse_match.h \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 0411908f..7f7e075b 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -24,6 +24,7 @@
#include "ff_charset.h"
#include "ff_wordset.h"
#include "ff_external.h"
+#include "ff_lexical.h"
void register_feature_functions() {
@@ -39,13 +40,13 @@ void register_feature_functions() {
RegisterFF<SourceWordPenalty>();
RegisterFF<ArityPenalty>();
RegisterFF<BLEUModel>();
+ RegisterFF<LexicalFeatures>();
//TODO: use for all features the new Register which requires static FF::usage(false,false) give name
ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
ff_registry.Register("RuleContextFeatures", new FFFactory<RuleContextFeatures>());
ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
- ff_registry.Register("RuleWordAlignmentFeatures", new FFFactory<RuleWordAlignmentFeatures>());
ff_registry.Register("ParseMatchFeatures", new FFFactory<ParseMatchFeatures>);
ff_registry.Register("SoftSyntaxFeatures", new FFFactory<SoftSyntaxFeatures>);
ff_registry.Register("SoftSyntaxFeaturesMindist", new FFFactory<SoftSyntaxFeaturesMindist>);
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 6783cad0..c384c33f 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -366,6 +366,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("beam_prune3", po::value<double>(), "Optional pass 3")
("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
+ ("add_extra_pass_through_features,Q", po::value<unsigned int>()->default_value(0), "Add PassThrough{1..N} features, capped at N.")
("k_best,k",po::value<int>(),"Extract the k best derivations")
("unique_k_best,r", "Unique k-best translation list")
("aligner,a", "Run as a word/phrase aligner (src & ref required)")
diff --git a/decoder/ff_lexical.h b/decoder/ff_lexical.h
new file mode 100644
index 00000000..21c85b27
--- /dev/null
+++ b/decoder/ff_lexical.h
@@ -0,0 +1,128 @@
+#ifndef FF_LEXICAL_H_
+#define FF_LEXICAL_H_
+
+#include <vector>
+#include <map>
+#include "trule.h"
+#include "ff.h"
+#include "hg.h"
+#include "array2d.h"
+#include "wordid.h"
+#include <sstream>
+#include <cassert>
+#include <cmath>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+#include "tdict.h"
+#include "hg.h"
+
+using namespace std;
+
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
+class LexicalFeatures : public FeatureFunction {
+public:
+ LexicalFeatures(const std::string& param) {
+ if (param.empty()) {
+ cerr << "LexicalFeatures: using T,D,I\n";
+ T_ = true; I_ = true; D_ = true;
+ } else {
+ const vector<string> argv = SplitOnWhitespace(param);
+ assert(argv.size() == 3);
+ T_ = (bool) atoi(argv[0].c_str());
+ I_ = (bool) atoi(argv[1].c_str());
+ D_ = (bool) atoi(argv[2].c_str());
+ cerr << "T=" << T_ << " I=" << I_ << " D=" << D_ << endl;
+ }
+ };
+ static std::string usage(bool p,bool d) {
+ return usage_helper("LexicalFeatures","[0/1 0/1 0/1]","Sparse lexical word translation indicator features. If arguments are supplied, specify like this: translations insertions deletions",p,d);
+ }
+protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+private:
+ mutable std::map<const TRule*, SparseVector<double> > rule2feats_;
+ bool T_;
+ bool I_;
+ bool D_;
+};
+
+void LexicalFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ rule2feats_.clear(); // std::map<const TRule*, SparseVector<double> >
+}
+
+void LexicalFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+
+ map<const TRule*, SparseVector<double> >::iterator it = rule2feats_.find(edge.rule_.get());
+ if (it == rule2feats_.end()) {
+ const TRule& rule = *edge.rule_;
+ it = rule2feats_.insert(make_pair(&rule, SparseVector<double>())).first;
+ SparseVector<double>& f = it->second;
+ std::vector<bool> sf(edge.rule_->FLength(),false); // stores if source tokens are visited by alignment points
+ std::vector<bool> se(edge.rule_->ELength(),false); // stores if target tokens are visited by alignment points
+ int fid = 0;
+ // translations
+ for (unsigned i=0;i<rule.a_.size();++i) {
+ const AlignmentPoint& ap = rule.a_[i];
+ sf[ap.s_] = true; // mark index as seen
+ se[ap.t_] = true; // mark index as seen
+ ostringstream os;
+ os << "LT:" << Escape(TD::Convert(rule.f_[ap.s_])) << ":" << Escape(TD::Convert(rule.e_[ap.t_]));
+ fid = FD::Convert(os.str());
+ if (fid <= 0) continue;
+ if (T_)
+ f.add_value(fid, 1.0);
+ }
+ // word deletions
+ for (unsigned i=0;i<sf.size();++i) {
+ if (!sf[i] && rule.f_[i] > 0) {// if not visited and is terminal
+ ostringstream os;
+ os << "LD:" << Escape(TD::Convert(rule.f_[i]));
+ fid = FD::Convert(os.str());
+ if (fid <= 0) continue;
+ if (D_)
+ f.add_value(fid, 1.0);
+ }
+ }
+ // word insertions
+ for (unsigned i=0;i<se.size();++i) {
+ if (!se[i] && rule.e_[i] >= 1) {// if not visited and is terminal
+ ostringstream os;
+ os << "LI:" << Escape(TD::Convert(rule.e_[i]));
+ fid = FD::Convert(os.str());
+ if (fid <= 0) continue;
+ if (I_)
+ f.add_value(fid, 1.0);
+ }
+ }
+ }
+ (*features) += it->second;
+}
+
+
+#endif
diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc
index 7bccf084..9533caed 100644
--- a/decoder/ff_rules.cc
+++ b/decoder/ff_rules.cc
@@ -69,28 +69,6 @@ void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
features->add_value(it->second, 1);
}
-RuleWordAlignmentFeatures::RuleWordAlignmentFeatures(const std::string& param) {
-}
-
-void RuleWordAlignmentFeatures::PrepareForInput(const SentenceMetadata& smeta) {
-}
-
-void RuleWordAlignmentFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const {
- const TRule& rule = *edge.rule_;
- ostringstream os;
- vector<AlignmentPoint> als = rule.als();
- std::vector<AlignmentPoint>::const_iterator xx = als.begin();
- for (; xx != als.end(); ++xx) {
- os << "WA:" << TD::Convert(rule.f_[xx->s_]) << ":" << TD::Convert(rule.e_[xx->t_]);
- }
- features->add_value(FD::Convert(Escape(os.str())), 1);
-}
-
RuleSourceBigramFeatures::RuleSourceBigramFeatures(const std::string& param) {
}
diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h
index 324d7a39..f210dc65 100644
--- a/decoder/ff_rules.h
+++ b/decoder/ff_rules.h
@@ -24,19 +24,6 @@ class RuleIdentityFeatures : public FeatureFunction {
mutable std::map<const TRule*, int> rule2_fid_;
};
-class RuleWordAlignmentFeatures : public FeatureFunction {
- public:
- RuleWordAlignmentFeatures(const std::string& param);
- protected:
- virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const HG::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const;
- virtual void PrepareForInput(const SentenceMetadata& smeta);
-};
-
class RuleSourceBigramFeatures : public FeatureFunction {
public:
RuleSourceBigramFeatures(const std::string& param);
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 88f62769..c3cfcaad 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -28,7 +28,7 @@ struct GlueGrammar : public TextGrammar {
};
struct PassThroughGrammar : public TextGrammar {
- PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0);
+ PassThroughGrammar(const Lattice& input, const std::string& cat, const unsigned int ctf_level=0, const unsigned int num_pt_features=0);
virtual bool HasRuleForSpan(int i, int j, int distance) const;
};
@@ -56,7 +56,7 @@ bool GlueGrammar::HasRuleForSpan(int i, int /* j */, int /* distance */) const {
return (i == 0);
}
-PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level) {
+PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const unsigned int ctf_level, const unsigned num_pt_features) {
unordered_set<WordID> ss;
for (int i = 0; i < input.size(); ++i) {
const vector<LatticeArc>& alts = input[i];
@@ -64,14 +64,21 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat,
const int j = alts[k].dist2next + i;
const string& src = TD::Convert(alts[k].label);
if (ss.count(alts[k].label) == 0) {
- int length = static_cast<int>(log(UTF8StringLen(src)) / log(1.6)) + 1;
- if (length > 6) length = 6;
- string len_feat = "PassThrough_0=1";
- len_feat[12] += length;
- TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 " + len_feat));
- pt->a_.push_back(AlignmentPoint(0,0));
- AddRule(pt);
- RefineRule(pt, ctf_level);
+ if (num_pt_features > 0) {
+ int length = static_cast<int>(log(UTF8StringLen(src)) / log(1.6)) + 1;
+ if (length > num_pt_features) length = num_pt_features;
+ string len_feat = "PassThrough_0=1";
+ len_feat[12] += length;
+ TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 " + len_feat));
+ pt->a_.push_back(AlignmentPoint(0,0));
+ AddRule(pt);
+ RefineRule(pt, ctf_level);
+ } else {
+ TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 "));
+ pt->a_.push_back(AlignmentPoint(0,0));
+ AddRule(pt);
+ RefineRule(pt, ctf_level);
+ }
ss.insert(alts[k].label);
}
}
@@ -86,6 +93,7 @@ struct SCFGTranslatorImpl {
SCFGTranslatorImpl(const boost::program_options::variables_map& conf) :
max_span_limit(conf["scfg_max_span_limit"].as<int>()),
add_pass_through_rules(conf.count("add_pass_through_rules")),
+ num_pt_features(conf["add_extra_pass_through_features"].as<unsigned int>()),
goal(conf["goal"].as<string>()),
default_nt(conf["scfg_default_nt"].as<string>()),
use_ctf_(conf.count("coarse_to_fine_beam_prune"))
@@ -140,6 +148,7 @@ struct SCFGTranslatorImpl {
const int max_span_limit;
const bool add_pass_through_rules;
+ const unsigned int num_pt_features;
const string goal;
const string default_nt;
const bool use_ctf_;
@@ -187,7 +196,7 @@ struct SCFGTranslatorImpl {
smeta->SetSourceLength(lattice.size());
if (add_pass_through_rules){
if (!SILENT) cerr << "Adding pass through grammar" << endl;
- PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_);
+ PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features);
g->SetGrammarName("PassThrough");
glist.push_back(GrammarPtr(g));
}