summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-04-24 17:18:10 +0100
committerPaul Baltescu <pauldb89@gmail.com>2013-04-24 17:18:10 +0100
commite8b412577b9d3fe2090b9f48443f919cd268c809 (patch)
treeb46a7b51d365519dfb5170d71bac33be6d3e29b9 /decoder
parentd189426a7ea56b71eb6e25ed02a7b0993cfb56a8 (diff)
parent5aee54869aa19cfe9be965e67a472e94449d16da (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am2
-rw-r--r--decoder/cdec_ff.cc2
-rw-r--r--decoder/ff_klm.cc49
-rw-r--r--decoder/ff_klm.h5
-rw-r--r--decoder/ff_ngrams.cc68
-rw-r--r--decoder/ff_rules.cc20
-rw-r--r--decoder/ff_rules.h1
-rw-r--r--decoder/ff_source_path.cc42
-rw-r--r--decoder/ff_source_path.h26
9 files changed, 185 insertions, 30 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 6499b38b..82b50f19 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -60,6 +60,7 @@ libcdec_a_SOURCES = \
ff_rules.h \
ff_ruleshape.h \
ff_sample_fsa.h \
+ ff_source_path.h \
ff_source_syntax.h \
ff_spans.h \
ff_tagger.h \
@@ -140,6 +141,7 @@ libcdec_a_SOURCES = \
ff_wordalign.cc \
ff_csplit.cc \
ff_tagger.cc \
+ ff_source_path.cc \
ff_source_syntax.cc \
ff_bleu.cc \
ff_factory.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 3ab0f9f6..0bf441d4 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -14,6 +14,7 @@
#include "ff_rules.h"
#include "ff_ruleshape.h"
#include "ff_bleu.h"
+#include "ff_source_path.h"
#include "ff_source_syntax.h"
#include "ff_register.h"
#include "ff_charset.h"
@@ -70,6 +71,7 @@ void register_feature_functions() {
ff_registry.Register("InputIndicator", new FFFactory<InputIndicator>);
ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>);
ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>);
+ ff_registry.Register("SourcePathFeatures", new FFFactory<SourcePathFeatures>);
ff_registry.Register("WordSet", new FFFactory<WordSet>);
ff_registry.Register("Dwarf", new FFFactory<Dwarf>);
ff_registry.Register("External", new FFFactory<ExternalFeature>);
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index fefa90bd..c8ca917a 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -1,6 +1,7 @@
#include "ff_klm.h"
#include <cstring>
+#include <cstdlib>
#include <iostream>
#include <boost/scoped_ptr.hpp>
@@ -151,8 +152,9 @@ template <class Model> class BoundaryRuleScore {
template <class Model>
class KLanguageModelImpl {
public:
- double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, double* emit, void* remnant) {
*oovs = 0;
+ *emit = 0;
const vector<WordID>& e = rule.e();
BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant));
unsigned i = 0;
@@ -169,8 +171,9 @@ class KLanguageModelImpl {
if (e[i] <= 0) {
ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]));
} else {
- const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future,
- // maybe handle emission
+ float ep = 0.f;
+ const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i], &ep);
+ if (ep) { *emit += ep; }
const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
if (cur_word == 0) (*oovs) += 1.0;
ruleScore.Terminal(cur_word);
@@ -205,12 +208,14 @@ class KLanguageModelImpl {
// if this is not a class-based LM, returns w untransformed,
// otherwise returns a word class mapping of w,
// returns TD::Convert("<unk>") if there is no mapping for w
- WordID ClassifyWordIfNecessary(WordID w) const {
+ WordID ClassifyWordIfNecessary(WordID w, float* emitp) const {
if (word2class_map_.empty()) return w;
if (w >= word2class_map_.size())
return kCDEC_UNK;
- else
- return word2class_map_[w];
+ else {
+ *emitp = word2class_map_[w].second;
+ return word2class_map_[w].first;
+ }
}
// converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0
@@ -256,32 +261,32 @@ class KLanguageModelImpl {
int lc = 0;
if (!SILENT)
cerr << " Loading word classes from " << file << " ...\n";
- AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"));
- AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"));
- while(in) {
- getline(in, line);
- if (!in) continue;
+ AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"), 0.0);
+ AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"), 0.0);
+ while(getline(in, line)) {
dummy.clear();
TD::ConvertSentence(line, &dummy);
++lc;
- if (dummy.size() != 2) {
+ if (dummy.size() != 3) {
+ cerr << " Class map file expects: CLASS WORD logp(WORD|CLASS)\n";
cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
abort();
}
- AddWordToClassMapping_(dummy[0], dummy[1]);
+ AddWordToClassMapping_(dummy[1], dummy[0], strtof(TD::Convert(dummy[2]).c_str(), NULL));
}
}
- void AddWordToClassMapping_(WordID word, WordID cls) {
+ void AddWordToClassMapping_(WordID word, WordID cls, float emit) {
if (word2class_map_.size() <= word) {
- word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK);
+ word2class_map_.resize((word + 10) * 1.1, pair<WordID,float>(kCDEC_UNK,0.f));
assert(word2class_map_.size() > word);
}
- if(word2class_map_[word] != kCDEC_UNK) {
+ if(word2class_map_[word].first != kCDEC_UNK) {
cerr << "Multiple classes for symbol " << TD::Convert(word) << endl;
abort();
}
- word2class_map_[word] = cls;
+ word2class_map_[word].first = cls;
+ word2class_map_[word].second = emit;
}
~KLanguageModelImpl() {
@@ -304,7 +309,9 @@ class KLanguageModelImpl {
int order_;
vector<lm::WordIndex> cdec2klm_map_;
- vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
+ vector<pair<WordID,float> > word2class_map_; // if this is a class-based LM,
+ // .first is the word->class mapping
+ // .second is the emission log probability
};
template <class Model>
@@ -322,6 +329,7 @@ KLanguageModel<Model>::KLanguageModel(const string& param) {
}
fid_ = FD::Convert(featname);
oov_fid_ = FD::Convert(featname+"_OOV");
+ emit_fid_ = FD::Convert(featname+"_Emit");
// cerr << "FID: " << oov_fid_ << endl;
SetStateSize(pimpl_->ReserveStateSize());
}
@@ -340,9 +348,12 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme
void* state) const {
double est = 0;
double oovs = 0;
- features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state));
+ double emit = 0;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, &emit, state));
if (oovs && oov_fid_)
features->set_value(oov_fid_, oovs);
+ if (emit && emit_fid_)
+ features->set_value(emit_fid_, emit);
}
template <class Model>
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
index b5ceffd0..db4032f7 100644
--- a/decoder/ff_klm.h
+++ b/decoder/ff_klm.h
@@ -28,8 +28,9 @@ class KLanguageModel : public FeatureFunction {
SparseVector<double>* estimated_features,
void* out_context) const;
private:
- int fid_; // conceptually const; mutable only to simplify constructor
- int oov_fid_; // will be zero if extra OOV feature is not configured by decoder
+ int fid_; // LanguageModel
+ int oov_fid_; // LanguageModel_OOV
+ int emit_fid_; // LanguageModel_Emit [only used for class-based LMs]
KLanguageModelImpl<Model>* pimpl_;
};
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
index 9c13fdbb..d337b28b 100644
--- a/decoder/ff_ngrams.cc
+++ b/decoder/ff_ngrams.cc
@@ -60,7 +60,7 @@ namespace {
}
}
-static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator) {
+static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator, string* cluster_file) {
vector<string> const& argv=SplitOnWhitespace(in);
*explicit_markers = false;
*order = 3;
@@ -103,6 +103,10 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order,
LMSPEC_NEXTARG;
prefixes[5] = *i;
break;
+ case 'c':
+ LMSPEC_NEXTARG;
+ *cluster_file = *i;
+ break;
case 'S':
LMSPEC_NEXTARG;
target_separator = *i;
@@ -124,6 +128,7 @@ usage:
<< "NgramFeatures Usage: \n"
<< " feature_function=NgramFeatures filename.lm [-x] [-o <order>] \n"
+ << " [-c <cluster-file>]\n"
<< " [-U <unigram-prefix>] [-B <bigram-prefix>][-T <trigram-prefix>]\n"
<< " [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]\n\n"
@@ -203,6 +208,12 @@ class NgramDetectorImpl {
SetFlag(flag, HAS_FULL_CONTEXT, state);
}
+ WordID MapToClusterIfNecessary(WordID w) const {
+ if (cluster_map.size() == 0) return w;
+ if (w >= cluster_map.size()) return kCDEC_UNK;
+ return cluster_map[w];
+ }
+
void FireFeatures(const State<5>& state, WordID cur, SparseVector<double>* feats) {
FidTree* ft = &fidroot_;
int n = 0;
@@ -285,7 +296,7 @@ class NgramDetectorImpl {
context_complete = true;
}
} else { // handle terminal
- const WordID cur_word = e[j];
+ const WordID cur_word = MapToClusterIfNecessary(e[j]);
SparseVector<double> p;
if (cur_word == kSOS_) {
state = BeginSentenceState();
@@ -348,9 +359,52 @@ class NgramDetectorImpl {
}
}
+ void ReadClusterFile(const string& clusters) {
+ ReadFile rf(clusters);
+ istream& in = *rf.stream();
+ string line;
+ int lc = 0;
+ string cluster;
+ string word;
+ while(getline(in, line)) {
+ ++lc;
+ if (line.size() == 0) continue;
+ if (line[0] == '#') continue;
+ unsigned cend = 1;
+ while((line[cend] != ' ' && line[cend] != '\t') && cend < line.size()) {
+ ++cend;
+ }
+ if (cend == line.size()) {
+ cerr << "Line " << lc << " in " << clusters << " malformed: " << line << endl;
+ abort();
+ }
+ unsigned wbeg = cend + 1;
+ while((line[wbeg] == ' ' || line[wbeg] == '\t') && wbeg < line.size()) {
+ ++wbeg;
+ }
+ if (wbeg == line.size()) {
+ cerr << "Line " << lc << " in " << clusters << " malformed: " << line << endl;
+ abort();
+ }
+ unsigned wend = wbeg + 1;
+ while((line[wend] != ' ' && line[wend] != '\t') && wend < line.size()) {
+ ++wend;
+ }
+ const WordID clusterid = TD::Convert(line.substr(0, cend));
+ const WordID wordid = TD::Convert(line.substr(wbeg, wend - wbeg));
+ if (wordid >= cluster_map.size())
+ cluster_map.resize(wordid + 10, kCDEC_UNK);
+ cluster_map[wordid] = clusterid;
+ }
+ cluster_map[kSOS_] = kSOS_;
+ cluster_map[kEOS_] = kEOS_;
+ }
+
+ vector<WordID> cluster_map;
+
public:
explicit NgramDetectorImpl(bool explicit_markers, unsigned order,
- vector<string>& prefixes, string& target_separator) :
+ vector<string>& prefixes, string& target_separator, const string& clusters) :
kCDEC_UNK(TD::Convert("<unk>")) ,
add_sos_eos_(!explicit_markers) {
order_ = order;
@@ -369,6 +423,9 @@ class NgramDetectorImpl {
dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
kSOS_ = TD::Convert("<s>");
kEOS_ = TD::Convert("</s>");
+
+ if (clusters.size())
+ ReadClusterFile(clusters);
}
~NgramDetectorImpl() {
@@ -409,9 +466,10 @@ NgramDetector::NgramDetector(const string& param) {
vector<string> prefixes;
bool explicit_markers = false;
unsigned order = 3;
- ParseArgs(param, &explicit_markers, &order, prefixes, target_separator);
+ string clusters;
+ ParseArgs(param, &explicit_markers, &order, prefixes, target_separator, &clusters);
pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes,
- target_separator);
+ target_separator, clusters);
SetStateSize(pimpl_->ReserveStateSize());
}
diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc
index 6716d3da..410e083c 100644
--- a/decoder/ff_rules.cc
+++ b/decoder/ff_rules.cc
@@ -107,7 +107,12 @@ void RuleSourceBigramFeatures::TraversalFeaturesImpl(const SentenceMetadata& sme
(*features) += it->second;
}
-RuleTargetBigramFeatures::RuleTargetBigramFeatures(const std::string& param) {
+RuleTargetBigramFeatures::RuleTargetBigramFeatures(const std::string& param) : inds(1000) {
+ for (unsigned i = 0; i < inds.size(); ++i) {
+ ostringstream os;
+ os << (i + 1);
+ inds[i] = os.str();
+ }
}
void RuleTargetBigramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
@@ -126,11 +131,18 @@ void RuleTargetBigramFeatures::TraversalFeaturesImpl(const SentenceMetadata& sme
it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
SparseVector<double>& f = it->second;
string prev = "<r>";
+ vector<WordID> nt_types(rule.Arity());
+ unsigned ntc = 0;
+ for (int i = 0; i < rule.f_.size(); ++i)
+ if (rule.f_[i] < 0) nt_types[ntc++] = -rule.f_[i];
for (int i = 0; i < rule.e_.size(); ++i) {
WordID w = rule.e_[i];
- if (w < 0) w = -w;
- if (w == 0) return;
- const string& cur = TD::Convert(w);
+ string cur;
+ if (w > 0) {
+ cur = TD::Convert(w);
+ } else {
+ cur = TD::Convert(nt_types[-w]) + inds[-w];
+ }
ostringstream os;
os << "RBT:" << prev << '_' << cur;
const int fid = FD::Convert(Escape(os.str()));
diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h
index b100ec34..f210dc65 100644
--- a/decoder/ff_rules.h
+++ b/decoder/ff_rules.h
@@ -51,6 +51,7 @@ class RuleTargetBigramFeatures : public FeatureFunction {
void* context) const;
virtual void PrepareForInput(const SentenceMetadata& smeta);
private:
+ std::vector<std::string> inds;
mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
};
diff --git a/decoder/ff_source_path.cc b/decoder/ff_source_path.cc
new file mode 100644
index 00000000..2a3bee2e
--- /dev/null
+++ b/decoder/ff_source_path.cc
@@ -0,0 +1,42 @@
+#include "ff_source_path.h"
+
+#include "hg.h"
+
+using namespace std;
+
+SourcePathFeatures::SourcePathFeatures(const string& param) : FeatureFunction(sizeof(int)) {}
+
+void SourcePathFeatures::FireBigramFeature(WordID prev, WordID cur, SparseVector<double>* features) const {
+ int& fid = bigram_fids[prev][cur];
+ if (!fid) fid = FD::Convert("SB:"+TD::Convert(prev) + "_" + TD::Convert(cur));
+ if (fid) features->add_value(fid, 1.0);
+}
+
+void SourcePathFeatures::FireUnigramFeature(WordID cur, SparseVector<double>* features) const {
+ int& fid = unigram_fids[cur];
+ if (!fid) fid = FD::Convert("SU:" + TD::Convert(cur));
+ if (fid) features->add_value(fid, 1.0);
+}
+
+void SourcePathFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ WordID* res = reinterpret_cast<WordID*>(context);
+ const vector<int>& f = edge.rule_->f();
+ int prev = 0;
+ unsigned ntc = 0;
+ for (unsigned i = 0; i < f.size(); ++i) {
+ int cur = f[i];
+ if (cur < 0)
+ cur = *reinterpret_cast<const WordID*>(ant_contexts[ntc++]);
+ else
+ FireUnigramFeature(cur, features);
+ if (prev) FireBigramFeature(prev, cur, features);
+ prev = cur;
+ }
+ *res = prev;
+}
+
diff --git a/decoder/ff_source_path.h b/decoder/ff_source_path.h
new file mode 100644
index 00000000..03126412
--- /dev/null
+++ b/decoder/ff_source_path.h
@@ -0,0 +1,26 @@
+#ifndef _FF_SOURCE_PATH_H_
+#define _FF_SOURCE_PATH_H_
+
+#include <vector>
+#include <map>
+#include "ff.h"
+
+class SourcePathFeatures : public FeatureFunction {
+ public:
+ SourcePathFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+
+ private:
+ void FireBigramFeature(WordID prev, WordID cur, SparseVector<double>* features) const;
+ void FireUnigramFeature(WordID cur, SparseVector<double>* features) const;
+ mutable std::map<WordID, std::map<WordID, int> > bigram_fids;
+ mutable std::map<WordID, int> unigram_fids;
+};
+
+#endif