author     Patrick Simianer <p@simianer.de>    2014-02-28 14:09:55 +0100
committer  Patrick Simianer <p@simianer.de>    2014-02-28 14:09:55 +0100
commit     739a8cd9a92ee10411e352e1677235a8c39ba8b3 (patch)
tree       6ea4e6994d0a8c9fb9750f25f2694f5388a4dfe8 /decoder
parent     ab71c44e61d00c788e84b44156d0be16191e267d (diff)
parent     5675965782e2c9201a7a2fe54b542f5b06d660ef (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'decoder')
-rw-r--r--  decoder/cdec_ff.cc       |   1
-rw-r--r--  decoder/decoder.cc       |  12
-rw-r--r--  decoder/ff_ngrams.cc     |  20
-rw-r--r--  decoder/ff_ruleshape.cc  | 138
-rw-r--r--  decoder/ff_ruleshape.h   |  46
5 files changed, 205 insertions, 12 deletions
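Before the raw diff below, a quick orientation on the largest piece of the patch: the new `RuleShape2` feature (`RuleShapeFeatures2`) maps every terminal in a rule to a word class (non-terminals become `NT`, unknown words `<unk>`) and fires one indicator feature per resulting rule shape, caching feature ids in a small trie. The following is a rough, self-contained C++ sketch of the feature-name construction it performs; the helper name and class labels are invented for illustration, and the actual logic is in `decoder/ff_ruleshape.cc` below.

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Illustrative only: builds an "RS:" indicator name from already class-mapped
// source and target sides, mirroring RuleShapeFeatures2::TraversalFeaturesImpl.
std::string RuleShapeName(const std::vector<std::string>& src_classes,
                          const std::vector<std::string>& trg_classes) {
  std::ostringstream os;
  os << "RS:";
  for (std::size_t i = 0; i < src_classes.size(); ++i)
    os << (i ? "_" : "") << src_classes[i];
  if (!src_classes.empty() && !trg_classes.empty()) os << "__";
  for (std::size_t i = 0; i < trg_classes.size(); ++i)
    os << (i ? "_" : "") << trg_classes[i];
  return os.str();
}

int main() {
  std::vector<std::string> src, trg;
  src.push_back("NT"); src.push_back("C12");  // invented class labels
  trg.push_back("NT"); trg.push_back("C7");
  std::cout << RuleShapeName(src, trg) << std::endl;  // prints RS:NT_C12__NT_C7
}
```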
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 8689a615..7f7e075b 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -59,6 +59,7 @@ void register_feature_functions() {
   ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
   ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);
   ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
+  ff_registry.Register("RuleShape2", new FFFactory<RuleShapeFeatures2>);
   ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
   ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
   ff_registry.Register("NewJump", new FFFactory<NewJump>);

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index e02c7730..31049216 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -408,7 +408,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
         ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
         ("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
         ("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)")
-        ("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules")
+        ("get_oracle_forest,o", "Calculate rescored hypergraph using approximate BLEU scoring of rules")
         ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
         ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
         ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
@@ -662,11 +662,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
   oracle.show_derivation=conf.count("show_derivations");
   remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
-  if (conf.count("extract_rules")) {
-    stringstream ss;
-    ss << sent_id;
-    extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str()));
-  }
   combine_size = conf["combine_size"].as<int>();
   if (combine_size < 1) combine_size = 1;
   sent_id = -1;
@@ -720,6 +715,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
     }
     cerr << " id = " << sent_id << endl;
   }
+  if (conf.count("extract_rules")) {
+    stringstream ss;
+    ss << sent_id << ".gz";
+    extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str()));
+  }
   string to_translate;
   Lattice ref;
   ParseTranslatorInputLattice(buf, &to_translate, &ref);
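The decoder.cc hunks above move the `extract_rules` output file creation out of the constructor, where `sent_id` is still -1, into `Decode()`, so each sentence gets its own gzipped file named after its actual id. Below is a minimal sketch of the resulting path construction; `ExtractRulesPath` is a hypothetical helper for illustration, not part of the patch.

```cpp
#include <sstream>
#include <string>

// Hypothetical helper illustrating the per-sentence naming used by the patch:
// "<extract_rules dir>/<sent_id>.gz", which can only be built once sent_id is known.
std::string ExtractRulesPath(const std::string& dir, int sent_id) {
  std::ostringstream ss;
  ss << dir << '/' << sent_id << ".gz";
  return ss.str();  // e.g. ExtractRulesPath("rules", 17) == "rules/17.gz"
}
```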
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
index d337b28b..0a97cba5 100644
--- a/decoder/ff_ngrams.cc
+++ b/decoder/ff_ngrams.cc
@@ -36,7 +36,7 @@ struct State {
   }
   explicit State(const State<MAX_ORDER>& other, unsigned order, WordID extend) {
     char om1 = order - 1;
-    assert(om1 > 0);
+    if (!om1) { memset(state, 0, sizeof(state)); return; }
     for (char i = 1; i < om1; ++i) state[i - 1]= other.state[i];
     state[om1 - 1] = extend;
   }
@@ -60,8 +60,9 @@ namespace {
   }
 }

-static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator, string* cluster_file) {
+static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator, string* cluster_file, string* featname) {
   vector<string> const& argv=SplitOnWhitespace(in);
+  *featname = "";
   *explicit_markers = false;
   *order = 3;
   prefixes.push_back("NOT-USED");
@@ -83,6 +84,9 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order,
       case 'x':
        *explicit_markers = true;
        break;
+      case 'n':
+       LMSPEC_NEXTARG; *featname=*i;
+       break;
      case 'U':
       LMSPEC_NEXTARG;
       prefixes[1] = *i;
@@ -148,7 +152,7 @@ usage:
    << "Example feature instantiation: \n"
    << " tri:a|b|c \n\n";
-  return false;
+  abort();
 }

 class NgramDetectorImpl {
@@ -226,6 +230,7 @@ class NgramDetectorImpl {
       ++n;
       if (!fid) {
         ostringstream os;
+        os << featname_;
         os << prefixes_[n];
         for (int i = n-1; i >= 0; --i) {
           os << (i != n-1 ? target_separator_ : "");
@@ -404,7 +409,8 @@ class NgramDetectorImpl {

 public:
  explicit NgramDetectorImpl(bool explicit_markers, unsigned order,
-                            vector<string>& prefixes, string& target_separator, const string& clusters) :
+                            vector<string>& prefixes, string& target_separator, const string& clusters,
+                            const string& featname) :
    kCDEC_UNK(TD::Convert("<unk>")) ,
    add_sos_eos_(!explicit_markers) {
    order_ = order;
@@ -414,6 +420,7 @@ class NgramDetectorImpl {
    unscored_words_offset_ = is_complete_offset_ + 1;
    prefixes_ = prefixes;
    target_separator_ = target_separator;
+   featname_ = featname;

    // special handling of beginning / ending sentence markers
    dummy_state_ = new char[state_size_];
@@ -454,6 +461,7 @@ class NgramDetectorImpl {
  TRulePtr dummy_rule_;
  vector<string> prefixes_;
  string target_separator_;
+  string featname_;
  struct FidTree {
    map<WordID, int> fids;
    map<WordID, FidTree> levels;
@@ -467,9 +475,9 @@ NgramDetector::NgramDetector(const string& param) {
  bool explicit_markers = false;
  unsigned order = 3;
  string clusters;
-  ParseArgs(param, &explicit_markers, &order, prefixes, target_separator, &clusters);
+  ParseArgs(param, &explicit_markers, &order, prefixes, target_separator, &clusters, &featname);
  pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes,
-                                 target_separator, clusters);
+                                 target_separator, clusters, featname);
  SetStateSize(pimpl_->ReserveStateSize());
 }

diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc
index 7bb548c4..35b41c46 100644
--- a/decoder/ff_ruleshape.cc
+++ b/decoder/ff_ruleshape.cc
@@ -1,5 +1,8 @@
 #include "ff_ruleshape.h"

+#include "filelib.h"
+#include "stringlib.h"
+#include "verbose.h"
 #include "trule.h"
 #include "hg.h"
 #include "fdict.h"
@@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta *
   features->set_value(cur->fid_, 1.0);
 }

+namespace {
+void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) {
+  vector<string> const& argv=SplitOnWhitespace(in);
+  *emapfile = "";
+  *fmapfile = "";
+  *pfxsize = 0;
+#define RSSPEC_NEXTARG if (i==argv.end()) { \
+    cerr << "Missing argument for "<<*last<<". "; goto usage; \
+  } else { ++i; }
+
+  for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+    string const& s=*i;
+    if (s[0]=='-') {
+      if (s.size()>2) goto fail;
+      switch (s[1]) {
+        case 'e':
+          if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); }
+          RSSPEC_NEXTARG; *emapfile=*i;
+          break;
+        case 'f':
+          if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); }
+          RSSPEC_NEXTARG; *fmapfile=*i;
+          break;
+        case 'p':
+          RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str());
+          break;
+#undef RSSPEC_NEXTARG
+        default:
+        fail:
+          cerr<<"Unknown RuleShape2 option "<<s<<" ; ";
+          goto usage;
+      }
+    } else {
+      cerr << "RuleShape2 bad argument!\n";
+      abort();
+    }
+  }
+  return;
+usage:
+  cerr << "Bad parameters for RuleShape2\n";
+  abort();
+}
+
+inline void AddWordToClassMapping_(vector<WordID>* pv, unsigned f, unsigned t, unsigned pfx_size) {
+  if (pfx_size) {
+    const string& ts = TD::Convert(t);
+    if (pfx_size < ts.size())
+      t = TD::Convert(ts.substr(0, pfx_size));
+  }
+  if (f >= pv->size())
+    pv->resize((f + 1) * 1.2);
+  (*pv)[f] = t;
+}
+}
+
+RuleShapeFeatures2::~RuleShapeFeatures2() {}
+
+RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("<unk>")) {
+  string emap;
+  string fmap;
+  unsigned pfxsize = 0;
+  ParseRSArgs(param, &emap, &fmap, &pfxsize);
+  has_src_ = fmap.size();
+  has_trg_ = emap.size();
+  if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_);
+  if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_);
+  if (!has_trg_ && !has_src_) {
+    cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n";
+    abort();
+  }
+}
+
+void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector<WordID>* pv) {
+  ReadFile rf(file);
+  istream& in = *rf.stream();
+  string line;
+  vector<WordID> dummy;
+  int lc = 0;
+  if (!SILENT)
+    cerr << " Loading word classes from " << file << " ...\n";
+  AddWordToClassMapping_(pv, TD::Convert("<s>"), TD::Convert("<s>"), 0);
+  AddWordToClassMapping_(pv, TD::Convert("</s>"), TD::Convert("</s>"), 0);
+  while(getline(in, line)) {
+    dummy.clear();
+    TD::ConvertSentence(line, &dummy);
+    ++lc;
+    if (dummy.size() != 2 && dummy.size() != 3) {
+      cerr << " Class map file expects: CLASS WORD [freq]\n";
+      cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
+      abort();
+    }
+    AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size);
+  }
+  if (!SILENT)
+    cerr << " Loaded word " << lc << " mapping rules.\n";
+}
+
+void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+                                               const Hypergraph::Edge& edge,
+                                               const vector<const void*>& /* ant_contexts */,
+                                               SparseVector<double>* features,
+                                               SparseVector<double>* /* estimated_features */,
+                                               void* /* context */) const {
+  const vector<int>& f = edge.rule_->f();
+  const vector<int>& e = edge.rule_->e();
+  Node* fid = &fidtree_;
+  if (has_src_) {
+    for (unsigned i = 0; i < f.size(); ++i)
+      fid = &fid->next_[MapF(f[i])];
+  }
+  if (has_trg_) {
+    for (unsigned i = 0; i < e.size(); ++i)
+      fid = &fid->next_[MapE(e[i])];
+  }
+  if (!fid->fid_) {
+    ostringstream os;
+    os << "RS:";
+    if (has_src_) {
+      for (unsigned i = 0; i < f.size(); ++i) {
+        if (i) os << '_';
+        os << TD::Convert(MapF(f[i]));
+      }
+      if (has_trg_) os << "__";
+    }
+    if (has_trg_) {
+      for (unsigned i = 0; i < e.size(); ++i) {
+        if (i) os << '_';
+        os << TD::Convert(MapE(e[i]));
+      }
+    }
+    fid->fid_ = FD::Convert(os.str());
+  }
+  features->set_value(fid->fid_, 1);
+}
+

diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h
index 9f20faf3..488cfd84 100644
--- a/decoder/ff_ruleshape.h
+++ b/decoder/ff_ruleshape.h
@@ -2,6 +2,7 @@
 #define _FF_RULESHAPE_H_

 #include <vector>
+#include <map>
 #include "ff.h"

 class RuleShapeFeatures : public FeatureFunction {
@@ -28,4 +29,49 @@ class RuleShapeFeatures : public FeatureFunction {
   }
 };

+class RuleShapeFeatures2 : public FeatureFunction {
+ public:
+  ~RuleShapeFeatures2();
+  RuleShapeFeatures2(const std::string& param);
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const HG::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* context) const;
+ private:
+  struct Node {
+    int fid_;
+    Node() : fid_() {}
+    std::map<WordID, Node> next_;
+  };
+  mutable Node fidtree_;
+
+  inline WordID MapE(WordID w) const {
+    if (w <= 0) return kNT;
+    unsigned res = 0;
+    if (w < e2class_.size()) res = e2class_[w];
+    if (!res) res = kUNK;
+    return res;
+  }
+
+  inline WordID MapF(WordID w) const {
+    if (w <= 0) return kNT;
+    unsigned res = 0;
+    if (w < f2class_.size()) res = f2class_[w];
+    if (!res) res = kUNK;
+    return res;
+  }
+
+  // prfx_size=0 => use full word classes otherwise truncate to specified length
+  void LoadWordClasses(const std::string& fname, unsigned pfxsize, std::vector<WordID>* pv);
+  const WordID kNT;
+  const WordID kUNK;
+  std::vector<WordID> e2class_;
+  std::vector<WordID> f2class_;
+  bool has_src_;
+  bool has_trg_;
+};
+
 #endif
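The class maps read by `LoadWordClasses` use the format its error message documents: one `CLASS WORD [freq]` entry per line, with `<s>` and `</s>` mapped to themselves, and `-p N` truncating class labels to their first N characters. Below is a small illustration of that truncation, under the assumption that class labels are hierarchical strings (e.g. Brown-cluster bit paths) so that a prefix denotes a coarser class; the map contents and helper are made up for the example.

```cpp
#include <iostream>
#include <map>
#include <string>

// Illustrative stand-in for the class lookup plus "-p" prefix truncation.
std::string ClassOf(const std::map<std::string, std::string>& word2class,
                    const std::string& word, unsigned pfx_size) {
  std::map<std::string, std::string>::const_iterator it = word2class.find(word);
  if (it == word2class.end()) return "<unk>";
  const std::string& cls = it->second;
  if (pfx_size && pfx_size < cls.size()) return cls.substr(0, pfx_size);
  return cls;
}

int main() {
  std::map<std::string, std::string> m;
  m["house"] = "110100";  // e.g. a line "110100 house 42" in a (made-up) class map
  std::cout << ClassOf(m, "house", 3) << std::endl;  // "110" with -p 3
  std::cout << ClassOf(m, "dog", 3) << std::endl;    // "<unk>"
}
```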