From f568b392f82fd94b788a1b38094855234d318205 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 14 Oct 2012 10:46:34 +0100 Subject: Update to faster but less cute search --- decoder/lazy.cc | 57 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 28 deletions(-) (limited to 'decoder') diff --git a/decoder/lazy.cc b/decoder/lazy.cc index c4138d7b..9dc657d6 100644 --- a/decoder/lazy.cc +++ b/decoder/lazy.cc @@ -8,6 +8,7 @@ #include "search/config.hh" #include "search/context.hh" #include "search/edge.hh" +#include "search/edge_queue.hh" #include "search/vertex.hh" #include "search/vertex_generator.hh" #include "util/exception.hh" @@ -75,7 +76,7 @@ template class Lazy : public LazyBase { void Search(unsigned int pop_limit, const Hypergraph &hg) const; private: - void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const; + unsigned char ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const; const Model m_; }; @@ -93,73 +94,73 @@ LazyBase *LazyBase::Load(const char *model_file, const std::vector &we } } -void PrintFinal(const Hypergraph &hg, const search::Edge *edge_base, const search::Final &final) { - const std::vector &words = hg.edges_[&final.From() - edge_base].rule_->e(); +void PrintFinal(const Hypergraph &hg, const search::Final &final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); boost::array::const_iterator child(final.Children().begin()); for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { if (*i > 0) { std::cout << TD::Convert(*i) << ' '; } else { - PrintFinal(hg, edge_base, **child++); + PrintFinal(hg, **child++); } } } template void Lazy::Search(unsigned int pop_limit, const Hypergraph &hg) const { boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - boost::scoped_array out_edges(new search::Edge[hg.edges_.size()]); search::Config config(weights_, pop_limit); search::Context context(config, m_); for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { - search::Vertex &out_vertex = out_vertices[i]; + search::EdgeQueue queue(context.PopLimit()); const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; for (unsigned int j = 0; j < down_edges.size(); ++j) { unsigned int edge_index = down_edges[j]; - ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], out_edges[edge_index]); - out_vertex.Add(out_edges[edge_index]); + unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge()); + search::Note note; + note.vp = &hg.edges_[edge_index]; + if (arity != 255) queue.AddEdge(arity, note); } - out_vertex.FinishedAdding(); - search::VertexGenerator(context, out_vertex); + search::VertexGenerator vertex_gen(context, out_vertices[i]); + queue.Search(context, vertex_gen); } - search::PartialVertex top = out_vertices[hg.nodes_.size() - 2].RootPartial(); - if (top.Empty()) { - std::cout << "NO PATH FOUND"; + const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (!top) { + std::cout << "NO PATH FOUND" << std::endl; } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - PrintFinal(hg, out_edges.get(), top.End()); - std::cout << "||| " << top.End().Bound() << std::endl; + PrintFinal(hg, *top); + std::cout << "||| " << top->Bound() << std::endl; } } -// TODO: get weights into here somehow. -template void Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const { +template unsigned char Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const { const std::vector &e = in.rule_->e(); std::vector words; unsigned int terminals = 0; + unsigned char nt = 0; for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { if (*word <= 0) { - out.Add(vertices[in.tail_nodes_[-*word]]); + out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial(); + if (out.nt[nt].Empty()) return 255; + ++nt; words.push_back(lm::kMaxWordIndex); } else { ++terminals; words.push_back(vocab_.FromCDec(*word)); } } + for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) { + out.nt[nt] = search::kBlankPartialVertex; + } if (final) { words.push_back(m_.GetVocabulary().EndSentence()); } - float additive = in.rule_->GetFeatureValues().dot(cdec_weights_); - UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product"); - additive -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - - out.InitRule().Init(context, additive, words, final); + out.score = in.rule_->GetFeatureValues().dot(cdec_weights_); + out.score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + out.score += search::ScoreRule(context, words, final, out.between); + return nt; } boost::scoped_ptr AwfulGlobalLazy; -- cgit v1.2.3