summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-03-10 01:58:30 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-03-10 01:58:30 -0500
commit19e0a382269042605c347b48e5ac92c5012f1ccc (patch)
tree966cac5e26788c1225e1e20257547902a3ba6be7 /decoder
parentb749a9ce861a1f800a0837a90e1376e4e5fc6739 (diff)
remove dependency on SRILM
Diffstat (limited to 'decoder')
-rw-r--r--decoder/decoder.cc9
-rw-r--r--decoder/ff_bleu.cc5
-rw-r--r--decoder/ff_csplit.cc93
-rw-r--r--decoder/ff_csplit.h5
-rwxr-xr-xdecoder/ff_from_fsa.h15
-rw-r--r--decoder/ff_lm.cc111
-rw-r--r--decoder/ff_lm.h5
-rwxr-xr-xdecoder/ff_lm_fsa.h13
-rwxr-xr-xdecoder/ff_sample_fsa.h2
9 files changed, 96 insertions, 162 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 239c8620..95ff6270 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -763,10 +763,6 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
}
- if (conf.count("show_partition")) {
- const prob_t z = Inside<prob_t, EdgeProb>(forest);
- cerr << " Init. partition log(Z): " << log(z) << endl;
- }
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
const RescoringPass& rp = rescoring_passes[pass];
@@ -793,6 +789,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,show_features,cur_weights,oracle.show_derivation);
}
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " " << passtr << " partition log(Z): " << log(z) << endl;
+ }
+
string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass);
string fulldp = "density_prune" + StringSuffixForRescoringPass(pass);
maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen);
diff --git a/decoder/ff_bleu.cc b/decoder/ff_bleu.cc
index aa4e6d85..a842bba8 100644
--- a/decoder/ff_bleu.cc
+++ b/decoder/ff_bleu.cc
@@ -13,8 +13,6 @@ char const* bleu_usage_verbose="Uses feature id 0! Make sure there are no other
#include "ff_bleu.h"
#include "tdict.h"
-#include "Vocab.h"
-#include "Ngram.h"
#include "hg.h"
#include "stringlib.h"
#include "sentence_metadata.h"
@@ -25,7 +23,7 @@ using namespace std;
class BLEUModelImpl {
public:
explicit BLEUModelImpl(int order) :
- ngram_(TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
+ buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
floor_(-100.0),
kSTART(TD::Convert("<s>")),
kSTOP(TD::Convert("</s>")),
@@ -219,7 +217,6 @@ class BLEUModelImpl {
}
protected:
- Ngram ngram_;
vector<WordID> buffer_;
const int order_;
const int state_size_;
diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc
index 204b7ce6..dee6f4f9 100644
--- a/decoder/ff_csplit.cc
+++ b/decoder/ff_csplit.cc
@@ -3,8 +3,7 @@
#include <set>
#include <cstring>
-#include "Vocab.h"
-#include "Ngram.h"
+#include "klm/lm/model.hh"
#include "sentence_metadata.h"
#include "lattice.h"
@@ -155,51 +154,62 @@ void BasicCSplitFeatures::TraversalFeaturesImpl(
pimpl_->TraversalFeaturesImpl(edge, smeta.GetSourceLattice().size(), features);
}
+namespace {
+struct CSVMapper : public lm::ngram::EnumerateVocab {
+ CSVMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
+ void Add(lm::WordIndex index, const StringPiece &str) {
+ const WordID cdec_id = TD::Convert(str.as_string());
+ if (cdec_id >= out_->size())
+ out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
+ (*out_)[cdec_id] = index;
+ }
+ vector<lm::WordIndex>* out_;
+ const lm::WordIndex kLM_UNKNOWN_TOKEN;
+};
+}
+
+template<class Model>
struct ReverseCharLMCSplitFeatureImpl {
- ReverseCharLMCSplitFeatureImpl(const string& param) :
- order_(5),
- vocab_(TD::dict_),
- ngram_(vocab_, order_) {
- kBOS = vocab_.getIndex("<s>");
- kEOS = vocab_.getIndex("</s>");
- File file(param.c_str(), "r", 0);
- assert(file);
- cerr << "Reading " << order_ << "-gram LM from " << param << endl;
- ngram_.read(file);
+ ReverseCharLMCSplitFeatureImpl(const string& param) {
+ CSVMapper vm(&cdec2klm_map_);
+ lm::ngram::Config conf;
+ conf.enumerate_vocab = &vm;
+ cerr << "Reading character LM from " << param << endl;
+ ngram_ = new Model(param.c_str(), conf);
+ order_ = ngram_->Order();
+ kEOS = MapWord(TD::Convert("</s>"));
+ assert(kEOS > 0);
+ }
+ lm::WordIndex MapWord(const WordID w) const {
+ if (w < cdec2klm_map_.size()) return cdec2klm_map_[w];
+ return 0;
}
double LeftPhonotacticProb(const Lattice& inword, const int start) {
const int end = inword.size();
- for (int i = 0; i < order_; ++i)
- sc[i] = kBOS;
+ lm::ngram::State state = ngram_->BeginSentenceState();
int sp = min(end - start, order_ - 1);
// cerr << "[" << start << "," << sp << "]\n";
- int ci = (order_ - sp - 1);
- int wi = start;
+ int wi = start + sp - 1;
while (sp > 0) {
- sc[ci] = inword[wi][0].label;
- // cerr << " CHAR: " << TD::Convert(sc[ci]) << " ci=" << ci << endl;
- ++wi;
- ++ci;
+ const lm::ngram::State scopy(state);
+ ngram_->Score(scopy, MapWord(inword[wi][0].label), state);
+ --wi;
--sp;
}
- // cerr << " END ci=" << ci << endl;
- sc[ci] = Vocab_None;
- const double startprob = ngram_.wordProb(kEOS, sc);
- // cerr << " PROB=" << startprob << endl;
+ const lm::ngram::State scopy(state);
+ const double startprob = ngram_->Score(scopy, kEOS, state);
return startprob;
}
private:
- const int order_;
- Vocab& vocab_;
- VocabIndex kBOS;
- VocabIndex kEOS;
- Ngram ngram_;
- VocabIndex sc[80];
+ Model* ngram_;
+ int order_;
+ vector<lm::WordIndex> cdec2klm_map_;
+ lm::WordIndex kEOS;
};
ReverseCharLMCSplitFeature::ReverseCharLMCSplitFeature(const string& param) :
- pimpl_(new ReverseCharLMCSplitFeatureImpl(param)),
+ pimpl_(new ReverseCharLMCSplitFeatureImpl<lm::ngram::ProbingModel>(param)),
fid_(FD::Convert("RevCharLM")) {}
void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
@@ -217,26 +227,5 @@ void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
if (edge.rule_->EWords() != 1) return;
const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_);
features->set_value(fid_, lpp);
-#if 0
- WordID neighbor_word = 0;
- const WordID word = edge.rule_->e_[1];
- const char* sword = TD::Convert(word);
- const int len = strlen(sword);
- int cur = 0;
- int chars = 0;
- while(cur < len) {
- cur += UTF8Len(sword[cur]);
- ++chars;
- }
- if (chars > 4 && (sword[0] == 's' || sword[0] == 'n')) {
- neighbor_word = TD::Convert(string(&sword[1]));
- }
- if (neighbor_word) {
- float nfreq = freq_dict_.LookUp(neighbor_word);
- cerr << "COMPARE: " << TD::Convert(word) << " & " << TD::Convert(neighbor_word) << endl;
- if (!nfreq) nfreq = 99.0f;
- features->set_value(fdoes_deletion_help_, (freq - nfreq));
- }
-#endif
}
diff --git a/decoder/ff_csplit.h b/decoder/ff_csplit.h
index c1cfb64b..38c0c5b8 100644
--- a/decoder/ff_csplit.h
+++ b/decoder/ff_csplit.h
@@ -4,6 +4,7 @@
#include <boost/shared_ptr.hpp>
#include "ff.h"
+#include "klm/lm/model.hh"
class BasicCSplitFeaturesImpl;
class BasicCSplitFeatures : public FeatureFunction {
@@ -20,7 +21,7 @@ class BasicCSplitFeatures : public FeatureFunction {
boost::shared_ptr<BasicCSplitFeaturesImpl> pimpl_;
};
-class ReverseCharLMCSplitFeatureImpl;
+template <class M> class ReverseCharLMCSplitFeatureImpl;
class ReverseCharLMCSplitFeature : public FeatureFunction {
public:
ReverseCharLMCSplitFeature(const std::string& param);
@@ -32,7 +33,7 @@ class ReverseCharLMCSplitFeature : public FeatureFunction {
SparseVector<double>* estimated_features,
void* out_context) const;
private:
- boost::shared_ptr<ReverseCharLMCSplitFeatureImpl> pimpl_;
+ boost::shared_ptr<ReverseCharLMCSplitFeatureImpl<lm::ngram::ProbingModel> > pimpl_;
const int fid_;
};
diff --git a/decoder/ff_from_fsa.h b/decoder/ff_from_fsa.h
index 26aca048..f2db8a4b 100755
--- a/decoder/ff_from_fsa.h
+++ b/decoder/ff_from_fsa.h
@@ -3,6 +3,11 @@
#include "ff_fsa.h"
+#ifndef TD__none
+// replacing dependency on SRILM
+#define TD__none -1
+#endif
+
#ifndef FSA_FF_DEBUG
# define FSA_FF_DEBUG 0
#endif
@@ -94,7 +99,7 @@ public:
return;
}
- // bear with me, because this is hard to understand. reminder: ant_contexts and out_state are left-words first (up to M, TD::none padded). if all M words are present, then FSA state follows. otherwise 0 bytes to keep memcmp/hash happy.
+ // bear with me, because this is hard to understand. reminder: ant_contexts and out_state are left-words first (up to M, TD__none padded). if all M words are present, then FSA state follows. otherwise 0 bytes to keep memcmp/hash happy.
//why do we compute heuristic in so many places? well, because that's how we know what state we should score words in once we're full on our left context (because of markov order bound, we know the score will be the same no matter what came before that left context)
// these left_* refer to our output (out_state):
@@ -163,7 +168,7 @@ public:
if (left_out<left_full) { // finally: partial heuristic for unfilled items
// fsa.reset(ff.heuristic_start_state()); fsa.scan(left_begin,left_out,&h_accum);
ff.ScanPhraseAccumOnly(smeta,edge,left_begin,left_out,ff.heuristic_start_state(),&h_accum);
- do { *left_out++=TD::none; } while(left_out<left_full); // none-terminate so left_end(out_state) will know how many words
+ do { *left_out++=TD__none; } while(left_out<left_full); // none-terminate so left_end(out_state) will know how many words
ff.state_zero(out_fsa_state); // so we compare / hash correctly. don't know state yet because left context isn't full
} else // or else store final right-state. heuristic was already assigned
ff.state_copy(out_fsa_state,fsa.cs);
@@ -233,7 +238,7 @@ public:
static void test() {
WordID w1[1],w1b[1],w2[2];
w1[0]=w2[0]=TD::Convert("hi");
- w2[1]=w1b[0]=TD::none;
+ w2[1]=w1b[0]=TD__none;
assert(left_end(w1,w1+1)==w1+1);
assert(left_end(w1b,w1b+1)==w1b);
assert(left_end(w2,w2+2)==w2+1);
@@ -262,12 +267,12 @@ private:
/*
state layout: left WordIds, followed by fsa state
left words have never been scored. last ones remaining will be scored on FinalTraversalFeatures only.
- right state is unknown until we have all M left words (less than M means TD::none will pad out right end). unk right state will be zeroed out for proper hash/equal recombination.
+ right state is unknown until we have all M left words (less than M means TD__none will pad out right end). unk right state will be zeroed out for proper hash/equal recombination.
*/
static inline WordID const* left_end(WordID const* left, WordID const* e) {
for (;e>left;--e)
- if (e[-1]!=TD::none) break;
+ if (e[-1]!=TD__none) break;
//post: [left,e] are the seen left words
return e;
}
diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc
index a9929253..afa36b96 100644
--- a/decoder/ff_lm.cc
+++ b/decoder/ff_lm.cc
@@ -59,8 +59,6 @@ char const* usage_verbose="-n determines the name of the feature (and its weight
#include "fast_lexical_cast.hpp"
#include "tdict.h"
-#include "Vocab.h"
-#include "Ngram.h"
#include "hg.h"
#include "stringlib.h"
@@ -80,41 +78,9 @@ string LanguageModel::usage(bool param,bool verbose) {
}
-// NgramShare will keep track of all loaded lms and reuse them.
-//TODO: ref counting by shared_ptr? for now, first one to load LM needs to stick around as long as all subsequent users.
-
#include <boost/shared_ptr.hpp>
using namespace boost;
-//WARNING: first person to add a pointer to ngram must keep it around until others are done using it.
-struct NgramShare
-{
-// typedef shared_ptr<Ngram> NP;
- typedef Ngram *NP;
- map<string,NP> ns;
- bool have(string const& file) const
- {
- return ns.find(file)!=ns.end();
- }
- NP get(string const& file) const
- {
- assert(have(file));
- return ns.find(file)->second;
- }
- void set(string const& file,NP n)
- {
- ns[file]=n;
- }
- void add(string const& file,NP n)
- {
- assert(!have(file));
- set(file,n);
- }
-};
-
-//TODO: namespace or static?
-NgramShare ngs;
-
namespace NgramCache {
struct Cache {
map<WordID, Cache> tree;
@@ -215,37 +181,28 @@ class LanguageModelImpl : public LanguageModelInterface {
state_size_ = OrderToStateSize(order)-1;
unigram=(order<=1);
floor_ = -100;
- kSTART = TD::ss;
- kSTOP = TD::se;
- kUNKNOWN = TD::unk;
- kNONE = TD::none;
+ kSTART = TD::Convert("<s>");
+ kSTOP = TD::Convert("</s>");
+ kUNKNOWN = TD::Convert("<unk>");
+ kNONE = 0;
kSTAR = TD::Convert("<{STAR}>");
}
public:
- explicit LanguageModelImpl(int order) : ngram_(TD::dict_, order)
+ explicit LanguageModelImpl(int order)
{
init(order);
}
-//TODO: show that unigram special case (0 state) computes what it should.
- LanguageModelImpl(int order, const string& f, int load_order=0) :
- ngram_(TD::dict_, load_order ? load_order : order)
- {
- init(order);
- File file(f.c_str(), "r", 0);
- assert(file);
- cerr << "Reading " << order_ << "-gram LM from " << f << endl;
- ngram_.read(file, false);
- }
-
virtual ~LanguageModelImpl() {
}
- Ngram *get_lm() // for make_lm_impl ngs sharing only.
+ //Ngram *get_lm() // for make_lm_impl ngs sharing only.
+ void *get_lm() // for make_lm_impl ngs sharing only.
{
- return &ngram_;
+ //return &ngram_;
+ return 0;
}
@@ -258,17 +215,19 @@ class LanguageModelImpl : public LanguageModelInterface {
}
virtual double WordProb(WordID word, WordID const* context) {
- return ngram_.wordProb(word, (VocabIndex*)context);
+ return -100;
+ //return ngram_.wordProb(word, (VocabIndex*)context);
}
// may be shorter than actual null-terminated length. context must be null terminated. len is just to save effort for subclasses that don't support contextID
virtual int ContextSize(WordID const* context,int len) {
unsigned ret;
- ngram_.contextID((VocabIndex*)context,ret);
+ //ngram_.contextID((VocabIndex*)context,ret);
return ret;
}
virtual double ContextBOW(WordID const* context,int shortened_len) {
- return ngram_.contextBOW((VocabIndex*)context,shortened_len);
+ //return ngram_.contextBOW((VocabIndex*)context,shortened_len);
+ return -100;
}
inline double LookupProbForBufferContents(int i) {
@@ -457,7 +416,6 @@ public:
}
protected:
- Ngram ngram_;
vector<WordID> buffer_;
int order_;
int state_size_;
@@ -470,8 +428,7 @@ public:
bool unigram;
};
-struct ClientLMI : public LanguageModelImpl
-{
+struct ClientLMI : public LanguageModelImpl {
ClientLMI(int order,string const& server) : LanguageModelImpl(order), client_(server)
{}
@@ -489,37 +446,13 @@ protected:
LMClient client_;
};
-struct ReuseLMI : public LanguageModelImpl
-{
- ReuseLMI(int order, Ngram *ng) : LanguageModelImpl(order), ng(ng)
- {}
- double WordProb(int word, WordID const* context) {
- return ng->wordProb(word, (VocabIndex*)context);
- }
- virtual int ContextSize(WordID const* context, int len) {
- unsigned ret;
- ng->contextID((VocabIndex*)context,ret);
- return ret;
- }
- virtual double ContextBOW(WordID const* context,int shortened_len) {
- return ng->contextBOW((VocabIndex*)context,shortened_len);
- }
-protected:
- Ngram *ng;
-};
-
LanguageModelImpl *make_lm_impl(int order, string const& f, int load_order)
{
if (f.find("lm://") == 0) {
return new ClientLMI(order,f.substr(5));
- } else if (load_order==0 && ngs.have(f)) {
- cerr<<"Reusing already loaded Ngram LM: "<<f<<endl;
- return new ReuseLMI(order,ngs.get(f));
} else {
- LanguageModelImpl *r=new LanguageModelImpl(order,f,load_order);
- if (!load_order || !ngs.have(f))
- ngs.add(f,r->get_lm());
- return r;
+ cerr << "LanguageModel no longer supports non-remote LMs. Please use KLanguageModel!\nPlease see http://cdec-decoder.org/index.php?title=Language_model_notes\n";
+ abort();
}
}
@@ -600,12 +533,12 @@ void LanguageModelFsa::set_ngram_order(int i) {
WordID *ss=(WordID*)start.begin();
WordID *hs=(WordID*)h_start.begin();
if (ctxlen_) { // avoid segfault in case of unigram lm (0 state)
- set_end_phrase(TD::se);
+ set_end_phrase(TD::Convert("</s>"));
// se is pretty boring in unigram case, just adds constant prob. check that this is what we want
- ss[0]=TD::ss; // start-sentence context (length 1)
- hs[0]=TD::none; // empty context
+ ss[0]=TD::Convert("<s>"); // start-sentence context (length 1)
+ hs[0]=0; // empty context
for (int i=1;i<ctxlen_;++i) {
- ss[i]=hs[i]=TD::none; // need this so storage is initialized for hashing.
+ ss[i]=hs[i]=0; // need this so storage is initialized for hashing.
//TODO: reevaluate whether state space comes cleared by allocator or not.
}
}
@@ -627,7 +560,7 @@ void LanguageModelFsa::print_state(ostream &o,void const* st) const {
for (int i=ctxlen_;i>0;sp=true) {
--i;
WordID w=wst[i];
- if (w==TD::none) continue;
+ if (w==0) continue;
if (sp) o<<' ';
o << TD::Convert(w);
}
diff --git a/decoder/ff_lm.h b/decoder/ff_lm.h
index e682481d..8885efce 100644
--- a/decoder/ff_lm.h
+++ b/decoder/ff_lm.h
@@ -8,6 +8,9 @@
#include "ff.h"
#include "config.h"
+// everything in this file is deprecated and may be broken.
+// Chris Dyer, Mar 2011
+
class LanguageModelInterface {
public:
double floor_;
@@ -29,7 +32,7 @@ class LanguageModelInterface {
double p=ContextBOW(context,slen);
while (len>slen) {
--len;
- context[len]=TD::none;
+ context[len]=0;
}
return p;
}
diff --git a/decoder/ff_lm_fsa.h b/decoder/ff_lm_fsa.h
index d2df943e..85b7ef44 100755
--- a/decoder/ff_lm_fsa.h
+++ b/decoder/ff_lm_fsa.h
@@ -21,8 +21,13 @@
#include "ff_fsa.h"
#include "ff_lm.h"
+#ifndef TD__none
+// replacing dependency on SRILM
+#define TD__none -1
+#endif
+
namespace {
-WordID empty_context=TD::none;
+WordID empty_context=TD__none;
}
struct LanguageModelFsa : public FsaFeatureFunctionBase<LanguageModelFsa> {
@@ -40,7 +45,7 @@ struct LanguageModelFsa : public FsaFeatureFunctionBase<LanguageModelFsa> {
}
static inline WordID const* left_end(WordID const* left, WordID const* e) {
for (;e>left;--e)
- if (e[-1]!=TD::none) break;
+ if (e[-1]!=TD__none) break;
//post: [left,e] are the seen left words
return e;
}
@@ -55,7 +60,7 @@ struct LanguageModelFsa : public FsaFeatureFunctionBase<LanguageModelFsa> {
} else {
WordID ctx[ngram_order_]; //alloca if you don't have C99
state_copy(ctx,old_st);
- ctx[ctxlen_]=TD::none;
+ ctx[ctxlen_]=TD__none;
Featval p=floored(pimpl_->WordProb(w,ctx));
FSALMDBG(de,"p("<<TD::Convert(w)<<"|"<<TD::Convert(ctx,ctx+ctxlen_)<<")="<<p);FSALMDBGnl(de);
// states are srilm contexts so are in reverse order (most recent word is first, then 1-back comes next, etc.).
@@ -88,7 +93,7 @@ struct LanguageModelFsa : public FsaFeatureFunctionBase<LanguageModelFsa> {
WP st_end=st+ctxlen_; // may include some null already (or none if full)
int nboth=nw+ctxlen_;
WordID ctx[nboth+1];
- ctx[nboth]=TD::none;
+ ctx[nboth]=TD__none;
// reverse order - state at very end of context, then [i,end) in rev order ending at ctx[0]
W ctx_score_end=wordcpy_reverse(ctx,begin,end);
wordcpy(ctx_score_end,st,st_end); // st already reversed.
diff --git a/decoder/ff_sample_fsa.h b/decoder/ff_sample_fsa.h
index 20d64b16..74d71b6a 100755
--- a/decoder/ff_sample_fsa.h
+++ b/decoder/ff_sample_fsa.h
@@ -114,7 +114,7 @@ struct LongerThanPrev : public FsaFeatureFunctionBase<LongerThanPrev> {
// similar example feature; base type exposes stateful type, defines markov_order 1, state size = sizeof(State)
struct ShorterThanPrev : FsaTypedBase<int,ShorterThanPrev> {
ShorterThanPrev(std::string const& param)
- : FsaTypedBase<int,ShorterThanPrev>(-1,4,singleton_sentence(TD::se)) // start, h_start, end_phrase
+ : FsaTypedBase<int,ShorterThanPrev>(-1,4,singleton_sentence(TD::Convert("</s>"))) // start, h_start, end_phrase
// h_start estimate state: anything <4 chars is usually shorter than previous
{ Init(); }
static std::string usage(bool param,bool verbose) {