summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/cdec_ff.cc1
-rw-r--r--decoder/decoder.cc12
-rw-r--r--decoder/ff_ngrams.cc20
-rw-r--r--decoder/ff_ruleshape.cc138
-rw-r--r--decoder/ff_ruleshape.h46
5 files changed, 205 insertions, 12 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 8689a615..7f7e075b 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -59,6 +59,7 @@ void register_feature_functions() {
ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);
ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
+ ff_registry.Register("RuleShape2", new FFFactory<RuleShapeFeatures2>);
ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
ff_registry.Register("NewJump", new FFFactory<NewJump>);
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index e02c7730..31049216 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -408,7 +408,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)")
- ("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules")
+ ("get_oracle_forest,o", "Calculate rescored hypergraph using approximate BLEU scoring of rules")
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
@@ -662,11 +662,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
oracle.show_derivation=conf.count("show_derivations");
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
- if (conf.count("extract_rules")) {
- stringstream ss;
- ss << sent_id;
- extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str()));
- }
combine_size = conf["combine_size"].as<int>();
if (combine_size < 1) combine_size = 1;
sent_id = -1;
@@ -720,6 +715,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
cerr << " id = " << sent_id << endl;
}
+ if (conf.count("extract_rules")) {
+ stringstream ss;
+ ss << sent_id << ".gz";
+ extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str()));
+ }
string to_translate;
Lattice ref;
ParseTranslatorInputLattice(buf, &to_translate, &ref);
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
index d337b28b..0a97cba5 100644
--- a/decoder/ff_ngrams.cc
+++ b/decoder/ff_ngrams.cc
@@ -36,7 +36,7 @@ struct State {
}
explicit State(const State<MAX_ORDER>& other, unsigned order, WordID extend) {
char om1 = order - 1;
- assert(om1 > 0);
+ if (!om1) { memset(state, 0, sizeof(state)); return; }
for (char i = 1; i < om1; ++i) state[i - 1]= other.state[i];
state[om1 - 1] = extend;
}
@@ -60,8 +60,9 @@ namespace {
}
}
-static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator, string* cluster_file) {
+static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator, string* cluster_file, string* featname) {
vector<string> const& argv=SplitOnWhitespace(in);
+ *featname = "";
*explicit_markers = false;
*order = 3;
prefixes.push_back("NOT-USED");
@@ -83,6 +84,9 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order,
case 'x':
*explicit_markers = true;
break;
+ case 'n':
+ LMSPEC_NEXTARG; *featname=*i;
+ break;
case 'U':
LMSPEC_NEXTARG;
prefixes[1] = *i;
@@ -148,7 +152,7 @@ usage:
<< "Example feature instantiation: \n"
<< " tri:a|b|c \n\n";
- return false;
+ abort();
}
class NgramDetectorImpl {
@@ -226,6 +230,7 @@ class NgramDetectorImpl {
++n;
if (!fid) {
ostringstream os;
+ os << featname_;
os << prefixes_[n];
for (int i = n-1; i >= 0; --i) {
os << (i != n-1 ? target_separator_ : "");
@@ -404,7 +409,8 @@ class NgramDetectorImpl {
public:
explicit NgramDetectorImpl(bool explicit_markers, unsigned order,
- vector<string>& prefixes, string& target_separator, const string& clusters) :
+ vector<string>& prefixes, string& target_separator, const string& clusters,
+ const string& featname) :
kCDEC_UNK(TD::Convert("<unk>")) ,
add_sos_eos_(!explicit_markers) {
order_ = order;
@@ -414,6 +420,7 @@ class NgramDetectorImpl {
unscored_words_offset_ = is_complete_offset_ + 1;
prefixes_ = prefixes;
target_separator_ = target_separator;
+ featname_ = featname;
// special handling of beginning / ending sentence markers
dummy_state_ = new char[state_size_];
@@ -454,6 +461,7 @@ class NgramDetectorImpl {
TRulePtr dummy_rule_;
vector<string> prefixes_;
string target_separator_;
+ string featname_;
struct FidTree {
map<WordID, int> fids;
map<WordID, FidTree> levels;
@@ -467,9 +475,9 @@ NgramDetector::NgramDetector(const string& param) {
bool explicit_markers = false;
unsigned order = 3;
string clusters;
- ParseArgs(param, &explicit_markers, &order, prefixes, target_separator, &clusters);
+ ParseArgs(param, &explicit_markers, &order, prefixes, target_separator, &clusters, &featname);
pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes,
- target_separator, clusters);
+ target_separator, clusters, featname);
SetStateSize(pimpl_->ReserveStateSize());
}
diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc
index 7bb548c4..35b41c46 100644
--- a/decoder/ff_ruleshape.cc
+++ b/decoder/ff_ruleshape.cc
@@ -1,5 +1,8 @@
#include "ff_ruleshape.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "verbose.h"
#include "trule.h"
#include "hg.h"
#include "fdict.h"
@@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta *
features->set_value(cur->fid_, 1.0);
}
+namespace {
+void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) {
+ vector<string> const& argv=SplitOnWhitespace(in);
+ *emapfile = "";
+ *fmapfile = "";
+ *pfxsize = 0;
+#define RSSPEC_NEXTARG if (i==argv.end()) { \
+ cerr << "Missing argument for "<<*last<<". "; goto usage; \
+ } else { ++i; }
+
+ for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+ string const& s=*i;
+ if (s[0]=='-') {
+ if (s.size()>2) goto fail;
+ switch (s[1]) {
+ case 'e':
+ if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *emapfile=*i;
+ break;
+ case 'f':
+ if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *fmapfile=*i;
+ break;
+ case 'p':
+ RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str());
+ break;
+#undef RSSPEC_NEXTARG
+ default:
+ fail:
+ cerr<<"Unknown RuleShape2 option "<<s<<" ; ";
+ goto usage;
+ }
+ } else {
+ cerr << "RuleShape2 bad argument!\n";
+ abort();
+ }
+ }
+ return;
+usage:
+ cerr << "Bad parameters for RuleShape2\n";
+ abort();
+}
+
+inline void AddWordToClassMapping_(vector<WordID>* pv, unsigned f, unsigned t, unsigned pfx_size) {
+ if (pfx_size) {
+ const string& ts = TD::Convert(t);
+ if (pfx_size < ts.size())
+ t = TD::Convert(ts.substr(0, pfx_size));
+ }
+ if (f >= pv->size())
+ pv->resize((f + 1) * 1.2);
+ (*pv)[f] = t;
+}
+}
+
+RuleShapeFeatures2::~RuleShapeFeatures2() {}
+
+RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("<unk>")) {
+ string emap;
+ string fmap;
+ unsigned pfxsize = 0;
+ ParseRSArgs(param, &emap, &fmap, &pfxsize);
+ has_src_ = fmap.size();
+ has_trg_ = emap.size();
+ if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_);
+ if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_);
+ if (!has_trg_ && !has_src_) {
+ cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n";
+ abort();
+ }
+}
+
+void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector<WordID>* pv) {
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string line;
+ vector<WordID> dummy;
+ int lc = 0;
+ if (!SILENT)
+ cerr << " Loading word classes from " << file << " ...\n";
+ AddWordToClassMapping_(pv, TD::Convert("<s>"), TD::Convert("<s>"), 0);
+ AddWordToClassMapping_(pv, TD::Convert("</s>"), TD::Convert("</s>"), 0);
+ while(getline(in, line)) {
+ dummy.clear();
+ TD::ConvertSentence(line, &dummy);
+ ++lc;
+ if (dummy.size() != 2 && dummy.size() != 3) {
+ cerr << " Class map file expects: CLASS WORD [freq]\n";
+ cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
+ abort();
+ }
+ AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size);
+ }
+ if (!SILENT)
+ cerr << " Loaded word " << lc << " mapping rules.\n";
+}
+
+void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& /* ant_contexts */,
+ SparseVector<double>* features,
+ SparseVector<double>* /* estimated_features */,
+ void* /* context */) const {
+ const vector<int>& f = edge.rule_->f();
+ const vector<int>& e = edge.rule_->e();
+ Node* fid = &fidtree_;
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i)
+ fid = &fid->next_[MapF(f[i])];
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i)
+ fid = &fid->next_[MapE(e[i])];
+ }
+ if (!fid->fid_) {
+ ostringstream os;
+ os << "RS:";
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapF(f[i]));
+ }
+ if (has_trg_) os << "__";
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapE(e[i]));
+ }
+ }
+ fid->fid_ = FD::Convert(os.str());
+ }
+ features->set_value(fid->fid_, 1);
+}
+
diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h
index 9f20faf3..488cfd84 100644
--- a/decoder/ff_ruleshape.h
+++ b/decoder/ff_ruleshape.h
@@ -2,6 +2,7 @@
#define _FF_RULESHAPE_H_
#include <vector>
+#include <map>
#include "ff.h"
class RuleShapeFeatures : public FeatureFunction {
@@ -28,4 +29,49 @@ class RuleShapeFeatures : public FeatureFunction {
}
};
+class RuleShapeFeatures2 : public FeatureFunction {
+ public:
+ ~RuleShapeFeatures2();
+ RuleShapeFeatures2(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ struct Node {
+ int fid_;
+ Node() : fid_() {}
+ std::map<WordID, Node> next_;
+ };
+ mutable Node fidtree_;
+
+ inline WordID MapE(WordID w) const {
+ if (w <= 0) return kNT;
+ unsigned res = 0;
+ if (w < e2class_.size()) res = e2class_[w];
+ if (!res) res = kUNK;
+ return res;
+ }
+
+ inline WordID MapF(WordID w) const {
+ if (w <= 0) return kNT;
+ unsigned res = 0;
+ if (w < f2class_.size()) res = f2class_[w];
+ if (!res) res = kUNK;
+ return res;
+ }
+
+ // prfx_size=0 => use full word classes otherwise truncate to specified length
+ void LoadWordClasses(const std::string& fname, unsigned pfxsize, std::vector<WordID>* pv);
+ const WordID kNT;
+ const WordID kUNK;
+ std::vector<WordID> e2class_;
+ std::vector<WordID> f2class_;
+ bool has_src_;
+ bool has_trg_;
+};
+
#endif