From 8882e9ebe158aef382bb5544559ef7f2a553db62 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Sep 2012 14:23:39 +0100 Subject: Update kenlm and build system --- klm/lm/vocab.hh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'klm/lm/vocab.hh') diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index a25432f9..074cd446 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -62,7 +62,7 @@ class SortedVocabulary : public base::Vocabulary { } // Size for purposes of file writing - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. WordIndex Bound() const { return bound_; } @@ -129,7 +129,7 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } -- cgit v1.2.3 From 0ff82d648446645df245decc1e9eafad304eb327 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 22 Oct 2012 14:04:27 +0100 Subject: Update search, make it compile --- Makefile.am | 1 + configure.ac | 6 +- decoder/Makefile.am | 3 +- decoder/decoder.cc | 8 +- decoder/incremental.cc | 184 +++++++++++++++++++++++++++++++++++++++ decoder/incremental.h | 11 +++ decoder/lazy.cc | 178 -------------------------------------- decoder/lazy.h | 11 --- dtrain/Makefile.am | 2 +- klm/alone/Jamfile | 4 - klm/alone/assemble.cc | 76 ---------------- klm/alone/assemble.hh | 21 ----- klm/alone/graph.hh | 87 ------------------- klm/alone/just_vocab.cc | 14 --- klm/alone/labeled_edge.hh | 30 ------- klm/alone/main.cc | 85 ------------------ klm/alone/read.cc | 118 ------------------------- klm/alone/read.hh | 29 ------- klm/alone/threading.cc | 80 ----------------- klm/alone/threading.hh | 129 --------------------------- klm/alone/vocab.cc | 19 ---- klm/alone/vocab.hh | 34 -------- klm/lm/model.cc | 2 +- klm/lm/vocab.cc | 4 +- klm/lm/vocab.hh | 2 +- klm/search/Jamfile | 2 +- klm/search/Makefile.am | 11 +++ klm/search/arity.hh | 8 -- klm/search/context.hh | 10 +-- klm/search/edge.hh | 53 ++++++++---- klm/search/edge_generator.cc | 144 ++++++++++++++----------------- klm/search/edge_generator.hh | 49 ++++++----- klm/search/edge_queue.cc | 25 ------ klm/search/edge_queue.hh | 73 ---------------- klm/search/final.hh | 41 ++++----- klm/search/header.hh | 57 ++++++++++++ klm/search/source.hh | 48 ----------- klm/search/types.hh | 8 +- klm/search/vertex.cc | 10 +-- klm/search/vertex.hh | 55 ++++++------ klm/search/vertex_generator.cc | 97 ++++++++++++--------- klm/search/vertex_generator.hh | 33 +++---- klm/util/Makefile.am | 2 + klm/util/ersatz_progress.hh | 2 +- klm/util/exception.hh | 2 +- klm/util/pool.cc | 35 ++++++++ klm/util/pool.hh | 45 ++++++++++ klm/util/probing_hash_table.hh | 2 +- klm/util/string_piece.cc | 192 +++++++++++++++++++++++++++++++++++++++++ klm/util/tokenize_piece.hh | 12 +++ mira/Makefile.am | 2 +- training/Makefile.am | 38 ++++---- 52 files changed, 838 insertions(+), 1356 deletions(-) create mode 100644 decoder/incremental.cc create mode 100644 decoder/incremental.h delete mode 100644 decoder/lazy.cc delete mode 100644 decoder/lazy.h delete mode 100644 klm/alone/Jamfile delete mode 100644 klm/alone/assemble.cc delete mode 100644 klm/alone/assemble.hh delete mode 100644 klm/alone/graph.hh delete mode 100644 klm/alone/just_vocab.cc delete mode 100644 klm/alone/labeled_edge.hh delete mode 100644 klm/alone/main.cc delete mode 100644 klm/alone/read.cc delete mode 100644 klm/alone/read.hh delete mode 100644 klm/alone/threading.cc delete mode 100644 klm/alone/threading.hh delete mode 100644 klm/alone/vocab.cc delete mode 100644 klm/alone/vocab.hh create mode 100644 klm/search/Makefile.am delete mode 100644 klm/search/arity.hh delete mode 100644 klm/search/edge_queue.cc delete mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/header.hh delete mode 100644 klm/search/source.hh create mode 100644 klm/util/pool.cc create mode 100644 klm/util/pool.hh create mode 100644 klm/util/string_piece.cc (limited to 'klm/lm/vocab.hh') diff --git a/Makefile.am b/Makefile.am index 3e0103a8..fefc470d 100644 --- a/Makefile.am +++ b/Makefile.am @@ -6,6 +6,7 @@ SUBDIRS = \ mteval \ klm/util \ klm/lm \ + klm/search \ decoder \ training \ training/liblbfgs \ diff --git a/configure.ac b/configure.ac index 03a0ee87..cb132d66 100644 --- a/configure.ac +++ b/configure.ac @@ -12,6 +12,7 @@ AC_PROG_CXX AC_LANG_CPLUSPLUS BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS +BOOST_SYSTEM BOOST_TEST AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) @@ -73,9 +74,9 @@ fi #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" -LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" +LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SYSTEM_LDFLAGS" # $BOOST_THREAD_LDFLAGS" -LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS" +LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SYSTEM_LIBS" # $BOOST_THREAD_LIBS" AC_CHECK_HEADER(google/dense_hash_map, @@ -123,6 +124,7 @@ AC_CONFIG_FILES([rampion/Makefile]) AC_CONFIG_FILES([minrisk/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) +AC_CONFIG_FILES([klm/search/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) AC_CONFIG_FILES([example_extff/Makefile]) diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 5c0a1964..f8f427d3 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -17,7 +17,7 @@ trule_test_SOURCES = trule_test.cc trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz cdec_SOURCES = cdec.cc -cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm @@ -73,6 +73,7 @@ libcdec_a_SOURCES = \ ff_source_syntax.cc \ ff_bleu.cc \ ff_factory.cc \ + incremental.cc \ lexalign.cc \ lextrans.cc \ tagger.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 052823ca..fe812011 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -39,7 +39,7 @@ #include "sampler.h" #include "forest_writer.h" // TODO this section should probably be handled by an Observer -#include "lazy.h" +#include "incremental.h" #include "hg_io.h" #include "aligner.h" @@ -412,7 +412,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_target_graph", po::value(), "Directory to write the target hypergraphs to") - ("lazy_search", po::value(), "Run lazy search with this language model file") + ("incremental_search", po::value(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") @@ -828,8 +828,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("show_target_graph")) HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); - if (conf.count("lazy_search")) { - PassToLazy(conf["lazy_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); + if (conf.count("incremental_search")) { + PassToIncremental(conf["incremental_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); o->NotifyDecodingComplete(smeta); return true; } diff --git a/decoder/incremental.cc b/decoder/incremental.cc new file mode 100644 index 00000000..768bbd65 --- /dev/null +++ b/decoder/incremental.cc @@ -0,0 +1,184 @@ +#include "incremental.h" + +#include "hg.h" +#include "fdict.h" +#include "tdict.h" + +#include "lm/enumerate_vocab.hh" +#include "lm/model.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/rule.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" +#include "util/exception.hh" + +#include +#include + +#include +#include + +namespace { + +struct MapVocab : public lm::EnumerateVocab { + public: + MapVocab() {} + + // Do not call after Lookup. + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); + out_[cdec_id] = index; + } + + // Assumes Add has been called and will never be called again. + lm::WordIndex FromCDec(WordID id) const { + return out_[out_.size() > id ? id : 0]; + } + + private: + std::vector out_; +}; + +class IncrementalBase { + public: + IncrementalBase(const std::vector &weights) : + cdec_weights_(weights), + weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + } + + virtual ~IncrementalBase() {} + + virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + + static IncrementalBase *Load(const char *model_file, const std::vector &weights); + + protected: + lm::ngram::Config GetConfig() { + lm::ngram::Config ret; + ret.enumerate_vocab = &vocab_; + return ret; + } + + MapVocab vocab_; + + const std::vector &cdec_weights_; + + const search::Weights weights_; +}; + +template class Incremental : public IncrementalBase { + public: + Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} + + void Search(unsigned int pop_limit, const Hypergraph &hg) const; + + private: + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + + const Model m_; +}; + +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); + } +} + +void PrintFinal(const Hypergraph &hg, const search::Final final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); + const search::Final *child(final.Children()); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + if (*i > 0) { + std::cout << TD::Convert(*i) << ' '; + } else { + PrintFinal(hg, *child++); + } + } +} + +template void Incremental::Search(unsigned int pop_limit, const Hypergraph &hg) const { + boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); + search::Config config(weights_, pop_limit); + search::Context context(config, m_); + + for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { + search::EdgeGenerator gen; + const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; + for (unsigned int j = 0; j < down_edges.size(); ++j) { + unsigned int edge_index = down_edges[j]; + ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + } + search::VertexGenerator vertex_gen(context, out_vertices[i]); + gen.Search(context, vertex_gen); + } + const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (top.Valid()) { + std::cout << "NO PATH FOUND" << std::endl; + } else { + PrintFinal(hg, top); + std::cout << "||| " << top.GetScore() << std::endl; + } +} + +template void Incremental::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { + const std::vector &e = in.rule_->e(); + std::vector words; + words.reserve(e.size()); + std::vector nts; + unsigned int terminals = 0; + float score = 0.0; + for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { + if (*word <= 0) { + nts.push_back(vertices[in.tail_nodes_[-*word]].RootPartial()); + if (nts.back().Empty()) return; + score += nts.back().Bound(); + words.push_back(lm::kMaxWordIndex); + } else { + ++terminals; + words.push_back(vocab_.FromCDec(*word)); + } + } + + if (final) { + words.push_back(m_.GetVocabulary().EndSentence()); + } + + search::PartialEdge out(gen.AllocateEdge(nts.size())); + + memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); + + search::Note note; + note.vp = ∈ + out.SetNote(note); + + score += in.rule_->GetFeatureValues().dot(cdec_weights_); + score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + score += search::ScoreRule(context, words, final, out.Between()); + out.SetScore(score); + + gen.AddEdge(out); +} + +boost::scoped_ptr AwfulGlobalIncremental; + +} // namespace + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { + if (!AwfulGlobalIncremental.get()) { + std::cerr << "Pop limit " << pop_limit << std::endl; + AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); + } + AwfulGlobalIncremental->Search(pop_limit, hg); +} diff --git a/decoder/incremental.h b/decoder/incremental.h new file mode 100644 index 00000000..180383ce --- /dev/null +++ b/decoder/incremental.h @@ -0,0 +1,11 @@ +#ifndef _INCREMENTAL_H_ +#define _INCREMENTAL_H_ + +#include "weights.h" +#include + +class Hypergraph; + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); + +#endif // _INCREMENTAL_H_ diff --git a/decoder/lazy.cc b/decoder/lazy.cc deleted file mode 100644 index 1e6a94fe..00000000 --- a/decoder/lazy.cc +++ /dev/null @@ -1,178 +0,0 @@ -#include "hg.h" -#include "lazy.h" -#include "fdict.h" -#include "tdict.h" - -#include "lm/enumerate_vocab.hh" -#include "lm/model.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "search/edge.hh" -#include "search/edge_queue.hh" -#include "search/vertex.hh" -#include "search/vertex_generator.hh" -#include "util/exception.hh" - -#include -#include - -#include -#include - -namespace { - -struct MapVocab : public lm::EnumerateVocab { - public: - MapVocab() {} - - // Do not call after Lookup. - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); - out_[cdec_id] = index; - } - - // Assumes Add has been called and will never be called again. - lm::WordIndex FromCDec(WordID id) const { - return out_[out_.size() > id ? id : 0]; - } - - private: - std::vector out_; -}; - -class LazyBase { - public: - LazyBase(const std::vector &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { - std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; - } - - virtual ~LazyBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static LazyBase *Load(const char *model_file, const std::vector &weights); - - protected: - lm::ngram::Config GetConfig() { - lm::ngram::Config ret; - ret.enumerate_vocab = &vocab_; - return ret; - } - - MapVocab vocab_; - - const std::vector &cdec_weights_; - - const search::Weights weights_; -}; - -template class Lazy : public LazyBase { - public: - Lazy(const char *model_file, const std::vector &weights) : LazyBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - unsigned char ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const; - - const Model m_; -}; - -LazyBase *LazyBase::Load(const char *model_file, const std::vector &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Lazy(model_file, weights); - case lm::ngram::REST_PROBING: - return new Lazy(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -void PrintFinal(const Hypergraph &hg, const search::Final &final) { - const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); - boost::array::const_iterator child(final.Children().begin()); - for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { - if (*i > 0) { - std::cout << TD::Convert(*i) << ' '; - } else { - PrintFinal(hg, **child++); - } - } -} - -template void Lazy::Search(unsigned int pop_limit, const Hypergraph &hg) const { - boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - search::Config config(weights_, pop_limit); - search::Context context(config, m_); - - for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { - search::EdgeQueue queue(context.PopLimit()); - const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; - for (unsigned int j = 0; j < down_edges.size(); ++j) { - unsigned int edge_index = down_edges[j]; - unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge()); - search::Note note; - note.vp = &hg.edges_[edge_index]; - if (arity != 255) queue.AddEdge(arity, note); - } - search::VertexGenerator vertex_gen(context, out_vertices[i]); - queue.Search(context, vertex_gen); - } - const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild(); - if (!top) { - std::cout << "NO PATH FOUND" << std::endl; - } else { - PrintFinal(hg, *top); - std::cout << "||| " << top->Bound() << std::endl; - } -} - -template unsigned char Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const { - const std::vector &e = in.rule_->e(); - std::vector words; - unsigned int terminals = 0; - unsigned char nt = 0; - out.score = 0.0; - for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { - if (*word <= 0) { - out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial(); - if (out.nt[nt].Empty()) return 255; - out.score += out.nt[nt].Bound(); - ++nt; - words.push_back(lm::kMaxWordIndex); - } else { - ++terminals; - words.push_back(vocab_.FromCDec(*word)); - } - } - for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) { - out.nt[fill] = search::kBlankPartialVertex; - } - - if (final) { - words.push_back(m_.GetVocabulary().EndSentence()); - } - - out.score += in.rule_->GetFeatureValues().dot(cdec_weights_); - out.score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - out.score += search::ScoreRule(context, words, final, out.between); - return nt; -} - -boost::scoped_ptr AwfulGlobalLazy; - -} // namespace - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalLazy.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights)); - } - AwfulGlobalLazy->Search(pop_limit, hg); -} diff --git a/decoder/lazy.h b/decoder/lazy.h deleted file mode 100644 index 94895b19..00000000 --- a/decoder/lazy.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef _LAZY_H_ -#define _LAZY_H_ - -#include "weights.h" -#include - -class Hypergraph; - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); - -#endif // _LAZY_H_ diff --git a/dtrain/Makefile.am b/dtrain/Makefile.am index 64fef489..ca9581f5 100644 --- a/dtrain/Makefile.am +++ b/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc -dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile deleted file mode 100644 index 2cc90c05..00000000 --- a/klm/alone/Jamfile +++ /dev/null @@ -1,4 +0,0 @@ -lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : .. : : .. ../search//search ../lm//kenlm ; - -exe decode : main.cc standalone main.cc : multi:..//boost_thread ; -exe just_vocab : just_vocab.cc standalone : multi:..//boost_thread ; diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc deleted file mode 100644 index 2ae72ce9..00000000 --- a/klm/alone/assemble.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "alone/assemble.hh" - -#include "alone/labeled_edge.hh" -#include "search/final.hh" - -#include - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final) { - const std::vector &words = static_cast(final.From()).Words(); - if (words.empty()) return o; - const search::Final *const *child = final.Children().data(); - std::vector::const_iterator i(words.begin()); - for (; i != words.end() - 1; ++i) { - if (*i) { - o << **i << ' '; - } else { - o << **child << ' '; - ++child; - } - } - - if (*i) { - if (**i != "") { - o << **i; - } - } else { - o << **child; - } - - return o; -} - -namespace { - -void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) { - for (unsigned int i = 0; i < level; ++i) - o << indent_str; -} - -void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) { - o << "(\n"; - MakeIndent(o, indent_str, indent); - const std::vector &words = static_cast(final.From()).Words(); - const search::Final *const *child = final.Children().data(); - for (std::vector::const_iterator i(words.begin()); i != words.end(); ++i) { - if (*i) { - o << **i; - if (i == words.end() - 1) { - o << '\n'; - MakeIndent(o, indent_str, indent); - } else { - o << ' '; - } - } else { - // One extra indent from the line we're currently on. - o << indent_str; - DetailedFinalInternal(o, **child, indent_str, indent + 1); - for (unsigned int i = 0; i < indent; ++i) o << indent_str; - ++child; - } - } - o << ")=" << final.Bound() << '\n'; -} -} // namespace - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) { - DetailedFinalInternal(o, final, indent_str, 0); -} - -void PrintFinal(const search::Final &final) { - std::cout << final << std::endl; -} - -} // namespace alone diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh deleted file mode 100644 index e6b0ad5c..00000000 --- a/klm/alone/assemble.hh +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef ALONE_ASSEMBLE__ -#define ALONE_ASSEMBLE__ - -#include - -namespace search { -class Final; -} // namespace search - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final); - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " "); - -// This isn't called anywhere but makes it easy to print from gdb. -void PrintFinal(const search::Final &final); - -} // namespace alone - -#endif // ALONE_ASSEMBLE__ diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh deleted file mode 100644 index 788352c9..00000000 --- a/klm/alone/graph.hh +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef ALONE_GRAPH__ -#define ALONE_GRAPH__ - -#include "alone/labeled_edge.hh" -#include "search/rule.hh" -#include "search/types.hh" -#include "search/vertex.hh" -#include "util/exception.hh" - -#include -#include -#include - -namespace alone { - -template class FixedAllocator : boost::noncopyable { - public: - FixedAllocator() : current_(NULL), end_(NULL) {} - - void Init(std::size_t count) { - assert(!current_); - array_.reset(new T[count]); - current_ = array_.get(); - end_ = current_ + count; - } - - T &operator[](std::size_t idx) { - return array_.get()[idx]; - } - - T *New() { - T *ret = current_++; - UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end"); - return ret; - } - - std::size_t Size() const { - return end_ - array_.get(); - } - - private: - boost::scoped_array array_; - T *current_, *end_; -}; - -class Graph : boost::noncopyable { - public: - typedef LabeledEdge Edge; - typedef search::Vertex Vertex; - - Graph() {} - - void SetCounts(std::size_t vertices, std::size_t edges) { - vertices_.Init(vertices); - edges_.Init(edges); - } - - Vertex *NewVertex() { - return vertices_.New(); - } - - std::size_t VertexSize() const { return vertices_.Size(); } - - Vertex &MutableVertex(std::size_t index) { - return vertices_[index]; - } - - Edge *NewEdge() { - return edges_.New(); - } - - std::size_t EdgeSize() const { return edges_.Size(); } - - void SetRoot(Vertex *root) { root_ = root; } - - Vertex &Root() { return *root_; } - - private: - FixedAllocator vertices_; - FixedAllocator edges_; - - Vertex *root_; -}; - -} // namespace alone - -#endif // ALONE_GRAPH__ diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc deleted file mode 100644 index 35aea5ed..00000000 --- a/klm/alone/just_vocab.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "alone/read.hh" -#include "util/file_piece.hh" - -#include - -int main() { - util::FilePiece f(0, "stdin", &std::cerr); - while (true) { - try { - alone::JustVocab(f, std::cout); - } catch (const util::EndOfFileException &e) { break; } - std::cout << '\n'; - } -} diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh deleted file mode 100644 index 94d8cbdf..00000000 --- a/klm/alone/labeled_edge.hh +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef ALONE_LABELED_EDGE__ -#define ALONE_LABELED_EDGE__ - -#include "search/edge.hh" - -#include -#include - -namespace alone { - -class LabeledEdge : public search::Edge { - public: - LabeledEdge() {} - - void AppendWord(const std::string *word) { - words_.push_back(word); - } - - const std::vector &Words() const { - return words_; - } - - private: - // NULL for non-terminals. - std::vector words_; -}; - -} // namespace alone - -#endif // ALONE_LABELED_EDGE__ diff --git a/klm/alone/main.cc b/klm/alone/main.cc deleted file mode 100644 index e09ab01d..00000000 --- a/klm/alone/main.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "alone/threading.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "util/exception.hh" -#include "util/file_piece.hh" -#include "util/usage.hh" - -#include - -#include -#include - -namespace alone { - -template void ReadLoop(const std::string &graph_prefix, Control &control) { - for (unsigned int sentence = 0; ; ++sentence) { - std::stringstream name; - name << graph_prefix << '/' << sentence; - std::auto_ptr file; - try { - file.reset(new util::FilePiece(name.str().c_str())); - } catch (const util::ErrnoException &e) { - if (e.Error() == ENOENT) return; - throw; - } - control.Add(file.release()); - } -} - -template void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - Model model(model_file); - search::Weights weights(weight_str); - search::Config config(weights, pop_limit); - - if (threads > 1) { -#ifdef WITH_THREADS - Controller controller(config, model, threads, std::cout); - ReadLoop(graph_prefix, controller); -#else - UTIL_THROW(util::Exception, "Threading support not compiled in."); -#endif - } else { - InThread controller(config, model, std::cout); - ReadLoop(graph_prefix, controller); - } -} - -void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - case lm::ngram::REST_PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -} // namespace alone - -int main(int argc, char *argv[]) { - if (argc < 5 || argc > 6) { - std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; - return 1; - } - -#ifdef WITH_THREADS - unsigned thread_count = boost::thread::hardware_concurrency(); -#else - unsigned thread_count = 1; -#endif - if (argc == 6) { - thread_count = boost::lexical_cast(argv[5]); - UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); - } - UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line."); - alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast(argv[4]), thread_count); - - util::PrintUsage(std::cerr); - return 0; -} diff --git a/klm/alone/read.cc b/klm/alone/read.cc deleted file mode 100644 index 0b20be35..00000000 --- a/klm/alone/read.cc +++ /dev/null @@ -1,118 +0,0 @@ -#include "alone/read.hh" - -#include "alone/graph.hh" -#include "alone/vocab.hh" -#include "search/arity.hh" -#include "search/context.hh" -#include "search/weights.hh" -#include "util/file_piece.hh" - -#include -#include - -#include - -namespace alone { - -namespace { - -template Graph::Edge &ReadEdge(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) { - Graph::Edge *ret = to.NewEdge(); - - StringPiece got; - - std::vector words; - unsigned long int terminals = 0; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) { - // non-terminal - char *end_ptr; - unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10); - UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got); - UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?"); - ret->Add(to.MutableVertex(child)); - words.push_back(lm::kMaxWordIndex); - ret->AppendWord(NULL); - } else { - const std::pair &found = vocab.FindOrAdd(got); - words.push_back(found.second); - ret->AppendWord(&found.first); - ++terminals; - } - } - if (final) { - // This is not counted for the word penalty. - words.push_back(vocab.EndSentence().second); - ret->AppendWord(&vocab.EndSentence().first); - } - // Hard-coded word penalty. - float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast(terminals) / M_LN10; - ret->InitRule().Init(context, additive, words, final); - unsigned int arity = ret->GetRule().Arity(); - UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity); - return *ret; -} - -} // namespace - -// TODO: refactor -void JustVocab(util::FilePiece &from, std::ostream &out) { - boost::unordered_set seen; - unsigned long int vertices = from.ReadULong(); - from.ReadULong(); // edges - UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero"); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - std::string temp; - for (unsigned long int i = 0; i < vertices; ++i) { - unsigned long int edge_count = from.ReadULong(); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - StringPiece got; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue; - temp.assign(got.data(), got.size()); - if (seen.insert(temp).second) out << temp << ' '; - } - from.ReadLine(); // weights - } - } - // Eat sentence - from.ReadLine(); -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab) { - unsigned long int vertices; - try { - vertices = from.ReadULong(); - } catch (const util::EndOfFileException &e) { return false; } - unsigned long int edges = from.ReadULong(); - UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices); - UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges); - --vertices; - --edges; - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - to.SetCounts(vertices, edges); - Graph::Vertex *vertex; - for (unsigned long int i = 0; ; ++i) { - vertex = to.NewVertex(); - unsigned long int edge_count = from.ReadULong(); - bool root = (i == vertices - 1); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - vertex->Add(ReadEdge(context, from, to, vocab, root)); - } - vertex->FinishedAdding(); - if (root) break; - } - to.SetRoot(vertex); - StringPiece str = from.ReadLine(); - UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root"); - // The edge - from.ReadLine(); - return true; -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone diff --git a/klm/alone/read.hh b/klm/alone/read.hh deleted file mode 100644 index 10769a86..00000000 --- a/klm/alone/read.hh +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef ALONE_READ__ -#define ALONE_READ__ - -#include "util/exception.hh" - -#include - -namespace util { class FilePiece; } - -namespace search { template class Context; } - -namespace alone { - -class Graph; -class Vocab; - -class FormatException : public util::Exception { - public: - FormatException() {} - ~FormatException() throw() {} -}; - -void JustVocab(util::FilePiece &from, std::ostream &to); - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone - -#endif // ALONE_READ__ diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc deleted file mode 100644 index 475386b6..00000000 --- a/klm/alone/threading.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "alone/threading.hh" - -#include "alone/assemble.hh" -#include "alone/graph.hh" -#include "alone/read.hh" -#include "alone/vocab.hh" -#include "lm/model.hh" -#include "search/context.hh" -#include "search/vertex_generator.hh" - -#include -#include -#include - -#include - -namespace alone { -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) { - search::Context context(config, model); - Graph graph; - Vocab vocab(model.GetVocabulary()); - { - boost::scoped_ptr in(in_ptr); - ReadCDec(context, *in, graph, vocab); - } - - for (std::size_t i = 0; i < graph.VertexSize(); ++i) { - search::VertexGenerator(context, graph.MutableVertex(i)); - } - search::PartialVertex top = graph.Root().RootPartial(); - if (top.Empty()) { - out << "NO PATH FOUND"; - } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - out << top.End() << " ||| " << top.End().Bound() << std::endl; - } -} - -template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); -template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); - -#ifdef WITH_THREADS -template void DecodeHandler::operator()(Input message) { - std::stringstream assemble; - Decode(config_, model_, message.file, assemble); - Produce(message.sentence_id, assemble.str()); -} - -template void DecodeHandler::Produce(unsigned int sentence_id, const std::string &str) { - Output out; - out.sentence_id = sentence_id; - out.str = new std::string(str); - out_.Produce(out); -} - -void PrintHandler::operator()(Output message) { - unsigned int relative = message.sentence_id - done_; - if (waiting_.size() <= relative) waiting_.resize(relative + 1); - waiting_[relative] = message.str; - for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) { - out_ << *lead; - delete lead; - } -} - -template Controller::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) : - sentence_id_(0), - printer_(decode_workers, 1, boost::ref(to), Output::Poison()), - decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {} - -template class Controller; -template class Controller; - -#endif - -} // namespace alone diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh deleted file mode 100644 index 0ab0f739..00000000 --- a/klm/alone/threading.hh +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef ALONE_THREADING__ -#define ALONE_THREADING__ - -#ifdef WITH_THREADS -#include "util/pcqueue.hh" -#include "util/pool.hh" -#endif - -#include -#include -#include - -namespace util { -class FilePiece; -} // namespace util - -namespace search { -class Config; -template class Context; -} // namespace search - -namespace alone { - -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out); - -class Graph; - -#ifdef WITH_THREADS -struct SentenceID { - unsigned int sentence_id; - bool operator==(const SentenceID &other) const { - return sentence_id == other.sentence_id; - } -}; - -struct Input : public SentenceID { - util::FilePiece *file; - static Input Poison() { - Input ret; - ret.sentence_id = static_cast(-1); - ret.file = NULL; - return ret; - } -}; - -struct Output : public SentenceID { - std::string *str; - static Output Poison() { - Output ret; - ret.sentence_id = static_cast(-1); - ret.str = NULL; - return ret; - } -}; - -template class DecodeHandler { - public: - typedef Input Request; - - DecodeHandler(const search::Config &config, const Model &model, util::PCQueue &out) : config_(config), model_(model), out_(out) {} - - void operator()(Input message); - - private: - void Produce(unsigned int sentence_id, const std::string &str); - - const search::Config &config_; - - const Model &model_; - - util::PCQueue &out_; -}; - -class PrintHandler { - public: - typedef Output Request; - - explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {} - - void operator()(Output message); - - private: - std::ostream &out_; - std::deque waiting_; - unsigned int done_; -}; - -template class Controller { - public: - // This config must remain valid. - explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to); - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Input input; - input.sentence_id = sentence_id_++; - input.file = in; - decoder_.Produce(input); - } - - private: - unsigned int sentence_id_; - - util::Pool printer_; - - util::Pool > decoder_; -}; -#endif - -// Same API as controller. -template class InThread { - public: - InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {} - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Decode(config_, model_, in, to_); - } - - private: - const search::Config &config_; - - const Model &model_; - - std::ostream &to_; -}; - -} // namespace alone -#endif // ALONE_THREADING__ diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc deleted file mode 100644 index ffe55301..00000000 --- a/klm/alone/vocab.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "alone/vocab.hh" - -#include "lm/virtual_interface.hh" -#include "util/string_piece.hh" - -namespace alone { - -Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("")) {} - -const std::pair &Vocab::FindOrAdd(const StringPiece &str) { - Map::const_iterator i(FindStringPiece(map_, str)); - if (i != map_.end()) return *i; - std::pair to_ins; - to_ins.first.assign(str.data(), str.size()); - to_ins.second = backing_.Index(str); - return *map_.insert(to_ins).first; -} - -} // namespace alone diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh deleted file mode 100644 index 3ac0f542..00000000 --- a/klm/alone/vocab.hh +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef ALONE_VOCAB__ -#define ALONE_VOCAB__ - -#include "lm/word_index.hh" -#include "util/string_piece.hh" - -#include -#include - -#include - -namespace lm { namespace base { class Vocabulary; } } - -namespace alone { - -class Vocab { - public: - explicit Vocab(const lm::base::Vocabulary &backing); - - const std::pair &FindOrAdd(const StringPiece &str); - - const std::pair &EndSentence() const { return end_sentence_; } - - private: - typedef boost::unordered_map Map; - Map map_; - - const lm::base::Vocabulary &backing_; - - const std::pair &end_sentence_; -}; - -} // namespace alone -#endif // ALONE_VCOAB__ diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 40af8a63..2fd20481 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template void GenericModel.. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; import testing ; diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..ccc5b7f6 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,11 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + edge_generator.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc \ + weights.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. + diff --git a/klm/search/arity.hh b/klm/search/arity.hh deleted file mode 100644 index 09c2c671..00000000 --- a/klm/search/arity.hh +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef SEARCH_ARITY__ -#define SEARCH_ARITY__ -namespace search { - -const unsigned int kMaxArity = 2; - -} // namespace search -#endif // SEARCH_ARITY__ diff --git a/klm/search/context.hh b/klm/search/context.hh index 27940053..62163144 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -7,6 +7,7 @@ #include "search/types.hh" #include "search/vertex.hh" #include "util/exception.hh" +#include "util/pool.hh" #include #include @@ -21,10 +22,8 @@ class ContextBase { public: explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - Final *NewFinal() { - Final *ret = final_pool_.construct(); - assert(ret); - return ret; + util::Pool &FinalPool() { + return final_pool_; } VertexNode *NewVertexNode() { @@ -42,7 +41,8 @@ class ContextBase { const Weights &GetWeights() const { return weights_; } private: - boost::object_pool final_pool_; + util::Pool final_pool_; + boost::object_pool vertex_node_pool_; unsigned int pop_limit_; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 77ab0ade..187904bf 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -2,30 +2,53 @@ #define SEARCH_EDGE__ #include "lm/state.hh" -#include "search/arity.hh" -#include "search/rule.hh" +#include "search/header.hh" #include "search/types.hh" #include "search/vertex.hh" +#include "util/pool.hh" -#include +#include + +#include namespace search { -struct PartialEdge { - Score score; - // Terminals - lm::ngram::ChartState between[kMaxArity + 1]; - // Non-terminals - PartialVertex nt[kMaxArity]; +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. + PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} - const lm::ngram::ChartState &CompletedState() const { - return between[0]; - } + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast(After()); + } + PartialVertex *NT() { + return reinterpret_cast(After()); + } - bool operator<(const PartialEdge &other) const { - return score < other.score; - } + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } }; + } // namespace search #endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 56239dfb..260159b1 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -4,117 +4,107 @@ #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" -#include "search/vertex_generator.hh" #include namespace search { -EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { -/* for (unsigned char i = 0; i < edge.Arity(); ++i) { - root.nt[i] = edge.GetVertex(i).RootPartial(); - } - for (unsigned char i = edge.Arity(); i < 2; ++i) { - root.nt[i] = kBlankPartialVertex; - }*/ - generate_.push(&root); - top_score_ = root.score; -} - namespace { -template float FastScore(const Context &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { - memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); - - float ret = 0.0; - lm::ngram::ChartState *before, *after; - if (victim == 0) { - before = &update.between[0]; - after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; - } else { - assert(victim == 1); - assert(arity == 2); - before = &update.between[previous.nt[0].Complete() ? 0 : 1]; - after = &update.between[2]; - } - const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); - const PartialVertex &update_nt = update.nt[victim]; +template void FastScore(const Context &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; const lm::ngram::ChartState &update_reveal = update_nt.State(); - float just_after = 0.0; if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { - just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); } - if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { - ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); } if (update_nt.Complete()) { if (update_reveal.left.full) { before->left.full = true; } else { assert(update_reveal.left.length == update_reveal.right.length); - ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); } - if (victim == 0) { - update.between[0].right = after->right; - } else { - update.between[2].left = before->left; + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); } } - return previous.score + (ret + just_after) * context.GetWeights().LM(); + update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); } } // namespace -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { +template PartialEdge EdgeGenerator::Pop(Context &context) { assert(!generate_.empty()); - PartialEdge &top = *generate_.top(); + PartialEdge top = generate_.top(); generate_.pop(); - unsigned int victim = 0; - unsigned char lowest_length = 255; - for (unsigned char i = 0; i != arity_; ++i) { - if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { - lowest_length = top.nt[i].Length(); - victim = i; + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. + { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } } - } - if (lowest_length == 255) { - // All states report complete. - top.between[0].right = top.between[arity_].right; - // Now top.between[0] is the full edge state. - top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return ⊤ + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; } - unsigned int stay = !victim; - PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); - float old_bound = top.nt[victim].Bound(); - // The alternate's score will change because alternate.nt[victim] changes. - bool split = top.nt[victim].Split(continuation.nt[victim]); - // top is now the alternate. + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); - continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, arity_, top, continuation); - // TODO: dedupe? - generate_.push(&continuation); + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); - if (split) { - // We have an alternate. - top.score += top.nt[victim].Bound() - old_bound; // TODO: dedupe? - generate_.push(&top); - } else { - partial_edge_pool.free(&top); + generate_.push(alternate); } - top_score_ = generate_.top()->score; - return NULL; + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. + return PartialEdge(); } -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 875ccc5e..582c78b7 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -3,11 +3,8 @@ #include "search/edge.hh" #include "search/note.hh" +#include "search/types.hh" -#include -#include - -#include #include namespace lm { @@ -20,38 +17,40 @@ namespace search { template class Context; -class VertexGenerator; - -struct PartialEdgePointerLess : std::binary_function { - bool operator()(const PartialEdge *first, const PartialEdge *second) const { - return *first < *second; - } -}; - class EdgeGenerator { public: - EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); + EdgeGenerator() {} - Score TopScore() const { - return top_score_; + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); } - Note GetNote() const { - return note_; + void AddEdge(PartialEdge edge) { + generate_.push(edge); } - // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. - template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template PartialEdge Pop(Context &context); + + template void Search(Context &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } private: - Score top_score_; - - unsigned char arity_; + util::Pool partial_edge_pool_; - typedef std::priority_queue, PartialEdgePointerLess> Generate; + typedef std::priority_queue Generate; Generate generate_; - - Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc deleted file mode 100644 index e3ae6ebf..00000000 --- a/klm/search/edge_queue.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "search/edge_queue.hh" - -#include "lm/left.hh" -#include "search/context.hh" - -#include - -namespace search { - -EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { - take_ = static_cast(partial_edge_pool_.malloc()); -} - -/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { - // Ignore empty edges. - for (unsigned char i = 0; i < edge.Arity(); ++i) { - PartialVertex root(edge.GetVertex(i).RootPartial()); - if (root.Empty()) return; - total_score += root.Bound(); - } - PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); - allocated.score = total_score; -}*/ - -} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh deleted file mode 100644 index 187eaed7..00000000 --- a/klm/search/edge_queue.hh +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef SEARCH_EDGE_QUEUE__ -#define SEARCH_EDGE_QUEUE__ - -#include "search/edge.hh" -#include "search/edge_generator.hh" -#include "search/note.hh" - -#include -#include - -#include - -namespace search { - -template class Context; - -class EdgeQueue { - public: - explicit EdgeQueue(unsigned int pop_limit_hint); - - PartialEdge &InitializeEdge() { - return *take_; - } - - void AddEdge(unsigned char arity, Note note) { - generate_.push(edge_pool_.construct(*take_, arity, note)); - take_ = static_cast(partial_edge_pool_.malloc()); - } - - bool Empty() const { return generate_.empty(); } - - /* Generate hypotheses and send them to output. Normally, output is a - * VertexGenerator, but the decoder may want to route edges to different - * vertices i.e. if they have different LHS non-terminal labels. - */ - template void Search(Context &context, Output &output) { - int to_pop = context.PopLimit(); - while (to_pop > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - PartialEdge *ret = top->Pop(context, partial_edge_pool_); - if (ret) { - output.NewHypothesis(*ret, top->GetNote()); - --to_pop; - if (top->TopScore() != -kScoreInf) { - generate_.push(top); - } - } else { - generate_.push(top); - } - } - output.FinishedSearch(); - } - - private: - boost::object_pool edge_pool_; - - struct LessByTopScore : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->TopScore() < second->TopScore(); - } - }; - - typedef std::priority_queue, LessByTopScore> Generate; - Generate generate_; - - boost::pool<> partial_edge_pool_; - - PartialEdge *take_; -}; - -} // namespace search -#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 1b3092ac..50e62cf2 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,37 +1,34 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/arity.hh" -#include "search/note.hh" -#include "search/types.hh" - -#include +#include "search/header.hh" +#include "util/pool.hh" namespace search { -class Final { +// A full hypothesis with pointers to children. +class Final : public Header { public: - typedef boost::array ChildArray; + Final() {} - void Reset(Score bound, Note note, const Final &left, const Final &right) { - bound_ = bound; - note_ = note; - children_[0] = &left; - children_[1] = &right; + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); } - const ChildArray &Children() const { return children_; } - - Note GetNote() const { return note_; } - - Score Bound() const { return bound_; } + // These are arrays of length GetArity(). + Final *Children() { + return reinterpret_cast(After()); + } + const Final *Children() const { + return reinterpret_cast(After()); + } private: - Score bound_; - - Note note_; - - ChildArray children_; + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } }; } // namespace search diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..25550dbe --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast(base_); + } + void SetScore(Score to) { + *reinterpret_cast(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast(base)) { + *reinterpret_cast(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/source.hh b/klm/search/source.hh deleted file mode 100644 index 11839f7b..00000000 --- a/klm/search/source.hh +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef SEARCH_SOURCE__ -#define SEARCH_SOURCE__ - -#include "search/types.hh" - -#include -#include - -namespace search { - -template class Source { - public: - Source() : bound_(kScoreInf) {} - - Index Size() const { - return final_.size(); - } - - Score Bound() const { - return bound_; - } - - const Final &operator[](Index index) const { - return *final_[index]; - } - - Score ScoreOrBound(Index index) const { - return Size() > index ? final_[index]->Total() : Bound(); - } - - protected: - void AddFinal(const Final &store) { - final_.push_back(&store); - } - - void SetBound(Score to) { - assert(to <= bound_ + 0.001); - bound_ = to; - } - - private: - std::vector final_; - - Score bound_; -}; - -} // namespace search -#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh index 9726379f..06eb5bfa 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -1,17 +1,13 @@ #ifndef SEARCH_TYPES__ #define SEARCH_TYPES__ -#include +#include namespace search { typedef float Score; -const Score kScoreInf = INFINITY; -// This could have been an enum but gcc wants 4 bytes. -typedef bool ExtendDirection; -const ExtendDirection kExtendLeft = 0; -const ExtendDirection kExtendRight = 1; +typedef uint32_t Arity; } // namespace search diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index cc53c0dd..11f4631f 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_functionBound(); + bound_ = end_.GetScore(); return; } if (extend_.size() == 1 && parent_ptr) { @@ -39,10 +39,4 @@ void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { bound_ = extend_.front()->Bound(); } -namespace { -VertexNode kBlankVertexNode; -} // namespace - -PartialVertex kBlankPartialVertex(kBlankVertexNode); - } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index e1a9ad11..52bc1dfe 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() : end_(NULL) {} + VertexNode() {} void InitRoot() { extend_.clear(); @@ -26,8 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - bound_ = -kScoreInf; - end_ = NULL; + end_ = Final(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -37,19 +36,20 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final *end) { end_ = end; } + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } - Final &MutableEnd() { return *end_; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return !end_.Valid() && extend_.empty(); } bool Complete() const { - return end_; + return end_.Valid(); } const lm::ngram::ChartState &State() const { return state_; } @@ -63,8 +63,8 @@ class VertexNode { return state_.left.length + state_.right.length; } - // May be NULL. - const Final *End() const { return end_; } + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -81,7 +81,7 @@ class VertexNode { bool right_full_; Score bound_; - Final *end_; + Final end_; }; class PartialVertex { @@ -97,7 +97,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -105,20 +105,24 @@ class PartialVertex { return index_ + 1 < back_->Size(); } - // Split into continuation and alternative, rendering this the alternative. - bool Split(PartialVertex &continuation) { + // Split into continuation and alternative, rendering this the continuation. + bool Split(PartialVertex &alternative) { assert(!Complete()); - continuation.back_ = &((*back_)[index_]); - continuation.index_ = 0; + bool ret; if (index_ + 1 < back_->Size()) { - ++index_; - return true; + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; } - return false; + back_ = &((*back_)[index_]); + index_ = 0; + return ret; } - const Final &End() const { - return *back_->End(); + const Final End() const { + return back_->End(); } private: @@ -126,25 +130,22 @@ class PartialVertex { unsigned int index_; }; -extern PartialVertex kBlankPartialVertex; - class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final *BestChild() const { + const Final BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return NULL; + return Final(); } else { PartialVertex continuation; while (!top.Complete()) { top.Split(continuation); - top = continuation; } - return &top.End(); + return top.End(); } } diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index d94e6e06..0945fe55 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -10,74 +10,85 @@ namespace search { VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); - root_.under = &gen.root_; } namespace { + const uint64_t kCompleteAdd = static_cast(-1); -} // namespace -void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); - if (!got.second) { - // Found it already. - Final &exists = *got.first->second; - if (exists.Bound() < partial.score) { - exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - } - return; +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + unsigned char left = 0, right = 0; - Trie *node = &root_; + Trie *node = &root; while (true) { if (left == state.left.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); for (; right < state.right.length; ++right) { - node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); } break; } - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); left++; if (right == state.right.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); for (; left < state.left.length; ++left) { - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); } break; } - node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); right++; } - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, note, partial); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); } -VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { - VertexGenerator::Trie &next = node.extend[added]; - if (!next.under) { - next.under = context_.NewVertexNode(); - lm::ngram::ChartState &writing = next.under->MutableState(); - writing = state; - writing.left.full &= left_full && state.left.full; - next.under->MutableRightFull() = right_full && state.left.full; - writing.left.length = left; - writing.right.length = right; - node.under->AddExtend(next.under); - } - return next; -} +} // namespace -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { - VertexNode &node = *starter.under; - assert(node.State().left.full == state.left.full); - assert(!node.End()); - Final *final = context_.NewFinal(); - final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - node.SetEnd(final); - return final; +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); } } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 6b98da3e..60e86112 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,13 +1,11 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/note.hh" +#include "search/edge.hh" #include "search/vertex.hh" #include -#include - namespace lm { namespace ngram { class ChartState; @@ -18,40 +16,29 @@ namespace search { class ContextBase; class Final; -struct PartialEdge; class VertexGenerator { public: VertexGenerator(ContextBase &context, Vertex &gen); - void NewHypothesis(const PartialEdge &partial, Note note); - - void FinishedSearch() { - root_.under->SortAndSet(context_, NULL); + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (!ret.second && ret.first->second < partial) { + ret.first->second = partial; + } } + void FinishedSearch(); + const Vertex &Generating() const { return gen_; } private: - // Parallel structure to VertexNode. - struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; - }; - - Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); - ContextBase &context_; Vertex &gen_; - Trie root_; - - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; }; diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5ceccf2c..5306850f 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \ file_piece.cc \ mmap.cc \ murmur_hash.cc \ + pool.cc \ + string_piece.cc \ usage.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index ff4d590f..9909736d 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,7 +4,7 @@ #include #include -#include +#include // Ersatz version of boost::progress so core language model doesn't depend on // boost. Also adds option to print nothing. diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 83f99cd6..053a850b 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -6,7 +6,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/pool.cc b/klm/util/pool.cc new file mode 100644 index 00000000..2dffd06f --- /dev/null +++ b/klm/util/pool.cc @@ -0,0 +1,35 @@ +#include "util/pool.hh" + +#include + +namespace util { + +Pool::Pool() { + current_ = NULL; + current_end_ = NULL; +} + +Pool::~Pool() { + FreeAll(); +} + +void Pool::FreeAll() { + for (std::vector::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) { + free(*i); + } + free_list_.clear(); + current_ = NULL; + current_end_ = NULL; +} + +void *Pool::More(std::size_t size) { + std::size_t amount = std::max(static_cast(32) << free_list_.size(), size); + uint8_t *ret = static_cast(malloc(amount)); + if (!ret) throw std::bad_alloc(); + free_list_.push_back(ret); + current_ = ret + size; + current_end_ = ret + amount; + return ret; +} + +} // namespace util diff --git a/klm/util/pool.hh b/klm/util/pool.hh new file mode 100644 index 00000000..72f8a0c8 --- /dev/null +++ b/klm/util/pool.hh @@ -0,0 +1,45 @@ +// Very simple pool. It can only allocate memory. And all of the memory it +// allocates must be freed at the same time. + +#ifndef UTIL_POOL__ +#define UTIL_POOL__ + +#include + +#include + +namespace util { + +class Pool { + public: + Pool(); + + ~Pool(); + + void *Allocate(std::size_t size) { + void *ret = current_; + current_ += size; + if (current_ < current_end_) { + return ret; + } else { + return More(size); + } + } + + void FreeAll(); + + private: + void *More(std::size_t size); + + std::vector free_list_; + + uint8_t *current_, *current_end_; + + // no copying + Pool(const Pool &); + Pool &operator=(const Pool &); +}; + +} // namespace util + +#endif // UTIL_POOL__ diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 770faa7e..4a8aff35 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -8,7 +8,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc new file mode 100644 index 00000000..b422cefc --- /dev/null +++ b/klm/util/string_piece.cc @@ -0,0 +1,192 @@ +// Copyright 2004 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in string_piece.hh. + +#include "util/string_piece.hh" + +#include + +#include + +#ifndef HAVE_ICU + +typedef StringPiece::size_type size_type; + +void StringPiece::CopyToString(std::string* target) const { + target->assign(ptr_, length_); +} + +size_type StringPiece::find(const StringPiece& s, size_type pos) const { + if (length_ < 0 || pos > static_cast(length_)) + return npos; + + const char* result = std::search(ptr_ + pos, ptr_ + length_, + s.ptr_, s.ptr_ + s.length_); + const size_type xpos = result - ptr_; + return xpos + s.length_ <= length_ ? xpos : npos; +} + +size_type StringPiece::find(char c, size_type pos) const { + if (length_ <= 0 || pos >= static_cast(length_)) { + return npos; + } + const char* result = std::find(ptr_ + pos, ptr_ + length_, c); + return result != ptr_ + length_ ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(const StringPiece& s, size_type pos) const { + if (length_ < s.length_) return npos; + const size_t ulen = length_; + if (s.length_ == 0) return std::min(ulen, pos); + + const char* last = ptr_ + std::min(ulen - s.length_, pos) + s.length_; + const char* result = std::find_end(ptr_, last, s.ptr_, s.ptr_ + s.length_); + return result != last ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(char c, size_type pos) const { + if (length_ <= 0) return npos; + for (int i = std::min(pos, static_cast(length_ - 1)); + i >= 0; --i) { + if (ptr_[i] == c) { + return i; + } + } + return npos; +} + +// For each character in characters_wanted, sets the index corresponding +// to the ASCII code of that character to 1 in table. This is used by +// the find_.*_of methods below to tell whether or not a character is in +// the lookup table in constant time. +// The argument `table' must be an array that is large enough to hold all +// the possible values of an unsigned char. Thus it should be be declared +// as follows: +// bool table[UCHAR_MAX + 1] +static inline void BuildLookupTable(const StringPiece& characters_wanted, + bool* table) { + const size_type length = characters_wanted.length(); + const char* const data = characters_wanted.data(); + for (size_type i = 0; i < length; ++i) { + table[static_cast(data[i])] = true; + } +} + +size_type StringPiece::find_first_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + if (s.length_ == 0) + return 0; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (!lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (; pos < length_; ++pos) { + if (ptr_[pos] != c) { + return pos; + } + } + return npos; +} + +size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + size_type i = std::min(pos, length_ - 1); + if (s.length_ == 0) + return i; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (; ; --i) { + if (!lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (ptr_[i] != c) + return i; + if (i == 0) + break; + } + return npos; +} + +StringPiece StringPiece::substr(size_type pos, size_type n) const { + if (pos > length_) pos = length_; + if (n > length_ - pos) n = length_ - pos; + return StringPiece(ptr_ + pos, n); +} + +const size_type StringPiece::npos = size_type(-1); + +#endif // !HAVE_ICU diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index c7e1c863..4a7f5460 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -54,6 +54,18 @@ class AnyCharacter { StringPiece chars_; }; +class AnyCharacterLast { + public: + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} + + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::find_end(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1); + } + + private: + StringPiece chars_; +}; + template class TokenIter : public boost::iterator_facade, const StringPiece, boost::forward_traversal_tag> { public: TokenIter() {} diff --git a/mira/Makefile.am b/mira/Makefile.am index 7b4a4e12..3f8f17cd 100644 --- a/mira/Makefile.am +++ b/mira/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = kbest_mira kbest_mira_SOURCES = kbest_mira.cc -kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/Makefile.am b/training/Makefile.am index 5254333a..f9c25391 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -32,60 +32,60 @@ libtraining_a_SOURCES = \ risk.cc mpi_online_optimize_SOURCES = mpi_online_optimize.cc -mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc -mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc -mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_features_SOURCES = mpi_extract_features.cc -mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc -mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc -mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz augment_grammar_SOURCES = augment_grammar.cc -augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz test_ngram_SOURCES = test_ngram.cc -test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz fast_align_SOURCES = fast_align.cc ttables.cc -fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbl_model_SOURCES = lbl_model.cc -lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz grammar_convert_SOURCES = grammar_convert.cc -grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz optimize_test_SOURCES = optimize_test.cc -optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz collapse_weights_SOURCES = collapse_weights.cc -collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbfgs_test_SOURCES = lbfgs_test.cc -lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc -mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc -mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc -mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc -mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz plftools_SOURCES = plftools.cc -plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm -- cgit v1.2.3 From de53e2e98acd0e2d07efb39bef430bd598908aa8 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 14 Dec 2012 12:39:04 -0800 Subject: Updated incremental, updated kenlm. Incremental assumes --- decoder/incremental.cc | 44 +++++++------- klm/lm/left.hh | 66 +++++++++++---------- klm/lm/max_order.hh | 7 +-- klm/lm/model.cc | 3 +- klm/lm/search_hashed.cc | 8 +-- klm/lm/search_hashed.hh | 2 +- klm/lm/vocab.cc | 7 ++- klm/lm/vocab.hh | 5 +- klm/search/Makefile.am | 4 +- klm/search/applied.hh | 86 +++++++++++++++++++++++++++ klm/search/config.hh | 25 ++++++-- klm/search/context.hh | 28 ++------- klm/search/dedupe.hh | 131 +++++++++++++++++++++++++++++++++++++++++ klm/search/edge_generator.cc | 3 +- klm/search/edge_generator.hh | 1 - klm/search/final.hh | 36 ----------- klm/search/header.hh | 9 ++- klm/search/nbest.cc | 106 +++++++++++++++++++++++++++++++++ klm/search/nbest.hh | 81 +++++++++++++++++++++++++ klm/search/note.hh | 12 ---- klm/search/rule.cc | 52 ++++++++-------- klm/search/rule.hh | 11 +++- klm/search/types.hh | 17 ++++++ klm/search/vertex.cc | 27 ++++++--- klm/search/vertex.hh | 37 +++++++----- klm/search/vertex_generator.cc | 44 +++----------- klm/search/vertex_generator.hh | 72 ++++++++++++++++++---- klm/search/weights.cc | 71 ---------------------- klm/search/weights.hh | 52 ---------------- klm/search/weights_test.cc | 38 ------------ 30 files changed, 680 insertions(+), 405 deletions(-) create mode 100644 klm/search/applied.hh create mode 100644 klm/search/dedupe.hh delete mode 100644 klm/search/final.hh create mode 100644 klm/search/nbest.cc create mode 100644 klm/search/nbest.hh delete mode 100644 klm/search/note.hh delete mode 100644 klm/search/weights.cc delete mode 100644 klm/search/weights.hh delete mode 100644 klm/search/weights_test.cc (limited to 'klm/lm/vocab.hh') diff --git a/decoder/incremental.cc b/decoder/incremental.cc index 46615b0b..85647a44 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -6,6 +6,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "search/applied.hh" #include "search/config.hh" #include "search/context.hh" #include "search/edge.hh" @@ -48,16 +49,16 @@ template class Incremental : public IncrementalBase { Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()), - weights_( - weights[FD::Convert("KLanguageModel")], - weights[FD::Convert("KLanguageModel_OOV")], - weights[FD::Convert("WordPenalty")]) { - std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + lm_(weights[FD::Convert("KLanguageModel")]), + oov_(weights[FD::Convert("KLanguageModel_OOV")]), + word_penalty_(weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << lm_ << " KLanguageModel_OOV " << oov_ << " WordPenalty " << word_penalty_ << std::endl; } + void Search(unsigned int pop_limit, const Hypergraph &hg) const; private: - void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + void ConvertEdge(const search::Context &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; lm::ngram::Config GetConfig() { lm::ngram::Config ret; @@ -69,46 +70,47 @@ template class Incremental : public IncrementalBase { const Model m_; - const search::Weights weights_; + const float lm_, oov_, word_penalty_; }; -void PrintFinal(const Hypergraph &hg, const search::Final final) { +void PrintApplied(const Hypergraph &hg, const search::Applied final) { const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); - const search::Final *child(final.Children()); + const search::Applied *child(final.Children()); for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { if (*i > 0) { std::cout << TD::Convert(*i) << ' '; } else { - PrintFinal(hg, *child++); + PrintApplied(hg, *child++); } } } template void Incremental::Search(unsigned int pop_limit, const Hypergraph &hg) const { boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - search::Config config(weights_, pop_limit); + search::Config config(lm_, pop_limit, search::NBestConfig(1)); search::Context context(config, m_); + search::SingleBest best; for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { search::EdgeGenerator gen; const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; for (unsigned int j = 0; j < down_edges.size(); ++j) { unsigned int edge_index = down_edges[j]; - ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + ConvertEdge(context, out_vertices.get(), hg.edges_[edge_index], gen); } - search::VertexGenerator vertex_gen(context, out_vertices[i]); + search::VertexGenerator vertex_gen(context, out_vertices[i], best); gen.Search(context, vertex_gen); } - const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + const search::Applied top = out_vertices[hg.nodes_.size() - 2].BestChild(); if (!top.Valid()) { std::cout << "NO PATH FOUND" << std::endl; } else { - PrintFinal(hg, top); + PrintApplied(hg, top); std::cout << "||| " << top.GetScore() << std::endl; } } -template void Incremental::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { +template void Incremental::ConvertEdge(const search::Context &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { const std::vector &e = in.rule_->e(); std::vector words; words.reserve(e.size()); @@ -127,10 +129,6 @@ template void Incremental::ConvertEdge(const search::Contex } } - if (final) { - words.push_back(m_.GetVocabulary().EndSentence()); - } - search::PartialEdge out(gen.AllocateEdge(nts.size())); memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); @@ -140,8 +138,10 @@ template void Incremental::ConvertEdge(const search::Contex out.SetNote(note); score += in.rule_->GetFeatureValues().dot(cdec_weights_); - score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - score += search::ScoreRule(context, words, final, out.Between()); + score -= static_cast(terminals) * word_penalty_ / M_LN10; + search::ScoreRuleRet res(search::ScoreRule(context.LanguageModel(), words, out.Between())); + score += res.prob * lm_ + static_cast(res.oov) * oov_; + out.SetScore(score); gen.AddEdge(out); diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 8c27232e..85c1ea37 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -51,36 +51,36 @@ namespace ngram { template class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } void BeginSentence() { - out_.right = model_.BeginSentenceState(); - // out_.left is empty. + out_->right = model_.BeginSentenceState(); + // out_->left is empty. left_done_ = true; } void Terminal(WordIndex word) { - State copy(out_.right); - FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + State copy(out_->right); + FullScoreReturn ret(model_.FullScore(copy, word, out_->right)); if (left_done_) { prob_ += ret.prob; return; } if (ret.independent_left) { prob_ += ret.prob; left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; - if (out_.right.length != copy.length + 1) + if (out_->right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob = 0.0) { prob_ = prob; - out_ = in; + *out_ = in; left_done_ = in.left.full; } @@ -89,23 +89,23 @@ template class RuleScore { if (!in.left.length) { if (in.left.full) { - for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; } return; } - if (!out_.right.length) { - out_.right = in.right; + if (!out_->right.length) { + out_->right = in.right; if (left_done_) { prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); return; } - if (out_.left.length) { + if (out_->left.length) { left_done_ = true; } else { - out_.left = in.left; + out_->left = in.left; left_done_ = in.left.full; } return; @@ -113,10 +113,10 @@ template class RuleScore { float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use = out_.right.length; + unsigned char next_use = out_->right.length; // First word - if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return; // Words after the first, so extending a bigram to begin with for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { @@ -127,54 +127,58 @@ template class RuleScore { if (in.left.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; return; } // Right state was minimized, so it's already independent of the new words to the left. if (in.right.length < in.left.length) { - out_.right = in.right; + out_->right = in.right; return; } // Shift exisiting words down. - for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) { *(i + in.right.length) = *i; } // Add words from in.right. - std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + std::copy(in.right.words, in.right.words + in.right.length, out_->right.words); // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. - std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); - std::copy(back, back + next_use, out_.right.backoff + in.right.length); - out_.right.length = in.right.length + next_use; + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); + std::copy(back, back + next_use, out_->right.backoff + in.right.length); + out_->right.length = in.right.length + next_use; } float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. - out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); + out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1); return prob_; } void Reset() { prob_ = 0.0; left_done_ = false; - out_.left.length = 0; - out_.right.length = 0; + out_->left.length = 0; + out_->right.length = 0; + } + void Reset(ChartState &replacement) { + out_ = &replacement; + Reset(); } private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( - out_.right.words, out_.right.words + next_use, // Words to extend into + out_->right.words, out_->right.words + next_use, // Words to extend into back_in, // Backoffs to use in.left.pointers[extend_length - 1], extend_length, // Words to be extended back_out, // Backoffs for the next score next_use)); // Length of n-gram to use in next scoring. - if (next_use != out_.right.length) { + if (next_use != out_->right.length) { left_done_ = true; if (!next_use) { // Early exit. - out_.right = in.right; + out_->right = in.right; prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); return true; } @@ -193,13 +197,13 @@ template class RuleScore { left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; } const M &model_; - ChartState &out_; + ChartState *out_; bool left_done_; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 989f8324..ea0dea46 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -4,9 +4,8 @@ * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ -#ifndef KENLM_MAX_ORDER -#define KENLM_MAX_ORDER 6 -#endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif + +#define KENLM_MAX_ORDER 5 diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..fc61efee 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template void GenericModel FullScoreReturn GenericModel void HashedSearch::InitializeFromARPA(const char * template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { NoRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { @@ -239,19 +239,19 @@ template <> void HashedSearch::DispatchBuild(util::FilePiece &f, cons case Config::REST_MAX: { MaxRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; case Config::REST_LOWER: { LowerRestBuild build(config, counts.size(), vocab); - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; } } -template template void HashedSearch::ApplyBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { +template template void HashedSearch::ApplyBuild(util::FilePiece &f, const std::vector &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { for (WordIndex i = 0; i < counts[0]; ++i) { build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]); } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index a52f107b..00595796 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -147,7 +147,7 @@ template class HashedSearch { // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); - template void ApplyBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); + template void ApplyBuild(util::FilePiece &f, const std::vector &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); class Unigram { public: diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 11c27518..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { } *end_ = hashed; if (enumerate_) { - strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); + void *copied = string_backing_.Allocate(str.size()); + memcpy(copied, str.data(), str.size()); + strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast(copied), str.size()); } ++end_; // This is 1 + the offset where it was inserted to make room for unk. @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { - util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); util::JointSort(begin_, end_, values); } for (WordIndex i = 0; i < static_cast(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { enumerate_->Add(i + 1, strings_to_enumerate_[i]); } strings_to_enumerate_.clear(); + string_backing_.FreeAll(); } else { util::JointSort(begin_, end_, reorder_vocab + 1); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index de54eb06..3902f117 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,6 +4,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary { EnumerateVocab *enumerate_; // Actual strings. Used only when loading from ARPA and enumerate_ != NULL - std::vector strings_to_enumerate_; + util::Pool string_backing_; + + std::vector strings_to_enumerate_; }; #pragma pack(push) diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index ccc5b7f6..5aea33c2 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -2,10 +2,10 @@ noinst_LIBRARIES = libksearch.a libksearch_a_SOURCES = \ edge_generator.cc \ + nbest.cc \ rule.cc \ vertex.cc \ - vertex_generator.cc \ - weights.cc + vertex_generator.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted. +template class GenericApplied : public Header { + public: + GenericApplied() {} + + GenericApplied(void *location, PartialEdge partial) + : Header(location) { + memcpy(Base(), partial.Base(), kHeaderSize); + Below *child_out = Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = Below(part->End()); + } + + GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { + SetScore(score); + SetNote(note); + } + + explicit GenericApplied(History from) : Header(from) {} + + + // These are arrays of length GetArity(). + Below *Children() { + return reinterpret_cast(After()); + } + const Below *Children() const { + return reinterpret_cast(After()); + } + + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Below); + } +}; + +// Applied rule that references itself. +class Applied : public GenericApplied { + private: + typedef GenericApplied P; + + public: + Applied() {} + Applied(void *location, PartialEdge partial) : P(location, partial) {} + Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses. +class SingleBest { + public: + typedef PartialEdge Combine; + + void Add(PartialEdge &existing, PartialEdge add) const { + if (!existing.Valid() || existing.GetScore() < add.GetScore()) + existing = add; + } + + NBestComplete Complete(PartialEdge partial) { + if (!partial.Valid()) + return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); + void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); + Applied(place_final, partial); + return NBestComplete( + place_final, + partial.CompletedState(), + partial.GetScore()); + } + + private: + util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh index ef8e2354..ba18c09e 100644 --- a/klm/search/config.hh +++ b/klm/search/config.hh @@ -1,23 +1,36 @@ #ifndef SEARCH_CONFIG__ #define SEARCH_CONFIG__ -#include "search/weights.hh" -#include "util/string_piece.hh" +#include "search/types.hh" namespace search { +struct NBestConfig { + explicit NBestConfig(unsigned int in_size) { + keep = in_size; + size = in_size; + } + + unsigned int keep, size; +}; + class Config { public: - Config(const Weights &weights, unsigned int pop_limit) : - weights_(weights), pop_limit_(pop_limit) {} + Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : + lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return lm_weight_; } unsigned int PopLimit() const { return pop_limit_; } + const NBestConfig &GetNBest() const { return nbest_; } + private: - Weights weights_; + Score lm_weight_; + unsigned int pop_limit_; + + NBestConfig nbest_; }; } // namespace search diff --git a/klm/search/context.hh b/klm/search/context.hh index 62163144..08f21bbf 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -1,30 +1,16 @@ #ifndef SEARCH_CONTEXT__ #define SEARCH_CONTEXT__ -#include "lm/model.hh" #include "search/config.hh" -#include "search/final.hh" -#include "search/types.hh" #include "search/vertex.hh" -#include "util/exception.hh" -#include "util/pool.hh" #include -#include - -#include namespace search { -class Weights; - class ContextBase { public: - explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - - util::Pool &FinalPool() { - return final_pool_; - } + explicit ContextBase(const Config &config) : config_(config) {} VertexNode *NewVertexNode() { VertexNode *ret = vertex_node_pool_.construct(); @@ -36,18 +22,16 @@ class ContextBase { vertex_node_pool_.destroy(node); } - unsigned int PopLimit() const { return pop_limit_; } + unsigned int PopLimit() const { return config_.PopLimit(); } - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return config_.LMWeight(); } - private: - util::Pool final_pool_; + const Config &GetConfig() const { return config_; } + private: boost::object_pool vertex_node_pool_; - unsigned int pop_limit_; - - const Weights &weights_; + Config config_; }; template class Context : public ContextBase { diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include +#include + +namespace search { + +class Dedupe { + public: + Dedupe() {} + + PartialEdge AllocateEdge(Arity arity) { + return behind_.AllocateEdge(arity); + } + + void AddEdge(PartialEdge edge) { + edge.MutableFlags() = 0; + + uint64_t hash = 0; + const PartialVertex *v = edge.NT(); + const PartialVertex *v_end = v + edge.GetArity(); + for (; v != v_end; ++v) { + const void *ptr = v->Identify(); + hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); + } + + const lm::ngram::ChartState *c = edge.Between(); + const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; + for (; c != c_end; ++c) hash = hash_value(*c, hash); + + std::pair ret(table_.insert(std::make_pair(hash, edge))); + if (!ret.second) FoundDupe(ret.first->second, edge); + } + + bool Empty() const { return behind_.Empty(); } + + template void Search(Context &context, Output &output) { + for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { + behind_.AddEdge(i->second); + } + Unpack unpack(output, *this); + behind_.Search(context, unpack); + } + + private: + void FoundDupe(PartialEdge &table, PartialEdge adding) { + if (table.GetFlags() & kPackedFlag) { + Packed &packed = *static_cast(table.GetNote().mut); + if (table.GetScore() >= adding.GetScore()) { + packed.others.push_back(adding); + return; + } + Note original(packed.original); + packed.original = adding.GetNote(); + adding.SetNote(table.GetNote()); + table.SetNote(original); + packed.others.push_back(table); + packed.starting = adding.GetScore(); + table = adding; + table.MutableFlags() |= kPackedFlag; + return; + } + PartialEdge loser; + if (adding.GetScore() > table.GetScore()) { + loser = table; + table = adding; + } else { + loser = adding; + } + // table is winner, loser is loser... + packed_.construct(table, loser); + } + + struct Packed { + Packed(PartialEdge winner, PartialEdge loser) + : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { + winner.MutableNote().vp = this; + winner.MutableFlags() |= kPackedFlag; + loser.MutableFlags() &= ~kPackedFlag; + } + Note original; + Score starting; + std::vector others; + }; + + template class Unpack { + public: + explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + + void NewHypothesis(PartialEdge edge) { + if (edge.GetFlags() & kPackedFlag) { + Packed &packed = *reinterpret_cast(edge.GetNote().mut); + edge.SetNote(packed.original); + edge.MutableFlags() = 0; + std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); + for (std::vector::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { + PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); + copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); + copy.MutableFlags() = 0; + copy.SetNote(i->GetNote()); + memcpy(copy.NT(), edge.NT(), copy_size); + output_.NewHypothesis(copy); + } + } + output_.NewHypothesis(edge); + } + + void FinishedSearch() { + output_.FinishedSearch(); + } + + private: + Output &output_; + Dedupe &owner_; + }; + + EdgeGenerator behind_; + + typedef boost::unordered_map Table; + Table table_; + + boost::object_pool packed_; + + static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 260159b1..eacf5de5 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -1,6 +1,7 @@ #include "search/edge_generator.hh" #include "lm/left.hh" +#include "lm/model.hh" #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" @@ -38,7 +39,7 @@ template void FastScore(const Context &context, Arity victi *cover = *(cover + 1); } } - update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); + update.SetScore(update.GetScore() + adjustment * context.LMWeight()); } } // namespace diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 582c78b7..203942c6 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,6 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" -#include "search/note.hh" #include "search/types.hh" #include diff --git a/klm/search/final.hh b/klm/search/final.hh deleted file mode 100644 index 50e62cf2..00000000 --- a/klm/search/final.hh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SEARCH_FINAL__ -#define SEARCH_FINAL__ - -#include "search/header.hh" -#include "util/pool.hh" - -namespace search { - -// A full hypothesis with pointers to children. -class Final : public Header { - public: - Final() {} - - Final(util::Pool &pool, Score score, Arity arity, Note note) - : Header(pool.Allocate(Size(arity)), arity) { - SetScore(score); - SetNote(note); - } - - // These are arrays of length GetArity(). - Final *Children() { - return reinterpret_cast(After()); - } - const Final *Children() const { - return reinterpret_cast(After()); - } - - private: - static std::size_t Size(Arity arity) { - return kHeaderSize + arity * sizeof(const Final); - } -}; - -} // namespace search - -#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh index 25550dbe..69f0eed0 100644 --- a/klm/search/header.hh +++ b/klm/search/header.hh @@ -3,7 +3,6 @@ // Header consisting of Score, Arity, and Note -#include "search/note.hh" #include "search/types.hh" #include @@ -24,6 +23,9 @@ class Header { bool operator<(const Header &other) const { return GetScore() < other.GetScore(); } + bool operator>(const Header &other) const { + return GetScore() > other.GetScore(); + } Arity GetArity() const { return *reinterpret_cast(base_ + sizeof(Score)); @@ -36,9 +38,14 @@ class Header { *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; } + uint8_t *Base() { return base_; } + const uint8_t *Base() const { return base_; } + protected: Header() : base_(NULL) {} + explicit Header(void *base) : base_(static_cast(base)) {} + Header(void *base, Arity arity) : base_(static_cast(base)) { *reinterpret_cast(base_ + sizeof(Score)) = arity; } diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include +#include +#include + +#include +#include + +namespace search { + +NBestList::NBestList(std::vector &partials, util::Pool &entry_pool, std::size_t keep) { + assert(!partials.empty()); + std::vector::iterator end; + if (partials.size() > keep) { + end = partials.begin() + keep; + std::nth_element(partials.begin(), end, partials.end(), std::greater()); + } else { + end = partials.end(); + } + for (std::vector::const_iterator i(partials.begin()); i != end; ++i) { + queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); + } +} + +Score NBestList::TopAfterConstructor() const { + assert(revealed_.empty()); + return queue_.top().GetScore(); +} + +const std::vector &NBestList::Extract(util::Pool &pool, std::size_t n) { + while (revealed_.size() < n && !queue_.empty()) { + MoveTop(pool); + } + return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { + if (index + 1 < revealed_.size()) + return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); + if (queue_.empty()) + return -INFINITY; + if (index + 1 == revealed_.size()) + return queue_.top().GetScore() - revealed_[index].GetScore(); + assert(index == revealed_.size()); + + MoveTop(pool); + + if (queue_.empty()) return -INFINITY; + return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { + assert(index <= revealed_.size()); + if (index == revealed_.size()) MoveTop(pool); + return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { + assert(!queue_.empty()); + QueueEntry entry(queue_.top()); + queue_.pop(); + RevealedRef *const children_begin = entry.Children(); + RevealedRef *const children_end = children_begin + entry.GetArity(); + Score basis = entry.GetScore(); + for (RevealedRef *child = children_begin; child != children_end; ++child) { + Score change = child->in_->Visit(pool, child->index_); + if (change != -INFINITY) { + assert(change < 0.001); + QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); + std::copy(children_begin, child, new_entry.Children()); + RevealedRef *update = new_entry.Children() + (child - children_begin); + update->in_ = child->in_; + update->index_ = child->index_ + 1; + std::copy(child + 1, children_end, update + 1); + queue_.push(new_entry); + } + // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. + if (child->index_) break; + } + + // Convert QueueEntry to Applied. This leaves some unused memory. + void *overwrite = entry.Children(); + for (unsigned int i = 0; i < entry.GetArity(); ++i) { + RevealedRef from(*(static_cast(overwrite) + i)); + *(static_cast(overwrite) + i) = from.in_->Get(pool, from.index_); + } + revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector &partials) { + assert(!partials.empty()); + NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); + return NBestComplete( + list, + partials.front().CompletedState(), // All partials have the same state + list->TopAfterConstructor()); +} + +const std::vector &NBest::Extract(History history) { + return static_cast(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include + +#include +#include +#include + +#include + +namespace search { + +class NBestList; + +class NBestList { + private: + class RevealedRef { + public: + explicit RevealedRef(History history) + : in_(static_cast(history)), index_(0) {} + + private: + friend class NBestList; + + NBestList *in_; + std::size_t index_; + }; + + typedef GenericApplied QueueEntry; + + public: + NBestList(std::vector &existing, util::Pool &entry_pool, std::size_t keep); + + Score TopAfterConstructor() const; + + const std::vector &Extract(util::Pool &pool, std::size_t n); + + private: + Score Visit(util::Pool &pool, std::size_t index); + + Applied Get(util::Pool &pool, std::size_t index); + + void MoveTop(util::Pool &pool); + + typedef std::vector Revealed; + Revealed revealed_; + + typedef std::priority_queue Queue; + Queue queue_; +}; + +class NBest { + public: + typedef std::vector Combine; + + explicit NBest(const NBestConfig &config) : config_(config) {} + + void Add(std::vector &existing, PartialEdge addition) const { + existing.push_back(addition); + } + + NBestComplete Complete(std::vector &partials); + + const std::vector &Extract(History root); + + private: + const NBestConfig config_; + + boost::object_pool list_pool_; + + util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/note.hh b/klm/search/note.hh deleted file mode 100644 index 50bed06e..00000000 --- a/klm/search/note.hh +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef SEARCH_NOTE__ -#define SEARCH_NOTE__ - -namespace search { - -union Note { - const void *vp; -}; - -} // namespace search - -#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 5b00207e..0244a09f 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -1,7 +1,7 @@ #include "search/rule.hh" +#include "lm/model.hh" #include "search/context.hh" -#include "search/final.hh" #include @@ -9,35 +9,35 @@ namespace search { -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing) { - unsigned int oov_count = 0; - float prob = 0.0; - const Model &model = context.LanguageModel(); - const lm::WordIndex oov = model.GetVocabulary().NotFound(); - for (std::vector::const_iterator word = words.begin(); ; ++word) { - lm::ngram::RuleScore scorer(model, *(writing++)); - // TODO: optimize - if (prepend_bos && (word == words.begin())) { - scorer.BeginSentence(); - } - for (; ; ++word) { - if (word == words.end()) { - prob += scorer.Finish(); - return static_cast(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); - } - if (*word == kNonTerminal) break; - if (*word == oov) ++oov_count; +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *writing) { + ScoreRuleRet ret; + ret.prob = 0.0; + ret.oov = 0; + const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); + lm::ngram::RuleScore scorer(model, *(writing++)); + std::vector::const_iterator word = words.begin(); + if (word != words.end() && *word == bos) { + scorer.BeginSentence(); + ++word; + } + for (; word != words.end(); ++word) { + if (*word == kNonTerminal) { + ret.prob += scorer.Finish(); + scorer.Reset(*(writing++)); + } else { + if (*word == oov) ++ret.oov; scorer.Terminal(*word); } - prob += scorer.Finish(); } + ret.prob += scorer.Finish(); + return ret; } -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 0ce2794d..43ca6162 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -9,11 +9,16 @@ namespace search { -template class Context; - const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *state_out); +struct ScoreRuleRet { + Score prob; + unsigned int oov; +}; + +// Pass and normally. +// Indicate non-terminals with kNonTerminal. +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/types.hh b/klm/search/types.hh index 06eb5bfa..f9c849b3 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -3,12 +3,29 @@ #include +namespace lm { namespace ngram { class ChartState; } } + namespace search { typedef float Score; typedef uint32_t Arity; +union Note { + const void *vp; +}; + +typedef void *History; + +struct NBestComplete { + NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score) + : history(in_history), state(&in_state), score(in_score) {} + + History history; + const lm::ngram::ChartState *state; + Score score; +}; + } // namespace search #endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 11f4631f..45842982 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_functionSortAndSet(context, parent_ptr); + if (extend_.size() == 1) { + parent_ptr = extend_[0]; + extend_[0]->RecursiveSortAndSet(context, parent_ptr); context.DeleteVertexNode(this); return; } for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->SortAndSet(context, &*i); + (*i)->RecursiveSortAndSet(context, *i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { + // This is the root. The root might be empty. + if (extend_.empty()) { + bound_ = -INFINITY; + return; + } + // The root cannot be replaced. There's always one transition. + for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->RecursiveSortAndSet(context, *i); } std::sort(extend_.begin(), extend_.end(), GreaterByBound()); bound_ = extend_.front()->Bound(); diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 52bc1dfe..10b3339b 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -2,7 +2,6 @@ #define SEARCH_VERTEX__ #include "lm/left.hh" -#include "search/final.hh" #include "search/types.hh" #include @@ -10,6 +9,7 @@ #include #include +#include #include namespace search { @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() {} + VertexNode() : end_() {} void InitRoot() { extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - end_ = Final(); + end_ = History(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -36,20 +36,21 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final end) { - assert(!end_.Valid()); + void SetEnd(History end, Score score) { + assert(!end_); end_ = end; + bound_ = score; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + void SortAndSet(ContextBase &context); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_.Valid() && extend_.empty(); + return !end_ && extend_.empty(); } bool Complete() const { - return end_.Valid(); + return end_; } const lm::ngram::ChartState &State() const { return state_; } @@ -64,7 +65,7 @@ class VertexNode { } // Will be invalid unless this is a leaf. - const Final End() const { return end_; } + const History End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -75,13 +76,15 @@ class VertexNode { } private: + void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + std::vector extend_; lm::ngram::ChartState state_; bool right_full_; Score bound_; - Final end_; + History end_; }; class PartialVertex { @@ -97,7 +100,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -121,7 +124,7 @@ class PartialVertex { return ret; } - const Final End() const { + const History End() const { return back_->End(); } @@ -130,16 +133,18 @@ class PartialVertex { unsigned int index_; }; +template class VertexGenerator; + class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final BestChild() const { + const History BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return Final(); + return History(); } else { PartialVertex continuation; while (!top.Complete()) { @@ -150,8 +155,8 @@ class Vertex { } private: - friend class VertexGenerator; - + template friend class VertexGenerator; + template friend class RootVertexGenerator; VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0945fe55..73139ffc 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -4,26 +4,18 @@ #include "search/context.hh" #include "search/edge.hh" +#include +#include + #include namespace search { -VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { - gen.root_.InitRoot(); -} - +#if BOOST_VERSION > 104200 namespace { const uint64_t kCompleteAdd = static_cast(-1); -// Parallel structure to VertexNode. -struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; -}; - Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { Trie &next = node.extend[added]; if (!next.under) { @@ -39,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n return next; } -void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { - Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); - Final *child_out = final.Children(); - const PartialVertex *part = partial.NT(); - const PartialVertex *const part_end_loop = part + partial.GetArity(); - for (; part != part_end_loop; ++part, ++child_out) - *child_out = part->End(); - - starter.under->SetEnd(final); -} +} // namespace -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { + const lm::ngram::ChartState &state = *end.state; unsigned char left = 0, right = 0; Trie *node = &root; @@ -77,18 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { } node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - CompleteTransition(context, *node, partial); + node->under->SetEnd(end.history, end.score); } -} // namespace - -void VertexGenerator::FinishedSearch() { - Trie root; - root.under = &gen_.root_; - for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { - AddHypothesis(context_, root, i->second); - } - root.under->SortAndSet(context_, NULL); -} +#endif // BOOST_VERSION } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 60e86112..da563c2d 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -2,9 +2,11 @@ #define SEARCH_VERTEX_GENERATOR__ #include "search/edge.hh" +#include "search/types.hh" #include "search/vertex.hh" #include +#include namespace lm { namespace ngram { @@ -15,21 +17,44 @@ class ChartState; namespace search { class ContextBase; -class Final; -class VertexGenerator { +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list. +template class VertexGenerator { public: - VertexGenerator(ContextBase &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { + gen.root_.InitRoot(); + } void NewHypothesis(PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); - if (!ret.second && ret.first->second < partial) { - ret.first->second = partial; - } + nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); } - void FinishedSearch(); + void FinishedSearch() { +#if BOOST_VERSION > 104200 + Trie root; + root.under = &gen_.root_; + for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, nbest_.Complete(i->second)); + } + existing_.clear(); + root.under->SortAndSet(context_); +#else + UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); +#endif + } const Vertex &Generating() const { return gen_; } @@ -38,8 +63,35 @@ class VertexGenerator { Vertex &gen_; - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; + + Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node. In theory, this should happen naturally due to state collapsing with +// and . If that's the case, VertexGenerator is fine, though it will +// make one connection. +template class RootVertexGenerator { + public: + RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + + void NewHypothesis(PartialEdge partial) { + out_.Add(combine_, partial); + } + + void FinishedSearch() { + gen_.root_.InitRoot(); + NBestComplete completed(out_.Complete(combine_)); + gen_.root_.SetEnd(completed.history, completed.score); + } + + private: + Vertex &gen_; + + typename Output::Combine combine_; + Output &out_; }; } // namespace search diff --git a/klm/search/weights.cc b/klm/search/weights.cc deleted file mode 100644 index d65471ad..00000000 --- a/klm/search/weights.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "search/weights.hh" -#include "util/tokenize_piece.hh" - -#include - -namespace search { - -namespace { -struct Insert { - void operator()(boost::unordered_map &map, StringPiece name, search::Score score) const { - std::string copy(name.data(), name.size()); - map[copy] = score; - } -}; - -struct DotProduct { - search::Score total; - DotProduct() : total(0.0) {} - - void operator()(const boost::unordered_map &map, StringPiece name, search::Score score) { - boost::unordered_map::const_iterator i(FindStringPiece(map, name)); - if (i != map.end()) - total += score * i->second; - } -}; - -template void Parse(StringPiece text, Map &map, Op &op) { - for (util::TokenIter spaces(text, ' '); spaces; ++spaces) { - util::TokenIter equals(*spaces, '='); - UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); - StringPiece name(*equals); - UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); - char *end; - // Assumes proper termination. - double value = std::strtod(equals->data(), &end); - UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); - UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); - op(map, name, value); - } -} - -} // namespace - -Weights::Weights(StringPiece text) { - Insert op; - Parse(text, map_, op); - lm_ = Steal("LanguageModel"); - oov_ = Steal("OOV"); - word_penalty_ = Steal("WordPenalty"); -} - -Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} - -search::Score Weights::DotNoLM(StringPiece text) const { - DotProduct dot; - Parse(text, map_, dot); - return dot.total; -} - -float Weights::Steal(const std::string &str) { - Map::iterator i(map_.find(str)); - if (i == map_.end()) { - return 0.0; - } else { - float ret = i->second; - map_.erase(i); - return ret; - } -} - -} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh deleted file mode 100644 index df1c419f..00000000 --- a/klm/search/weights.hh +++ /dev/null @@ -1,52 +0,0 @@ -// For now, the individual features are not kept. -#ifndef SEARCH_WEIGHTS__ -#define SEARCH_WEIGHTS__ - -#include "search/types.hh" -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include - -#include - -namespace search { - -class WeightParseException : public util::Exception { - public: - WeightParseException() {} - ~WeightParseException() throw() {} -}; - -class Weights { - public: - // Parses weights, sets lm_weight_, removes it from map_. - explicit Weights(StringPiece text); - - // Just the three scores we care about adding. - Weights(Score lm, Score oov, Score word_penalty); - - Score DotNoLM(StringPiece text) const; - - Score LM() const { return lm_; } - - Score OOV() const { return oov_; } - - Score WordPenalty() const { return word_penalty_; } - - // Mostly for testing. - const boost::unordered_map &GetMap() const { return map_; } - - private: - float Steal(const std::string &str); - - typedef boost::unordered_map Map; - - Map map_; - - Score lm_, oov_, word_penalty_; -}; - -} // namespace search - -#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc deleted file mode 100644 index 4811ff06..00000000 --- a/klm/search/weights_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "search/weights.hh" - -#define BOOST_TEST_MODULE WeightTest -#include -#include - -namespace search { -namespace { - -#define CHECK_WEIGHT(value, string) \ - i = parsed.find(string); \ - BOOST_REQUIRE(i != parsed.end()); \ - BOOST_CHECK_CLOSE((value), i->second, 0.001); - -BOOST_AUTO_TEST_CASE(parse) { - // These are not real feature weights. - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - const boost::unordered_map &parsed = w.GetMap(); - boost::unordered_map::const_iterator i; - CHECK_WEIGHT(0.0, "rarity"); - CHECK_WEIGHT(0.0, "phrase-SGT"); - CHECK_WEIGHT(9.45117, "phrase-TGS"); - CHECK_WEIGHT(2.33833, "lexical-SGT"); - BOOST_CHECK(parsed.end() == parsed.find("lm")); - BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); - CHECK_WEIGHT(-28.3317, "lexical-TGS"); - CHECK_WEIGHT(5.0, "glue?"); -} - -BOOST_AUTO_TEST_CASE(dot) { - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); -} - -} // namespace -} // namespace search -- cgit v1.2.3