Diffstat (limited to 'decoder/ff_lm.cc')
-rw-r--r-- | decoder/ff_lm.cc | 454 |
1 file changed, 454 insertions, 0 deletions
diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc
new file mode 100644
index 00000000..a1dc8b81
--- /dev/null
+++ b/decoder/ff_lm.cc
@@ -0,0 +1,454 @@
+#include "ff_lm.h"
+
+#include <sstream>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <netinet/in.h>
+#include <netdb.h>
+
+#include <boost/shared_ptr.hpp>
+
+#include "tdict.h"
+#include "Vocab.h"
+#include "Ngram.h"
+#include "hg.h"
+#include "stringlib.h"
+
+#ifdef HAVE_RANDLM
+#include "RandLM.h"
+#endif
+
+using namespace std;
+
+namespace NgramCache {
+  struct Cache {
+    map<WordID, Cache> tree;
+    float prob;
+    Cache() : prob() {}
+  };
+  static Cache cache_;
+  void Clear() { cache_.tree.clear(); }
+}
+
+struct LMClient {
+
+  LMClient(const char* host) : port(6666) {
+    strcpy(request_buffer, "prob ");
+    s = const_cast<char*>(strchr(host, ':'));  // TODO fix const_cast
+    if (s != NULL) {
+      *s = '\0';
+      ++s;
+      port = atoi(s);
+    }
+    sock = socket(AF_INET, SOCK_STREAM, 0);
+    hp = gethostbyname(host);
+    if (hp == NULL) {
+      cerr << "unknown host " << host << endl;
+      abort();
+    }
+    bzero((char *)&server, sizeof(server));
+    bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length);
+    server.sin_family = hp->h_addrtype;
+    server.sin_port = htons(port);
+
+    int errors = 0;
+    while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) {
+      cerr << "Error: connect()\n";
+      sleep(1);
+      errors++;
+      if (errors > 3) exit(1);
+    }
+    cerr << "Connected to LM on " << host << " on port " << port << endl;
+  }
+
+  float wordProb(int word, int* context) {
+    NgramCache::Cache* cur = &NgramCache::cache_;
+    int i = 0;
+    while (context[i] > 0) {
+      cur = &cur->tree[context[i++]];
+    }
+    cur = &cur->tree[word];
+    if (cur->prob) { return cur->prob; }
+
+    i = 0;
+    int pos = TD::AppendString(word, 5, 16000, request_buffer);
+    while (context[i] > 0) {
+      assert(pos < 15995);
+      request_buffer[pos] = ' ';
+      ++pos;
+      pos = TD::AppendString(context[i], pos, 16000, request_buffer);
+      ++i;
+    }
+    assert(pos < 15999);
+    request_buffer[pos] = '\n';
+    ++pos;
+    request_buffer[pos] = 0;
+    write(sock, request_buffer, pos);
+    int r = read(sock, res, 6);
+    int errors = 0;
+    int cnt = 0;
+    while (1) {
+      if (r < 0) {
+        errors++; sleep(1);
+        cerr << "Error: read()\n";
+        if (errors > 5) exit(1);
+      } else if (r==0 || res[cnt] == '\n') { break; }
+      else {
+        cnt += r;
+        if (cnt==6) break;
+        read(sock, &res[cnt], 6-cnt);
+      }
+    }
+    cur->prob = *reinterpret_cast<float*>(res);
+    return cur->prob;
+  }
+
+ private:
+  int sock, port;
+  char *s;
+  struct hostent *hp;
+  struct sockaddr_in server;
+  char res[8];
+  char request_buffer[16000];
+};
+
+class LanguageModelImpl {
+ public:
+  explicit LanguageModelImpl(int order) :
+      ngram_(*TD::dict_, order), buffer_(), order_(order),
+      state_size_(OrderToStateSize(order) - 1),
+      floor_(-100.0),
+      client_(),
+      kSTART(TD::Convert("<s>")),
+      kSTOP(TD::Convert("</s>")),
+      kUNKNOWN(TD::Convert("<unk>")),
+      kNONE(-1),
+      kSTAR(TD::Convert("<{STAR}>")) {}
+
+  LanguageModelImpl(int order, const string& f) :
+      ngram_(*TD::dict_, order), buffer_(), order_(order),
+      state_size_(OrderToStateSize(order) - 1),
+      floor_(-100.0),
+      client_(NULL),
+      kSTART(TD::Convert("<s>")),
+      kSTOP(TD::Convert("</s>")),
+      kUNKNOWN(TD::Convert("<unk>")),
+      kNONE(-1),
+      kSTAR(TD::Convert("<{STAR}>")) {
+    if (f.find("lm://") == 0) {
+      client_ = new LMClient(f.substr(5).c_str());
+    } else {
+      File file(f.c_str(), "r", 0);
+      assert(file);
+      cerr << "Reading " << order_ << "-gram LM from " << f << endl;
+      ngram_.read(file, false);
+    }
+  }
+
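+  // NOTE (editorial sketch, not part of the original commit): when constructed
+  // with an "lm://host[:port]" parameter, word probabilities come from the
+  // LMClient above instead of a local SRILM table. From wordProb, the assumed
+  // wire protocol is one request per n-gram:
+  //   client sends:   "prob <word> <context...>\n"  (context most recent
+  //                    first, as in SRILM's wordProb convention)
+  //   server replies: a short record whose first four bytes are reinterpreted
+  //                    as the float log probability.
+  // Replies are memoized in NgramCache, so each distinct n-gram crosses the
+  // socket only once until NgramCache::Clear() is called.
+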
+  virtual ~LanguageModelImpl() {
+    delete client_;
+  }
+
+  inline int StateSize(const void* state) const {
+    return *(static_cast<const char*>(state) + state_size_);
+  }
+
+  inline void SetStateSize(int size, void* state) const {
+    *(static_cast<char*>(state) + state_size_) = size;
+  }
+
+  virtual double WordProb(int word, int* context) {
+    return client_ ?
+          client_->wordProb(word, context)
+        : ngram_.wordProb(word, (VocabIndex*)context);
+  }
+
+  inline double LookupProbForBufferContents(int i) {
+// int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; }
+    double p = WordProb(buffer_[i], &buffer_[i+1]);
+    if (p < floor_) p = floor_;
+// cerr << ")=" << p << endl;
+    return p;
+  }
+
+  string DebugStateToString(const void* state) const {
+    int len = StateSize(state);
+    const int* astate = reinterpret_cast<const int*>(state);
+    string res = "[";
+    for (int i = 0; i < len; ++i) {
+      res += " ";
+      res += TD::Convert(astate[i]);
+    }
+    res += " ]";
+    return res;
+  }
+
+  inline double ProbNoRemnant(int i, int len) {
+    int edge = len;
+    bool flag = true;
+    double sum = 0.0;
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+        flag = false;
+      } else if (buffer_[i] <= 0) {
+        edge = i;
+        flag = true;
+      } else {
+        if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+          sum += LookupProbForBufferContents(i);
+      }
+      --i;
+    }
+    return sum;
+  }
+
+  double EstimateProb(const vector<WordID>& phrase) {
+    int len = phrase.size();
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    for (int j = 0; j < len; ++j,--i)
+      buffer_[i] = phrase[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  double EstimateProb(const void* state) {
+    int len = StateSize(state);
+    // cerr << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    const int* astate = reinterpret_cast<const int*>(state);
+    int i = len - 1;
+    for (int j = 0; j < len; ++j,--i)
+      buffer_[i] = astate[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  double FinalTraversalCost(const void* state) {
+    int slen = StateSize(state);
+    int len = slen + 2;
+    // cerr << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    buffer_[len-1] = kSTART;
+    const int* astate = reinterpret_cast<const int*>(state);
+    int i = len - 2;
+    for (int j = 0; j < slen; ++j,--i)
+      buffer_[i] = astate[j];
+    buffer_[i] = kSTOP;
+    assert(i == 0);
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
+    int len = rule.ELength() - rule.Arity();
+    for (int i = 0; i < ant_states.size(); ++i)
+      len += StateSize(ant_states[i]);
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    const vector<WordID>& e = rule.e();
+    for (int j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) {
+        const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+        int slen = StateSize(astate);
+        for (int k = 0; k < slen; ++k)
+          buffer_[i--] = astate[k];
+      } else {
+        buffer_[i--] = e[j];
+      }
+    }
+
+    double sum = 0.0;
+    int* remnant = reinterpret_cast<int*>(vstate);
+    int j = 0;
+    i = len - 1;
+    int edge = len;
+
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+      } else if (edge-i >= order_) {
+        sum += LookupProbForBufferContents(i);
+      } else if (edge == len && remnant) {
+        remnant[j++] = buffer_[i];
+      }
+      --i;
+    }
+    if (!remnant) return sum;
+
+    if (edge != len || len >= order_) {
+      remnant[j++] = kSTAR;
+      if (order_-1 < edge) edge = order_-1;
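+      // (editorial note) At this point remnant[0..j-1] holds the left-edge
+      // words that still lack enough context to score, and kSTAR marks the
+      // elided middle; the loop below appends the final (at most order_-1)
+      // words, which is all a parent edge needs to extend the context.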
+      for (int i = edge-1; i >= 0; --i)
+        remnant[j++] = buffer_[i];
+    }
+
+    SetStateSize(j, vstate);
+    return sum;
+  }
+
+  static int OrderToStateSize(int order) {
+    return ((order-1) * 2 + 1) * sizeof(WordID) + 1;
+  }
+
+ protected:
+  Ngram ngram_;
+  vector<WordID> buffer_;
+  const int order_;
+  const int state_size_;
+  const double floor_;
+ private:
+  LMClient* client_;
+
+ public:
+  const WordID kSTART;
+  const WordID kSTOP;
+  const WordID kUNKNOWN;
+  const WordID kNONE;
+  const WordID kSTAR;
+};
+
+LanguageModel::LanguageModel(const string& param) :
+    fid_(FD::Convert("LanguageModel")) {
+  vector<string> argv;
+  int argc = SplitOnWhitespace(param, &argv);
+  int order = 3;
+  // TODO add support for -n FeatureName
+  string filename;
+  if (argc < 1) { cerr << "LanguageModel requires a filename, minimally!\n"; abort(); }
+  else if (argc == 1) { filename = argv[0]; }
+  else if (argc == 2 || argc > 3) { cerr << "Don't understand 'LanguageModel " << param << "'\n"; }
+  else if (argc == 3) {
+    if (argv[0] == "-o") {
+      order = atoi(argv[1].c_str());
+      filename = argv[2];
+    } else if (argv[1] == "-o") {
+      order = atoi(argv[2].c_str());
+      filename = argv[0];
+    }
+  }
+  SetStateSize(LanguageModelImpl::OrderToStateSize(order));
+  pimpl_ = new LanguageModelImpl(order, filename);
+}
+
+LanguageModel::~LanguageModel() {
+  delete pimpl_;
+}
+
+string LanguageModel::DebugStateToString(const void* state) const {
+  return pimpl_->DebugStateToString(state);
+}
+
+void LanguageModel::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                          const Hypergraph::Edge& edge,
+                                          const vector<const void*>& ant_states,
+                                          SparseVector<double>* features,
+                                          SparseVector<double>* estimated_features,
+                                          void* state) const {
+  (void) smeta;
+  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+  estimated_features->set_value(fid_, pimpl_->EstimateProb(state));
+}
+
+void LanguageModel::FinalTraversalFeatures(const void* ant_state,
+                                           SparseVector<double>* features) const {
+  features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
+
+#ifdef HAVE_RANDLM
+struct RandLMImpl : public LanguageModelImpl {
+  RandLMImpl(int order, randlm::RandLM* rlm) :
+      LanguageModelImpl(order),
+      rlm_(rlm),
+      oov_(rlm->getWordID(rlm->getOOV())),
+      rb_(1000, oov_) {
+    map<int, randlm::WordID> map_cdec2randlm;
+    int max_wordid = 0;
+    for(map<randlm::Word, randlm::WordID>::const_iterator it = rlm->vocabStart();
+        it != rlm->vocabEnd(); ++it) {
+      const int cur = TD::Convert(it->first);
+      map_cdec2randlm[TD::Convert(it->first)] = it->second;
+      if (cur > max_wordid) max_wordid = cur;
+    }
+    cdec2randlm_.resize(max_wordid + 1, oov_);
+    for (map<int, randlm::WordID>::iterator it = map_cdec2randlm.begin();
+         it != map_cdec2randlm.end(); ++it)
+      cdec2randlm_[it->first] = it->second;
+    map_cdec2randlm.clear();
+  }
+
+  inline randlm::WordID Convert2RandLM(int w) {
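+    // (editorial note) cdec2randlm_ is the dense WordID-to-randlm::WordID
+    // table built in the constructor; ids beyond it were never seen by RandLM
+    // and fall back to its OOV id.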
+    return (w < cdec2randlm_.size() ? cdec2randlm_[w] : oov_);
+  }
+
+  virtual double WordProb(int word, int* context) {
+    int i = order_;
+    int c = 1;
+    rb_[i] = Convert2RandLM(word);
+    while (i > 1 && *context > 0) {
+      --i;
+      rb_[i] = Convert2RandLM(*context);
+      ++context;
+      ++c;
+    }
+    const void* finalState = 0;
+    int found;
+    //cerr << "I = " << i << endl;
+    return rlm_->getProb(&rb_[i], c, &found, &finalState);
+  }
+ private:
+  boost::shared_ptr<randlm::RandLM> rlm_;
+  randlm::WordID oov_;
+  vector<randlm::WordID> cdec2randlm_;
+  vector<randlm::WordID> rb_;
+};
+
+LanguageModelRandLM::LanguageModelRandLM(const string& param) :
+    fid_(FD::Convert("RandLM")) {
+  vector<string> argv;
+  int argc = SplitOnWhitespace(param, &argv);
+  int order = 3;
+  // TODO add support for -n FeatureName
+  string filename;
+  if (argc < 1) { cerr << "RandLM requires a filename, minimally!\n"; abort(); }
+  else if (argc == 1) { filename = argv[0]; }
+  else if (argc == 2 || argc > 3) { cerr << "Don't understand 'RandLM " << param << "'\n"; }
+  else if (argc == 3) {
+    if (argv[0] == "-o") {
+      order = atoi(argv[1].c_str());
+      filename = argv[2];
+    } else if (argv[1] == "-o") {
+      order = atoi(argv[2].c_str());
+      filename = argv[0];
+    }
+  }
+  SetStateSize(LanguageModelImpl::OrderToStateSize(order));
+  int cache_MB = 200;  // increase cache size
+  randlm::RandLM* rlm = randlm::RandLM::initRandLM(filename, order, cache_MB);
+  assert(rlm != NULL);
+  pimpl_ = new RandLMImpl(order, rlm);
+}
+
+LanguageModelRandLM::~LanguageModelRandLM() {
+  delete pimpl_;
+}
+
+void LanguageModelRandLM::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                                const Hypergraph::Edge& edge,
+                                                const vector<const void*>& ant_states,
+                                                SparseVector<double>* features,
+                                                SparseVector<double>* estimated_features,
+                                                void* state) const {
+  (void) smeta;
+  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+  estimated_features->set_value(fid_, pimpl_->EstimateProb(state));
+}
+
+void LanguageModelRandLM::FinalTraversalFeatures(const void* ant_state,
+                                                 SparseVector<double>* features) const {
+  features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
+
+#endif
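+
+// Parameter summary (editorial note): both feature constructors above accept
+// "<file>", "-o <order> <file>", or "<file> -o <order>", with order defaulting
+// to 3. For LanguageModel, a filename of the form "lm://host[:port]" selects
+// the remote LMClient instead of an SRILM file on disk.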