From 36778292e0e98d9b5307a38104f1e99aad7c3aa4 Mon Sep 17 00:00:00 2001 From: graehl Date: Mon, 5 Jul 2010 18:52:56 +0000 Subject: reuse of same-filename Ngram objects (caveat: lifetime is up to you) git-svn-id: https://ws10smt.googlecode.com/svn/trunk@132 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/ff_lm.cc | 172 +++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 136 insertions(+), 36 deletions(-) (limited to 'decoder/ff_lm.cc') diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc index 21c05cf2..1f89e24f 100644 --- a/decoder/ff_lm.cc +++ b/decoder/ff_lm.cc @@ -1,5 +1,7 @@ //TODO: allow features to reorder by heuristic*weight the rules' terminal phrases (or of hyperedges'). if first pass has pruning, then compute over whole ruleset as part of heuristic +//TODO: verify that this is true: if ngram order is bigger than lm state's, then the longest possible ngram scores are still used. if you really want a lower order, a truncated copy of the LM should be small enough. otherwise, an option to null out words outside of the order's window would need to be implemented. + #include "ff_lm.h" #include @@ -10,6 +12,7 @@ #include #include +#include #include "tdict.h" #include "Vocab.h" @@ -24,6 +27,41 @@ using namespace std; +// intend to have a 0-state prelm-pass heuristic LM that is better than 1gram (like how estimated_features are lower order estimates). NgramShare will keep track of all loaded lms and reuse them. +//TODO: ref counting by shared_ptr? for now, first one to load LM needs to stick around as long as all subsequent users. + +#include +using namespace boost; + +//WARNING: first person to add a pointer to ngram must keep it around until others are done using it. +struct NgramShare +{ +// typedef shared_ptr NP; + typedef Ngram *NP; + map ns; + bool have(string const& file) const + { + return ns.find(file)!=ns.end(); + } + NP get(string const& file) const + { + assert(have(file)); + return ns.find(file)->second; + } + void set(string const& file,NP n) + { + ns[file]=n; + } + void add(string const& file,NP n) + { + assert(!have(file)); + set(file,n); + } +}; + +//TODO: namespace or static? +NgramShare ngs; + namespace NgramCache { struct Cache { map tree; @@ -36,7 +74,8 @@ namespace NgramCache { struct LMClient { - LMClient(const char* host) : port(6666) { + LMClient(string hostname) : port(6666) { + char const* host=hostname.c_str(); strcpy(request_buffer, "prob "); s = const_cast(strchr(host, ':')); // TODO fix const_cast if (s != NULL) { @@ -121,7 +160,6 @@ class LanguageModelImpl { explicit LanguageModelImpl(int order) : ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), floor_(-100.0), - client_(), kSTART(TD::Convert("")), kSTOP(TD::Convert("")), kUNKNOWN(TD::Convert("")), @@ -131,26 +169,26 @@ class LanguageModelImpl { LanguageModelImpl(int order, const string& f) : ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), floor_(-100.0), - client_(NULL), kSTART(TD::Convert("")), kSTOP(TD::Convert("")), kUNKNOWN(TD::Convert("")), kNONE(-1), kSTAR(TD::Convert("<{STAR}>")) { - if (f.find("lm://") == 0) { - client_ = new LMClient(f.substr(5).c_str()); - } else { - File file(f.c_str(), "r", 0); - assert(file); - cerr << "Reading " << order_ << "-gram LM from " << f << endl; - ngram_.read(file, false); - } + File file(f.c_str(), "r", 0); + assert(file); + cerr << "Reading " << order_ << "-gram LM from " << f << endl; + ngram_.read(file, false); } virtual ~LanguageModelImpl() { - delete client_; } + Ngram *get_lm() // for make_lm_impl ngs sharing only. + { + return &ngram_; + } + + inline int StateSize(const void* state) const { return *(static_cast(state) + state_size_); } @@ -160,9 +198,7 @@ class LanguageModelImpl { } virtual double WordProb(int word, int* context) { - return client_ ? - client_->wordProb(word, context) - : ngram_.wordProb(word, (VocabIndex*)context); + return ngram_.wordProb(word, (VocabIndex*)context); } inline double LookupProbForBufferContents(int i) { @@ -243,6 +279,7 @@ class LanguageModelImpl { return ProbNoRemnant(len - 1, len); } + //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb) double LookupWords(const TRule& rule, const vector& ant_states, void* vstate) { int len = rule.ELength() - rule.Arity(); for (int i = 0; i < ant_states.size(); ++i) @@ -301,9 +338,6 @@ class LanguageModelImpl { const int order_; const int state_size_; const double floor_; - private: - LMClient* client_; - public: const WordID kSTART; const WordID kSTOP; @@ -312,27 +346,93 @@ class LanguageModelImpl { const WordID kSTAR; }; -LanguageModel::LanguageModel(const string& param) : - fid_(FD::Convert("LanguageModel")) { - vector argv; - int argc = SplitOnWhitespace(param, &argv); - int order = 3; - // TODO add support for -n FeatureName - string filename; - if (argc < 1) { cerr << "LanguageModel requires a filename, minimally!\n"; abort(); } - else if (argc == 1) { filename = argv[0]; } - else if (argc == 2 || argc > 3) { cerr << "Don't understand 'LanguageModel " << param << "'\n"; } - else if (argc == 3) { - if (argv[0] == "-o") { - order = atoi(argv[1].c_str()); - filename = argv[2]; - } else if (argv[1] == "-o") { - order = atoi(argv[2].c_str()); - filename = argv[0]; +struct ClientLMI : public LanguageModelImpl +{ + ClientLMI(int order,string const& server) : LanguageModelImpl(order), client_(server) + {} + + virtual double WordProb(int word, int* context) { + return client_.wordProb(word, context); + } + +protected: + LMClient client_; +}; + +struct ReuseLMI : public LanguageModelImpl +{ + ReuseLMI(int order, Ngram *ng) : LanguageModelImpl(order), ng(ng) + {} + virtual double WordProb(int word, int* context) { + return ng->wordProb(word, (VocabIndex*)context); + } +protected: + Ngram *ng; +}; + +LanguageModelImpl *make_lm_impl(int order, string const& f) +{ + if (f.find("lm://") == 0) { + return new ClientLMI(order,f.substr(5)); + } else if (ngs.have(f)) { + return new ReuseLMI(order,ngs.get(f)); + } else { + LanguageModelImpl *r=new LanguageModelImpl(order,f); + ngs.add(f,r->get_lm()); + return r; + } +} + +bool parse_lmspec(std::string const& in, int &order, string &featurename, string &filename) +{ + vector const& argv=SplitOnWhitespace(in); + featurename="LanguageModel"; + order=3; +#define LMSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'o': + LMSPEC_NEXTARG; order=lexical_cast(*i); + break; + case 'n': + LMSPEC_NEXTARG; featurename=*i; + break; +#undef LMSPEC_NEXTARG + default: + fail: + cerr<<"Unknown LanguageModel option "<