diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-10 01:58:30 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-10 01:58:30 -0500 |
commit | 19e0a382269042605c347b48e5ac92c5012f1ccc (patch) | |
tree | 966cac5e26788c1225e1e20257547902a3ba6be7 /decoder/ff_csplit.cc | |
parent | b749a9ce861a1f800a0837a90e1376e4e5fc6739 (diff) |
remove dependency on SRILM
Diffstat (limited to 'decoder/ff_csplit.cc')
-rw-r--r-- | decoder/ff_csplit.cc | 93 |
1 files changed, 41 insertions, 52 deletions
diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc index 204b7ce6..dee6f4f9 100644 --- a/decoder/ff_csplit.cc +++ b/decoder/ff_csplit.cc @@ -3,8 +3,7 @@ #include <set> #include <cstring> -#include "Vocab.h" -#include "Ngram.h" +#include "klm/lm/model.hh" #include "sentence_metadata.h" #include "lattice.h" @@ -155,51 +154,62 @@ void BasicCSplitFeatures::TraversalFeaturesImpl( pimpl_->TraversalFeaturesImpl(edge, smeta.GetSourceLattice().size(), features); } +namespace { +struct CSVMapper : public lm::ngram::EnumerateVocab { + CSVMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector<lm::WordIndex>* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; +} + +template<class Model> struct ReverseCharLMCSplitFeatureImpl { - ReverseCharLMCSplitFeatureImpl(const string& param) : - order_(5), - vocab_(TD::dict_), - ngram_(vocab_, order_) { - kBOS = vocab_.getIndex("<s>"); - kEOS = vocab_.getIndex("</s>"); - File file(param.c_str(), "r", 0); - assert(file); - cerr << "Reading " << order_ << "-gram LM from " << param << endl; - ngram_.read(file); + ReverseCharLMCSplitFeatureImpl(const string& param) { + CSVMapper vm(&cdec2klm_map_); + lm::ngram::Config conf; + conf.enumerate_vocab = &vm; + cerr << "Reading character LM from " << param << endl; + ngram_ = new Model(param.c_str(), conf); + order_ = ngram_->Order(); + kEOS = MapWord(TD::Convert("</s>")); + assert(kEOS > 0); + } + lm::WordIndex MapWord(const WordID w) const { + if (w < cdec2klm_map_.size()) return cdec2klm_map_[w]; + return 0; } double LeftPhonotacticProb(const Lattice& inword, const int start) { const int end = inword.size(); - for (int i = 0; i < order_; ++i) - sc[i] = kBOS; + lm::ngram::State state = ngram_->BeginSentenceState(); int sp = min(end - start, order_ - 1); // cerr << "[" << start << "," << sp << "]\n"; - int ci = (order_ - sp - 1); - int wi = start; + int wi = start + sp - 1; while (sp > 0) { - sc[ci] = inword[wi][0].label; - // cerr << " CHAR: " << TD::Convert(sc[ci]) << " ci=" << ci << endl; - ++wi; - ++ci; + const lm::ngram::State scopy(state); + ngram_->Score(scopy, MapWord(inword[wi][0].label), state); + --wi; --sp; } - // cerr << " END ci=" << ci << endl; - sc[ci] = Vocab_None; - const double startprob = ngram_.wordProb(kEOS, sc); - // cerr << " PROB=" << startprob << endl; + const lm::ngram::State scopy(state); + const double startprob = ngram_->Score(scopy, kEOS, state); return startprob; } private: - const int order_; - Vocab& vocab_; - VocabIndex kBOS; - VocabIndex kEOS; - Ngram ngram_; - VocabIndex sc[80]; + Model* ngram_; + int order_; + vector<lm::WordIndex> cdec2klm_map_; + lm::WordIndex kEOS; }; ReverseCharLMCSplitFeature::ReverseCharLMCSplitFeature(const string& param) : - pimpl_(new ReverseCharLMCSplitFeatureImpl(param)), + pimpl_(new ReverseCharLMCSplitFeatureImpl<lm::ngram::ProbingModel>(param)), fid_(FD::Convert("RevCharLM")) {} void ReverseCharLMCSplitFeature::TraversalFeaturesImpl( @@ -217,26 +227,5 @@ void ReverseCharLMCSplitFeature::TraversalFeaturesImpl( if (edge.rule_->EWords() != 1) return; const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_); features->set_value(fid_, lpp); -#if 0 - WordID neighbor_word = 0; - const WordID word = edge.rule_->e_[1]; - const char* sword = TD::Convert(word); - const int len = strlen(sword); - int cur = 0; - int chars = 0; - while(cur < len) { - cur += UTF8Len(sword[cur]); - ++chars; - } - if (chars > 4 && (sword[0] == 's' || sword[0] == 'n')) { - neighbor_word = TD::Convert(string(&sword[1])); - } - if (neighbor_word) { - float nfreq = freq_dict_.LookUp(neighbor_word); - cerr << "COMPARE: " << TD::Convert(word) << " & " << TD::Convert(neighbor_word) << endl; - if (!nfreq) nfreq = 99.0f; - features->set_value(fdoes_deletion_help_, (freq - nfreq)); - } -#endif } |