summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-02-16 01:11:59 -0500
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-02-16 01:11:59 -0500
commit039fc3fa3a60137fc9f61c3e1505c9bef89fe4da (patch)
tree7689b2b68103e6e83b86e85eac4d7fb7432b352c /decoder
parent32523a40609d89268669925d5678f185a7729caa (diff)
new rule shape features
Diffstat (limited to 'decoder')
-rw-r--r--decoder/cdec_ff.cc1
-rw-r--r--decoder/ff_ruleshape.cc138
-rw-r--r--decoder/ff_ruleshape.h46
3 files changed, 185 insertions, 0 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index b2541722..0411908f 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -58,6 +58,7 @@ void register_feature_functions() {
ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);
ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
+ ff_registry.Register("RuleShape2", new FFFactory<RuleShapeFeatures2>);
ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
ff_registry.Register("NewJump", new FFFactory<NewJump>);
diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc
index 7bb548c4..35b41c46 100644
--- a/decoder/ff_ruleshape.cc
+++ b/decoder/ff_ruleshape.cc
@@ -1,5 +1,8 @@
#include "ff_ruleshape.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "verbose.h"
#include "trule.h"
#include "hg.h"
#include "fdict.h"
@@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta *
features->set_value(cur->fid_, 1.0);
}
+namespace {
+void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) {
+ vector<string> const& argv=SplitOnWhitespace(in);
+ *emapfile = "";
+ *fmapfile = "";
+ *pfxsize = 0;
+#define RSSPEC_NEXTARG if (i==argv.end()) { \
+ cerr << "Missing argument for "<<*last<<". "; goto usage; \
+ } else { ++i; }
+
+ for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+ string const& s=*i;
+ if (s[0]=='-') {
+ if (s.size()>2) goto fail;
+ switch (s[1]) {
+ case 'e':
+ if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *emapfile=*i;
+ break;
+ case 'f':
+ if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *fmapfile=*i;
+ break;
+ case 'p':
+ RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str());
+ break;
+#undef RSSPEC_NEXTARG
+ default:
+ fail:
+ cerr<<"Unknown RuleShape2 option "<<s<<" ; ";
+ goto usage;
+ }
+ } else {
+ cerr << "RuleShape2 bad argument!\n";
+ abort();
+ }
+ }
+ return;
+usage:
+ cerr << "Bad parameters for RuleShape2\n";
+ abort();
+}
+
+inline void AddWordToClassMapping_(vector<WordID>* pv, unsigned f, unsigned t, unsigned pfx_size) {
+ if (pfx_size) {
+ const string& ts = TD::Convert(t);
+ if (pfx_size < ts.size())
+ t = TD::Convert(ts.substr(0, pfx_size));
+ }
+ if (f >= pv->size())
+ pv->resize((f + 1) * 1.2);
+ (*pv)[f] = t;
+}
+}
+
+RuleShapeFeatures2::~RuleShapeFeatures2() {}
+
+RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("<unk>")) {
+ string emap;
+ string fmap;
+ unsigned pfxsize = 0;
+ ParseRSArgs(param, &emap, &fmap, &pfxsize);
+ has_src_ = fmap.size();
+ has_trg_ = emap.size();
+ if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_);
+ if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_);
+ if (!has_trg_ && !has_src_) {
+ cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n";
+ abort();
+ }
+}
+
+void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector<WordID>* pv) {
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string line;
+ vector<WordID> dummy;
+ int lc = 0;
+ if (!SILENT)
+ cerr << " Loading word classes from " << file << " ...\n";
+ AddWordToClassMapping_(pv, TD::Convert("<s>"), TD::Convert("<s>"), 0);
+ AddWordToClassMapping_(pv, TD::Convert("</s>"), TD::Convert("</s>"), 0);
+ while(getline(in, line)) {
+ dummy.clear();
+ TD::ConvertSentence(line, &dummy);
+ ++lc;
+ if (dummy.size() != 2 && dummy.size() != 3) {
+ cerr << " Class map file expects: CLASS WORD [freq]\n";
+ cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
+ abort();
+ }
+ AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size);
+ }
+ if (!SILENT)
+ cerr << " Loaded word " << lc << " mapping rules.\n";
+}
+
+void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& /* ant_contexts */,
+ SparseVector<double>* features,
+ SparseVector<double>* /* estimated_features */,
+ void* /* context */) const {
+ const vector<int>& f = edge.rule_->f();
+ const vector<int>& e = edge.rule_->e();
+ Node* fid = &fidtree_;
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i)
+ fid = &fid->next_[MapF(f[i])];
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i)
+ fid = &fid->next_[MapE(e[i])];
+ }
+ if (!fid->fid_) {
+ ostringstream os;
+ os << "RS:";
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapF(f[i]));
+ }
+ if (has_trg_) os << "__";
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapE(e[i]));
+ }
+ }
+ fid->fid_ = FD::Convert(os.str());
+ }
+ features->set_value(fid->fid_, 1);
+}
+
diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h
index 9f20faf3..488cfd84 100644
--- a/decoder/ff_ruleshape.h
+++ b/decoder/ff_ruleshape.h
@@ -2,6 +2,7 @@
#define _FF_RULESHAPE_H_
#include <vector>
+#include <map>
#include "ff.h"
class RuleShapeFeatures : public FeatureFunction {
@@ -28,4 +29,49 @@ class RuleShapeFeatures : public FeatureFunction {
}
};
+class RuleShapeFeatures2 : public FeatureFunction {
+ public:
+ ~RuleShapeFeatures2();
+ RuleShapeFeatures2(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ struct Node {
+ int fid_;
+ Node() : fid_() {}
+ std::map<WordID, Node> next_;
+ };
+ mutable Node fidtree_;
+
+ inline WordID MapE(WordID w) const {
+ if (w <= 0) return kNT;
+ unsigned res = 0;
+ if (w < e2class_.size()) res = e2class_[w];
+ if (!res) res = kUNK;
+ return res;
+ }
+
+ inline WordID MapF(WordID w) const {
+ if (w <= 0) return kNT;
+ unsigned res = 0;
+ if (w < f2class_.size()) res = f2class_[w];
+ if (!res) res = kUNK;
+ return res;
+ }
+
+ // prfx_size=0 => use full word classes otherwise truncate to specified length
+ void LoadWordClasses(const std::string& fname, unsigned pfxsize, std::vector<WordID>* pv);
+ const WordID kNT;
+ const WordID kUNK;
+ std::vector<WordID> e2class_;
+ std::vector<WordID> f2class_;
+ bool has_src_;
+ bool has_trg_;
+};
+
#endif