summaryrefslogtreecommitdiff
path: root/decoder/ff_ruleshape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ff_ruleshape.cc')
-rw-r--r--decoder/ff_ruleshape.cc138
1 files changed, 138 insertions, 0 deletions
diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc
index 7bb548c4..35b41c46 100644
--- a/decoder/ff_ruleshape.cc
+++ b/decoder/ff_ruleshape.cc
@@ -1,5 +1,8 @@
#include "ff_ruleshape.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "verbose.h"
#include "trule.h"
#include "hg.h"
#include "fdict.h"
@@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta *
features->set_value(cur->fid_, 1.0);
}
+namespace {
+void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) {
+ vector<string> const& argv=SplitOnWhitespace(in);
+ *emapfile = "";
+ *fmapfile = "";
+ *pfxsize = 0;
+#define RSSPEC_NEXTARG if (i==argv.end()) { \
+ cerr << "Missing argument for "<<*last<<". "; goto usage; \
+ } else { ++i; }
+
+ for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+ string const& s=*i;
+ if (s[0]=='-') {
+ if (s.size()>2) goto fail;
+ switch (s[1]) {
+ case 'e':
+ if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *emapfile=*i;
+ break;
+ case 'f':
+ if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *fmapfile=*i;
+ break;
+ case 'p':
+ RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str());
+ break;
+#undef RSSPEC_NEXTARG
+ default:
+ fail:
+ cerr<<"Unknown RuleShape2 option "<<s<<" ; ";
+ goto usage;
+ }
+ } else {
+ cerr << "RuleShape2 bad argument!\n";
+ abort();
+ }
+ }
+ return;
+usage:
+ cerr << "Bad parameters for RuleShape2\n";
+ abort();
+}
+
+inline void AddWordToClassMapping_(vector<WordID>* pv, unsigned f, unsigned t, unsigned pfx_size) {
+ if (pfx_size) {
+ const string& ts = TD::Convert(t);
+ if (pfx_size < ts.size())
+ t = TD::Convert(ts.substr(0, pfx_size));
+ }
+ if (f >= pv->size())
+ pv->resize((f + 1) * 1.2);
+ (*pv)[f] = t;
+}
+}
+
+RuleShapeFeatures2::~RuleShapeFeatures2() {}
+
+RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("<unk>")) {
+ string emap;
+ string fmap;
+ unsigned pfxsize = 0;
+ ParseRSArgs(param, &emap, &fmap, &pfxsize);
+ has_src_ = fmap.size();
+ has_trg_ = emap.size();
+ if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_);
+ if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_);
+ if (!has_trg_ && !has_src_) {
+ cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n";
+ abort();
+ }
+}
+
+void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector<WordID>* pv) {
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string line;
+ vector<WordID> dummy;
+ int lc = 0;
+ if (!SILENT)
+ cerr << " Loading word classes from " << file << " ...\n";
+ AddWordToClassMapping_(pv, TD::Convert("<s>"), TD::Convert("<s>"), 0);
+ AddWordToClassMapping_(pv, TD::Convert("</s>"), TD::Convert("</s>"), 0);
+ while(getline(in, line)) {
+ dummy.clear();
+ TD::ConvertSentence(line, &dummy);
+ ++lc;
+ if (dummy.size() != 2 && dummy.size() != 3) {
+ cerr << " Class map file expects: CLASS WORD [freq]\n";
+ cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
+ abort();
+ }
+ AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size);
+ }
+ if (!SILENT)
+ cerr << " Loaded word " << lc << " mapping rules.\n";
+}
+
+void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& /* ant_contexts */,
+ SparseVector<double>* features,
+ SparseVector<double>* /* estimated_features */,
+ void* /* context */) const {
+ const vector<int>& f = edge.rule_->f();
+ const vector<int>& e = edge.rule_->e();
+ Node* fid = &fidtree_;
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i)
+ fid = &fid->next_[MapF(f[i])];
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i)
+ fid = &fid->next_[MapE(e[i])];
+ }
+ if (!fid->fid_) {
+ ostringstream os;
+ os << "RS:";
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapF(f[i]));
+ }
+ if (has_trg_) os << "__";
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapE(e[i]));
+ }
+ }
+ fid->fid_ = FD::Convert(os.str());
+ }
+ features->set_value(fid->fid_, 1);
+}
+