summaryrefslogtreecommitdiff
path: root/decoder/ff_csplit.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-03-10 01:58:30 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-03-10 01:58:30 -0500
commit4f9933d668d247ea5831c3f2af0b996a94da28f7 (patch)
tree965f1ade8619dbb7387a3b33d29f157e04a6fff2 /decoder/ff_csplit.cc
parent0acaff8e91fbec0699da8c4a84fdba8c4be9c229 (diff)
remove dependency on SRILM
Diffstat (limited to 'decoder/ff_csplit.cc')
-rw-r--r--decoder/ff_csplit.cc93
1 files changed, 41 insertions, 52 deletions
diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc
index 204b7ce6..dee6f4f9 100644
--- a/decoder/ff_csplit.cc
+++ b/decoder/ff_csplit.cc
@@ -3,8 +3,7 @@
#include <set>
#include <cstring>
-#include "Vocab.h"
-#include "Ngram.h"
+#include "klm/lm/model.hh"
#include "sentence_metadata.h"
#include "lattice.h"
@@ -155,51 +154,62 @@ void BasicCSplitFeatures::TraversalFeaturesImpl(
pimpl_->TraversalFeaturesImpl(edge, smeta.GetSourceLattice().size(), features);
}
+namespace {
+struct CSVMapper : public lm::ngram::EnumerateVocab {
+ CSVMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
+ void Add(lm::WordIndex index, const StringPiece &str) {
+ const WordID cdec_id = TD::Convert(str.as_string());
+ if (cdec_id >= out_->size())
+ out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
+ (*out_)[cdec_id] = index;
+ }
+ vector<lm::WordIndex>* out_;
+ const lm::WordIndex kLM_UNKNOWN_TOKEN;
+};
+}
+
+template<class Model>
struct ReverseCharLMCSplitFeatureImpl {
- ReverseCharLMCSplitFeatureImpl(const string& param) :
- order_(5),
- vocab_(TD::dict_),
- ngram_(vocab_, order_) {
- kBOS = vocab_.getIndex("<s>");
- kEOS = vocab_.getIndex("</s>");
- File file(param.c_str(), "r", 0);
- assert(file);
- cerr << "Reading " << order_ << "-gram LM from " << param << endl;
- ngram_.read(file);
+ ReverseCharLMCSplitFeatureImpl(const string& param) {
+ CSVMapper vm(&cdec2klm_map_);
+ lm::ngram::Config conf;
+ conf.enumerate_vocab = &vm;
+ cerr << "Reading character LM from " << param << endl;
+ ngram_ = new Model(param.c_str(), conf);
+ order_ = ngram_->Order();
+ kEOS = MapWord(TD::Convert("</s>"));
+ assert(kEOS > 0);
+ }
+ lm::WordIndex MapWord(const WordID w) const {
+ if (w < cdec2klm_map_.size()) return cdec2klm_map_[w];
+ return 0;
}
double LeftPhonotacticProb(const Lattice& inword, const int start) {
const int end = inword.size();
- for (int i = 0; i < order_; ++i)
- sc[i] = kBOS;
+ lm::ngram::State state = ngram_->BeginSentenceState();
int sp = min(end - start, order_ - 1);
// cerr << "[" << start << "," << sp << "]\n";
- int ci = (order_ - sp - 1);
- int wi = start;
+ int wi = start + sp - 1;
while (sp > 0) {
- sc[ci] = inword[wi][0].label;
- // cerr << " CHAR: " << TD::Convert(sc[ci]) << " ci=" << ci << endl;
- ++wi;
- ++ci;
+ const lm::ngram::State scopy(state);
+ ngram_->Score(scopy, MapWord(inword[wi][0].label), state);
+ --wi;
--sp;
}
- // cerr << " END ci=" << ci << endl;
- sc[ci] = Vocab_None;
- const double startprob = ngram_.wordProb(kEOS, sc);
- // cerr << " PROB=" << startprob << endl;
+ const lm::ngram::State scopy(state);
+ const double startprob = ngram_->Score(scopy, kEOS, state);
return startprob;
}
private:
- const int order_;
- Vocab& vocab_;
- VocabIndex kBOS;
- VocabIndex kEOS;
- Ngram ngram_;
- VocabIndex sc[80];
+ Model* ngram_;
+ int order_;
+ vector<lm::WordIndex> cdec2klm_map_;
+ lm::WordIndex kEOS;
};
ReverseCharLMCSplitFeature::ReverseCharLMCSplitFeature(const string& param) :
- pimpl_(new ReverseCharLMCSplitFeatureImpl(param)),
+ pimpl_(new ReverseCharLMCSplitFeatureImpl<lm::ngram::ProbingModel>(param)),
fid_(FD::Convert("RevCharLM")) {}
void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
@@ -217,26 +227,5 @@ void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
if (edge.rule_->EWords() != 1) return;
const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_);
features->set_value(fid_, lpp);
-#if 0
- WordID neighbor_word = 0;
- const WordID word = edge.rule_->e_[1];
- const char* sword = TD::Convert(word);
- const int len = strlen(sword);
- int cur = 0;
- int chars = 0;
- while(cur < len) {
- cur += UTF8Len(sword[cur]);
- ++chars;
- }
- if (chars > 4 && (sword[0] == 's' || sword[0] == 'n')) {
- neighbor_word = TD::Convert(string(&sword[1]));
- }
- if (neighbor_word) {
- float nfreq = freq_dict_.LookUp(neighbor_word);
- cerr << "COMPARE: " << TD::Convert(word) << " & " << TD::Convert(neighbor_word) << endl;
- if (!nfreq) nfreq = 99.0f;
- features->set_value(fdoes_deletion_help_, (freq - nfreq));
- }
-#endif
}