summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/ff_klm.cc49
-rw-r--r--decoder/ff_klm.h5
-rw-r--r--decoder/ff_ngrams.cc68
-rw-r--r--decoder/ff_rules.cc20
-rw-r--r--decoder/ff_rules.h1
-rw-r--r--decoder/kbest.h18
6 files changed, 127 insertions, 34 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index fefa90bd..c8ca917a 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -1,6 +1,7 @@
#include "ff_klm.h"
#include <cstring>
+#include <cstdlib>
#include <iostream>
#include <boost/scoped_ptr.hpp>
@@ -151,8 +152,9 @@ template <class Model> class BoundaryRuleScore {
template <class Model>
class KLanguageModelImpl {
public:
- double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, double* emit, void* remnant) {
*oovs = 0;
+ *emit = 0;
const vector<WordID>& e = rule.e();
BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant));
unsigned i = 0;
@@ -169,8 +171,9 @@ class KLanguageModelImpl {
if (e[i] <= 0) {
ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]));
} else {
- const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future,
- // maybe handle emission
+ float ep = 0.f;
+ const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i], &ep);
+ if (ep) { *emit += ep; }
const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
if (cur_word == 0) (*oovs) += 1.0;
ruleScore.Terminal(cur_word);
@@ -205,12 +208,14 @@ class KLanguageModelImpl {
// if this is not a class-based LM, returns w untransformed,
// otherwise returns a word class mapping of w,
// returns TD::Convert("<unk>") if there is no mapping for w
- WordID ClassifyWordIfNecessary(WordID w) const {
+ WordID ClassifyWordIfNecessary(WordID w, float* emitp) const {
if (word2class_map_.empty()) return w;
if (w >= word2class_map_.size())
return kCDEC_UNK;
- else
- return word2class_map_[w];
+ else {
+ *emitp = word2class_map_[w].second;
+ return word2class_map_[w].first;
+ }
}
// converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0
@@ -256,32 +261,32 @@ class KLanguageModelImpl {
int lc = 0;
if (!SILENT)
cerr << " Loading word classes from " << file << " ...\n";
- AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"));
- AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"));
- while(in) {
- getline(in, line);
- if (!in) continue;
+ AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"), 0.0);
+ AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"), 0.0);
+ while(getline(in, line)) {
dummy.clear();
TD::ConvertSentence(line, &dummy);
++lc;
- if (dummy.size() != 2) {
+ if (dummy.size() != 3) {
+ cerr << " Class map file expects: CLASS WORD logp(WORD|CLASS)\n";
cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
abort();
}
- AddWordToClassMapping_(dummy[0], dummy[1]);
+ AddWordToClassMapping_(dummy[1], dummy[0], strtof(TD::Convert(dummy[2]).c_str(), NULL));
}
}
- void AddWordToClassMapping_(WordID word, WordID cls) {
+ void AddWordToClassMapping_(WordID word, WordID cls, float emit) {
if (word2class_map_.size() <= word) {
- word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK);
+ word2class_map_.resize((word + 10) * 1.1, pair<WordID,float>(kCDEC_UNK,0.f));
assert(word2class_map_.size() > word);
}
- if(word2class_map_[word] != kCDEC_UNK) {
+ if(word2class_map_[word].first != kCDEC_UNK) {
cerr << "Multiple classes for symbol " << TD::Convert(word) << endl;
abort();
}
- word2class_map_[word] = cls;
+ word2class_map_[word].first = cls;
+ word2class_map_[word].second = emit;
}
~KLanguageModelImpl() {
@@ -304,7 +309,9 @@ class KLanguageModelImpl {
int order_;
vector<lm::WordIndex> cdec2klm_map_;
- vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
+ vector<pair<WordID,float> > word2class_map_; // if this is a class-based LM,
+ // .first is the word->class mapping
+ // .second is the emission log probability
};
template <class Model>
@@ -322,6 +329,7 @@ KLanguageModel<Model>::KLanguageModel(const string& param) {
}
fid_ = FD::Convert(featname);
oov_fid_ = FD::Convert(featname+"_OOV");
+ emit_fid_ = FD::Convert(featname+"_Emit");
// cerr << "FID: " << oov_fid_ << endl;
SetStateSize(pimpl_->ReserveStateSize());
}
@@ -340,9 +348,12 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme
void* state) const {
double est = 0;
double oovs = 0;
- features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state));
+ double emit = 0;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, &emit, state));
if (oovs && oov_fid_)
features->set_value(oov_fid_, oovs);
+ if (emit && emit_fid_)
+ features->set_value(emit_fid_, emit);
}
template <class Model>
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
index b5ceffd0..db4032f7 100644
--- a/decoder/ff_klm.h
+++ b/decoder/ff_klm.h
@@ -28,8 +28,9 @@ class KLanguageModel : public FeatureFunction {
SparseVector<double>* estimated_features,
void* out_context) const;
private:
- int fid_; // conceptually const; mutable only to simplify constructor
- int oov_fid_; // will be zero if extra OOV feature is not configured by decoder
+ int fid_; // LanguageModel
+ int oov_fid_; // LanguageModel_OOV
+ int emit_fid_; // LanguageModel_Emit [only used for class-based LMs]
KLanguageModelImpl<Model>* pimpl_;
};
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
index 9c13fdbb..d337b28b 100644
--- a/decoder/ff_ngrams.cc
+++ b/decoder/ff_ngrams.cc
@@ -60,7 +60,7 @@ namespace {
}
}
-static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator) {
+static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator, string* cluster_file) {
vector<string> const& argv=SplitOnWhitespace(in);
*explicit_markers = false;
*order = 3;
@@ -103,6 +103,10 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order,
LMSPEC_NEXTARG;
prefixes[5] = *i;
break;
+ case 'c':
+ LMSPEC_NEXTARG;
+ *cluster_file = *i;
+ break;
case 'S':
LMSPEC_NEXTARG;
target_separator = *i;
@@ -124,6 +128,7 @@ usage:
<< "NgramFeatures Usage: \n"
<< " feature_function=NgramFeatures filename.lm [-x] [-o <order>] \n"
+ << " [-c <cluster-file>]\n"
<< " [-U <unigram-prefix>] [-B <bigram-prefix>][-T <trigram-prefix>]\n"
<< " [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]\n\n"
@@ -203,6 +208,12 @@ class NgramDetectorImpl {
SetFlag(flag, HAS_FULL_CONTEXT, state);
}
+ WordID MapToClusterIfNecessary(WordID w) const {
+ if (cluster_map.size() == 0) return w;
+ if (w >= cluster_map.size()) return kCDEC_UNK;
+ return cluster_map[w];
+ }
+
void FireFeatures(const State<5>& state, WordID cur, SparseVector<double>* feats) {
FidTree* ft = &fidroot_;
int n = 0;
@@ -285,7 +296,7 @@ class NgramDetectorImpl {
context_complete = true;
}
} else { // handle terminal
- const WordID cur_word = e[j];
+ const WordID cur_word = MapToClusterIfNecessary(e[j]);
SparseVector<double> p;
if (cur_word == kSOS_) {
state = BeginSentenceState();
@@ -348,9 +359,52 @@ class NgramDetectorImpl {
}
}
+ void ReadClusterFile(const string& clusters) {
+ ReadFile rf(clusters);
+ istream& in = *rf.stream();
+ string line;
+ int lc = 0;
+ string cluster;
+ string word;
+ while(getline(in, line)) {
+ ++lc;
+ if (line.size() == 0) continue;
+ if (line[0] == '#') continue;
+ unsigned cend = 1;
+ while((line[cend] != ' ' && line[cend] != '\t') && cend < line.size()) {
+ ++cend;
+ }
+ if (cend == line.size()) {
+ cerr << "Line " << lc << " in " << clusters << " malformed: " << line << endl;
+ abort();
+ }
+ unsigned wbeg = cend + 1;
+ while((line[wbeg] == ' ' || line[wbeg] == '\t') && wbeg < line.size()) {
+ ++wbeg;
+ }
+ if (wbeg == line.size()) {
+ cerr << "Line " << lc << " in " << clusters << " malformed: " << line << endl;
+ abort();
+ }
+ unsigned wend = wbeg + 1;
+ while((line[wend] != ' ' && line[wend] != '\t') && wend < line.size()) {
+ ++wend;
+ }
+ const WordID clusterid = TD::Convert(line.substr(0, cend));
+ const WordID wordid = TD::Convert(line.substr(wbeg, wend - wbeg));
+ if (wordid >= cluster_map.size())
+ cluster_map.resize(wordid + 10, kCDEC_UNK);
+ cluster_map[wordid] = clusterid;
+ }
+ cluster_map[kSOS_] = kSOS_;
+ cluster_map[kEOS_] = kEOS_;
+ }
+
+ vector<WordID> cluster_map;
+
public:
explicit NgramDetectorImpl(bool explicit_markers, unsigned order,
- vector<string>& prefixes, string& target_separator) :
+ vector<string>& prefixes, string& target_separator, const string& clusters) :
kCDEC_UNK(TD::Convert("<unk>")) ,
add_sos_eos_(!explicit_markers) {
order_ = order;
@@ -369,6 +423,9 @@ class NgramDetectorImpl {
dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
kSOS_ = TD::Convert("<s>");
kEOS_ = TD::Convert("</s>");
+
+ if (clusters.size())
+ ReadClusterFile(clusters);
}
~NgramDetectorImpl() {
@@ -409,9 +466,10 @@ NgramDetector::NgramDetector(const string& param) {
vector<string> prefixes;
bool explicit_markers = false;
unsigned order = 3;
- ParseArgs(param, &explicit_markers, &order, prefixes, target_separator);
+ string clusters;
+ ParseArgs(param, &explicit_markers, &order, prefixes, target_separator, &clusters);
pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes,
- target_separator);
+ target_separator, clusters);
SetStateSize(pimpl_->ReserveStateSize());
}
diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc
index 6716d3da..410e083c 100644
--- a/decoder/ff_rules.cc
+++ b/decoder/ff_rules.cc
@@ -107,7 +107,12 @@ void RuleSourceBigramFeatures::TraversalFeaturesImpl(const SentenceMetadata& sme
(*features) += it->second;
}
-RuleTargetBigramFeatures::RuleTargetBigramFeatures(const std::string& param) {
+RuleTargetBigramFeatures::RuleTargetBigramFeatures(const std::string& param) : inds(1000) {
+ for (unsigned i = 0; i < inds.size(); ++i) {
+ ostringstream os;
+ os << (i + 1);
+ inds[i] = os.str();
+ }
}
void RuleTargetBigramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
@@ -126,11 +131,18 @@ void RuleTargetBigramFeatures::TraversalFeaturesImpl(const SentenceMetadata& sme
it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
SparseVector<double>& f = it->second;
string prev = "<r>";
+ vector<WordID> nt_types(rule.Arity());
+ unsigned ntc = 0;
+ for (int i = 0; i < rule.f_.size(); ++i)
+ if (rule.f_[i] < 0) nt_types[ntc++] = -rule.f_[i];
for (int i = 0; i < rule.e_.size(); ++i) {
WordID w = rule.e_[i];
- if (w < 0) w = -w;
- if (w == 0) return;
- const string& cur = TD::Convert(w);
+ string cur;
+ if (w > 0) {
+ cur = TD::Convert(w);
+ } else {
+ cur = TD::Convert(nt_types[-w]) + inds[-w];
+ }
ostringstream os;
os << "RBT:" << prev << '_' << cur;
const int fid = FD::Convert(Escape(os.str()));
diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h
index b100ec34..f210dc65 100644
--- a/decoder/ff_rules.h
+++ b/decoder/ff_rules.h
@@ -51,6 +51,7 @@ class RuleTargetBigramFeatures : public FeatureFunction {
void* context) const;
virtual void PrepareForInput(const SentenceMetadata& smeta);
private:
+ std::vector<std::string> inds;
mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
};
diff --git a/decoder/kbest.h b/decoder/kbest.h
index 9a55f653..44c23151 100644
--- a/decoder/kbest.h
+++ b/decoder/kbest.h
@@ -6,6 +6,7 @@
#include <tr1/unordered_set>
#include <boost/shared_ptr.hpp>
+#include <boost/type_traits.hpp>
#include "wordid.h"
#include "hg.h"
@@ -134,7 +135,7 @@ namespace KBest {
}
add_next = false;
- if (cand.size() > 0) {
+ while (!add_next && cand.size() > 0) {
std::pop_heap(cand.begin(), cand.end(), HeapCompare());
Derivation* d = cand.back();
cand.pop_back();
@@ -145,10 +146,15 @@ namespace KBest {
if (!filter(d->yield)) {
D.push_back(d);
add_next = true;
+ } else {
+ // just because a node already derived a string (or whatever
+ // equivalent derivation class), you need to add its successors
+ // to the node's candidate pool
+ LazyNext(d, &cand, &s.ds);
}
- } else {
- break;
}
+ if (!add_next)
+ break;
}
if (k < D.size()) return D[k]; else return NULL;
}
@@ -184,7 +190,11 @@ namespace KBest {
s.cand.push_back(d);
}
- const unsigned effective_k = std::min(k_prime, s.cand.size());
+ unsigned effective_k = s.cand.size();
+ if (boost::is_same<DerivationFilter,NoFilter<T> >::value) {
+ // if there's no filter you can use this optimization
+ effective_k = std::min(k_prime, s.cand.size());
+ }
const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k;
std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare());
s.cand.resize(effective_k);