From f794c881ccaf4aa0646300dea1e1f6e7a307e019 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Wed, 12 Sep 2012 18:36:03 +0100
Subject: Partially written bridge to lazy

---
 decoder/lazy.cc | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 decoder/lazy.h  |   8 ++++
 2 files changed, 130 insertions(+)
 create mode 100644 decoder/lazy.cc
 create mode 100644 decoder/lazy.h

diff --git a/decoder/lazy.cc b/decoder/lazy.cc
new file mode 100644
index 00000000..f5b61c75
--- /dev/null
+++ b/decoder/lazy.cc
@@ -0,0 +1,122 @@
+#include "hg.h"
+#include "lazy.h"
+#include "tdict.h"
+
+#include "lm/enumerate_vocab.hh"
+#include "lm/model.hh"
+#include "search/edge.hh"
+#include "search/vertex.hh"
+#include "util/exception.hh"
+
+#include <boost/scoped_array.hpp>
+
+namespace {
+
+struct MapVocab : public lm::EnumerateVocab {
+  public:
+    MapVocab() {}
+
+    // Do not call after Lookup.
+    void Add(lm::WordIndex index, const StringPiece &str) {
+      const WordID cdec_id = TD::Convert(str.as_string());
+      if (cdec_id >= out_.size()) out_.resize(cdec_id + 1);
+      out_[cdec_id] = index;
+    }
+
+    // Assumes Add has been called and will never be called again.
+    lm::WordIndex FromCDec(WordID id) const {
+      return out_[out_.size() > id ? id : 0];
+    }
+
+  private:
+    std::vector<lm::WordIndex> out_;
+};
+
+class LazyBase {
+  public:
+    LazyBase() {}
+
+    virtual ~LazyBase() {}
+
+    virtual void Search(const Hypergraph &hg) const = 0;
+
+    static LazyBase *Load(const char *model_file);
+
+  protected:
+    lm::ngram::Config GetConfig() {
+      lm::ngram::Config ret;
+      ret.enumerate_vocab = &vocab_;
+      return ret;
+    }
+
+    MapVocab vocab_;
+};
+
+template <class Model> class Lazy : public LazyBase {
+  public:
+    explicit Lazy(const char *model_file) : m_(model_file, GetConfig()) {}
+
+    void Search(const Hypergraph &hg) const;
+
+  private:
+    void ConvertEdge(const Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const;
+
+    const Model m_;
+};
+
+LazyBase *LazyBase::Load(const char *model_file) {
+  lm::ngram::ModelType model_type;
+  if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
+  switch (model_type) {
+    case lm::ngram::PROBING:
+      return new Lazy<lm::ngram::ProbingModel>(model_file);
+    case lm::ngram::REST_PROBING:
+      return new Lazy<lm::ngram::RestProbingModel>(model_file);
+    default:
+      UTIL_THROW(util::Exception, "Sorry, this LM type isn't supported yet.");
+  }
+}
+
+template <class Model> void Lazy<Model>::Search(const Hypergraph &hg) const {
+  boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
+  boost::scoped_array<search::Edge> out_edges(new search::Edge[hg.edges_.size()]);
+  for (unsigned int i = 0; i < hg.nodes_.size(); ++i) {
+    search::Vertex *out_vertex = &out_vertices[i];
+    const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_;
+    for (unsigned int j = 0; j < down_edges.size(); ++j) {
+      unsigned int edge_index = down_edges[j];
+      const Hypergraph::Edge &in_edge = hg.edges_[edge_index];
+      search::Edge &out_edge = out_edges[edge_index];
+    }
+  }
+}
+
+// TODO: get weights into here somehow.
+template <class Model> void Lazy<Model>::ConvertEdge(const Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const {
+  const std::vector<WordID> &e = in.rule_->e();
+  std::vector<lm::WordIndex> words;
+  unsigned int terminals = 0;
+  for (std::vector<WordID>::const_iterator word = e.begin(); word != e.end(); ++word) {
+    if (*word <= 0) {
+      out.Add(vertices[in.tail_nodes_[-*word]]);
+      words.push_back(lm::kMaxWordIndex);
+    } else {
+      ++terminals;
+      words.push_back(vocab_.FromCDec(*word));
+    }
+  }
+
+  if (final) {
+    words.push_back(m_.GetVocabulary().EndSentence());
+  }
+
+  float additive = in.rule_->GetFeatureValues().dot(weight_vector);
+
+  out.InitRule().Init(context, additive, words, final);
+}
+
+} // namespace
+
+void PassToLazy(const Hypergraph &hg) {
+
+}
diff --git a/decoder/lazy.h b/decoder/lazy.h
new file mode 100644
index 00000000..aecd030d
--- /dev/null
+++ b/decoder/lazy.h
@@ -0,0 +1,8 @@
+#ifndef _LAZY_H_
+#define _LAZY_H_
+
+class Hypergraph;
+
+void PassToLazy(const Hypergraph &hg);
+
+#endif // _LAZY_H_
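The PassToLazy() stub above is left empty by this commit. A minimal sketch of how the bridge might eventually be driven, assuming the model path gets threaded in from the decoder later (the extra model_file argument and the one-model-per-process singleton below are illustrative assumptions, not part of this patch):

    // Hypothetical driver, not in the commit: load the language model once,
    // then hand each hypergraph to the lazy search.  Only LazyBase::Load and
    // Search from the patch are used; the signature change is assumed.
    void PassToLazy(const char *model_file, const Hypergraph &hg) {
      static LazyBase *lazy = LazyBase::Load(model_file);  // assumes one LM per process
      lazy->Search(hg);  // converts the hypergraph and runs the lazy search
    }

Weights are still missing here, matching the TODO in the patch itself.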