Diffstat (limited to 'decoder')
-rw-r--r--  decoder/ff_lm.cc | 172
1 file changed, 136 insertions(+), 36 deletions(-)
diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc
index 21c05cf2..1f89e24f 100644
--- a/decoder/ff_lm.cc
+++ b/decoder/ff_lm.cc
@@ -1,5 +1,7 @@
 //TODO: allow features to reorder by heuristic*weight the rules' terminal phrases (or of hyperedges').  if first pass has pruning, then compute over whole ruleset as part of heuristic
+//TODO: verify that this is true: if ngram order is bigger than lm state's, then the longest possible ngram scores are still used.  if you really want a lower order, a truncated copy of the LM should be small enough.  otherwise, an option to null out words outside of the order's window would need to be implemented.
+
 #include "ff_lm.h"
 
 #include <sstream>
@@ -10,6 +12,7 @@
 #include <netdb.h>
 
 #include <boost/shared_ptr.hpp>
+#include <boost/lexical_cast.hpp>
 
 #include "tdict.h"
 #include "Vocab.h"
@@ -24,6 +27,41 @@
 
 using namespace std;
 
+// intend to have a 0-state prelm-pass heuristic LM that is better than 1gram (like how estimated_features are lower order estimates).  NgramShare will keep track of all loaded lms and reuse them.
+//TODO: ref counting by shared_ptr?  for now, first one to load LM needs to stick around as long as all subsequent users.
+
+#include <boost/shared_ptr.hpp>
+using namespace boost;
+
+//WARNING: first person to add a pointer to ngram must keep it around until others are done using it.
+struct NgramShare
+{
+//  typedef shared_ptr<Ngram> NP;
+  typedef Ngram *NP;
+  map<string,NP> ns;
+  bool have(string const& file) const
+  {
+    return ns.find(file)!=ns.end();
+  }
+  NP get(string const& file) const
+  {
+    assert(have(file));
+    return ns.find(file)->second;
+  }
+  void set(string const& file,NP n)
+  {
+    ns[file]=n;
+  }
+  void add(string const& file,NP n)
+  {
+    assert(!have(file));
+    set(file,n);
+  }
+};
+
+//TODO: namespace or static?
+NgramShare ngs;
+
 namespace NgramCache {
   struct Cache {
     map<WordID, Cache> tree;
@@ -36,7 +74,8 @@ namespace NgramCache {
 
 struct LMClient {
-  LMClient(const char* host) : port(6666) {
+  LMClient(string hostname) : port(6666) {
+    char const* host=hostname.c_str();
     strcpy(request_buffer, "prob ");
     s = const_cast<char*>(strchr(host, ':'));  // TODO fix const_cast
     if (s != NULL) {
@@ -121,7 +160,6 @@ class LanguageModelImpl {
   explicit LanguageModelImpl(int order) :
       ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
       floor_(-100.0),
-      client_(),
       kSTART(TD::Convert("<s>")),
       kSTOP(TD::Convert("</s>")),
       kUNKNOWN(TD::Convert("<unk>")),
@@ -131,26 +169,26 @@ class LanguageModelImpl {
   LanguageModelImpl(int order, const string& f) :
       ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
       floor_(-100.0),
-      client_(NULL),
       kSTART(TD::Convert("<s>")),
       kSTOP(TD::Convert("</s>")),
       kUNKNOWN(TD::Convert("<unk>")),
       kNONE(-1),
       kSTAR(TD::Convert("<{STAR}>")) {
-    if (f.find("lm://") == 0) {
-      client_ = new LMClient(f.substr(5).c_str());
-    } else {
-      File file(f.c_str(), "r", 0);
-      assert(file);
-      cerr << "Reading " << order_ << "-gram LM from " << f << endl;
-      ngram_.read(file, false);
-    }
+    File file(f.c_str(), "r", 0);
+    assert(file);
+    cerr << "Reading " << order_ << "-gram LM from " << f << endl;
+    ngram_.read(file, false);
   }
 
   virtual ~LanguageModelImpl() {
-    delete client_;
   }
 
+  Ngram *get_lm() // for make_lm_impl ngs sharing only.
+  {
+    return &ngram_;
+  }
+
   inline int StateSize(const void* state) const {
     return *(static_cast<const char*>(state) + state_size_);
   }
@@ -160,9 +198,7 @@ class LanguageModelImpl {
   }
 
   virtual double WordProb(int word, int* context) {
-    return client_ ?
-          client_->wordProb(word, context)
-        : ngram_.wordProb(word, (VocabIndex*)context);
+    return ngram_.wordProb(word, (VocabIndex*)context);
   }
 
   inline double LookupProbForBufferContents(int i) {
@@ -243,6 +279,7 @@ class LanguageModelImpl {
     return ProbNoRemnant(len - 1, len);
   }
 
+  //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb)
   double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
     int len = rule.ELength() - rule.Arity();
     for (int i = 0; i < ant_states.size(); ++i)
@@ -301,9 +338,6 @@ class LanguageModelImpl {
   const int order_;
   const int state_size_;
   const double floor_;
- private:
-  LMClient* client_;
-
  public:
   const WordID kSTART;
   const WordID kSTOP;
@@ -312,27 +346,93 @@ class LanguageModelImpl {
   const WordID kSTAR;
 };
 
-LanguageModel::LanguageModel(const string& param) :
-    fid_(FD::Convert("LanguageModel")) {
-  vector<string> argv;
-  int argc = SplitOnWhitespace(param, &argv);
-  int order = 3;
-  // TODO add support for -n FeatureName
-  string filename;
-  if (argc < 1) { cerr << "LanguageModel requires a filename, minimally!\n"; abort(); }
-  else if (argc == 1) { filename = argv[0]; }
-  else if (argc == 2 || argc > 3) { cerr << "Don't understand 'LanguageModel " << param << "'\n"; }
-  else if (argc == 3) {
-    if (argv[0] == "-o") {
-      order = atoi(argv[1].c_str());
-      filename = argv[2];
-    } else if (argv[1] == "-o") {
-      order = atoi(argv[2].c_str());
-      filename = argv[0];
+struct ClientLMI : public LanguageModelImpl
+{
+  ClientLMI(int order,string const& server) : LanguageModelImpl(order), client_(server)
+  {}
+
+  virtual double WordProb(int word, int* context) {
+    return client_.wordProb(word, context);
+  }
+
+protected:
+  LMClient client_;
+};
+
+struct ReuseLMI : public LanguageModelImpl
+{
+  ReuseLMI(int order, Ngram *ng) : LanguageModelImpl(order), ng(ng)
+  {}
+  virtual double WordProb(int word, int* context) {
+    return ng->wordProb(word, (VocabIndex*)context);
+  }
+protected:
+  Ngram *ng;
+};
+
+LanguageModelImpl *make_lm_impl(int order, string const& f)
+{
+  if (f.find("lm://") == 0) {
+    return new ClientLMI(order,f.substr(5));
+  } else if (ngs.have(f)) {
+    return new ReuseLMI(order,ngs.get(f));
+  } else {
+    LanguageModelImpl *r=new LanguageModelImpl(order,f);
+    ngs.add(f,r->get_lm());
+    return r;
+  }
+}
+
+bool parse_lmspec(std::string const& in, int &order, string &featurename, string &filename)
+{
+  vector<string> const& argv=SplitOnWhitespace(in);
+  featurename="LanguageModel";
+  order=3;
+#define LMSPEC_NEXTARG if (i==argv.end()) {            \
+    cerr << "Missing argument for "<<*last<<". "; goto usage; \
+    } else { ++i; }
+
+  for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+    string const& s=*i;
+    if (s[0]=='-') {
+      if (s.size()>2) goto fail;
+      switch (s[1]) {
+      case 'o':
+        LMSPEC_NEXTARG; order=lexical_cast<int>(*i);
+        break;
+      case 'n':
+        LMSPEC_NEXTARG; featurename=*i;
+        break;
+#undef LMSPEC_NEXTARG
+      default:
+      fail:
+        cerr<<"Unknown LanguageModel option "<<s<<" ; ";
+        goto usage;
+      }
+    } else {
+      if (filename.empty())
+        filename=s;
+      else {
+        cerr<<"More than one filename provided. ";
+        goto usage;
+      }
     }
   }
+  return true;
+usage:
+  cerr<<"LanguageModel specification should be: [-o order] [-n featurename] filename"<<endl<<" you provided: "<<in<<endl;
+  return false;
+}
+
+
+LanguageModel::LanguageModel(const string& param) {
+  int order;
+  string featurename,filename;
+  if (!parse_lmspec(param,order,featurename,filename))
+    abort();
+  fid_=FD::Convert("LanguageModel");
   SetStateSize(LanguageModelImpl::OrderToStateSize(order));
-  pimpl_ = new LanguageModelImpl(order, filename);
+  pimpl_ = make_lm_impl(order,filename);
 }
 
 LanguageModel::~LanguageModel() {

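For reference, the spec grammar accepted by parse_lmspec is [-o order] [-n featurename] filename, with the flags in either order and order defaulting to 3. Hypothetical examples (file names and host are placeholders):

  LanguageModel trigram.lm                 # order 3, LM read from a local file
  LanguageModel -o 5 -n LM2 big.5gram.lm   # 5-gram under a custom feature name
  LanguageModel -o 4 lm://localhost:6666   # lm:// prefix selects the ClientLMI network client

Note that make_lm_impl shares a single Ngram when the same filename is given twice (the second spec gets a ReuseLMI over the first load), and that while -n featurename is parsed, the constructor still hard-codes fid_=FD::Convert("LanguageModel"), so the custom name does not take effect yet in this commit. Also, last in LMSPEC_NEXTARG is never assigned before *last is printed on the missing-argument path.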