From 039fc3fa3a60137fc9f61c3e1505c9bef89fe4da Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 16 Feb 2014 01:11:59 -0500 Subject: new rule shape features --- decoder/ff_ruleshape.cc | 138 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) (limited to 'decoder/ff_ruleshape.cc') diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc index 7bb548c4..35b41c46 100644 --- a/decoder/ff_ruleshape.cc +++ b/decoder/ff_ruleshape.cc @@ -1,5 +1,8 @@ #include "ff_ruleshape.h" +#include "filelib.h" +#include "stringlib.h" +#include "verbose.h" #include "trule.h" #include "hg.h" #include "fdict.h" @@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta * features->set_value(cur->fid_, 1.0); } +namespace { +void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) { + vector const& argv=SplitOnWhitespace(in); + *emapfile = ""; + *fmapfile = ""; + *pfxsize = 0; +#define RSSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'e': + if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); } + RSSPEC_NEXTARG; *emapfile=*i; + break; + case 'f': + if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); } + RSSPEC_NEXTARG; *fmapfile=*i; + break; + case 'p': + RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str()); + break; +#undef RSSPEC_NEXTARG + default: + fail: + cerr<<"Unknown RuleShape2 option "<* pv, unsigned f, unsigned t, unsigned pfx_size) { + if (pfx_size) { + const string& ts = TD::Convert(t); + if (pfx_size < ts.size()) + t = TD::Convert(ts.substr(0, pfx_size)); + } + if (f >= pv->size()) + pv->resize((f + 1) * 1.2); + (*pv)[f] = t; +} +} + +RuleShapeFeatures2::~RuleShapeFeatures2() {} + +RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("")) { + string emap; + string fmap; + unsigned pfxsize = 0; + ParseRSArgs(param, &emap, &fmap, &pfxsize); + has_src_ = fmap.size(); + has_trg_ = emap.size(); + if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_); + if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_); + if (!has_trg_ && !has_src_) { + cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n"; + abort(); + } +} + +void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector* pv) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector dummy; + int lc = 0; + if (!SILENT) + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(pv, TD::Convert(""), TD::Convert(""), 0); + AddWordToClassMapping_(pv, TD::Convert(""), TD::Convert(""), 0); + while(getline(in, line)) { + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2 && dummy.size() != 3) { + cerr << " Class map file expects: CLASS WORD [freq]\n"; + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size); + } + if (!SILENT) + cerr << " Loaded word " << lc << " mapping rules.\n"; +} + +void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector& /* ant_contexts */, + SparseVector* features, + SparseVector* /* estimated_features */, + void* /* context */) const { + const vector& f = edge.rule_->f(); + const vector& e = edge.rule_->e(); + Node* fid = &fidtree_; + if (has_src_) { + for (unsigned i = 0; i < f.size(); ++i) + fid = &fid->next_[MapF(f[i])]; + } + if (has_trg_) { + for (unsigned i = 0; i < e.size(); ++i) + fid = &fid->next_[MapE(e[i])]; + } + if (!fid->fid_) { + ostringstream os; + os << "RS:"; + if (has_src_) { + for (unsigned i = 0; i < f.size(); ++i) { + if (i) os << '_'; + os << TD::Convert(MapF(f[i])); + } + if (has_trg_) os << "__"; + } + if (has_trg_) { + for (unsigned i = 0; i < e.size(); ++i) { + if (i) os << '_'; + os << TD::Convert(MapE(e[i])); + } + } + fid->fid_ = FD::Convert(os.str()); + } + features->set_value(fid->fid_, 1); +} + -- cgit v1.2.3