diff options
Diffstat (limited to 'decoder/ff_lm.cc')
-rw-r--r-- | decoder/ff_lm.cc | 103 |
1 files changed, 87 insertions, 16 deletions
diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc index 6579fbee..a5f43867 100644 --- a/decoder/ff_lm.cc +++ b/decoder/ff_lm.cc @@ -20,6 +20,7 @@ char const* usage_verbose="-n determines the name of the feature (and its weight #endif #include "ff_lm.h" +#include "ff_lm_fsa.h" #include <sstream> #include <unistd.h> @@ -44,8 +45,12 @@ char const* usage_verbose="-n determines the name of the feature (and its weight using namespace std; +string LanguageModelFsa::usage(bool param,bool verbose) { + return FeatureFunction::usage_helper("LanguageModelFsa",usage_short,usage_verbose,param,verbose); +} + string LanguageModel::usage(bool param,bool verbose) { - return usage_helper(usage_name,usage_short,usage_verbose,param,verbose); + return FeatureFunction::usage_helper(usage_name,usage_short,usage_verbose,param,verbose); } @@ -126,7 +131,7 @@ struct LMClient { cerr << "Connected to LM on " << host << " on port " << port << endl; } - float wordProb(int word, int* context) { + float wordProb(int word, WordID const* context) { NgramCache::Cache* cur = &NgramCache::cache_; int i = 0; while (context[i] > 0) { @@ -183,10 +188,10 @@ class LanguageModelImpl { order_=order; state_size_ = OrderToStateSize(order)-1; unigram=(order<=1); - floor_=-100; - kSTART = TD::Convert("<s>"); - kSTOP = TD::Convert("</s>"); - kUNKNOWN = TD::Convert("<unk>"); + floor_ = -100; + kSTART = TD::ss; + kSTOP = TD::se; + kUNKNOWN = TD::unk; kNONE = TD::none; kSTAR = TD::Convert("<{STAR}>"); } @@ -226,7 +231,7 @@ class LanguageModelImpl { *(static_cast<char*>(state) + state_size_) = size; } - virtual double WordProb(int word, int* context) { + virtual double WordProb(WordID word, WordID const* context) { return ngram_.wordProb(word, (VocabIndex*)context); } @@ -425,8 +430,8 @@ public: vector<WordID> buffer_; int order_; int state_size_; - double floor_; public: + double floor_; WordID kSTART; WordID kSTOP; WordID kUNKNOWN; @@ -440,7 +445,7 @@ struct ClientLMI : public LanguageModelImpl ClientLMI(int order,string const& server) : LanguageModelImpl(order), client_(server) {} - virtual double WordProb(int word, int* context) { + virtual double WordProb(int word, WordID const* context) { return client_.wordProb(word, context); } @@ -452,7 +457,7 @@ struct ReuseLMI : public LanguageModelImpl { ReuseLMI(int order, Ngram *ng) : LanguageModelImpl(order), ng(ng) {} - double WordProb(int word, int* context) { + double WordProb(int word, WordID const* context) { return ng->wordProb(word, (VocabIndex*)context); } protected: @@ -520,8 +525,7 @@ usage: return false; } - -LanguageModel::LanguageModel(const string& param) { +LanguageModelImpl *make_lm_impl(string const& param, int *order_out, int *fid_out) { int order,load_order; string featurename,filename; if (!parse_lmspec(param,order,featurename,filename,load_order)) @@ -530,12 +534,80 @@ LanguageModel::LanguageModel(const string& param) { if (load_order) cerr<<" loading LM as order "<<load_order; cerr<<endl; - fid_=FD::Convert(featurename); - pimpl_ = make_lm_impl(order,filename,load_order); + *order_out=order; + *fid_out=FD::Convert(featurename); + return make_lm_impl(order,filename,load_order); +} + + +LanguageModel::LanguageModel(const string& param) { + int order; + pimpl_ = make_lm_impl(param,&order,&fid_); //TODO: see if it's actually possible to set order_ later to mutate an already used FF for e.g. multipass. comment in ff.h says only to change state size in constructor. clone instead? differently -n named ones from same lm filename are already possible, so no urgency. SetStateSize(LanguageModelImpl::OrderToStateSize(order)); } +//TODO: decide whether to waste a word of space so states are always none-terminated for SRILM. otherwise we have to copy +void LanguageModelFsa::set_ngram_order(int i) { + assert(i>0); + ngram_order_=i; + ctxlen_=i-1; + set_state_bytes(ctxlen_*sizeof(WordID)); + set_end_phrase(TD::se); //TODO: pretty boring in unigram case, just adds constant prob - bu WordID *ss=(WordID*)start.begin(); + WordID *hs=(WordID*)h_start.begin(); +t for compat. with non-fsa version, leave it + if (ctxlen_) { // avoid segfault in case of unigram lm (0 state) + ss[0]=TD::ss; // start-sentence context (length 1) + hs[0]=TD::none; // empty context + for (int i=1;i<ctxlen_;++i) { + ss[i]=hs[i]=TD::none; // need this so storage is initialized for hashing. + //TODO: reevaluate whether state space comes cleared by allocator or not. + } + } +} +namespace { +WordID empty_context=TD::none; +} + +LanguageModelFsa::LanguageModelFsa(string const& param) { + int lmorder; + pimpl_ = make_lm_impl(param,&lmorder,&fid_); + floor_=pimpl_->floor_; + set_ngram_order(lmorder); +} + +//TODO: use sri equivalent states (expose in lm impl?) +void LanguageModelFsa::Scan(SentenceMetadata const& /* smeta */,const Hypergraph::Edge& /* edge */,WordID w,void const* old_st,void *new_st,FeatureVector *features) const { + //variable length array is in C99, msvc++, if it doesn't support it, #ifdef it or use a stackalloc call (forget the name) + Featval p; + if (ctxlen_) { + WordID ctx[ngram_order_]; + state_cpy(ctx,old_st); + ctx[ctxlen_]=TD::none; // make this part of state? wastes space but saves copies. + p=pimpl_->WordProb(w,ctx); +// states are sri contexts so are in reverse order (most recent word is first, then 1-back comes next, etc.). + WordID *nst=(WordID *)new_st; + nst[0]=w; // new most recent word + to_state(nst+1,ctx,ctxlen_-1); // rotate old words right + } else { + p=pimpl_->WordProb(w,&empty_context); + } + add_feat(features,(p<floor_)?floor_:p); +} + +void LanguageModelFsa::print_state(ostream &o,void *st) const { + WordID *wst=(WordID *)st; + o<<'['; + for (int i=ctxlen_;i>0;) { + --i; + WordID w=wst[i]; + if (w==TD::none) continue; + if (i) o<<' '; + o << TD::Convert(w); + } + o<<']'; +} + Features LanguageModel::features() const { return single_feature(fid_); } @@ -548,13 +620,12 @@ string LanguageModel::DebugStateToString(const void* state) const{ return pimpl_->DebugStateToString(state); } -void LanguageModel::TraversalFeaturesImpl(const SentenceMetadata& smeta, +void LanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge, const vector<const void*>& ant_states, SparseVector<double>* features, SparseVector<double>* estimated_features, void* state) const { - (void) smeta; features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state)); estimated_features->set_value(fid_, pimpl_->EstimateProb(state)); } |