From b6e70b420ed993ee73f71058d04b382147896068 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Sun, 12 Aug 2012 23:33:21 -0400
Subject: use new union api

---
 decoder/decoder.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index a6f7b1ce..a69a6d05 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -24,6 +24,7 @@
 #include "hg.h"
 #include "sentence_metadata.h"
 #include "hg_intersect.h"
+#include "hg_union.h"
 #include "oracle_bleu.h"
 #include "apply_models.h"
@@ -980,7 +981,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
       bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
       if (!succeeded) abort();
     }
-    new_hg.Union(forest);
+    HG::Union(forest, &new_hg);
     bool succeeded = writer.Write(new_hg, false);
     if (!succeeded) abort();
   } else {
@@ -1067,7 +1068,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
       bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
       if (!succeeded) abort();
     }
-    new_hg.Union(forest);
+    HG::Union(forest, &new_hg);
     bool succeeded = writer.Write(new_hg, false);
     if (!succeeded) abort();
   } else {
-- 
cgit v1.2.3
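Note on this commit: the old Hypergraph::Union member function is replaced by a free
function declared in the new hg_union.h header. Only the call sites appear in the
patch, so the exact declaration is an assumption; from the two calls above it reads
roughly as follows, with the first argument merged into the hypergraph pointed to by
the second:

    // Hedged sketch of the new API implied by the call sites above.
    // HG::Union is assumed to merge `in` into `*out` (hypothetical signature).
    #include "hg.h"
    #include "hg_union.h"

    void MergeForests(const Hypergraph& in, Hypergraph* out) {
      // old API (removed): out->Union(in);
      HG::Union(in, out);  // new API: free function in namespace HG
    }
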
From 8505fdfdf0bc4ce9acec42e1980a2fdd4f254109 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Thu, 13 Sep 2012 11:15:32 +0100
Subject: It compiles.

---
 decoder/Jamfile       |  2 ++
 decoder/decoder.cc    |  4 +++
 decoder/lazy.cc       | 78 +++++++++++++++++++++++++++++++++++++--------------
 decoder/lazy.h        |  5 +++-
 klm/search/config.hh  |  6 ++--
 klm/search/weights.cc |  2 ++
 klm/search/weights.hh | 17 ++++++-----
 7 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/decoder/Jamfile b/decoder/Jamfile
index da02d063..d778dc7f 100644
--- a/decoder/Jamfile
+++ b/decoder/Jamfile
@@ -58,10 +58,12 @@ lib decoder :
   rescore_translator.cc
   hg_remove_eps.cc
   hg_union.cc
+  lazy.cc
   $(glc)
   ..//utils
   ..//mteval
   ../klm/lm//kenlm
+  ../klm/search//search
   ..//boost_program_options
   : .
   :
   :
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index a69a6d05..3a410cf2 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -38,6 +38,7 @@
 #include "sampler.h"
 
 #include "forest_writer.h" // TODO this section should probably be handled by an Observer
+#include "lazy.h"
 #include "hg_io.h"
 #include "aligner.h"
@@ -832,6 +833,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
   if (conf.count("show_target_graph"))
     HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
 
+  if (conf.count("lazy_search"))
+    PassToLazy(forest, CurrentWeightVector());
+
   for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
     const RescoringPass& rp = rescoring_passes[pass];
     const vector<weight_t>& cur_weights = *rp.weight_vector;
diff --git a/decoder/lazy.cc b/decoder/lazy.cc
index f5b61c75..4776c1b8 100644
--- a/decoder/lazy.cc
+++ b/decoder/lazy.cc
@@ -1,15 +1,23 @@
 #include "hg.h"
 #include "lazy.h"
+#include "fdict.h"
 #include "tdict.h"
 
 #include "lm/enumerate_vocab.hh"
 #include "lm/model.hh"
+#include "search/config.hh"
+#include "search/context.hh"
 #include "search/edge.hh"
 #include "search/vertex.hh"
+#include "search/vertex_generator.hh"
 #include "util/exception.hh"
 
+#include <boost/scoped_array.hpp>
 #include <boost/scoped_ptr.hpp>
 
+#include <iostream>
+#include <vector>
+
 namespace {
 
 struct MapVocab : public lm::EnumerateVocab {
 
   // Do not call after Lookup.
   void Add(lm::WordIndex index, const StringPiece &str) {
     const WordID cdec_id = TD::Convert(str.as_string());
-    if (cdec_id >= out_->size()) out_.resize(cdec_id + 1);
+    if (cdec_id >= out_.size()) out_.resize(cdec_id + 1);
     out_[cdec_id] = index;
   }
 
   // Assumes Add has been called and will never be called again.
   lm::WordIndex FromCDec(WordID id) const {
-    return out_[out.size() > id ? id : 0];
+    return out_[out_.size() > id ? id : 0];
   }
 
   private:
     std::vector<lm::WordIndex> out_;
 };
 
 class LazyBase {
   public:
-    LazyBase() {}
+    LazyBase(const std::vector<weight_t> &weights) :
+      cdec_weights_(weights),
+      config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {}
 
     virtual ~LazyBase() {}
 
     virtual void Search(const Hypergraph &hg) const = 0;
 
-    static LazyBase *Load(const char *model_file);
+    static LazyBase *Load(const char *model_file, const std::vector<weight_t> &weights);
 
   protected:
-    lm::ngram::Config GetConfig() const {
+    lm::ngram::Config GetConfig() {
       lm::ngram::Config ret;
       ret.enumerate_vocab = &vocab_;
       return ret;
     }
 
     MapVocab vocab_;
+
+    const std::vector<weight_t> &cdec_weights_;
+
+    const search::Config config_;
 };
 
 template <class Model> class Lazy : public LazyBase {
   public:
-    explicit Lazy(const char *model_file) : m_(model_file, GetConfig()) {}
+    Lazy(const char *model_file, const std::vector<weight_t> &weights) : LazyBase(weights), m_(model_file, GetConfig()) {}
 
     void Search(const Hypergraph &hg) const;
 
   private:
-    void ConvertEdge(const Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const;
+    void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const;
 
     const Model m_;
 };
 
-static LazyBase *LazyBase::Load(const char *model_file) {
+LazyBase *LazyBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
   lm::ngram::ModelType model_type;
-  if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING;
+  if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
   switch (model_type) {
     case lm::ngram::PROBING:
-      return new Lazy<lm::ngram::ProbingModel>(model_file);
+      return new Lazy<lm::ngram::ProbingModel>(model_file, weights);
     case lm::ngram::REST_PROBING:
-      return new Lazy<lm::ngram::RestProbingModel>(model_file);
+      return new Lazy<lm::ngram::RestProbingModel>(model_file, weights);
     default:
       UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
   }
 }
@@ -80,25 +94,41 @@ static LazyBase *LazyBase::Load(const char *model_file) {
 template <class Model> void Lazy<Model>::Search(const Hypergraph &hg) const {
   boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
   boost::scoped_array<search::Edge> out_edges(new search::Edge[hg.edges_.size()]);
+
+  search::Context<Model> context(config_, m_);
+
   for (unsigned int i = 0; i < hg.nodes_.size(); ++i) {
-    search::Vertex *out_vertex = out_vertices[i];
+    search::Vertex &out_vertex = out_vertices[i];
     const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_;
-    for (unsigned int j = 0; j < edges.size(); ++j) {
+    for (unsigned int j = 0; j < down_edges.size(); ++j) {
       unsigned int edge_index = down_edges[j];
-      const Hypergraph::Edge &in_edge = hg.edges_[edge_index];
-      search::Edge &out_edge = out_edges[edge_index];
+      ConvertEdge(context, i == hg.nodes_.size() - 1, out_vertices.get(), hg.edges_[edge_index], out_edges[edge_index]);
+      out_vertex.Add(out_edges[edge_index]);
     }
+    out_vertex.FinishedAdding();
+    search::VertexGenerator(context, out_vertex);
+  }
+  search::PartialVertex top = out_vertices[hg.nodes_.size() - 1].RootPartial();
+  if (top.Empty()) {
+    std::cout << "NO PATH FOUND";
+  } else {
+    search::PartialVertex continuation;
+    while (!top.Complete()) {
+      top.Split(continuation);
+      top = continuation;
+    }
+    std::cout << top.End().Bound() << std::endl;
   }
 }
 
 // TODO: get weights into here somehow.
-template <class Model> void Lazy<Model>::ConvertEdge(const Context &context, bool final, search::Vertices *vertices, const Hypergraph::Edge &in, search::Edge &out) const {
-  const std::vector<WordID> &e = in_edge.rule_->e();
+template <class Model> void Lazy<Model>::ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const {
+  const std::vector<WordID> &e = in.rule_->e();
   std::vector<lm::WordIndex> words;
   unsigned int terminals = 0;
   for (std::vector<WordID>::const_iterator word = e.begin(); word != e.end(); ++word) {
     if (*word <= 0) {
-      out.Add(vertices[edge.tail_nodes_[-*word]]);
+      out.Add(vertices[in.tail_nodes_[-*word]]);
       words.push_back(lm::kMaxWordIndex);
     } else {
       ++terminals;
       words.push_back(vocab_.FromCDec(*word));
     }
   }
 
   if (final) {
     words.push_back(m_.GetVocabulary().EndSentence());
   }
 
-  float additive = edge.rule_->GetFeatureValues().dot(weight_vector);
+  float additive = in.rule_->GetFeatureValues().dot(cdec_weights_);
+  additive -= terminals * context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10;
 
   out.InitRule().Init(context, additive, words, final);
 }
 
-} // namespace
+boost::scoped_ptr<LazyBase> AwfulGlobalLazy;
 
-void PassToLazy(const Hypergraph &hg) {
+} // namespace
 
+void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights) {
+  if (!AwfulGlobalLazy.get()) {
+    AwfulGlobalLazy.reset(LazyBase::Load("lm", weights));
+  }
+  AwfulGlobalLazy->Search(hg);
 }
diff --git a/decoder/lazy.h b/decoder/lazy.h
index aecd030d..3e71a3b0 100644
--- a/decoder/lazy.h
+++ b/decoder/lazy.h
@@ -1,8 +1,11 @@
 #ifndef _LAZY_H_
 #define _LAZY_H_
 
+#include "weights.h"
+#include <vector>
+
 class Hypergraph;
 
-void PassToLazy(const Hypergraph &hg);
+void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights);
 
 #endif // _LAZY_H_
diff --git a/klm/search/config.hh b/klm/search/config.hh
index e21e4b7c..ef8e2354 100644
--- a/klm/search/config.hh
+++ b/klm/search/config.hh
@@ -8,15 +8,15 @@ namespace search {
 
 class Config {
   public:
-    Config(StringPiece weight_str, unsigned int pop_limit) :
-      weights_(weight_str), pop_limit_(pop_limit) {}
+    Config(const Weights &weights, unsigned int pop_limit) :
+      weights_(weights), pop_limit_(pop_limit) {}
 
     const Weights &GetWeights() const { return weights_; }
 
     unsigned int PopLimit() const { return pop_limit_; }
 
   private:
-    search::Weights weights_;
+    Weights weights_;
 
     unsigned int pop_limit_;
 };
diff --git a/klm/search/weights.cc b/klm/search/weights.cc
index 82ff3f12..d65471ad 100644
--- a/klm/search/weights.cc
+++ b/klm/search/weights.cc
@@ -49,6 +49,8 @@ Weights::Weights(StringPiece text) {
   word_penalty_ = Steal("WordPenalty");
 }
 
+Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {}
+
 search::Score Weights::DotNoLM(StringPiece text) const {
   DotProduct dot;
   Parse(text, map_, dot);
diff --git a/klm/search/weights.hh b/klm/search/weights.hh
index 4a4388c7..df1c419f 100644
--- a/klm/search/weights.hh
+++ b/klm/search/weights.hh
@@ -23,25 +23,28 @@ class Weights {
     // Parses weights, sets lm_weight_, removes it from map_.
     explicit Weights(StringPiece text);
 
-    search::Score DotNoLM(StringPiece text) const;
+    // Just the three scores we care about adding.
+    Weights(Score lm, Score oov, Score word_penalty);
 
-    search::Score LM() const { return lm_; }
+    Score DotNoLM(StringPiece text) const;
 
-    search::Score OOV() const { return oov_; }
+    Score LM() const { return lm_; }
 
-    search::Score WordPenalty() const { return word_penalty_; }
+    Score OOV() const { return oov_; }
+
+    Score WordPenalty() const { return word_penalty_; }
 
     // Mostly for testing.
-    const boost::unordered_map<std::string, search::Score> &GetMap() const { return map_; }
+    const boost::unordered_map<std::string, Score> &GetMap() const { return map_; }
 
   private:
     float Steal(const std::string &str);
 
-    typedef boost::unordered_map<std::string, search::Score> Map;
+    typedef boost::unordered_map<std::string, Score> Map;
 
     Map map_;
 
-    search::Score lm_, oov_, word_penalty_;
+    Score lm_, oov_, word_penalty_;
 };
 
 } // namespace search
-- 
cgit v1.2.3
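Note on this commit: ConvertEdge relies on cdec's rule encoding, which is worth
spelling out because the loop above depends on it. In rule_->e(), a value <= 0 is a
nonterminal slot whose child vertex is tail_nodes_[-value], and a positive value is
a target-side word id. A minimal sketch of that convention (the handler functions
are hypothetical names, not cdec API):

    // Sketch only; HandleNonterminal/HandleTerminal are hypothetical.
    const std::vector<WordID>& e = edge.rule_->e();
    for (std::vector<WordID>::const_iterator word = e.begin(); word != e.end(); ++word) {
      if (*word <= 0)
        HandleNonterminal(edge.tail_nodes_[-*word]);  // index into tail_nodes_
      else
        HandleTerminal(*word);                        // cdec target WordID
    }

Also note the hard-coded LazyBase::Load("lm", weights): at this point the language
model path is fixed to the literal file name "lm"; the next commit makes it
configurable.
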
From a950a83a807518e465706c3712d6f80afff460b9 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Thu, 13 Sep 2012 04:28:30 -0700
Subject: Allow lm file name, print weights

---
 decoder/decoder.cc |  3 ++-
 decoder/lazy.cc    | 10 +++++++---
 decoder/lazy.h     |  2 +-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 3a410cf2..525c6ba6 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
         ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")
         ("show_cfg_search_space", "Show the search space as a CFG")
         ("show_target_graph", po::value<string>(), "Directory to write the target hypergraphs to")
+        ("lazy_search", po::value<string>(), "Run lazy search with this language model file")
         ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
         ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
         ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
@@ -834,7 +835,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
     HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
 
   if (conf.count("lazy_search"))
-    PassToLazy(forest, CurrentWeightVector());
+    PassToLazy(conf["lazy_search"].as<string>().c_str(), CurrentWeightVector(), forest);
 
   for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
     const RescoringPass& rp = rescoring_passes[pass];
diff --git a/decoder/lazy.cc b/decoder/lazy.cc
index 4776c1b8..58a9e08a 100644
--- a/decoder/lazy.cc
+++ b/decoder/lazy.cc
@@ -44,7 +44,9 @@ class LazyBase {
   public:
     LazyBase(const std::vector<weight_t> &weights) :
       cdec_weights_(weights),
-      config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {}
+      config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {
+      std::cerr << "Weights KLanguageModel " << config_.GetWeights().LM() << " KLanguageModel_OOV " << config_.GetWeights().OOV() << " WordPenalty " << config_.GetWeights().WordPenalty() << std::endl;
+    }
 
     virtual ~LazyBase() {}
@@ -95,6 +97,7 @@ template <class Model> void Lazy<Model>::Search(const Hypergraph &hg) const {
   boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
   boost::scoped_array<search::Edge> out_edges(new search::Edge[hg.edges_.size()]);
+
 
   search::Context<Model> context(config_, m_);
 
   for (unsigned int i = 0; i < hg.nodes_.size(); ++i) {
@@ -141,6 +144,7 @@ template <class Model> void Lazy<Model>::ConvertEdge(const search::Context<Mode
   }
 
   float additive = in.rule_->GetFeatureValues().dot(cdec_weights_);
+  UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product");
   additive -= terminals * context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10;
 
   out.InitRule().Init(context, additive, words, final);
@@ -150,9 +154,9 @@ boost::scoped_ptr<LazyBase> AwfulGlobalLazy;
 
 } // namespace
 
-void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights) {
+void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg) {
   if (!AwfulGlobalLazy.get()) {
-    AwfulGlobalLazy.reset(LazyBase::Load("lm", weights));
+    AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights));
   }
   AwfulGlobalLazy->Search(hg);
 }
diff --git a/decoder/lazy.h b/decoder/lazy.h
index 3e71a3b0..d1f030d1 100644
--- a/decoder/lazy.h
+++ b/decoder/lazy.h
@@ -6,6 +6,6 @@
 
 class Hypergraph;
 
-void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights);
+void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg);
 
 #endif // _LAZY_H_
-- 
cgit v1.2.3
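Note on this commit: the model file name is now threaded through PassToLazy, and the
constructor prints the three dense weights it pulls from cdec's weight vector by
feature name. The lookup pattern, isolated (this assumes the vector is long enough
to contain each feature id; operator[] is unchecked here):

    // FD::Convert maps a feature name to its integer id in cdec.
    const weight_t lm  = weights[FD::Convert("KLanguageModel")];
    const weight_t oov = weights[FD::Convert("KLanguageModel_OOV")];
    const weight_t wp  = weights[FD::Convert("WordPenalty")];
    search::Weights sw(lm, oov, wp);  // the new three-score constructor
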
From c32c03c1e0bb1e0407c90032cd3bf41f8bd61251 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Thu, 13 Sep 2012 07:53:16 -0700
Subject: Steal cubepruning_pop_limit command line argument

---
 decoder/decoder.cc |  2 +-
 decoder/lazy.cc    | 19 ++++++++++---------
 decoder/lazy.h     |  2 +-
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 525c6ba6..83077a68 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -835,7 +835,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
     HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
 
   if (conf.count("lazy_search"))
-    PassToLazy(conf["lazy_search"].as<string>().c_str(), CurrentWeightVector(), forest);
+    PassToLazy(conf["lazy_search"].as<string>().c_str(), CurrentWeightVector(), pop_limit, forest);
 
   for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
     const RescoringPass& rp = rescoring_passes[pass];
diff --git a/decoder/lazy.cc b/decoder/lazy.cc
index 9d69dac6..0f12a1ff 100644
--- a/decoder/lazy.cc
+++ b/decoder/lazy.cc
@@ -44,13 +44,13 @@ class LazyBase {
   public:
     LazyBase(const std::vector<weight_t> &weights) :
       cdec_weights_(weights),
-      config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {
-      std::cerr << "Weights KLanguageModel " << config_.GetWeights().LM() << " KLanguageModel_OOV " << config_.GetWeights().OOV() << " WordPenalty " << config_.GetWeights().WordPenalty() << std::endl;
+      weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) {
+      std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl;
     }
 
     virtual ~LazyBase() {}
 
-    virtual void Search(const Hypergraph &hg) const = 0;
+    virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0;
 
     static LazyBase *Load(const char *model_file, const std::vector<weight_t> &weights);
 
@@ -65,14 +65,14 @@ class LazyBase {
 
     const std::vector<weight_t> &cdec_weights_;
 
-    const search::Config config_;
+    const search::Weights weights_;
 };
 
 template <class Model> class Lazy : public LazyBase {
   public:
     Lazy(const char *model_file, const std::vector<weight_t> &weights) : LazyBase(weights), m_(model_file, GetConfig()) {}
 
-    void Search(const Hypergraph &hg) const;
+    void Search(unsigned int pop_limit, const Hypergraph &hg) const;
 
   private:
     void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const;
@@ -105,10 +105,11 @@ void PrintFinal(const Hypergraph &hg, const search::Edge *edge_base, const searc
   }
 }
 
-template <class Model> void Lazy<Model>::Search(const Hypergraph &hg) const {
+template <class Model> void Lazy<Model>::Search(unsigned int pop_limit, const Hypergraph &hg) const {
   boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
   boost::scoped_array<search::Edge> out_edges(new search::Edge[hg.edges_.size()]);
-  search::Context<Model> context(config_, m_);
+  search::Config config(weights_, pop_limit);
+  search::Context<Model> context(config, m_);
 
   for (unsigned int i = 0; i < hg.nodes_.size(); ++i) {
     search::Vertex &out_vertex = out_vertices[i];
@@ -165,9 +166,9 @@ boost::scoped_ptr<LazyBase> AwfulGlobalLazy;
 
 } // namespace
 
-void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg) {
+void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) {
   if (!AwfulGlobalLazy.get()) {
     AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights));
   }
-  AwfulGlobalLazy->Search(hg);
+  AwfulGlobalLazy->Search(pop_limit, hg);
 }
diff --git a/decoder/lazy.h b/decoder/lazy.h
index d1f030d1..94895b19 100644
--- a/decoder/lazy.h
+++ b/decoder/lazy.h
@@ -6,6 +6,6 @@
 
 class Hypergraph;
 
-void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg);
+void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg);
 
 #endif // _LAZY_H_
-- 
cgit v1.2.3

From 28403a7d3cbca2de743a7d654ffb9e1600ce7c5c Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Thu, 11 Oct 2012 06:02:26 -0400
Subject: Skip rest of decoder when using incremental

---
 decoder/decoder.cc | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 83077a68..29eaa4f6 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -834,8 +834,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
   if (conf.count("show_target_graph"))
     HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
 
-  if (conf.count("lazy_search"))
+  if (conf.count("lazy_search")) {
     PassToLazy(conf["lazy_search"].as<string>().c_str(), CurrentWeightVector(), pop_limit, forest);
+    o->NotifyDecodingComplete(smeta);
+    return true;
+  }
 
   for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
     const RescoringPass& rp = rescoring_passes[pass];
-- 
cgit v1.2.3
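Note on these two commits: the search is now configured per call, with the pop limit
taken from the decoder's existing cubepruning_pop_limit option instead of the
hard-coded 1000, and the decoder returns early once lazy search has run. The control
flow, isolated (names from the patch; what NotifyDecodingComplete does internally is
not shown in this log):

    if (conf.count("lazy_search")) {
      PassToLazy(conf["lazy_search"].as<string>().c_str(),
                 CurrentWeightVector(), pop_limit, forest);
      o->NotifyDecodingComplete(smeta);  // let the observer finish this sentence
      return true;                       // skip rescoring passes, k-best, etc.
    }
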
From 21a8287a2e1451db41c35494647c7b8c3e7e5adc Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Mon, 15 Oct 2012 23:20:41 -0400
Subject: get rid of nested class that was causing header polution

---
 decoder/decoder.cc |  14 ++---
 decoder/ff.cc      |   4 +-
 decoder/hg.h       | 180 ++++++++++++++++++++---------------------------
 decoder/hg_io.cc   |   4 +-
 utils/weights.cc   |   8 +--
 5 files changed, 83 insertions(+), 127 deletions(-)

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index a69a6d05..47b298b9 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -871,13 +871,13 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
       if (rp.fid_summary) {
         if (summary_feature_type == kEDGE_PROB) {
           const prob_t z = forest.PushWeightsToGoal(1.0);
-          if (!isfinite(log(z)) || isnan(log(z))) {
+          if (!std::isfinite(log(z)) || std::isnan(log(z))) {
             cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
           } else {
             for (int i = 0; i < forest.edges_.size(); ++i) {
               const double log_prob_transition = log(forest.edges_[i].edge_prob_); // locally normalized by the edge
                                                                                    // head node by forest.PushWeightsToGoal
-              if (!isfinite(log_prob_transition) || isnan(log_prob_transition)) {
+              if (!std::isfinite(log_prob_transition) || std::isnan(log_prob_transition)) {
                 cerr << "Edge: i=" << i << " got bad inside prob: " << *forest.edges_[i].rule_ << endl;
                 abort();
               }
@@ -889,7 +889,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
         } else if (summary_feature_type == kNODE_RISK) {
           Hypergraph::EdgeProbs posts;
           const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts);
-          if (!isfinite(log(z)) || isnan(log(z))) {
+          if (!std::isfinite(log(z)) || std::isnan(log(z))) {
             cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
           } else {
             for (int i = 0; i < forest.nodes_.size(); ++i) {
@@ -898,7 +898,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
               for (int j = 0; j < in_edges.size(); ++j)
                 node_post += (posts[in_edges[j]] / z);
               const double log_np = log(node_post);
-              if (!isfinite(log_np) || isnan(log_np)) {
+              if (!std::isfinite(log_np) || std::isnan(log_np)) {
                 cerr << "got bad posterior prob for node " << i << endl;
                 abort();
               }
@@ -913,13 +913,13 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
         } else if (summary_feature_type == kEDGE_RISK) {
           Hypergraph::EdgeProbs posts;
           const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts);
-          if (!isfinite(log(z)) || isnan(log(z))) {
+          if (!std::isfinite(log(z)) || std::isnan(log(z))) {
             cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
           } else {
             assert(posts.size() == forest.edges_.size());
             for (int i = 0; i < posts.size(); ++i) {
               const double log_np = log(posts[i] / z);
-              if (!isfinite(log_np) || isnan(log_np)) {
+              if (!std::isfinite(log_np) || std::isnan(log_np)) {
                 cerr << "got bad posterior prob for node " << i << endl;
                 abort();
               }
@@ -1090,7 +1090,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
         cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl;
         exit(1);
       }
-      assert(!isnan(log_ref_z));
+      assert(!std::isnan(log_ref_z));
       ref_exp -= full_exp;
       acc_vec += ref_exp;
       acc_obj += (log_z - log_ref_z);
diff --git a/decoder/ff.cc b/decoder/ff.cc
index 557e0b5f..008fcad4 100644
--- a/decoder/ff.cc
+++ b/decoder/ff.cc
@@ -175,7 +175,7 @@ void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
                                  Hypergraph::Edge* edge,
                                  FFState* context,
                                  prob_t* combination_cost_estimate) const {
-  edge->reset_info();
+  //edge->reset_info();
   context->resize(state_size_);
   if (state_size_ > 0) {
     memset(&(*context)[0], 0, state_size_);
@@ -203,7 +203,7 @@ void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
 
 void ModelSet::AddFinalFeatures(const FFState& state, Hypergraph::Edge* edge,SentenceMetadata const& smeta) const {
   assert(1 == edge->rule_->Arity());
-  edge->reset_info();
+  //edge->reset_info();
   for (int i = 0; i < models_.size(); ++i) {
     const FeatureFunction& ff = *models_[i];
     const void* ant_state = NULL;
diff --git a/decoder/hg.h b/decoder/hg.h
index 6d67f2fa..f53d2fd2 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -33,47 +33,20 @@
 // slow
 #undef HG_EDGES_TOPO_SORTED
 
-class Hypergraph;
-typedef boost::shared_ptr<Hypergraph> HypergraphP;
-
-// class representing an acyclic hypergraph
-//  - edges have 1 head, 0..n tails
-class Hypergraph {
-public:
-  Hypergraph() : is_linear_chain_(false) {}
+// SmallVector is a fast, small vector implementation for sizes <= 2
+typedef SmallVectorUnsigned TailNodeVector; // indices in nodes_
+typedef std::vector<int> EdgesVector; // indices in edges_
 
-  // SmallVector is a fast, small vector implementation for sizes <= 2
-  typedef SmallVectorUnsigned TailNodeVector; // indices in nodes_
-  typedef std::vector<int> EdgesVector; // indices in edges_
-
-  // TODO get rid of cat_?
-  // TODO keep cat_ and add span and/or state? :)
-  struct Node {
-    Node() : id_(), cat_() {}
-    int id_; // equal to this object's position in the nodes_ vector
-    WordID cat_; // non-terminal category if <0, 0 if not set
-    WordID NT() const { return -cat_; }
-    EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us)  indices in edges_
-    EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
-    void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting
-      cat_=o.cat_;
-    }
-    void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) {
-      copy_fixed(o);
-      id_=n2[id_];
-      e2.reindex_push_back(o.in_edges_,in_edges_);
-      e2.reindex_push_back(o.out_edges_,out_edges_);
-    }
-  };
+enum {
+  NONE=0,CATEGORY=1,SPAN=2,PROB=4,FEATURES=8,RULE=16,RULE_LHS=32,PREV_SPAN=64,ALL=0xFFFFFFFF
+};
 
+namespace HG {
 
-  // TODO get rid of edge_prob_? (can be computed on the fly as the dot
-  // product of the weight vector and the feature values)
   struct Edge {
-//    int poplimit; //TODO: cube pruning per edge limit?  per node didn't work well at all.  also, inside cost + outside(node) is the same information i'd use to set a per-edge limit anyway - and nonmonotonicity in cube pruning may mean it's good to favor edge (in same node) w/ relatively worse score
     Edge() : i_(-1), j_(-1), prev_i_(-1), prev_j_(-1) {}
     Edge(int id,Edge const& copy_pod_from) : id_(id) { copy_pod(copy_pod_from); } // call copy_features yourself later.
-    Edge(int id,Edge const& copy_from,TailNodeVector const& tail) // fully inits - probably more expensive when push_back(Edge(...)) than setting after
+    Edge(int id,Edge const& copy_from,TailNodeVector const& tail) // fully inits - probably more expensive when push_back(Edge(...)) than sett
      : tail_nodes_(tail),id_(id) { copy_pod(copy_from);copy_features(copy_from); }
     inline int Arity() const { return tail_nodes_.size(); }
     int head_node_; // refers to a position in nodes_
@@ -83,8 +56,6 @@ public:
     prob_t edge_prob_; // dot product of weights and feat_values
     int id_; // equal to this object's position in the edges_ vector
 
-    //FIXME: these span ids belong in Node, not Edge, right?  every node should have the same spans.
-
     // span info. typically, i_ and j_ refer to indices in the source sentence.
     // In synchronous parsing, i_ and j_ will refer to target sentence/lattice indices
     // while prev_i_ prev_j_ will refer to positions in the source.
@@ -97,54 +68,6 @@ public:
     short int j_;
     short int prev_i_;
    short int prev_j_;
-
-    void copy_info(Edge const& o) {
-#if USE_INFO_EDGE
-      set_info(o.info_.str()); // by convention, each person putting info here starts with a separator (e.g. space). it's empty if nobody put any info there.
-#else
-      (void) o;
-#endif
-    }
-    void copy_pod(Edge const& o) {
-      rule_=o.rule_;
-      i_ = o.i_; j_ = o.j_; prev_i_ = o.prev_i_; prev_j_ = o.prev_j_;
-    }
-    void copy_features(Edge const& o) {
-      feature_values_=o.feature_values_;
-      copy_info(o);
-    }
-    void copy_fixed(Edge const& o) {
-      copy_pod(o);
-      copy_features(o);
-      edge_prob_ = o.edge_prob_;
-    }
-    void copy_reindex(Edge const& o,indices_after const& n2,indices_after const& e2) {
-      copy_fixed(o);
-      head_node_=n2[o.head_node_];
-      id_=e2[o.id_];
-      n2.reindex_push_back(o.tail_nodes_,tail_nodes_);
-    }
-
-#if USE_INFO_EDGE
-    std::ostringstream info_;
-    void set_info(std::string const& s) {
-      info_.str(s);
-      info_.seekp(0,std::ios_base::end);
-    }
-    Edge(Edge const& o) : head_node_(o.head_node_),tail_nodes_(o.tail_nodes_),rule_(o.rule_),feature_values_(o.feature_values_),edge_prob_(o.edge_prob_),id_(o.id_),i_(o.i_),j_(o.j_),prev_i_(o.prev_i_),prev_j_(o.prev_j_), info_(o.info_.str(),std::ios_base::ate) {
-//      info_.seekp(0,std::ios_base::end);
-    }
-    void operator=(Edge const& o) {
-      head_node_ = o.head_node_; tail_nodes_ = o.tail_nodes_; rule_ = o.rule_; feature_values_ = o.feature_values_; edge_prob_ = o.edge_prob_; id_ = o.id_; i_ = o.i_; j_ = o.j_; prev_i_ = o.prev_i_; prev_j_ = o.prev_j_;
-      set_info(o.info_.str());
-    }
-    std::string info() const { return info_.str(); }
-    void reset_info() { info_.str(""); info_.clear(); }
-#else
-    std::string info() const { return std::string(); }
-    void reset_info() { }
-    void set_info(std::string const& ) { }
-#endif
     void show(std::ostream &o,unsigned mask=SPAN|RULE) const {
       o<<'{';
       if (mask&CATEGORY)
        ...
        o<<' '<<rule_->AsString(mask&RULE_LHS);
-      if (USE_INFO_EDGE) {
-        std::string const& i=info();
-        if (mask&&!i.empty()) o << " |||"<<i;
-      }
       ...
     }
     template <class EdgeRecurse,class TEdgeHandle>
     std::string derivation_tree(EdgeRecurse const& re,TEdgeHandle const& eh,bool indent=true,int show_mask=SPAN|RULE,int maxdepth=0x7FFFFFFF,int depth=0) const {
       std::ostringstream o;
@@ -203,7 +138,43 @@
       return o.str();
     }
   };
 
-  // all this info ought to live in Node, but for some reason it's on Edges.
+  // TODO get rid of cat_?
+  // TODO keep cat_ and add span and/or state? :)
+  struct Node {
+    Node() : id_(), cat_() {}
+    int id_; // equal to this object's position in the nodes_ vector
+    WordID cat_; // non-terminal category if <0, 0 if not set
+    WordID NT() const { return -cat_; }
+    EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us)  indices in edges_
+    EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
+    void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting
+      cat_=o.cat_;
+    }
+    void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) {
+      copy_fixed(o);
+      id_=n2[id_];
+      e2.reindex_push_back(o.in_edges_,in_edges_);
+      e2.reindex_push_back(o.out_edges_,out_edges_);
+    }
+  };
+
+} // namespace HG
+
+class Hypergraph;
+typedef boost::shared_ptr<Hypergraph> HypergraphP;
+
+// class representing an acyclic hypergraph
+//  - edges have 1 head, 0..n tails
+class Hypergraph {
+public:
+  Hypergraph() : is_linear_chain_(false) {}
+  typedef HG::Node Node;
+  typedef HG::Edge Edge;
+  typedef SmallVectorUnsigned TailNodeVector; // indices in nodes_
+  typedef std::vector<int> EdgesVector; // indices in edges_
+  enum {
+    NONE=0,CATEGORY=1,SPAN=2,PROB=4,FEATURES=8,RULE=16,RULE_LHS=32,PREV_SPAN=64,ALL=0xFFFFFFFF
+  };
+
   // except for stateful models that have split nt,span, this should identify the node
   void SetNodeOrigin(int nodeid,NTSpan &r) const {
     Node const &n=nodes_[nodeid];
@@ -230,18 +201,9 @@ public:
     }
     return s;
   }
-  // 0 if none, -TD index otherwise (just like in rule)
   WordID NodeLHS(int nodeid) const {
     Node const &n=nodes_[nodeid];
     return n.NT();
-    /*
-    if (!n.in_edges_.empty()) {
-      Edge const& e=edges_[n.in_edges_.front()];
-      if (e.rule_)
-        return -e.rule_->lhs_;
-    }
-    return 0;
-    */
   }
 
   typedef std::vector<prob_t> EdgeProbs;
@@ -250,14 +212,8 @@ public:
  typedef std::vector<bool> NodeMask;
 
   std::string show_viterbi_tree(bool indent=true,int show_mask=SPAN|RULE,int maxdepth=0x7FFFFFFF,int depth=0) const;
-// builds viterbi hg and returns it formatted as a pretty string
-
-  enum {
-    NONE=0,CATEGORY=1,SPAN=2,PROB=4,FEATURES=8,RULE=16,RULE_LHS=32,PREV_SPAN=64,ALL=0xFFFFFFFF
-  };
 
   std::string show_first_tree(bool indent=true,int show_mask=SPAN|RULE,int maxdepth=0x7FFFFFFF,int depth=0) const;
-  // same as above, but takes in_edges_[0] all the way down - to make it viterbi cost (1-best), call ViterbiSortInEdges() first
 
   typedef Edge const* EdgeHandle;
   EdgeHandle operator()(int tailn,int /*taili*/,EdgeHandle /*parent*/) const {
@@ -334,7 +290,7 @@ public:
   Edge* AddEdge(Edge const& in_edge, const TailNodeVector& tail) {
     edges_.push_back(Edge(edges_.size(),in_edge));
     Edge* edge = &edges_.back();
-    edge->copy_features(in_edge);
+    edge->feature_values_ = in_edge.feature_values_;
     edge->tail_nodes_ = tail; // possibly faster than copying to Edge() constructed above then copying via push_back. perhaps optimized it's the same.
     index_tails(*edge);
     return edge;
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index 3a68a429..8f604c89 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -392,8 +392,8 @@ string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses
       const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]];
       const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS;
       double prob = log(e.edge_prob_);
-      if (isinf(prob)) { prob = -9e20; }
-      if (isnan(prob)) { prob = 0; }
+      if (std::isinf(prob)) { prob = -9e20; }
+      if (std::isnan(prob)) { prob = 0; }
       os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),";
     }
     os << "),";
diff --git a/utils/weights.cc b/utils/weights.cc
index f56e2a20..575877b6 100644
--- a/utils/weights.cc
+++ b/utils/weights.cc
@@ -34,7 +34,7 @@ void Weights::InitFromFile(const string& filename,
   int weight_count = 0;
   bool fl = false;
   string buf;
-  weight_t val = 0;
+  double val = 0;
   while (in) {
     getline(in, buf);
     if (buf.size() == 0) continue;
@@ -53,7 +53,7 @@ void Weights::InitFromFile(const string& filename,
       if (feature_list) { feature_list->push_back(buf.substr(start, end - start)); }
       while(end < buf.size() && buf[end] == ' ') ++end;
       val = strtod(&buf.c_str()[end], NULL);
-      if (isnan(val)) {
+      if (std::isnan(val)) {
         cerr << FD::Convert(fid) << " has weight NaN!\n";
         abort();
       }
@@ -127,8 +127,8 @@ void Weights::InitSparseVector(const vector<weight_t>& dv,
 
 void Weights::SanityCheck(const vector<weight_t>& w) {
   for (unsigned i = 0; i < w.size(); ++i) {
-    assert(!isnan(w[i]));
-    assert(!isinf(w[i]));
+    assert(!std::isnan(w[i]));
+    assert(!std::isinf(w[i]));
   }
 }
-- 
cgit v1.2.3
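Note on this commit: Edge and Node move out of class Hypergraph into namespace HG,
with typedefs preserving the old Hypergraph::Edge spelling, so that other headers can
forward-declare them instead of including all of hg.h. The pattern the next commit
exploits:

    // A header can now depend on edges without pulling in hg.h:
    namespace HG { struct Edge; struct Node; }
    class Hypergraph;

    void ScoreEdge(const HG::Edge& edge);  // hypothetical consumer; only a
                                           // declaration is needed here

It also qualifies isnan/isfinite/isinf as std::, which is required once the standard
library only provides them in namespace std (as newer toolchains do).
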
From 09a522afef40e514e741ddf5f9bbc61cfb2170e8 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Tue, 16 Oct 2012 00:37:21 -0400
Subject: clean up of bad header includes

---
 decoder/Makefile.am         |   2 +
 decoder/apply_models.cc     |   1 +
 decoder/cdec_ff.cc          |   1 +
 decoder/cfg.h               |   2 +-
 decoder/cfg_format.h        |   2 +-
 decoder/cfg_test.cc         |   4 +-
 decoder/decoder.cc          |  14 ++---
 decoder/exp_semiring.h      |   2 +-
 decoder/ff.cc               | 200 +++----------------------------------------
 decoder/ff.h                | 238 +++-----------------------------------------
 decoder/ff_basic.cc         |  80 +++++++++++++++
 decoder/ff_basic.h          |  68 +++++++++++++
 decoder/ff_bleu.h           |   2 +-
 decoder/ff_charset.cc       |   6 +-
 decoder/ff_charset.h        |   6 +-
 decoder/ff_context.cc       |   2 +
 decoder/ff_context.h        |   2 +-
 decoder/ff_csplit.cc        |   1 +
 decoder/ff_csplit.h         |   4 +-
 decoder/ff_dwarf.cc         |   1 +
 decoder/ff_dwarf.h          |   2 +-
 decoder/ff_external.cc      |   8 +-
 decoder/ff_external.h       |   6 +-
 decoder/ff_factory.h        |   4 -
 decoder/ff_klm.cc           |   6 --
 decoder/ff_klm.h            |   3 +-
 decoder/ff_lm.cc            |   4 -
 decoder/ff_lm.h             |   5 +-
 decoder/ff_ngrams.h         |   2 +-
 decoder/ff_rules.cc         |   2 +
 decoder/ff_rules.h          |   5 +-
 decoder/ff_ruleshape.cc     |   2 +
 decoder/ff_ruleshape.h      |   2 +-
 decoder/ff_source_syntax.cc |   1 +
 decoder/ff_source_syntax.h  |   4 +-
 decoder/ff_spans.cc         |   2 +
 decoder/ff_spans.h          |   4 +-
 decoder/ff_tagger.cc        |   1 +
 decoder/ff_tagger.h         |   6 +-
 decoder/ff_wordalign.h      |  30 +++---
 decoder/ff_wordset.cc       |   1 +
 decoder/ff_wordset.h        |   5 +-
 decoder/ffset.cc            |  72 ++++++++++++++
 decoder/ffset.h             |  57 +++++++++++
 decoder/grammar_test.cc     |   2 +
 decoder/hg.h                |  10 +-
 decoder/hg_io.cc            |   2 +-
 decoder/inside_outside.h    |   8 +-
 decoder/kbest.h             |  14 +--
 decoder/oracle_bleu.h       |  11 +-
 decoder/program_options.h   |   2 +-
 decoder/tromble_loss.h      |   2 +-
 decoder/viterbi.cc          |   4 +-
 decoder/viterbi.h           |  32 +++---
 example_extff/ff_example.cc |   2 +
 55 files changed, 429 insertions(+), 530 deletions(-)
 create mode 100644 decoder/ff_basic.cc
 create mode 100644 decoder/ff_basic.h
 create mode 100644 decoder/ffset.cc
 create mode 100644 decoder/ffset.h

diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 28863dbe..5c0a1964 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -56,6 +56,8 @@ libcdec_a_SOURCES = \
   phrasetable_fst.cc \
   trule.cc \
   ff.cc \
+  ffset.cc \
+  ff_basic.cc \
   ff_rules.cc \
   ff_wordset.cc \
   ff_context.cc \
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9ba59d1b..330de9e2 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -16,6 +16,7 @@
 #include "verbose.h"
 #include "hg.h"
 #include "ff.h"
+#include "ffset.h"
 
 #define NORMAL_CP 1
 #define FAST_CP 2
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 54f6e12b..99ab7473 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -1,6 +1,7 @@
 #include <boost/shared_ptr.hpp>
 
 #include "ff.h"
+#include "ff_basic.h"
 #include "ff_context.h"
 #include "ff_spans.h"
 #include "ff_lm.h"
diff --git a/decoder/cfg.h b/decoder/cfg.h
index 8cb29bb9..aeeacb83 100644
--- a/decoder/cfg.h
+++ b/decoder/cfg.h
@@ -130,7 +130,7 @@ struct CFG {
     int lhs; // index into nts
     RHS rhs;
     prob_t p; // h unused for now (there's nothing admissable, and p is already using 1st pass inside as pushed toward top)
-    FeatureVector f; // may be empty, unless copy_features on Init
+    SparseVector<double> f; // may be empty, unless copy_features on Init
     IF_CFG_TRULE(TRulePtr rule;)
     int size() const { // for stats only
       return rhs.size();
diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h
index 2f40d483..d12da261 100644
--- a/decoder/cfg_format.h
+++ b/decoder/cfg_format.h
@@ -100,7 +100,7 @@ struct CFGFormat {
     }
   }
 
-  void print_features(std::ostream &o,prob_t p,FeatureVector const& fv=FeatureVector()) const {
+  void print_features(std::ostream &o,prob_t p,SparseVector<double> const& fv=SparseVector<double>()) const {
     bool logp=(logprob_feat && p!=prob_t::One());
     if (features || logp) {
       o << partsep;
diff --git a/decoder/cfg_test.cc b/decoder/cfg_test.cc
index b8f4cf11..316c6d16 100644
--- a/decoder/cfg_test.cc
+++ b/decoder/cfg_test.cc
@@ -25,9 +25,9 @@ struct CFGTest : public TestWithParam {
   Hypergraph hg;
   CFG cfg;
   CFGFormat form;
-  FeatureVector weights;
+  SparseVector<double> weights;
 
-  static void JsonFN(Hypergraph &hg,CFG &cfg,FeatureVector &featw,std::string file
+  static void JsonFN(Hypergraph &hg,CFG &cfg,SparseVector<double> &featw,std::string file
                      ,std::string const& wts="Model_0 1 EgivenF 1 f1 1")
   {
     istringstream ws(wts);
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 47b298b9..fef88d3f 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -29,6 +29,7 @@
 #include "oracle_bleu.h"
 #include "apply_models.h"
 #include "ff.h"
+#include "ffset.h"
 #include "ff_factory.h"
 #include "viterbi.h"
 #include "kbest.h"
@@ -90,11 +91,6 @@ inline void ShowBanner() {
   cerr << "cdec v1.0 (c) 2009-2011 by Chris Dyer\n";
 }
 
-inline void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) {
-  cerr<<header<<": ";
-  ...
-}
-
 inline string str(char const* name,po::variables_map const& conf) {
   return conf[name].as<string>();
 }
@@ -132,7 +128,7 @@ inline boost::shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose
   }
   boost::shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param);
   if (!pf) exit(1);
-  int nbyte=pf->NumBytesContext();
+  int nbyte=pf->StateSize();
   if (verbose_feature_functions && !SILENT)
     cerr<<"State is "<<nbyte<<" bytes for feature "<<ffp<<endl;
@@ ... @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
       vector<weight_t> dummy; // = last_weights
       Oracle oc=oracle.ComputeOracle(smeta,&forest,dummy,10,conf["forest_output"].as<string>());
       if (!SILENT) cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
       if (!SILENT) cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
diff --git a/decoder/exp_semiring.h b/decoder/exp_semiring.h
index 111eaaf1..2a9034bb 100644
--- a/decoder/exp_semiring.h
+++ b/decoder/exp_semiring.h
@@ -59,7 +59,7 @@ struct PRWeightFunction {
   explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(),
                             const RWeightFunction& rwf = RWeightFunction()) :
     pweight(pwf), rweight(rwf) {}
-  PRPair<P,R> operator()(const Hypergraph::Edge& e) const {
+  PRPair<P,R> operator()(const HG::Edge& e) const {
     const P p = pweight(e);
     const R r = rweight(e);
     return PRPair<P,R>(p, r * p);
diff --git a/decoder/ff.cc b/decoder/ff.cc
index 008fcad4..6e276a5e 100644
--- a/decoder/ff.cc
+++ b/decoder/ff.cc
@@ -1,9 +1,3 @@
-//TODO: non-sparse vector for all feature functions? modelset applymodels keeps track of who has what features? it's nice having FF that could generate a handful out of 10000 possible feats, though.
-
-//TODO: actually score rule_feature()==true features once only, hash keyed on rule or modify TRule directly? need to keep clear in forest which features come from models vs. rules; then rescoring could drop all the old models features at once
-
-#include "fast_lexical_cast.hpp"
-#include <stdexcept>
 #include "ff.h"
 
 #include "tdict.h"
@@ -16,8 +10,7 @@ FeatureFunction::~FeatureFunction() {}
 void FeatureFunction::PrepareForInput(const SentenceMetadata&) {}
 
 void FeatureFunction::FinalTraversalFeatures(const void* /* ant_state */,
-                                             SparseVector<double>* /* features */) const {
-}
+                                             SparseVector<double>* /* features */) const {}
 
 string FeatureFunction::usage_helper(std::string const& name,std::string const& params,std::string const& details,bool sp,bool sd) {
   string r=name;
@@ -32,188 +25,21 @@ string FeatureFunction::usage_helper(std::string const& name,std::string const&
   return r;
 }
 
-Features FeatureFunction::single_feature(WordID feat) {
-  return Features(1,feat);
-}
-
-Features ModelSet::all_features(std::ostream *warn,bool warn0) {
-  //return ::all_features(models_,weights_,warn,warn0);
-}
-
-void show_features(Features const& ffs,DenseWeightVector const& weights_,std::ostream &out,std::ostream &warn,bool warn_zero_wt) {
-  out << "Weight  Feature\n";
-  for (unsigned i=0;i<ffs.size();++i) {
-    ...
-  }
-}
+void FeatureFunction::FinalTraversalFeatures(const SentenceMetadata& /* smeta */,
+                                             const HG::Edge& /* edge */,
+                                             const void* residual_state,
+                                             SparseVector<double>* final_features) const {
+  FinalTraversalFeatures(residual_state,final_features);
+}
 
-void FeatureFunction::TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                            const Hypergraph::Edge& edge,
-                                            const std::vector<const void*>& ant_states,
-                                            SparseVector<double>* features,
-                                            SparseVector<double>* estimated_features,
-                                            void* state) const {
-  throw std::runtime_error("TraversalFeaturesImpl not implemented - override it or TraversalFeaturesLog.\n");
+void FeatureFunction::TraversalFeaturesImpl(const SentenceMetadata&,
+                                            const Hypergraph::Edge&,
+                                            const std::vector<const void*>&,
+                                            SparseVector<double>*,
+                                            SparseVector<double>*,
+                                            void*) const {
+  cerr << "TraversalFeaturesImpl not implemented - override it or TraversalFeaturesLog\n";
   abort();
 }
 
-void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                        const Hypergraph::Edge& edge,
-                                        const std::vector<const void*>& ant_states,
-                                        SparseVector<double>* features,
-                                        SparseVector<double>* estimated_features,
-                                        void* state) const {
-  (void) smeta;
-  (void) ant_states;
-  (void) state;
-  (void) estimated_features;
-  features->set_value(fid_, edge.rule_->EWords() * value_);
-}
-
-SourceWordPenalty::SourceWordPenalty(const string& param) :
-    fid_(FD::Convert("SourceWordPenalty")),
-    value_(-1.0 / log(10)) {
-  if (!param.empty()) {
-    cerr << "Warning SourceWordPenalty ignoring parameter: " << param << endl;
-  }
-}
-
-Features SourceWordPenalty::features() const {
-  return single_feature(fid_);
-}
-
-Features WordPenalty::features() const {
-  return single_feature(fid_);
-}
-
-
-void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                              const Hypergraph::Edge& edge,
-                                              const std::vector<const void*>& ant_states,
-                                              SparseVector<double>* features,
-                                              SparseVector<double>* estimated_features,
-                                              void* state) const {
-  (void) smeta;
-  (void) ant_states;
-  (void) state;
-  (void) estimated_features;
-  features->set_value(fid_, edge.rule_->FWords() * value_);
-}
-
-ArityPenalty::ArityPenalty(const std::string& param) :
-    value_(-1.0 / log(10)) {
-  string fname = "Arity_";
-  unsigned MAX=DEFAULT_MAX_ARITY;
-  using namespace boost;
-  if (!param.empty())
-    MAX=lexical_cast<unsigned>(param);
-  for (unsigned i = 0; i <= MAX; ++i) {
-    WordID fid=FD::Convert(fname+lexical_cast<string>(i));
-    fids_.push_back(fid);
-  }
-  while (!fids_.empty() && fids_.back()==0) fids_.pop_back(); // pretty up features vector in case FD was frozen.  doesn't change anything
-}
-
-Features ArityPenalty::features() const {
-  return Features(fids_.begin(),fids_.end());
-}
-
-void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                         const Hypergraph::Edge& edge,
-                                         const std::vector<const void*>& ant_states,
-                                         SparseVector<double>* features,
-                                         SparseVector<double>* estimated_features,
-                                         void* state) const {
-  (void) smeta;
-  (void) ant_states;
-  (void) state;
-  (void) estimated_features;
-  unsigned a=edge.Arity();
-  features->set_value(a<fids_.size() ? fids_[a] : 0, 1.0);
-}
-
-ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) :
-    models_(models),
-    weights_(w),
-    state_size_(0),
-    model_state_pos_(models.size()) {
-  for (int i = 0; i < models_.size(); ++i) {
-    model_state_pos_[i] = state_size_;
-    state_size_ += models_[i]->NumBytesContext();
-  }
-}
-
-void ModelSet::PrepareForInput(const SentenceMetadata& smeta) {
-  for (int i = 0; i < models_.size(); ++i)
-    const_cast<FeatureFunction*>(models_[i])->PrepareForInput(smeta);
-}
-
-void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
-                                 const Hypergraph& /* hg */,
-                                 const FFStates& node_states,
-                                 Hypergraph::Edge* edge,
-                                 FFState* context,
-                                 prob_t* combination_cost_estimate) const {
-  //edge->reset_info();
-  context->resize(state_size_);
-  if (state_size_ > 0) {
-    memset(&(*context)[0], 0, state_size_);
-  }
-  SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL
-  if (combination_cost_estimate) *combination_cost_estimate = prob_t::One();
-  for (int i = 0; i < models_.size(); ++i) {
-    const FeatureFunction& ff = *models_[i];
-    void* cur_ff_context = NULL;
-    vector<const void*> ants(edge->tail_nodes_.size());
-    bool has_context = ff.NumBytesContext() > 0;
-    if (has_context) {
-      int spos = model_state_pos_[i];
-      cur_ff_context = &(*context)[spos];
-      for (int i = 0; i < ants.size(); ++i) {
-        ants[i] = &node_states[edge->tail_nodes_[i]][spos];
-      }
-    }
-    ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context);
-  }
-  if (combination_cost_estimate)
-    combination_cost_estimate->logeq(est_vals.dot(weights_));
-  edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
-}
-
-void ModelSet::AddFinalFeatures(const FFState& state, Hypergraph::Edge* edge,SentenceMetadata const& smeta) const {
-  assert(1 == edge->rule_->Arity());
-  //edge->reset_info();
-  for (int i = 0; i < models_.size(); ++i) {
-    const FeatureFunction& ff = *models_[i];
-    const void* ant_state = NULL;
-    bool has_context = ff.NumBytesContext() > 0;
-    if (has_context) {
-      int spos = model_state_pos_[i];
-      ant_state = &state[spos];
-    }
-    ff.FinalTraversalFeatures(smeta, *edge, ant_state, &edge->feature_values_);
-  }
-  edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
-}
-
diff --git a/decoder/ff.h b/decoder/ff.h
index 227787ca..4acbb7e3 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -1,26 +1,13 @@
 #ifndef _FF_H_
 #define _FF_H_
 
-#define DEBUG_INIT 0
-#if DEBUG_INIT
-# include <iostream>
-# define DBGINIT(a) do { std::cerr<<a<<"\n"; } while(0)
-#else
-# define DBGINIT(a)
-#endif
+#include <string>
 #include <vector>
-#include <cstring>
-#include "fdict.h"
-#include "hg.h"
-#include "feature_vector.h"
-#include "value_array.h"
+#include "sparse_vector.h"
 
+namespace HG { struct Edge; struct Node; }
+class Hypergraph;
 class SentenceMetadata;
-class FeatureFunction;  // see definition below
-
-typedef std::vector<WordID> Features; // set of features ids
 
 // if you want to develop a new feature, inherit from this class and
 // override TraversalFeaturesImpl(...). If it's a feature that returns /
 class FeatureFunction {
   friend class ExternalFeature;
  public:
   std::string name_; // set by FF factory using usage()
-  bool debug_; // also set by FF factory checking param for immediate initial "debug"
-  //called after constructor, but before name_ and debug_ have been set
-  virtual void Init() { DBGINIT("default FF::Init name="<<name_); }
-  ...
   FeatureFunction() : state_size_() {}
   explicit FeatureFunction(int state_size) : state_size_(state_size) {}
   virtual ~FeatureFunction();
   bool IsStateful() const { return state_size_ > 0; }
+  int StateSize() const { return state_size_; }
 
   // override this. not virtual because we want to expose this to factory template for help before creating a FF
   static std::string usage(bool show_params,bool show_details) {
     return usage_helper("FIXME_feature_needs_name","[no parameters]","[no documentation yet]",show_params,show_details);
   }
   static std::string usage_helper(std::string const& name,std::string const& params,std::string const& details,bool show_params,bool show_details);
-  static Features single_feature(int feat);
-public:
-
-  // stateless feature that doesn't depend on source span: override and return true. then your feature can be precomputed over rules.
-  virtual bool rule_feature() const { return false; }
 
   // called once, per input, before any feature calls to TraversalFeatures, etc.
   // used to initialize sentence-specific data structures
   virtual void PrepareForInput(const SentenceMetadata& smeta);
 
-  //OVERRIDE THIS:
-  virtual Features features() const { return single_feature(FD::Convert(name_)); }
-  // returns the number of bytes of context that this feature function will
-  // (maximally) use.  By default, 0 ("stateless" models in Hiero/Joshua).
-  // NOTE: this value is fixed for the instance of your class, you cannot
-  // use different amounts of memory for different nodes in the forest.  this will be read as soon as you create a ModelSet, then fixed forever on
-  inline int NumBytesContext() const { return state_size_; }
-
   // Compute the feature values and (if this applies) the estimates of the
   // feature values when this edge is used incorporated into a larger context
   inline void TraversalFeatures(const SentenceMetadata& smeta,
-                                Hypergraph::Edge& edge,
+                                const HG::Edge& edge,
                                 const std::vector<const void*>& ant_contexts,
-                                FeatureVector* features,
-                                FeatureVector* estimated_features,
+                                SparseVector<double>* features,
+                                SparseVector<double>* estimated_features,
                                 void* out_state) const {
-    TraversalFeaturesLog(smeta, edge, ant_contexts,
+    TraversalFeaturesImpl(smeta, edge, ant_contexts,
                          features, estimated_features, out_state);
     // TODO it's easy for careless feature function developers to overwrite
     // the end of their state and clobber someone else's memory.  These bugs
     // will be horrendously painful to track down.  There should be some
     // optional strict mode that's enforced.
   }
 
 protected:
   virtual void FinalTraversalFeatures(const void* residual_state,
-                                      FeatureVector* final_features) const;
+                                      SparseVector<double>* final_features) const;
 public:
   //override either this or one of above.
   virtual void FinalTraversalFeatures(const SentenceMetadata& /* smeta */,
-                                      Hypergraph::Edge& /* edge */, // so you can log()
+                                      const HG::Edge& /* edge */,
                                       const void* residual_state,
-                                      FeatureVector* final_features) const {
-    FinalTraversalFeatures(residual_state,final_features);
-  }
-
+                                      SparseVector<double>* final_features) const;
 
 protected:
   // context is a pointer to a buffer of size NumBytesContext() that the
   // feature function can write its state to.  It's up to the feature function
   // to determine how much space it needs and to determine how to encode its
   // residual contextual information since it is OPAQUE to all clients outside
   // of the particular FeatureFunction class.  There is one exception:
   // equality of the contents (i.e., memcmp) is required to determine whether
   // two states can be combined.
-
-  // by Log, I mean that the edge is non-const only so you can log to it with INFO_EDGE(edge,msg<<"etc."). most features don't use this so implement the below. it has a different name to allow a default implementation without name hiding when inheriting + overriding just 1.
-  virtual void TraversalFeaturesLog(const SentenceMetadata& smeta,
-                                    Hypergraph::Edge& edge, // this is writable only so you can use log()
-                                    const std::vector<const void*>& ant_contexts,
-                                    FeatureVector* features,
-                                    FeatureVector* estimated_features,
-                                    void* context) const {
-    TraversalFeaturesImpl(smeta,edge,ant_contexts,features,estimated_features,context);
-  }
-
-  // override above or below.
   virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                     Hypergraph::Edge const& edge,
+                                     const HG::Edge& edge,
                                      const std::vector<const void*>& ant_contexts,
-                                     FeatureVector* features,
-                                     FeatureVector* estimated_features,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
                                      void* context) const;
 
   // !!! ONLY call this from subclass *CONSTRUCTORS* !!!
   void SetStateSize(size_t state_size) { state_size_ = state_size; }
-  int StateSize() const { return state_size_; }
-
  private:
   int state_size_;
 };
 
-
-// word penalty feature, for each word on the E side of a rule,
-// add value_
-class WordPenalty : public FeatureFunction {
- public:
-  Features features() const;
-  WordPenalty(const std::string& param);
-  static std::string usage(bool p,bool d) {
-    return usage_helper("WordPenalty","","number of target words (local feature)",p,d);
-  }
-  bool rule_feature() const { return true; }
- protected:
-  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                     const Hypergraph::Edge& edge,
-                                     const std::vector<const void*>& ant_contexts,
-                                     FeatureVector* features,
-                                     FeatureVector* estimated_features,
-                                     void* context) const;
- private:
-  const int fid_;
-  const double value_;
-};
-
-class SourceWordPenalty : public FeatureFunction {
- public:
-  bool rule_feature() const { return true; }
-  Features features() const;
-  SourceWordPenalty(const std::string& param);
-  static std::string usage(bool p,bool d) {
-    return usage_helper("SourceWordPenalty","","number of source words (local feature, and meaningless except when input has non-constant number of source words, e.g. segmentation/morphology/speech recognition lattice)",p,d);
-  }
- protected:
-  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                     const Hypergraph::Edge& edge,
-                                     const std::vector<const void*>& ant_contexts,
-                                     FeatureVector* features,
-                                     FeatureVector* estimated_features,
-                                     void* context) const;
- private:
-  const int fid_;
-  const double value_;
-};
-
-#define DEFAULT_MAX_ARITY 9
-#define DEFAULT_MAX_ARITY_STRINGIZE(x) #x
-#define DEFAULT_MAX_ARITY_STRINGIZE_EVAL(x) DEFAULT_MAX_ARITY_STRINGIZE(x)
-#define DEFAULT_MAX_ARITY_STR DEFAULT_MAX_ARITY_STRINGIZE_EVAL(DEFAULT_MAX_ARITY)
-
-class ArityPenalty : public FeatureFunction {
- public:
-  bool rule_feature() const { return true; }
-  Features features() const;
-  ArityPenalty(const std::string& param);
-  static std::string usage(bool p,bool d) {
-    return usage_helper("ArityPenalty","[MaxArity(default " DEFAULT_MAX_ARITY_STR ")]","Indicator feature Arity_N=1 for rule of arity N (local feature). 0<=N<=MaxArity(default " DEFAULT_MAX_ARITY_STR ")",p,d);
-  }
-
- protected:
-  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                     const Hypergraph::Edge& edge,
-                                     const std::vector<const void*>& ant_contexts,
-                                     FeatureVector* features,
-                                     FeatureVector* estimated_features,
-                                     void* context) const;
- private:
-  std::vector<WordID> fids_;
-  const double value_;
-};
-
-void show_features(Features const& features,DenseWeightVector const& weights,std::ostream &out,std::ostream &warn,bool warn_zero_wt=true); //show features and weights
-
-template <class FFp>
-Features all_features(std::vector<FFp> const& models_,DenseWeightVector &weights_,std::ostream *warn=0,bool warn_fid_0=false) {
-  using namespace std;
-  Features ffs;
-#define WARNFF(x) do { if (warn) { *warn << "WARNING: "<< x << endl; } } while(0)
-  typedef map<WordID,string> FFM;
-  FFM ff_from;
-  for (unsigned i=0;i<models_.size();++i) {
-    string const& ffname=models_[i]->name_;
-    Features si=models_[i]->features();
-    if (si.empty()) {
-      WARNFF(ffname<<" doesn't yet report any feature IDs - either supply feature weight, or use --no_freeze_feature_set, or implement features() method");
-    }
-    unsigned n0=0;
-    for (unsigned j=0;j<si.size();++j) {
-      WordID fid=si[j];
-      if (!fid) ++n0;
-      if (fid >= weights_.size())
-        weights_.resize(fid+1);
-      if (warn_fid_0 || fid) {
-        pair<FFM::iterator,bool> i_new=ff_from.insert(FFM::value_type(fid,ffname));
-        if (i_new.second) {
-          if (fid)
-            ffs.push_back(fid);
-          else
-            WARNFF("Feature id 0 for "<<ffname);
-        } else {
-          WARNFF(ffname<<" duplicates feature with "<<i_new.first->second);
-        }
-      }
-    }
-    if (n0)
-      WARNFF(ffname<<" (models["<<i<<"]) has "<<n0<<" zero feature ids");
-  }
-#undef WARNFF
-  return ffs;
-}
-
-template <class FFp>
-void show_all_features(std::vector<FFp> const& models_,DenseWeightVector &weights_,std::ostream &out,std::ostream &warn,bool warn_fid_0=true,bool warn_zero_wt=true) {
-  return show_features(all_features(models_,weights_,&warn,warn_fid_0),weights_,out,warn,warn_zero_wt);
-}
-
-typedef ValueArray<uint8_t> FFState; // this is about 10% faster than string.
-//typedef std::string FFState;
-
-//FIXME: only context.data() is required to be contiguous, and it becomes invalid after next string operation. use ValueArray instead? (higher performance perhaps, save a word due to fixed size)
-typedef std::vector<FFState> FFStates;
-
-// this class is a set of FeatureFunctions that can be used to score, rescore,
-// etc. a (translation?) forest
-class ModelSet {
- public:
-  ModelSet(const std::vector<double>& weights,
-           const std::vector<const FeatureFunction*>& models);
-
-  // sets edge->feature_values_ and edge->edge_prob_
-  // NOTE: edge must not necessarily be in hg.edges_ but its TAIL nodes
-  // must be.  edge features are supposed to be overwritten, not added to (possibly because rule features aren't in ModelSet so need to be left alone
-  void AddFeaturesToEdge(const SentenceMetadata& smeta,
-                         const Hypergraph& hg,
-                         const FFStates& node_states,
-                         Hypergraph::Edge* edge,
-                         FFState* residual_context,
-                         prob_t* combination_cost_estimate = NULL) const;
-
-  //this is called INSTEAD of above when result of edge is goal (must be a unary rule - i.e. one variable, but typically it's assumed that there are no target terminals either (e.g. for LM))
-  void AddFinalFeatures(const FFState& residual_context,
-                        Hypergraph::Edge* edge,
-                        SentenceMetadata const& smeta) const;
-
-  // this is called once before any feature functions apply to a hypergraph
-  // it can be used to initialize sentence-specific data structures
-  void PrepareForInput(const SentenceMetadata& smeta);
-
-  bool empty() const { return models_.empty(); }
-
-  bool stateless() const { return !state_size_; }
-  Features all_features(std::ostream *warnings=0,bool warn_fid_zero=false); // this will warn about duplicate features as well (one function overwrites the feature of another). also resizes weights_ so it is large enough to hold the (0) weight for the largest reported feature id.  since 0 is a NULL feature id, it's never included.  if warn_fid_zero, then even the first 0 id is
-  void show_features(std::ostream &out,std::ostream &warn,bool warn_zero_wt=true);
- private:
-  std::vector<const FeatureFunction*> models_;
-  const std::vector<double>& weights_;
-  int state_size_;
-  std::vector<int> model_state_pos_;
-};
-
 #endif
diff --git a/decoder/ff_basic.cc b/decoder/ff_basic.cc
new file mode 100644
index 00000000..f9404d24
--- /dev/null
+++ b/decoder/ff_basic.cc
@@ -0,0 +1,80 @@
+#include "ff_basic.h"
+
+#include "fast_lexical_cast.hpp"
+#include "hg.h"
+
+using namespace std;
+
+// Hiero and Joshua use log_10(e) as the value, so I do to
+WordPenalty::WordPenalty(const string& param) :
+    fid_(FD::Convert("WordPenalty")),
+    value_(-1.0 / log(10)) {
+  if (!param.empty()) {
+    cerr << "Warning WordPenalty ignoring parameter: " << param << endl;
+  }
+}
+
+void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                        const Hypergraph::Edge& edge,
+                                        const std::vector<const void*>& ant_states,
+                                        SparseVector<double>* features,
+                                        SparseVector<double>* estimated_features,
+                                        void* state) const {
+  (void) smeta;
+  (void) ant_states;
+  (void) state;
+  (void) estimated_features;
+  features->set_value(fid_, edge.rule_->EWords() * value_);
+}
+
+
+SourceWordPenalty::SourceWordPenalty(const string& param) :
+    fid_(FD::Convert("SourceWordPenalty")),
+    value_(-1.0 / log(10)) {
+  if (!param.empty()) {
+    cerr << "Warning SourceWordPenalty ignoring parameter: " << param << endl;
+  }
+}
+
+void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                              const Hypergraph::Edge& edge,
+                                              const std::vector<const void*>& ant_states,
+                                              SparseVector<double>* features,
+                                              SparseVector<double>* estimated_features,
+                                              void* state) const {
+  (void) smeta;
+  (void) ant_states;
+  (void) state;
+  (void) estimated_features;
+  features->set_value(fid_, edge.rule_->FWords() * value_);
+}
+
+
+ArityPenalty::ArityPenalty(const std::string& param) :
+    value_(-1.0 / log(10)) {
+  string fname = "Arity_";
+  unsigned MAX=DEFAULT_MAX_ARITY;
+  using namespace boost;
+  if (!param.empty())
+    MAX=lexical_cast<unsigned>(param);
+  for (unsigned i = 0; i <= MAX; ++i) {
+    WordID fid=FD::Convert(fname+lexical_cast<string>(i));
+    fids_.push_back(fid);
+  }
+  while (!fids_.empty() && fids_.back()==0) fids_.pop_back(); // pretty up features vector in case FD was frozen.  doesn't change anything
+}
+
+void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                         const Hypergraph::Edge& edge,
+                                         const std::vector<const void*>& ant_states,
+                                         SparseVector<double>* features,
+                                         SparseVector<double>* estimated_features,
+                                         void* state) const {
+  (void) smeta;
+  (void) ant_states;
+  (void) state;
+  (void) estimated_features;
+  unsigned a=edge.Arity();
+  features->set_value(a<fids_.size() ? fids_[a] : 0, 1.0);
+}
diff --git a/decoder/ff_basic.h b/decoder/ff_basic.h
new file mode 100644
--- /dev/null
+++ b/decoder/ff_basic.h
@@ -0,0 +1,68 @@
+#ifndef _FF_BASIC_H_
+#define _FF_BASIC_H_
+
+#include "ff.h"
+
+// word penalty feature, for each word on the E side of a rule,
+// add value_
+class WordPenalty : public FeatureFunction {
+ public:
+  WordPenalty(const std::string& param);
+  static std::string usage(bool p,bool d) {
+    return usage_helper("WordPenalty","","number of target words (local feature)",p,d);
+  }
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const HG::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* context) const;
+ private:
+  const int fid_;
+  const double value_;
+};
+
+class SourceWordPenalty : public FeatureFunction {
+ public:
+  SourceWordPenalty(const std::string& param);
+  static std::string usage(bool p,bool d) {
+    return usage_helper("SourceWordPenalty","","number of source words (local feature, and meaningless except when input has non-constant number of source words, e.g. segmentation/morphology/speech recognition lattice)",p,d);
+  }
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const HG::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* context) const;
+ private:
+  const int fid_;
+  const double value_;
+};
+
+#define DEFAULT_MAX_ARITY 9
+#define DEFAULT_MAX_ARITY_STRINGIZE(x) #x
+#define DEFAULT_MAX_ARITY_STRINGIZE_EVAL(x) DEFAULT_MAX_ARITY_STRINGIZE(x)
+#define DEFAULT_MAX_ARITY_STR DEFAULT_MAX_ARITY_STRINGIZE_EVAL(DEFAULT_MAX_ARITY)
+
+class ArityPenalty : public FeatureFunction {
+ public:
+  ArityPenalty(const std::string& param);
+  static std::string usage(bool p,bool d) {
+    return usage_helper("ArityPenalty","[MaxArity(default " DEFAULT_MAX_ARITY_STR ")]","Indicator feature Arity_N=1 for rule of arity N (local feature). 0<=N<=MaxArity(default " DEFAULT_MAX_ARITY_STR ")",p,d);
+  }
+
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const HG::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* context) const;
+ private:
+  std::vector<WordID> fids_;
+  const double value_;
+};
+
+#endif
diff --git a/decoder/ff_bleu.h b/decoder/ff_bleu.h
index 5544920e..344dc788 100644
--- a/decoder/ff_bleu.h
+++ b/decoder/ff_bleu.h
@@ -20,7 +20,7 @@ class BLEUModel : public FeatureFunction {
   static std::string usage(bool param,bool verbose);
  protected:
   virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
-                                     const Hypergraph::Edge& edge,
+                                     const HG::Edge& edge,
                                      const std::vector<const void*>& ant_contexts,
                                      SparseVector<double>* features,
                                      SparseVector<double>* estimated_features,
diff --git a/decoder/ff_charset.cc b/decoder/ff_charset.cc
index 472de82b..6429088b 100644
--- a/decoder/ff_charset.cc
+++ b/decoder/ff_charset.cc
@@ -1,5 +1,7 @@
 #include "ff_charset.h"
 
+#include "tdict.h"
+#include "hg.h"
 #include "fdict.h"
 #include "stringlib.h"
 
@@ -20,8 +22,8 @@ bool ContainsNonLatin(const string& word) {
 void NonLatinCount::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                           const Hypergraph::Edge& edge,
                                           const std::vector<const void*>& ant_contexts,
-                                          FeatureVector* features,
-                                          FeatureVector* estimated_features,
+                                          SparseVector<double>* features,
+                                          SparseVector<double>* estimated_features,
                                           void* context) const {
   const vector<WordID>& e = edge.rule_->e();
   int count = 0;
diff --git a/decoder/ff_charset.h b/decoder/ff_charset.h
index b1ad537e..267ef65d 100644
--- a/decoder/ff_charset.h
+++ b/decoder/ff_charset.h
@@ -13,10 +13,10 @@ class NonLatinCount : public FeatureFunction {
   NonLatinCount(const std::string& param);
  protected:
   virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, - FeatureVector* features, - FeatureVector* estimated_features, + SparseVector* features, + SparseVector* estimated_features, void* context) const; private: mutable std::map is_non_latin_; diff --git a/decoder/ff_context.cc b/decoder/ff_context.cc index 9de4d737..f2b0e67c 100644 --- a/decoder/ff_context.cc +++ b/decoder/ff_context.cc @@ -5,12 +5,14 @@ #include #include +#include "hg.h" #include "filelib.h" #include "stringlib.h" #include "sentence_metadata.h" #include "lattice.h" #include "fdict.h" #include "verbose.h" +#include "tdict.h" RuleContextFeatures::RuleContextFeatures(const string& param) { // cerr << "initializing RuleContextFeatures with parameters: " << param; diff --git a/decoder/ff_context.h b/decoder/ff_context.h index 89bcb557..19198ec3 100644 --- a/decoder/ff_context.h +++ b/decoder/ff_context.h @@ -14,7 +14,7 @@ class RuleContextFeatures : public FeatureFunction { RuleContextFeatures(const string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc index 252dbf8c..e6f78f84 100644 --- a/decoder/ff_csplit.cc +++ b/decoder/ff_csplit.cc @@ -5,6 +5,7 @@ #include "klm/lm/model.hh" +#include "hg.h" #include "sentence_metadata.h" #include "lattice.h" #include "tdict.h" diff --git a/decoder/ff_csplit.h b/decoder/ff_csplit.h index 38c0c5b8..64d42526 100644 --- a/decoder/ff_csplit.h +++ b/decoder/ff_csplit.h @@ -12,7 +12,7 @@ class BasicCSplitFeatures : public FeatureFunction { BasicCSplitFeatures(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -27,7 +27,7 @@ class ReverseCharLMCSplitFeature : public FeatureFunction { ReverseCharLMCSplitFeature(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_dwarf.cc b/decoder/ff_dwarf.cc index 43528405..fe7a472e 100644 --- a/decoder/ff_dwarf.cc +++ b/decoder/ff_dwarf.cc @@ -4,6 +4,7 @@ #include #include #include +#include "hg.h" #include "ff_dwarf.h" #include "dwarf.h" #include "wordid.h" diff --git a/decoder/ff_dwarf.h b/decoder/ff_dwarf.h index 083fcc7c..3d6a7da6 100644 --- a/decoder/ff_dwarf.h +++ b/decoder/ff_dwarf.h @@ -56,7 +56,7 @@ class Dwarf : public FeatureFunction { function word alignments set by 3. 
*/ void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_external.cc b/decoder/ff_external.cc index dbb903d0..dea0e20f 100644 --- a/decoder/ff_external.cc +++ b/decoder/ff_external.cc @@ -1,8 +1,10 @@ #include "ff_external.h" -#include "stringlib.h" #include +#include "stringlib.h" +#include "hg.h" + using namespace std; ExternalFeature::ExternalFeature(const string& param) { @@ -50,8 +52,8 @@ void ExternalFeature::FinalTraversalFeatures(const void* context, void ExternalFeature::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const std::vector& ant_contexts, - FeatureVector* features, - FeatureVector* estimated_features, + SparseVector* features, + SparseVector* estimated_features, void* context) const { ff_ext->TraversalFeaturesImpl(smeta, edge, ant_contexts, features, estimated_features, context); } diff --git a/decoder/ff_external.h b/decoder/ff_external.h index 283e58e8..3e2bee51 100644 --- a/decoder/ff_external.h +++ b/decoder/ff_external.h @@ -13,10 +13,10 @@ class ExternalFeature : public FeatureFunction { SparseVector* features) const; protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, - FeatureVector* features, - FeatureVector* estimated_features, + SparseVector* features, + SparseVector* estimated_features, void* context) const; private: void* lib_handle; diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h index 5eb68c8b..bfdd3257 100644 --- a/decoder/ff_factory.h +++ b/decoder/ff_factory.h @@ -43,7 +43,6 @@ template struct FFFactory : public FactoryBase { FP Create(std::string param) const { FF *ret=new FF(param); - ret->Init(); return FP(ret); } virtual std::string usage(bool params,bool verbose) const { @@ -57,7 +56,6 @@ template struct FsaFactory : public FactoryBase { FP Create(std::string param) const { FF *ret=new FF(param); - ret->Init(); return FP(ret); } virtual std::string usage(bool params,bool verbose) const { @@ -98,8 +96,6 @@ struct FactoryRegistry : public UntypedFactoryRegistry { if (debug) cerr<<"debug enabled for "<(*it->second).Create(param); - res->init_name_debug(ffname,debug); - // could add a res->Init() here instead of in Create if we wanted feature id to potentially differ based on the registered name rather than static usage() - of course, specific feature ids can be computed on the basis of feature param as well; this only affects the default single feature id=name return res; } }; diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 09ef282c..fefa90bd 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -326,11 +326,6 @@ KLanguageModel::KLanguageModel(const string& param) { SetStateSize(pimpl_->ReserveStateSize()); } -template -Features KLanguageModel::features() const { - return single_feature(fid_); -} - template KLanguageModel::~KLanguageModel() { delete pimpl_; @@ -362,7 +357,6 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, template boost::shared_ptr CreateModel(const std::string ¶m) { KLanguageModel *ret = new KLanguageModel(param); - ret->Init(); return boost::shared_ptr(ret); } diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 6efe50f6..b5ceffd0 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -20,10 +20,9 @@ class KLanguageModel : public FeatureFunction { virtual void 
FinalTraversalFeatures(const void* context, SparseVector* features) const; static std::string usage(bool param,bool verbose); - Features features() const; protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc index 5e16d4e3..6ec7b4f3 100644 --- a/decoder/ff_lm.cc +++ b/decoder/ff_lm.cc @@ -519,10 +519,6 @@ LanguageModel::LanguageModel(const string& param) { SetStateSize(LanguageModelImpl::OrderToStateSize(order)); } -Features LanguageModel::features() const { - return single_feature(fid_); -} - LanguageModel::~LanguageModel() { delete pimpl_; } diff --git a/decoder/ff_lm.h b/decoder/ff_lm.h index ccee4268..94e18f00 100644 --- a/decoder/ff_lm.h +++ b/decoder/ff_lm.h @@ -55,10 +55,9 @@ class LanguageModel : public FeatureFunction { SparseVector* features) const; std::string DebugStateToString(const void* state) const; static std::string usage(bool param,bool verbose); - Features features() const; protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -81,7 +80,7 @@ class LanguageModelRandLM : public FeatureFunction { std::string DebugStateToString(const void* state) const; protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h index 064dbb49..4965d235 100644 --- a/decoder/ff_ngrams.h +++ b/decoder/ff_ngrams.h @@ -17,7 +17,7 @@ class NgramDetector : public FeatureFunction { SparseVector* features) const; protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc index bd4c4cc0..0aafb0ba 100644 --- a/decoder/ff_rules.cc +++ b/decoder/ff_rules.cc @@ -10,6 +10,8 @@ #include "lattice.h" #include "fdict.h" #include "verbose.h" +#include "tdict.h" +#include "hg.h" using namespace std; diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h index 48d8bd05..7f5e1dfa 100644 --- a/decoder/ff_rules.h +++ b/decoder/ff_rules.h @@ -3,6 +3,7 @@ #include #include +#include "trule.h" #include "ff.h" #include "array2d.h" #include "wordid.h" @@ -12,7 +13,7 @@ class RuleIdentityFeatures : public FeatureFunction { RuleIdentityFeatures(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -27,7 +28,7 @@ class RuleNgramFeatures : public FeatureFunction { RuleNgramFeatures(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc index f56ccfa9..7bb548c4 100644 --- a/decoder/ff_ruleshape.cc +++ b/decoder/ff_ruleshape.cc @@ -1,5 +1,7 @@ #include 
"ff_ruleshape.h" +#include "trule.h" +#include "hg.h" #include "fdict.h" #include diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h index 23c9827e..9f20faf3 100644 --- a/decoder/ff_ruleshape.h +++ b/decoder/ff_ruleshape.h @@ -9,7 +9,7 @@ class RuleShapeFeatures : public FeatureFunction { RuleShapeFeatures(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 035132b4..a1997695 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -3,6 +3,7 @@ #include #include +#include "hg.h" #include "sentence_metadata.h" #include "array2d.h" #include "filelib.h" diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h index 279563e1..a8c7150a 100644 --- a/decoder/ff_source_syntax.h +++ b/decoder/ff_source_syntax.h @@ -11,7 +11,7 @@ class SourceSyntaxFeatures : public FeatureFunction { ~SourceSyntaxFeatures(); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -28,7 +28,7 @@ class SourceSpanSizeFeatures : public FeatureFunction { ~SourceSpanSizeFeatures(); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index 0483517b..0ccac69b 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -4,6 +4,8 @@ #include #include +#include "hg.h" +#include "tdict.h" #include "filelib.h" #include "stringlib.h" #include "sentence_metadata.h" diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index 24e0dede..d2f5e84c 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -12,7 +12,7 @@ class SpanFeatures : public FeatureFunction { SpanFeatures(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -49,7 +49,7 @@ class CMR2008ReorderingFeatures : public FeatureFunction { CMR2008ReorderingFeatures(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index fd9210fa..7f9af9cd 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -2,6 +2,7 @@ #include +#include "hg.h" #include "tdict.h" #include "sentence_metadata.h" #include "stringlib.h" diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h index bd5b62c0..46418b0c 100644 --- a/decoder/ff_tagger.h +++ b/decoder/ff_tagger.h @@ -18,7 +18,7 @@ class Tagger_BigramIndicator : public FeatureFunction { Tagger_BigramIndicator(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -39,7 +39,7 @@ class 
LexicalPairIndicator : public FeatureFunction { virtual void PrepareForInput(const SentenceMetadata& smeta); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -59,7 +59,7 @@ class OutputIndicator : public FeatureFunction { OutputIndicator(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index d7a2dda8..ba3d0b9b 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -13,7 +13,7 @@ class RelativeSentencePosition : public FeatureFunction { RelativeSentencePosition(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -36,7 +36,7 @@ class SourceBigram : public FeatureFunction { void PrepareForInput(const SentenceMetadata& smeta); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -55,7 +55,7 @@ class LexNullJump : public FeatureFunction { LexNullJump(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -72,7 +72,7 @@ class NewJump : public FeatureFunction { NewJump(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -109,7 +109,7 @@ class LexicalTranslationTrigger : public FeatureFunction { LexicalTranslationTrigger(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -132,14 +132,14 @@ class BlunsomSynchronousParseHack : public FeatureFunction { BlunsomSynchronousParseHack(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, void* out_context) const; private: inline bool DoesNotBelong(const void* state) const { - for (int i = 0; i < NumBytesContext(); ++i) { + for (int i = 0; i < StateSize(); ++i) { if (*(static_cast(state) + i)) return false; } return true; @@ -148,9 +148,9 @@ class BlunsomSynchronousParseHack : public FeatureFunction { inline void AppendAntecedentString(const void* state, std::vector* yield) const { int i = 0; int ind = 0; - while (i < NumBytesContext() && !(*(static_cast(state) + i))) { ++i; ind += 8; } - // std::cerr << i << " " << NumBytesContext() << std::endl; - assert(i != NumBytesContext()); + while (i < StateSize() && !(*(static_cast(state) + i))) { ++i; ind += 8; } 
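// Aside: the per-feature state scanned here acts as a bitmask over positions
// in the reference sentence (see the assert(ind < cur_ref_->size()) a few
// lines below): byte i covers positions [8*i, 8*i+8), least-significant bit
// first. A minimal sketch of the scan the loop above performs, as a
// hypothetical standalone helper that is not part of this header:
//
//   static inline int FirstSetPosition(const unsigned char* state, int n) {
//     int i = 0, ind = 0;
//     while (i < n && !state[i]) { ++i; ind += 8; }  // skip all-zero bytes
//     if (i == n) return -1;                         // empty mask
//     for (int comp = 1; !(state[i] & comp); comp <<= 1) ++ind;  // bit in byte
//     return ind;  // position index of the first set bit
//   }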
+ // std::cerr << i << " " << StateSize() << std::endl; + assert(i != StateSize()); assert(ind < cur_ref_->size()); int cur = *(static_cast(state) + i); int comp = 1; @@ -171,7 +171,7 @@ class BlunsomSynchronousParseHack : public FeatureFunction { } inline void SetStateMask(int start, int end, void* state) const { - assert((end / 8) < NumBytesContext()); + assert((end / 8) < StateSize()); int i = 0; int comp = 1; for (int j = 0; j < start; ++j) { @@ -209,7 +209,7 @@ class WordPairFeatures : public FeatureFunction { WordPairFeatures(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -226,7 +226,7 @@ class IdentityCycleDetector : public FeatureFunction { IdentityCycleDetector(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -242,7 +242,7 @@ class InputIndicator : public FeatureFunction { InputIndicator(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, @@ -258,7 +258,7 @@ class Fertility : public FeatureFunction { Fertility(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ff_wordset.cc b/decoder/ff_wordset.cc index 44468899..70cea7de 100644 --- a/decoder/ff_wordset.cc +++ b/decoder/ff_wordset.cc @@ -1,5 +1,6 @@ #include "ff_wordset.h" +#include "hg.h" #include "fdict.h" #include #include diff --git a/decoder/ff_wordset.h b/decoder/ff_wordset.h index 7c9a3fb7..639e1514 100644 --- a/decoder/ff_wordset.h +++ b/decoder/ff_wordset.h @@ -2,6 +2,7 @@ #define _FF_WORDSET_H_ #include "ff.h" +#include "tdict.h" #include #include @@ -32,11 +33,9 @@ class WordSet : public FeatureFunction { ~WordSet() { } - Features features() const { return single_feature(fid_); } - protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + const HG::Edge& edge, const std::vector& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/ffset.cc b/decoder/ffset.cc new file mode 100644 index 00000000..653a29f8 --- /dev/null +++ b/decoder/ffset.cc @@ -0,0 +1,72 @@ +#include "ffset.h" + +#include "ff.h" +#include "tdict.h" +#include "hg.h" + +using namespace std; + +ModelSet::ModelSet(const vector& w, const vector& models) : + models_(models), + weights_(w), + state_size_(0), + model_state_pos_(models.size()) { + for (int i = 0; i < models_.size(); ++i) { + model_state_pos_[i] = state_size_; + state_size_ += models_[i]->StateSize(); + } +} + +void ModelSet::PrepareForInput(const SentenceMetadata& smeta) { + for (int i = 0; i < models_.size(); ++i) + const_cast(models_[i])->PrepareForInput(smeta); +} + +void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& /* hg */, + const FFStates& node_states, + HG::Edge* edge, + FFState* context, + prob_t* combination_cost_estimate) const { + 
//edge->reset_info(); + context->resize(state_size_); + if (state_size_ > 0) { + memset(&(*context)[0], 0, state_size_); + } + SparseVector est_vals; // only computed if combination_cost_estimate is non-NULL + if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + void* cur_ff_context = NULL; + vector ants(edge->tail_nodes_.size()); + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + cur_ff_context = &(*context)[spos]; + for (int i = 0; i < ants.size(); ++i) { + ants[i] = &node_states[edge->tail_nodes_[i]][spos]; + } + } + ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context); + } + if (combination_cost_estimate) + combination_cost_estimate->logeq(est_vals.dot(weights_)); + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + +void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMetadata const& smeta) const { + assert(1 == edge->rule_->Arity()); + //edge->reset_info(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + const void* ant_state = NULL; + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + ant_state = &state[spos]; + } + ff.FinalTraversalFeatures(smeta, *edge, ant_state, &edge->feature_values_); + } + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + diff --git a/decoder/ffset.h b/decoder/ffset.h new file mode 100644 index 00000000..28aef667 --- /dev/null +++ b/decoder/ffset.h @@ -0,0 +1,57 @@ +#ifndef _FFSET_H_ +#define _FFSET_H_ + +#include +#include "value_array.h" +#include "prob.h" + +namespace HG { struct Edge; struct Node; } +class Hypergraph; +class FeatureFunction; +class SentenceMetadata; +class FeatureFunction; // see definition below + +// TODO let states be dynamically sized +typedef ValueArray FFState; // this is a fixed array, but about 10% faster than string + +//FIXME: only context.data() is required to be contiguous, and it becomes invalid after next string operation. use ValueArray instead? (higher performance perhaps, save a word due to fixed size) +typedef std::vector FFStates; + +// this class is a set of FeatureFunctions that can be used to score, rescore, +// etc. a (translation?) forest +class ModelSet { + public: + ModelSet(const std::vector& weights, + const std::vector& models); + + // sets edge->feature_values_ and edge->edge_prob_ + // NOTE: edge must not necessarily be in hg.edges_ but its TAIL nodes + // must be. edge features are supposed to be overwritten, not added to (possibly because rule features aren't in ModelSet so need to be left alone + void AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& hg, + const FFStates& node_states, + HG::Edge* edge, + FFState* residual_context, + prob_t* combination_cost_estimate = NULL) const; + + //this is called INSTEAD of above when result of edge is goal (must be a unary rule - i.e. one variable, but typically it's assumed that there are no target terminals either (e.g. 
for LM)) + void AddFinalFeatures(const FFState& residual_context, + HG::Edge* edge, + SentenceMetadata const& smeta) const; + + // this is called once before any feature functions apply to a hypergraph + // it can be used to initialize sentence-specific data structures + void PrepareForInput(const SentenceMetadata& smeta); + + bool empty() const { return models_.empty(); } + + bool stateless() const { return !state_size_; } + + private: + std::vector models_; + const std::vector& weights_; + int state_size_; + std::vector model_state_pos_; +}; + +#endif diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 4500490a..912f4f12 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -10,7 +10,9 @@ #include "tdict.h" #include "grammar.h" #include "bottom_up_parser.h" +#include "hg.h" #include "ff.h" +#include "ffset.h" #include "weights.h" using namespace std; diff --git a/decoder/hg.h b/decoder/hg.h index f53d2fd2..3d8cd9bc 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -490,14 +490,14 @@ private: // for generic Viterbi/Inside algorithms struct EdgeProb { typedef prob_t Weight; - inline const prob_t& operator()(const Hypergraph::Edge& e) const { return e.edge_prob_; } + inline const prob_t& operator()(const HG::Edge& e) const { return e.edge_prob_; } }; struct EdgeSelectEdgeWeightFunction { typedef prob_t Weight; typedef std::vector EdgeMask; EdgeSelectEdgeWeightFunction(const EdgeMask& v) : v_(v) {} - inline prob_t operator()(const Hypergraph::Edge& e) const { + inline prob_t operator()(const HG::Edge& e) const { if (v_[e.id_]) return prob_t::One(); else return prob_t::Zero(); } @@ -507,7 +507,7 @@ private: struct ScaledEdgeProb { ScaledEdgeProb(const double& alpha) : alpha_(alpha) {} - inline prob_t operator()(const Hypergraph::Edge& e) const { return e.edge_prob_.pow(alpha_); } + inline prob_t operator()(const HG::Edge& e) const { return e.edge_prob_.pow(alpha_); } const double alpha_; typedef prob_t Weight; }; @@ -516,7 +516,7 @@ struct ScaledEdgeProb { struct EdgeFeaturesAndProbWeightFunction { typedef SparseVector Weight; typedef Weight Result; //TODO: change Result->Weight everywhere? 
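// Aside: each of these functors exposes the same two-part contract -- a
// `typedef ... Weight` plus a `Weight operator()(const HG::Edge&) const` --
// and that contract is all the generic Inside/Outside/Viterbi templates need
// to be instantiated over a forest. A minimal usage sketch, assuming the
// Inside() declaration from inside_outside.h further down in this patch
// (result vector and weight function default-constructed when omitted):
//
//   // Sum of edge_prob_ products over all derivations of the forest:
//   prob_t z = Inside<prob_t, EdgeProb>(forest);
//
//   // Number of distinct derivations, by weighting every edge 1.0:
//   double n = Inside<double, TransitionCountWeightFunction>(forest);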
- inline const Weight operator()(const Hypergraph::Edge& e) const { + inline const Weight operator()(const HG::Edge& e) const { SparseVector res; for (SparseVector::const_iterator it = e.feature_values_.begin(); it != e.feature_values_.end(); ++it) @@ -527,7 +527,7 @@ struct EdgeFeaturesAndProbWeightFunction { struct TransitionCountWeightFunction { typedef double Weight; - inline double operator()(const Hypergraph::Edge& e) const { (void)e; return 1.0; } + inline double operator()(const HG::Edge& e) const { (void)e; return 1.0; } }; #endif diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index 8f604c89..64c6663e 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -28,7 +28,7 @@ struct HGReader : public JSONParser { hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); } } - void CreateEdge(const TRulePtr& rule, FeatureVector* feats, const SmallVectorUnsigned& tail) { + void CreateEdge(const TRulePtr& rule, SparseVector* feats, const SmallVectorUnsigned& tail) { Hypergraph::Edge* edge = hg.AddEdge(rule, tail); feats->swap(edge->feature_values_); edge->i_ = spans[0]; diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h index f73a1d3f..c0377fe8 100644 --- a/decoder/inside_outside.h +++ b/decoder/inside_outside.h @@ -42,7 +42,7 @@ WeightType Inside(const Hypergraph& hg, Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_; const unsigned num_in_edges = in.size(); for (unsigned j = 0; j < num_in_edges; ++j) { - const Hypergraph::Edge& edge = hg.edges_[in[j]]; + const HG::Edge& edge = hg.edges_[in[j]]; WeightType score = weight(edge); for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k) { const int tail_node_index = edge.tail_nodes_[k]; @@ -74,7 +74,7 @@ void Outside(const Hypergraph& hg, Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_; const int num_in_edges = in.size(); for (int j = 0; j < num_in_edges; ++j) { - const Hypergraph::Edge& edge = hg.edges_[in[j]]; + const HG::Edge& edge = hg.edges_[in[j]]; WeightType head_and_edge_weight = weight(edge); head_and_edge_weight *= head_node_outside_score; const int num_tail_nodes = edge.tail_nodes_.size(); @@ -138,7 +138,7 @@ struct InsideOutsides { Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_; const int num_in_edges = in.size(); for (int j = 0; j < num_in_edges; ++j) { - const Hypergraph::Edge& edge = hg.edges_[in[j]]; + const HG::Edge& edge = hg.edges_[in[j]]; KType kbar_e = outside[i]; const int num_tail_nodes = edge.tail_nodes_.size(); for (int k = 0; k < num_tail_nodes; ++k) @@ -156,7 +156,7 @@ struct InsideOutsides { const int num_in_edges = in.size(); for (int j = 0; j < num_in_edges; ++j) { int edgei=in[j]; - const Hypergraph::Edge& edge = hg.edges_[edgei]; + const HG::Edge& edge = hg.edges_[edgei]; V x=weight(edge)*outside[i]; const int num_tail_nodes = edge.tail_nodes_.size(); for (int k = 0; k < num_tail_nodes; ++k) diff --git a/decoder/kbest.h b/decoder/kbest.h index 9af3a20e..9a55f653 100644 --- a/decoder/kbest.h +++ b/decoder/kbest.h @@ -48,7 +48,7 @@ namespace KBest { } struct Derivation { - Derivation(const Hypergraph::Edge& e, + Derivation(const HG::Edge& e, const SmallVectorInt& jv, const WeightType& w, const SparseVector& f) : @@ -58,11 +58,11 @@ namespace KBest { feature_values(f) {} // dummy constructor, just for query - Derivation(const Hypergraph::Edge& e, + Derivation(const HG::Edge& e, const SmallVectorInt& jv) : edge(&e), j(jv) {} T yield; - const Hypergraph::Edge* const edge; + const HG::Edge* const edge; const SmallVectorInt j; const WeightType score; const SparseVector 
feature_values; @@ -82,8 +82,8 @@ namespace KBest { Derivation const* d; explicit EdgeHandle(Derivation const* d) : d(d) { } // operator bool() const { return d->edge; } - operator Hypergraph::Edge const* () const { return d->edge; } -// Hypergraph::Edge const * operator ->() const { return d->edge; } + operator HG::Edge const* () const { return d->edge; } +// HG::Edge const * operator ->() const { return d->edge; } }; EdgeHandle operator()(unsigned t,unsigned taili,EdgeHandle const& parent) const { @@ -158,7 +158,7 @@ namespace KBest { // the yield is computed in LazyKthBest before the derivation is added to D // returns NULL if j refers to derivation numbers larger than the // antecedent structure define - Derivation* CreateDerivation(const Hypergraph::Edge& e, const SmallVectorInt& j) { + Derivation* CreateDerivation(const HG::Edge& e, const SmallVectorInt& j) { WeightType score = w(e); SparseVector feats = e.feature_values_; for (int i = 0; i < e.Arity(); ++i) { @@ -177,7 +177,7 @@ namespace KBest { const Hypergraph::Node& node = g.nodes_[v]; for (unsigned i = 0; i < node.in_edges_.size(); ++i) { - const Hypergraph::Edge& edge = g.edges_[node.in_edges_[i]]; + const HG::Edge& edge = g.edges_[node.in_edges_[i]]; SmallVectorInt jv(edge.Arity(), 0); Derivation* d = CreateDerivation(edge, jv); assert(d); diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index b603e27a..d2c4715c 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -12,6 +12,7 @@ #include "scorer.h" #include "hg.h" #include "ff_factory.h" +#include "ffset.h" #include "ff_bleu.h" #include "sparse_vector.h" #include "viterbi.h" @@ -26,7 +27,7 @@ struct Translation { typedef std::vector Sentence; Sentence sentence; - FeatureVector features; + SparseVector features; Translation() { } Translation(Hypergraph const& hg,WeightVector *feature_weights=0) { @@ -57,14 +58,14 @@ struct Oracle { } // feature 0 will be the error rate in fear and hope // move toward hope - FeatureVector ModelHopeGradient() const { - FeatureVector r=hope.features-model.features; + SparseVector ModelHopeGradient() const { + SparseVector r=hope.features-model.features; r.set_value(0,0); return r; } // move toward hope from fear - FeatureVector FearHopeGradient() const { - FeatureVector r=hope.features-fear.features; + SparseVector FearHopeGradient() const { + SparseVector r=hope.features-fear.features; r.set_value(0,0); return r; } diff --git a/decoder/program_options.h b/decoder/program_options.h index 87afb320..3cd7649a 100644 --- a/decoder/program_options.h +++ b/decoder/program_options.h @@ -94,7 +94,7 @@ struct any_printer : public boost::function {} template - explicit any_printer(T const* tag) : F(typed_print()) { + explicit any_printer(T const*) : F(typed_print()) { } template diff --git a/decoder/tromble_loss.h b/decoder/tromble_loss.h index 599a2d54..fde33100 100644 --- a/decoder/tromble_loss.h +++ b/decoder/tromble_loss.h @@ -28,7 +28,7 @@ class TrombleLossComputer : private boost::base_from_member& ant_contexts, SparseVector* features, SparseVector* estimated_features, diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc index 1b9c6665..9e381ac6 100644 --- a/decoder/viterbi.cc +++ b/decoder/viterbi.cc @@ -139,8 +139,8 @@ inline bool close_enough(double a,double b,double epsilon) return diff<=epsilon*fabs(a) || diff<=epsilon*fabs(b); } -FeatureVector ViterbiFeatures(Hypergraph const& hg,WeightVector const* weights,bool fatal_dotprod_disagreement) { - FeatureVector r; +SparseVector ViterbiFeatures(Hypergraph const& 
hg,WeightVector const* weights,bool fatal_dotprod_disagreement) { + SparseVector r; const prob_t p = Viterbi(hg, &r); if (weights) { double logp=log(p); diff --git a/decoder/viterbi.h b/decoder/viterbi.h index 03e961a2..a8a0ea7f 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -14,10 +14,10 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest" //TODO: make T a typename inside Traversal and WeightType a typename inside WeightFunction? // Traversal must implement: // typedef T Result; -// void operator()(Hypergraph::Edge const& e,const vector& ants, Result* result) const; +// void operator()(HG::Edge const& e,const vector& ants, Result* result) const; // WeightFunction must implement: // typedef prob_t Weight; -// Weight operator()(Hypergraph::Edge const& e) const; +// Weight operator()(HG::Edge const& e) const; template typename WeightFunction::Weight Viterbi(const Hypergraph& hg, typename Traversal::Result* result, @@ -39,9 +39,9 @@ typename WeightFunction::Weight Viterbi(const Hypergraph& hg, *cur_node_best_weight = WeightType(1); continue; } - Hypergraph::Edge const* edge_best=0; + HG::Edge const* edge_best=0; for (unsigned j = 0; j < num_in_edges; ++j) { - const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + const HG::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; WeightType score = weight(edge); for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k) score *= vit_weight[edge.tail_nodes_[k]]; @@ -51,7 +51,7 @@ typename WeightFunction::Weight Viterbi(const Hypergraph& hg, } } assert(edge_best); - Hypergraph::Edge const& edgeb=*edge_best; + HG::Edge const& edgeb=*edge_best; std::vector antsb(edgeb.tail_nodes_.size()); for (unsigned k = 0; k < edgeb.tail_nodes_.size(); ++k) antsb[k] = &vit_result[edgeb.tail_nodes_[k]]; @@ -98,7 +98,7 @@ prob_t Viterbi(const Hypergraph& hg, struct PathLengthTraversal { typedef int Result; - void operator()(const Hypergraph::Edge& edge, + void operator()(const HG::Edge& edge, const std::vector& ants, int* result) const { (void) edge; @@ -109,7 +109,7 @@ struct PathLengthTraversal { struct ESentenceTraversal { typedef std::vector Result; - void operator()(const Hypergraph::Edge& edge, + void operator()(const HG::Edge& edge, const std::vector& ants, Result* result) const { edge.rule_->ESubstitute(ants, result); @@ -118,7 +118,7 @@ struct ESentenceTraversal { struct ELengthTraversal { typedef int Result; - void operator()(const Hypergraph::Edge& edge, + void operator()(const HG::Edge& edge, const std::vector& ants, int* result) const { *result = edge.rule_->ELength() - edge.rule_->Arity(); @@ -128,7 +128,7 @@ struct ELengthTraversal { struct FSentenceTraversal { typedef std::vector Result; - void operator()(const Hypergraph::Edge& edge, + void operator()(const HG::Edge& edge, const std::vector& ants, Result* result) const { edge.rule_->FSubstitute(ants, result); @@ -142,7 +142,7 @@ struct ETreeTraversal { const std::string space; const std::string right; typedef std::vector Result; - void operator()(const Hypergraph::Edge& edge, + void operator()(const HG::Edge& edge, const std::vector& ants, Result* result) const { Result tmp; @@ -162,7 +162,7 @@ struct FTreeTraversal { const std::string space; const std::string right; typedef std::vector Result; - void operator()(const Hypergraph::Edge& edge, + void operator()(const HG::Edge& edge, const std::vector& ants, Result* result) const { Result tmp; @@ -177,8 +177,8 @@ struct FTreeTraversal { }; struct ViterbiPathTraversal { - typedef std::vector Result; - void 
operator()(const Hypergraph::Edge& edge, + typedef std::vector Result; + void operator()(const HG::Edge& edge, std::vector const& ants, Result* result) const { for (unsigned i = 0; i < ants.size(); ++i) @@ -189,8 +189,8 @@ struct ViterbiPathTraversal { }; struct FeatureVectorTraversal { - typedef FeatureVector Result; - void operator()(Hypergraph::Edge const& edge, + typedef SparseVector Result; + void operator()(HG::Edge const& edge, std::vector const& ants, Result* result) const { for (unsigned i = 0; i < ants.size(); ++i) @@ -210,6 +210,6 @@ int ViterbiELength(const Hypergraph& hg); int ViterbiPathLength(const Hypergraph& hg); /// if weights supplied, assert viterbi prob = features.dot(*weights) (exception if fatal, cerr warn if not). return features (sum over all edges in viterbi derivation) -FeatureVector ViterbiFeatures(Hypergraph const& hg,WeightVector const* weights=0,bool fatal_dotprod_disagreement=false); +SparseVector ViterbiFeatures(Hypergraph const& hg,WeightVector const* weights=0,bool fatal_dotprod_disagreement=false); #endif diff --git a/example_extff/ff_example.cc b/example_extff/ff_example.cc index 51ebf364..4e478ecd 100644 --- a/example_extff/ff_example.cc +++ b/example_extff/ff_example.cc @@ -2,6 +2,8 @@ #include #include +#include "hg.h" + using namespace std; // example of a "stateful" feature made available as an external library -- cgit v1.2.3 From 1fb7bfbbe287e868522613871ed6ca74369ed2a1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 22 Oct 2012 14:04:27 +0100 Subject: Update search, make it compile --- Makefile.am | 1 + configure.ac | 6 +- decoder/Makefile.am | 3 +- decoder/decoder.cc | 8 +- decoder/incremental.cc | 184 +++++++++++++++++++++++++++++++++++++++ decoder/incremental.h | 11 +++ decoder/lazy.cc | 178 -------------------------------------- decoder/lazy.h | 11 --- dtrain/Makefile.am | 2 +- klm/alone/Jamfile | 4 - klm/alone/assemble.cc | 76 ---------------- klm/alone/assemble.hh | 21 ----- klm/alone/graph.hh | 87 ------------------- klm/alone/just_vocab.cc | 14 --- klm/alone/labeled_edge.hh | 30 ------- klm/alone/main.cc | 85 ------------------ klm/alone/read.cc | 118 ------------------------- klm/alone/read.hh | 29 ------- klm/alone/threading.cc | 80 ----------------- klm/alone/threading.hh | 129 --------------------------- klm/alone/vocab.cc | 19 ---- klm/alone/vocab.hh | 34 -------- klm/lm/model.cc | 2 +- klm/lm/vocab.cc | 4 +- klm/lm/vocab.hh | 2 +- klm/search/Jamfile | 2 +- klm/search/Makefile.am | 11 +++ klm/search/arity.hh | 8 -- klm/search/context.hh | 10 +-- klm/search/edge.hh | 53 ++++++++---- klm/search/edge_generator.cc | 144 ++++++++++++++----------------- klm/search/edge_generator.hh | 49 ++++++----- klm/search/edge_queue.cc | 25 ------ klm/search/edge_queue.hh | 73 ---------------- klm/search/final.hh | 41 ++++----- klm/search/header.hh | 57 ++++++++++++ klm/search/source.hh | 48 ----------- klm/search/types.hh | 8 +- klm/search/vertex.cc | 10 +-- klm/search/vertex.hh | 55 ++++++------ klm/search/vertex_generator.cc | 97 ++++++++++++--------- klm/search/vertex_generator.hh | 33 +++---- klm/util/Makefile.am | 2 + klm/util/ersatz_progress.hh | 2 +- klm/util/exception.hh | 2 +- klm/util/pool.cc | 35 ++++++++ klm/util/pool.hh | 45 ++++++++++ klm/util/probing_hash_table.hh | 2 +- klm/util/string_piece.cc | 192 +++++++++++++++++++++++++++++++++++++++++ klm/util/tokenize_piece.hh | 12 +++ mira/Makefile.am | 2 +- training/Makefile.am | 38 ++++---- 52 files changed, 838 insertions(+), 1356 deletions(-) create mode 100644 
decoder/incremental.cc create mode 100644 decoder/incremental.h delete mode 100644 decoder/lazy.cc delete mode 100644 decoder/lazy.h delete mode 100644 klm/alone/Jamfile delete mode 100644 klm/alone/assemble.cc delete mode 100644 klm/alone/assemble.hh delete mode 100644 klm/alone/graph.hh delete mode 100644 klm/alone/just_vocab.cc delete mode 100644 klm/alone/labeled_edge.hh delete mode 100644 klm/alone/main.cc delete mode 100644 klm/alone/read.cc delete mode 100644 klm/alone/read.hh delete mode 100644 klm/alone/threading.cc delete mode 100644 klm/alone/threading.hh delete mode 100644 klm/alone/vocab.cc delete mode 100644 klm/alone/vocab.hh create mode 100644 klm/search/Makefile.am delete mode 100644 klm/search/arity.hh delete mode 100644 klm/search/edge_queue.cc delete mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/header.hh delete mode 100644 klm/search/source.hh create mode 100644 klm/util/pool.cc create mode 100644 klm/util/pool.hh create mode 100644 klm/util/string_piece.cc (limited to 'decoder/decoder.cc') diff --git a/Makefile.am b/Makefile.am index 3e0103a8..fefc470d 100644 --- a/Makefile.am +++ b/Makefile.am @@ -6,6 +6,7 @@ SUBDIRS = \ mteval \ klm/util \ klm/lm \ + klm/search \ decoder \ training \ training/liblbfgs \ diff --git a/configure.ac b/configure.ac index 03a0ee87..cb132d66 100644 --- a/configure.ac +++ b/configure.ac @@ -12,6 +12,7 @@ AC_PROG_CXX AC_LANG_CPLUSPLUS BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS +BOOST_SYSTEM BOOST_TEST AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) @@ -73,9 +74,9 @@ fi #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" -LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" +LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SYSTEM_LDFLAGS" # $BOOST_THREAD_LDFLAGS" -LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS" +LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SYSTEM_LIBS" # $BOOST_THREAD_LIBS" AC_CHECK_HEADER(google/dense_hash_map, @@ -123,6 +124,7 @@ AC_CONFIG_FILES([rampion/Makefile]) AC_CONFIG_FILES([minrisk/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) +AC_CONFIG_FILES([klm/search/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) AC_CONFIG_FILES([example_extff/Makefile]) diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 5c0a1964..f8f427d3 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -17,7 +17,7 @@ trule_test_SOURCES = trule_test.cc trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz cdec_SOURCES = cdec.cc -cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. 
-I../mteval -I../utils -I../klm @@ -73,6 +73,7 @@ libcdec_a_SOURCES = \ ff_source_syntax.cc \ ff_bleu.cc \ ff_factory.cc \ + incremental.cc \ lexalign.cc \ lextrans.cc \ tagger.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 052823ca..fe812011 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -39,7 +39,7 @@ #include "sampler.h" #include "forest_writer.h" // TODO this section should probably be handled by an Observer -#include "lazy.h" +#include "incremental.h" #include "hg_io.h" #include "aligner.h" @@ -412,7 +412,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_target_graph", po::value(), "Directory to write the target hypergraphs to") - ("lazy_search", po::value(), "Run lazy search with this language model file") + ("incremental_search", po::value(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") @@ -828,8 +828,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("show_target_graph")) HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); - if (conf.count("lazy_search")) { - PassToLazy(conf["lazy_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); + if (conf.count("incremental_search")) { + PassToIncremental(conf["incremental_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); o->NotifyDecodingComplete(smeta); return true; } diff --git a/decoder/incremental.cc b/decoder/incremental.cc new file mode 100644 index 00000000..768bbd65 --- /dev/null +++ b/decoder/incremental.cc @@ -0,0 +1,184 @@ +#include "incremental.h" + +#include "hg.h" +#include "fdict.h" +#include "tdict.h" + +#include "lm/enumerate_vocab.hh" +#include "lm/model.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/rule.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" +#include "util/exception.hh" + +#include +#include + +#include +#include + +namespace { + +struct MapVocab : public lm::EnumerateVocab { + public: + MapVocab() {} + + // Do not call after Lookup. + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); + out_[cdec_id] = index; + } + + // Assumes Add has been called and will never be called again. + lm::WordIndex FromCDec(WordID id) const { + return out_[out_.size() > id ? 
id : 0]; + } + + private: + std::vector out_; +}; + +class IncrementalBase { + public: + IncrementalBase(const std::vector &weights) : + cdec_weights_(weights), + weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + } + + virtual ~IncrementalBase() {} + + virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + + static IncrementalBase *Load(const char *model_file, const std::vector &weights); + + protected: + lm::ngram::Config GetConfig() { + lm::ngram::Config ret; + ret.enumerate_vocab = &vocab_; + return ret; + } + + MapVocab vocab_; + + const std::vector &cdec_weights_; + + const search::Weights weights_; +}; + +template class Incremental : public IncrementalBase { + public: + Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} + + void Search(unsigned int pop_limit, const Hypergraph &hg) const; + + private: + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + + const Model m_; +}; + +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); + } +} + +void PrintFinal(const Hypergraph &hg, const search::Final final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); + const search::Final *child(final.Children()); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + if (*i > 0) { + std::cout << TD::Convert(*i) << ' '; + } else { + PrintFinal(hg, *child++); + } + } +} + +template void Incremental::Search(unsigned int pop_limit, const Hypergraph &hg) const { + boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); + search::Config config(weights_, pop_limit); + search::Context context(config, m_); + + for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { + search::EdgeGenerator gen; + const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; + for (unsigned int j = 0; j < down_edges.size(); ++j) { + unsigned int edge_index = down_edges[j]; + ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + } + search::VertexGenerator vertex_gen(context, out_vertices[i]); + gen.Search(context, vertex_gen); + } + const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (top.Valid()) { + std::cout << "NO PATH FOUND" << std::endl; + } else { + PrintFinal(hg, top); + std::cout << "||| " << top.GetScore() << std::endl; + } +} + +template void Incremental::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { + const std::vector &e = in.rule_->e(); + std::vector words; + words.reserve(e.size()); + std::vector nts; + unsigned int terminals = 0; + float score = 0.0; + for (std::vector::const_iterator word = e.begin(); word != 
e.end(); ++word) { + if (*word <= 0) { + nts.push_back(vertices[in.tail_nodes_[-*word]].RootPartial()); + if (nts.back().Empty()) return; + score += nts.back().Bound(); + words.push_back(lm::kMaxWordIndex); + } else { + ++terminals; + words.push_back(vocab_.FromCDec(*word)); + } + } + + if (final) { + words.push_back(m_.GetVocabulary().EndSentence()); + } + + search::PartialEdge out(gen.AllocateEdge(nts.size())); + + memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); + + search::Note note; + note.vp = ∈ + out.SetNote(note); + + score += in.rule_->GetFeatureValues().dot(cdec_weights_); + score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + score += search::ScoreRule(context, words, final, out.Between()); + out.SetScore(score); + + gen.AddEdge(out); +} + +boost::scoped_ptr AwfulGlobalIncremental; + +} // namespace + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { + if (!AwfulGlobalIncremental.get()) { + std::cerr << "Pop limit " << pop_limit << std::endl; + AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); + } + AwfulGlobalIncremental->Search(pop_limit, hg); +} diff --git a/decoder/incremental.h b/decoder/incremental.h new file mode 100644 index 00000000..180383ce --- /dev/null +++ b/decoder/incremental.h @@ -0,0 +1,11 @@ +#ifndef _INCREMENTAL_H_ +#define _INCREMENTAL_H_ + +#include "weights.h" +#include + +class Hypergraph; + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); + +#endif // _INCREMENTAL_H_ diff --git a/decoder/lazy.cc b/decoder/lazy.cc deleted file mode 100644 index 1e6a94fe..00000000 --- a/decoder/lazy.cc +++ /dev/null @@ -1,178 +0,0 @@ -#include "hg.h" -#include "lazy.h" -#include "fdict.h" -#include "tdict.h" - -#include "lm/enumerate_vocab.hh" -#include "lm/model.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "search/edge.hh" -#include "search/edge_queue.hh" -#include "search/vertex.hh" -#include "search/vertex_generator.hh" -#include "util/exception.hh" - -#include -#include - -#include -#include - -namespace { - -struct MapVocab : public lm::EnumerateVocab { - public: - MapVocab() {} - - // Do not call after Lookup. - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); - out_[cdec_id] = index; - } - - // Assumes Add has been called and will never be called again. - lm::WordIndex FromCDec(WordID id) const { - return out_[out_.size() > id ? 
-  }
-
- private:
-  std::vector<lm::WordIndex> out_;
-};
-
-class LazyBase {
- public:
-  LazyBase(const std::vector<weight_t> &weights) :
-    cdec_weights_(weights),
-    weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) {
-    std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl;
-  }
-
-  virtual ~LazyBase() {}
-
-  virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0;
-
-  static LazyBase *Load(const char *model_file, const std::vector<weight_t> &weights);
-
- protected:
-  lm::ngram::Config GetConfig() {
-    lm::ngram::Config ret;
-    ret.enumerate_vocab = &vocab_;
-    return ret;
-  }
-
-  MapVocab vocab_;
-
-  const std::vector<weight_t> &cdec_weights_;
-
-  const search::Weights weights_;
-};
-
-template <class Model> class Lazy : public LazyBase {
- public:
-  Lazy(const char *model_file, const std::vector<weight_t> &weights) : LazyBase(weights), m_(model_file, GetConfig()) {}
-
-  void Search(unsigned int pop_limit, const Hypergraph &hg) const;
-
- private:
-  unsigned char ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const;
-
-  const Model m_;
-};
-
-LazyBase *LazyBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
-  lm::ngram::ModelType model_type;
-  if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
-  switch (model_type) {
-    case lm::ngram::PROBING:
-      return new Lazy<lm::ngram::ProbingModel>(model_file, weights);
-    case lm::ngram::REST_PROBING:
-      return new Lazy<lm::ngram::RestProbingModel>(model_file, weights);
-    default:
-      UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
-  }
-}
-
-void PrintFinal(const Hypergraph &hg, const search::Final &final) {
-  const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e();
-  boost::array<const search::Final*, search::kMaxArity>::const_iterator child(final.Children().begin());
-  for (std::vector<WordID>::const_iterator i = words.begin(); i != words.end(); ++i) {
-    if (*i > 0) {
-      std::cout << TD::Convert(*i) << ' ';
-    } else {
-      PrintFinal(hg, **child++);
-    }
-  }
-}
-
-template <class Model> void Lazy<Model>::Search(unsigned int pop_limit, const Hypergraph &hg) const {
-  boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
-  search::Config config(weights_, pop_limit);
-  search::Context<Model> context(config, m_);
-
-  for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) {
-    search::EdgeQueue queue(context.PopLimit());
-    const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_;
-    for (unsigned int j = 0; j < down_edges.size(); ++j) {
-      unsigned int edge_index = down_edges[j];
-      unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge());
-      search::Note note;
-      note.vp = &hg.edges_[edge_index];
-      if (arity != 255) queue.AddEdge(arity, note);
-    }
-    search::VertexGenerator vertex_gen(context, out_vertices[i]);
-    queue.Search(context, vertex_gen);
-  }
-  const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild();
-  if (!top) {
-    std::cout << "NO PATH FOUND" << std::endl;
-  } else {
-    PrintFinal(hg, *top);
-    std::cout << "||| " << top->Bound() << std::endl;
-  }
-}
-
-template <class Model> unsigned char Lazy<Model>::ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const {
-  const std::vector<WordID> &e = in.rule_->e();
-  std::vector<lm::WordIndex> words;
-  unsigned int terminals = 0;
-  unsigned char nt = 0;
-  out.score = 0.0;
-  for (std::vector<WordID>::const_iterator word = e.begin(); word != e.end(); ++word) {
-    if (*word <= 0) {
-      out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial();
-      if (out.nt[nt].Empty()) return 255;
-      out.score += out.nt[nt].Bound();
-      ++nt;
-      words.push_back(lm::kMaxWordIndex);
-    } else {
-      ++terminals;
-      words.push_back(vocab_.FromCDec(*word));
-    }
-  }
-  for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) {
-    out.nt[fill] = search::kBlankPartialVertex;
-  }
-
-  if (final) {
-    words.push_back(m_.GetVocabulary().EndSentence());
-  }
-
-  out.score += in.rule_->GetFeatureValues().dot(cdec_weights_);
-  out.score -= static_cast<float>(terminals) * context.GetWeights().WordPenalty() / M_LN10;
-  out.score += search::ScoreRule(context, words, final, out.between);
-  return nt;
-}
-
-boost::scoped_ptr<LazyBase> AwfulGlobalLazy;
-
-} // namespace
-
-void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) {
-  if (!AwfulGlobalLazy.get()) {
-    std::cerr << "Pop limit " << pop_limit << std::endl;
-    AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights));
-  }
-  AwfulGlobalLazy->Search(pop_limit, hg);
-}
diff --git a/decoder/lazy.h b/decoder/lazy.h
deleted file mode 100644
index 94895b19..00000000
--- a/decoder/lazy.h
+++ /dev/null
@@ -1,11 +0,0 @@
-#ifndef _LAZY_H_
-#define _LAZY_H_
-
-#include "weights.h"
-#include <vector>
-
-class Hypergraph;
-
-void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg);
-
-#endif // _LAZY_H_
diff --git a/dtrain/Makefile.am b/dtrain/Makefile.am
index 64fef489..ca9581f5 100644
--- a/dtrain/Makefile.am
+++ b/dtrain/Makefile.am
@@ -1,7 +1,7 @@
 bin_PROGRAMS = dtrain
 
 dtrain_SOURCES = dtrain.cc score.cc
-dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile
deleted file mode 100644
index 2cc90c05..00000000
--- a/klm/alone/Jamfile
+++ /dev/null
@@ -1,4 +0,0 @@
-lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : <include>.. : : <include>..
../search//search ../lm//kenlm ; - -exe decode : main.cc standalone main.cc : multi:..//boost_thread ; -exe just_vocab : just_vocab.cc standalone : multi:..//boost_thread ; diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc deleted file mode 100644 index 2ae72ce9..00000000 --- a/klm/alone/assemble.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "alone/assemble.hh" - -#include "alone/labeled_edge.hh" -#include "search/final.hh" - -#include - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final) { - const std::vector &words = static_cast(final.From()).Words(); - if (words.empty()) return o; - const search::Final *const *child = final.Children().data(); - std::vector::const_iterator i(words.begin()); - for (; i != words.end() - 1; ++i) { - if (*i) { - o << **i << ' '; - } else { - o << **child << ' '; - ++child; - } - } - - if (*i) { - if (**i != "") { - o << **i; - } - } else { - o << **child; - } - - return o; -} - -namespace { - -void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) { - for (unsigned int i = 0; i < level; ++i) - o << indent_str; -} - -void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) { - o << "(\n"; - MakeIndent(o, indent_str, indent); - const std::vector &words = static_cast(final.From()).Words(); - const search::Final *const *child = final.Children().data(); - for (std::vector::const_iterator i(words.begin()); i != words.end(); ++i) { - if (*i) { - o << **i; - if (i == words.end() - 1) { - o << '\n'; - MakeIndent(o, indent_str, indent); - } else { - o << ' '; - } - } else { - // One extra indent from the line we're currently on. - o << indent_str; - DetailedFinalInternal(o, **child, indent_str, indent + 1); - for (unsigned int i = 0; i < indent; ++i) o << indent_str; - ++child; - } - } - o << ")=" << final.Bound() << '\n'; -} -} // namespace - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) { - DetailedFinalInternal(o, final, indent_str, 0); -} - -void PrintFinal(const search::Final &final) { - std::cout << final << std::endl; -} - -} // namespace alone diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh deleted file mode 100644 index e6b0ad5c..00000000 --- a/klm/alone/assemble.hh +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef ALONE_ASSEMBLE__ -#define ALONE_ASSEMBLE__ - -#include - -namespace search { -class Final; -} // namespace search - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final); - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " "); - -// This isn't called anywhere but makes it easy to print from gdb. 
-void PrintFinal(const search::Final &final); - -} // namespace alone - -#endif // ALONE_ASSEMBLE__ diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh deleted file mode 100644 index 788352c9..00000000 --- a/klm/alone/graph.hh +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef ALONE_GRAPH__ -#define ALONE_GRAPH__ - -#include "alone/labeled_edge.hh" -#include "search/rule.hh" -#include "search/types.hh" -#include "search/vertex.hh" -#include "util/exception.hh" - -#include -#include -#include - -namespace alone { - -template class FixedAllocator : boost::noncopyable { - public: - FixedAllocator() : current_(NULL), end_(NULL) {} - - void Init(std::size_t count) { - assert(!current_); - array_.reset(new T[count]); - current_ = array_.get(); - end_ = current_ + count; - } - - T &operator[](std::size_t idx) { - return array_.get()[idx]; - } - - T *New() { - T *ret = current_++; - UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end"); - return ret; - } - - std::size_t Size() const { - return end_ - array_.get(); - } - - private: - boost::scoped_array array_; - T *current_, *end_; -}; - -class Graph : boost::noncopyable { - public: - typedef LabeledEdge Edge; - typedef search::Vertex Vertex; - - Graph() {} - - void SetCounts(std::size_t vertices, std::size_t edges) { - vertices_.Init(vertices); - edges_.Init(edges); - } - - Vertex *NewVertex() { - return vertices_.New(); - } - - std::size_t VertexSize() const { return vertices_.Size(); } - - Vertex &MutableVertex(std::size_t index) { - return vertices_[index]; - } - - Edge *NewEdge() { - return edges_.New(); - } - - std::size_t EdgeSize() const { return edges_.Size(); } - - void SetRoot(Vertex *root) { root_ = root; } - - Vertex &Root() { return *root_; } - - private: - FixedAllocator vertices_; - FixedAllocator edges_; - - Vertex *root_; -}; - -} // namespace alone - -#endif // ALONE_GRAPH__ diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc deleted file mode 100644 index 35aea5ed..00000000 --- a/klm/alone/just_vocab.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "alone/read.hh" -#include "util/file_piece.hh" - -#include - -int main() { - util::FilePiece f(0, "stdin", &std::cerr); - while (true) { - try { - alone::JustVocab(f, std::cout); - } catch (const util::EndOfFileException &e) { break; } - std::cout << '\n'; - } -} diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh deleted file mode 100644 index 94d8cbdf..00000000 --- a/klm/alone/labeled_edge.hh +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef ALONE_LABELED_EDGE__ -#define ALONE_LABELED_EDGE__ - -#include "search/edge.hh" - -#include -#include - -namespace alone { - -class LabeledEdge : public search::Edge { - public: - LabeledEdge() {} - - void AppendWord(const std::string *word) { - words_.push_back(word); - } - - const std::vector &Words() const { - return words_; - } - - private: - // NULL for non-terminals. 
- std::vector words_; -}; - -} // namespace alone - -#endif // ALONE_LABELED_EDGE__ diff --git a/klm/alone/main.cc b/klm/alone/main.cc deleted file mode 100644 index e09ab01d..00000000 --- a/klm/alone/main.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "alone/threading.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "util/exception.hh" -#include "util/file_piece.hh" -#include "util/usage.hh" - -#include - -#include -#include - -namespace alone { - -template void ReadLoop(const std::string &graph_prefix, Control &control) { - for (unsigned int sentence = 0; ; ++sentence) { - std::stringstream name; - name << graph_prefix << '/' << sentence; - std::auto_ptr file; - try { - file.reset(new util::FilePiece(name.str().c_str())); - } catch (const util::ErrnoException &e) { - if (e.Error() == ENOENT) return; - throw; - } - control.Add(file.release()); - } -} - -template void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - Model model(model_file); - search::Weights weights(weight_str); - search::Config config(weights, pop_limit); - - if (threads > 1) { -#ifdef WITH_THREADS - Controller controller(config, model, threads, std::cout); - ReadLoop(graph_prefix, controller); -#else - UTIL_THROW(util::Exception, "Threading support not compiled in."); -#endif - } else { - InThread controller(config, model, std::cout); - ReadLoop(graph_prefix, controller); - } -} - -void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - case lm::ngram::REST_PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -} // namespace alone - -int main(int argc, char *argv[]) { - if (argc < 5 || argc > 6) { - std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; - return 1; - } - -#ifdef WITH_THREADS - unsigned thread_count = boost::thread::hardware_concurrency(); -#else - unsigned thread_count = 1; -#endif - if (argc == 6) { - thread_count = boost::lexical_cast(argv[5]); - UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); - } - UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. 
Pass it on the command line."); - alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast(argv[4]), thread_count); - - util::PrintUsage(std::cerr); - return 0; -} diff --git a/klm/alone/read.cc b/klm/alone/read.cc deleted file mode 100644 index 0b20be35..00000000 --- a/klm/alone/read.cc +++ /dev/null @@ -1,118 +0,0 @@ -#include "alone/read.hh" - -#include "alone/graph.hh" -#include "alone/vocab.hh" -#include "search/arity.hh" -#include "search/context.hh" -#include "search/weights.hh" -#include "util/file_piece.hh" - -#include -#include - -#include - -namespace alone { - -namespace { - -template Graph::Edge &ReadEdge(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) { - Graph::Edge *ret = to.NewEdge(); - - StringPiece got; - - std::vector words; - unsigned long int terminals = 0; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) { - // non-terminal - char *end_ptr; - unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10); - UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got); - UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?"); - ret->Add(to.MutableVertex(child)); - words.push_back(lm::kMaxWordIndex); - ret->AppendWord(NULL); - } else { - const std::pair &found = vocab.FindOrAdd(got); - words.push_back(found.second); - ret->AppendWord(&found.first); - ++terminals; - } - } - if (final) { - // This is not counted for the word penalty. - words.push_back(vocab.EndSentence().second); - ret->AppendWord(&vocab.EndSentence().first); - } - // Hard-coded word penalty. - float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast(terminals) / M_LN10; - ret->InitRule().Init(context, additive, words, final); - unsigned int arity = ret->GetRule().Arity(); - UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity); - return *ret; -} - -} // namespace - -// TODO: refactor -void JustVocab(util::FilePiece &from, std::ostream &out) { - boost::unordered_set seen; - unsigned long int vertices = from.ReadULong(); - from.ReadULong(); // edges - UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero"); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - std::string temp; - for (unsigned long int i = 0; i < vertices; ++i) { - unsigned long int edge_count = from.ReadULong(); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - StringPiece got; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue; - temp.assign(got.data(), got.size()); - if (seen.insert(temp).second) out << temp << ' '; - } - from.ReadLine(); // weights - } - } - // Eat sentence - from.ReadLine(); -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab) { - unsigned long int vertices; - try { - vertices = from.ReadULong(); - } catch (const util::EndOfFileException &e) { return false; } - unsigned long int edges = from.ReadULong(); - UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices); - UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << 
edges); - --vertices; - --edges; - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - to.SetCounts(vertices, edges); - Graph::Vertex *vertex; - for (unsigned long int i = 0; ; ++i) { - vertex = to.NewVertex(); - unsigned long int edge_count = from.ReadULong(); - bool root = (i == vertices - 1); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - vertex->Add(ReadEdge(context, from, to, vocab, root)); - } - vertex->FinishedAdding(); - if (root) break; - } - to.SetRoot(vertex); - StringPiece str = from.ReadLine(); - UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root"); - // The edge - from.ReadLine(); - return true; -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone diff --git a/klm/alone/read.hh b/klm/alone/read.hh deleted file mode 100644 index 10769a86..00000000 --- a/klm/alone/read.hh +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef ALONE_READ__ -#define ALONE_READ__ - -#include "util/exception.hh" - -#include - -namespace util { class FilePiece; } - -namespace search { template class Context; } - -namespace alone { - -class Graph; -class Vocab; - -class FormatException : public util::Exception { - public: - FormatException() {} - ~FormatException() throw() {} -}; - -void JustVocab(util::FilePiece &from, std::ostream &to); - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone - -#endif // ALONE_READ__ diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc deleted file mode 100644 index 475386b6..00000000 --- a/klm/alone/threading.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "alone/threading.hh" - -#include "alone/assemble.hh" -#include "alone/graph.hh" -#include "alone/read.hh" -#include "alone/vocab.hh" -#include "lm/model.hh" -#include "search/context.hh" -#include "search/vertex_generator.hh" - -#include -#include -#include - -#include - -namespace alone { -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) { - search::Context context(config, model); - Graph graph; - Vocab vocab(model.GetVocabulary()); - { - boost::scoped_ptr in(in_ptr); - ReadCDec(context, *in, graph, vocab); - } - - for (std::size_t i = 0; i < graph.VertexSize(); ++i) { - search::VertexGenerator(context, graph.MutableVertex(i)); - } - search::PartialVertex top = graph.Root().RootPartial(); - if (top.Empty()) { - out << "NO PATH FOUND"; - } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - out << top.End() << " ||| " << top.End().Bound() << std::endl; - } -} - -template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); -template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); - -#ifdef WITH_THREADS -template void DecodeHandler::operator()(Input message) { - std::stringstream assemble; - Decode(config_, model_, message.file, assemble); - Produce(message.sentence_id, assemble.str()); -} - -template void DecodeHandler::Produce(unsigned int sentence_id, const std::string &str) { - Output out; - out.sentence_id = sentence_id; - out.str = new 
std::string(str); - out_.Produce(out); -} - -void PrintHandler::operator()(Output message) { - unsigned int relative = message.sentence_id - done_; - if (waiting_.size() <= relative) waiting_.resize(relative + 1); - waiting_[relative] = message.str; - for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) { - out_ << *lead; - delete lead; - } -} - -template Controller::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) : - sentence_id_(0), - printer_(decode_workers, 1, boost::ref(to), Output::Poison()), - decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {} - -template class Controller; -template class Controller; - -#endif - -} // namespace alone diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh deleted file mode 100644 index 0ab0f739..00000000 --- a/klm/alone/threading.hh +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef ALONE_THREADING__ -#define ALONE_THREADING__ - -#ifdef WITH_THREADS -#include "util/pcqueue.hh" -#include "util/pool.hh" -#endif - -#include -#include -#include - -namespace util { -class FilePiece; -} // namespace util - -namespace search { -class Config; -template class Context; -} // namespace search - -namespace alone { - -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out); - -class Graph; - -#ifdef WITH_THREADS -struct SentenceID { - unsigned int sentence_id; - bool operator==(const SentenceID &other) const { - return sentence_id == other.sentence_id; - } -}; - -struct Input : public SentenceID { - util::FilePiece *file; - static Input Poison() { - Input ret; - ret.sentence_id = static_cast(-1); - ret.file = NULL; - return ret; - } -}; - -struct Output : public SentenceID { - std::string *str; - static Output Poison() { - Output ret; - ret.sentence_id = static_cast(-1); - ret.str = NULL; - return ret; - } -}; - -template class DecodeHandler { - public: - typedef Input Request; - - DecodeHandler(const search::Config &config, const Model &model, util::PCQueue &out) : config_(config), model_(model), out_(out) {} - - void operator()(Input message); - - private: - void Produce(unsigned int sentence_id, const std::string &str); - - const search::Config &config_; - - const Model &model_; - - util::PCQueue &out_; -}; - -class PrintHandler { - public: - typedef Output Request; - - explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {} - - void operator()(Output message); - - private: - std::ostream &out_; - std::deque waiting_; - unsigned int done_; -}; - -template class Controller { - public: - // This config must remain valid. - explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to); - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Input input; - input.sentence_id = sentence_id_++; - input.file = in; - decoder_.Produce(input); - } - - private: - unsigned int sentence_id_; - - util::Pool printer_; - - util::Pool > decoder_; -}; -#endif - -// Same API as controller. -template class InThread { - public: - InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {} - - // Takes ownership of in. 
- void Add(util::FilePiece *in) { - Decode(config_, model_, in, to_); - } - - private: - const search::Config &config_; - - const Model &model_; - - std::ostream &to_; -}; - -} // namespace alone -#endif // ALONE_THREADING__ diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc deleted file mode 100644 index ffe55301..00000000 --- a/klm/alone/vocab.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "alone/vocab.hh" - -#include "lm/virtual_interface.hh" -#include "util/string_piece.hh" - -namespace alone { - -Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("")) {} - -const std::pair &Vocab::FindOrAdd(const StringPiece &str) { - Map::const_iterator i(FindStringPiece(map_, str)); - if (i != map_.end()) return *i; - std::pair to_ins; - to_ins.first.assign(str.data(), str.size()); - to_ins.second = backing_.Index(str); - return *map_.insert(to_ins).first; -} - -} // namespace alone diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh deleted file mode 100644 index 3ac0f542..00000000 --- a/klm/alone/vocab.hh +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef ALONE_VOCAB__ -#define ALONE_VOCAB__ - -#include "lm/word_index.hh" -#include "util/string_piece.hh" - -#include -#include - -#include - -namespace lm { namespace base { class Vocabulary; } } - -namespace alone { - -class Vocab { - public: - explicit Vocab(const lm::base::Vocabulary &backing); - - const std::pair &FindOrAdd(const StringPiece &str); - - const std::pair &EndSentence() const { return end_sentence_; } - - private: - typedef boost::unordered_map Map; - Map map_; - - const lm::base::Vocabulary &backing_; - - const std::pair &end_sentence_; -}; - -} // namespace alone -#endif // ALONE_VCOAB__ diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 40af8a63..2fd20481 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template void GenericModel.. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; import testing ; diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..ccc5b7f6 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,11 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + edge_generator.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc \ + weights.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. 
+ diff --git a/klm/search/arity.hh b/klm/search/arity.hh deleted file mode 100644 index 09c2c671..00000000 --- a/klm/search/arity.hh +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef SEARCH_ARITY__ -#define SEARCH_ARITY__ -namespace search { - -const unsigned int kMaxArity = 2; - -} // namespace search -#endif // SEARCH_ARITY__ diff --git a/klm/search/context.hh b/klm/search/context.hh index 27940053..62163144 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -7,6 +7,7 @@ #include "search/types.hh" #include "search/vertex.hh" #include "util/exception.hh" +#include "util/pool.hh" #include #include @@ -21,10 +22,8 @@ class ContextBase { public: explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - Final *NewFinal() { - Final *ret = final_pool_.construct(); - assert(ret); - return ret; + util::Pool &FinalPool() { + return final_pool_; } VertexNode *NewVertexNode() { @@ -42,7 +41,8 @@ class ContextBase { const Weights &GetWeights() const { return weights_; } private: - boost::object_pool final_pool_; + util::Pool final_pool_; + boost::object_pool vertex_node_pool_; unsigned int pop_limit_; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 77ab0ade..187904bf 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -2,30 +2,53 @@ #define SEARCH_EDGE__ #include "lm/state.hh" -#include "search/arity.hh" -#include "search/rule.hh" +#include "search/header.hh" #include "search/types.hh" #include "search/vertex.hh" +#include "util/pool.hh" -#include +#include + +#include namespace search { -struct PartialEdge { - Score score; - // Terminals - lm::ngram::ChartState between[kMaxArity + 1]; - // Non-terminals - PartialVertex nt[kMaxArity]; +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. 
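// [Editor's sketch, not part of the original patch] The Size() helper below
// implies that each PartialEdge is one contiguous util::Pool allocation:
//   | Score | Arity | Note | arity * PartialVertex | chart_states * ChartState |
// NT() and Between() are computed as offsets past the fixed-size header,
// which is why copying a PartialEdge is shallow: only the base pointer moves.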
+ PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} - const lm::ngram::ChartState &CompletedState() const { - return between[0]; - } + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast(After()); + } + PartialVertex *NT() { + return reinterpret_cast(After()); + } - bool operator<(const PartialEdge &other) const { - return score < other.score; - } + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } }; + } // namespace search #endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 56239dfb..260159b1 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -4,117 +4,107 @@ #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" -#include "search/vertex_generator.hh" #include namespace search { -EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { -/* for (unsigned char i = 0; i < edge.Arity(); ++i) { - root.nt[i] = edge.GetVertex(i).RootPartial(); - } - for (unsigned char i = edge.Arity(); i < 2; ++i) { - root.nt[i] = kBlankPartialVertex; - }*/ - generate_.push(&root); - top_score_ = root.score; -} - namespace { -template float FastScore(const Context &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { - memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); - - float ret = 0.0; - lm::ngram::ChartState *before, *after; - if (victim == 0) { - before = &update.between[0]; - after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; - } else { - assert(victim == 1); - assert(arity == 2); - before = &update.between[previous.nt[0].Complete() ? 
0 : 1]; - after = &update.between[2]; - } - const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); - const PartialVertex &update_nt = update.nt[victim]; +template void FastScore(const Context &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; const lm::ngram::ChartState &update_reveal = update_nt.State(); - float just_after = 0.0; if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { - just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); } - if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { - ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); } if (update_nt.Complete()) { if (update_reveal.left.full) { before->left.full = true; } else { assert(update_reveal.left.length == update_reveal.right.length); - ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); } - if (victim == 0) { - update.between[0].right = after->right; - } else { - update.between[2].left = before->left; + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); } } - return previous.score + (ret + just_after) * context.GetWeights().LM(); + update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); } } // namespace -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { +template PartialEdge EdgeGenerator::Pop(Context &context) { assert(!generate_.empty()); - PartialEdge &top = *generate_.top(); + PartialEdge top = generate_.top(); generate_.pop(); - unsigned int victim = 0; - unsigned char lowest_length = 255; - for (unsigned char i = 0; i != arity_; ++i) { - if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { - lowest_length = top.nt[i].Length(); - victim = i; + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. 
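// [Editor's note, not part of the original patch] The loop below picks as
// "victim" the incomplete non-terminal with the shortest revealed state,
// i.e. the one whose split should tighten the score bound the most; a
// lowest_length that survives as 255 means every non-terminal is complete.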
+ { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } } - } - if (lowest_length == 255) { - // All states report complete. - top.between[0].right = top.between[arity_].right; - // Now top.between[0] is the full edge state. - top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return ⊤ + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; } - unsigned int stay = !victim; - PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); - float old_bound = top.nt[victim].Bound(); - // The alternate's score will change because alternate.nt[victim] changes. - bool split = top.nt[victim].Split(continuation.nt[victim]); - // top is now the alternate. + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); - continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, arity_, top, continuation); - // TODO: dedupe? - generate_.push(&continuation); + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); - if (split) { - // We have an alternate. - top.score += top.nt[victim].Bound() - old_bound; // TODO: dedupe? - generate_.push(&top); - } else { - partial_edge_pool.free(&top); + generate_.push(alternate); } - top_score_ = generate_.top()->score; - return NULL; + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. 
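// [Editor's note, not part of the original patch] Callers loop on this:
// Search() in edge_generator.hh keeps popping until Valid() returns true or
// the pop limit runs out, counting only complete hypotheses against the limit.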
+ return PartialEdge(); } -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 875ccc5e..582c78b7 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -3,11 +3,8 @@ #include "search/edge.hh" #include "search/note.hh" +#include "search/types.hh" -#include -#include - -#include #include namespace lm { @@ -20,38 +17,40 @@ namespace search { template class Context; -class VertexGenerator; - -struct PartialEdgePointerLess : std::binary_function { - bool operator()(const PartialEdge *first, const PartialEdge *second) const { - return *first < *second; - } -}; - class EdgeGenerator { public: - EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); + EdgeGenerator() {} - Score TopScore() const { - return top_score_; + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); } - Note GetNote() const { - return note_; + void AddEdge(PartialEdge edge) { + generate_.push(edge); } - // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. - template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template PartialEdge Pop(Context &context); + + template void Search(Context &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } private: - Score top_score_; - - unsigned char arity_; + util::Pool partial_edge_pool_; - typedef std::priority_queue, PartialEdgePointerLess> Generate; + typedef std::priority_queue Generate; Generate generate_; - - Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc deleted file mode 100644 index e3ae6ebf..00000000 --- a/klm/search/edge_queue.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "search/edge_queue.hh" - -#include "lm/left.hh" -#include "search/context.hh" - -#include - -namespace search { - -EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { - take_ = static_cast(partial_edge_pool_.malloc()); -} - -/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { - // Ignore empty edges. 
- for (unsigned char i = 0; i < edge.Arity(); ++i) { - PartialVertex root(edge.GetVertex(i).RootPartial()); - if (root.Empty()) return; - total_score += root.Bound(); - } - PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); - allocated.score = total_score; -}*/ - -} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh deleted file mode 100644 index 187eaed7..00000000 --- a/klm/search/edge_queue.hh +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef SEARCH_EDGE_QUEUE__ -#define SEARCH_EDGE_QUEUE__ - -#include "search/edge.hh" -#include "search/edge_generator.hh" -#include "search/note.hh" - -#include -#include - -#include - -namespace search { - -template class Context; - -class EdgeQueue { - public: - explicit EdgeQueue(unsigned int pop_limit_hint); - - PartialEdge &InitializeEdge() { - return *take_; - } - - void AddEdge(unsigned char arity, Note note) { - generate_.push(edge_pool_.construct(*take_, arity, note)); - take_ = static_cast(partial_edge_pool_.malloc()); - } - - bool Empty() const { return generate_.empty(); } - - /* Generate hypotheses and send them to output. Normally, output is a - * VertexGenerator, but the decoder may want to route edges to different - * vertices i.e. if they have different LHS non-terminal labels. - */ - template void Search(Context &context, Output &output) { - int to_pop = context.PopLimit(); - while (to_pop > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - PartialEdge *ret = top->Pop(context, partial_edge_pool_); - if (ret) { - output.NewHypothesis(*ret, top->GetNote()); - --to_pop; - if (top->TopScore() != -kScoreInf) { - generate_.push(top); - } - } else { - generate_.push(top); - } - } - output.FinishedSearch(); - } - - private: - boost::object_pool edge_pool_; - - struct LessByTopScore : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->TopScore() < second->TopScore(); - } - }; - - typedef std::priority_queue, LessByTopScore> Generate; - Generate generate_; - - boost::pool<> partial_edge_pool_; - - PartialEdge *take_; -}; - -} // namespace search -#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 1b3092ac..50e62cf2 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,37 +1,34 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/arity.hh" -#include "search/note.hh" -#include "search/types.hh" - -#include +#include "search/header.hh" +#include "util/pool.hh" namespace search { -class Final { +// A full hypothesis with pointers to children. +class Final : public Header { public: - typedef boost::array ChildArray; + Final() {} - void Reset(Score bound, Note note, const Final &left, const Final &right) { - bound_ = bound; - note_ = note; - children_[0] = &left; - children_[1] = &right; + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); } - const ChildArray &Children() const { return children_; } - - Note GetNote() const { return note_; } - - Score Bound() const { return bound_; } + // These are arrays of length GetArity(). 
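// [Editor's sketch, not part of the original patch] A finished derivation is
// walked through these children, e.g. with a depth-first visitor:
//   void Walk(const search::Final f) {
//     for (search::Arity i = 0; i < f.GetArity(); ++i) Walk(f.Children()[i]);
//   }
// PrintFinal() in decoder/incremental.cc interleaves this walk with the
// rule's terminal words.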
+ Final *Children() { + return reinterpret_cast(After()); + } + const Final *Children() const { + return reinterpret_cast(After()); + } private: - Score bound_; - - Note note_; - - ChildArray children_; + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } }; } // namespace search diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..25550dbe --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast(base_); + } + void SetScore(Score to) { + *reinterpret_cast(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast(base)) { + *reinterpret_cast(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/source.hh b/klm/search/source.hh deleted file mode 100644 index 11839f7b..00000000 --- a/klm/search/source.hh +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef SEARCH_SOURCE__ -#define SEARCH_SOURCE__ - -#include "search/types.hh" - -#include -#include - -namespace search { - -template class Source { - public: - Source() : bound_(kScoreInf) {} - - Index Size() const { - return final_.size(); - } - - Score Bound() const { - return bound_; - } - - const Final &operator[](Index index) const { - return *final_[index]; - } - - Score ScoreOrBound(Index index) const { - return Size() > index ? final_[index]->Total() : Bound(); - } - - protected: - void AddFinal(const Final &store) { - final_.push_back(&store); - } - - void SetBound(Score to) { - assert(to <= bound_ + 0.001); - bound_ = to; - } - - private: - std::vector final_; - - Score bound_; -}; - -} // namespace search -#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh index 9726379f..06eb5bfa 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -1,17 +1,13 @@ #ifndef SEARCH_TYPES__ #define SEARCH_TYPES__ -#include +#include namespace search { typedef float Score; -const Score kScoreInf = INFINITY; -// This could have been an enum but gcc wants 4 bytes. 
-typedef bool ExtendDirection; -const ExtendDirection kExtendLeft = 0; -const ExtendDirection kExtendRight = 1; +typedef uint32_t Arity; } // namespace search diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index cc53c0dd..11f4631f 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_functionBound(); + bound_ = end_.GetScore(); return; } if (extend_.size() == 1 && parent_ptr) { @@ -39,10 +39,4 @@ void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { bound_ = extend_.front()->Bound(); } -namespace { -VertexNode kBlankVertexNode; -} // namespace - -PartialVertex kBlankPartialVertex(kBlankVertexNode); - } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index e1a9ad11..52bc1dfe 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() : end_(NULL) {} + VertexNode() {} void InitRoot() { extend_.clear(); @@ -26,8 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - bound_ = -kScoreInf; - end_ = NULL; + end_ = Final(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -37,19 +36,20 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final *end) { end_ = end; } + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } - Final &MutableEnd() { return *end_; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return !end_.Valid() && extend_.empty(); } bool Complete() const { - return end_; + return end_.Valid(); } const lm::ngram::ChartState &State() const { return state_; } @@ -63,8 +63,8 @@ class VertexNode { return state_.left.length + state_.right.length; } - // May be NULL. - const Final *End() const { return end_; } + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -81,7 +81,7 @@ class VertexNode { bool right_full_; Score bound_; - Final *end_; + Final end_; }; class PartialVertex { @@ -97,7 +97,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -105,20 +105,24 @@ class PartialVertex { return index_ + 1 < back_->Size(); } - // Split into continuation and alternative, rendering this the alternative. - bool Split(PartialVertex &continuation) { + // Split into continuation and alternative, rendering this the continuation. 
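// [Editor's note, not part of the original patch] This is the lazy best-first
// recursion: the continuation (this) descends into its current best child,
// while the alternative stays at this node advanced to index_ + 1, so worse
// siblings are only expanded if the search comes back for them.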
+ bool Split(PartialVertex &alternative) { assert(!Complete()); - continuation.back_ = &((*back_)[index_]); - continuation.index_ = 0; + bool ret; if (index_ + 1 < back_->Size()) { - ++index_; - return true; + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; } - return false; + back_ = &((*back_)[index_]); + index_ = 0; + return ret; } - const Final &End() const { - return *back_->End(); + const Final End() const { + return back_->End(); } private: @@ -126,25 +130,22 @@ class PartialVertex { unsigned int index_; }; -extern PartialVertex kBlankPartialVertex; - class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final *BestChild() const { + const Final BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return NULL; + return Final(); } else { PartialVertex continuation; while (!top.Complete()) { top.Split(continuation); - top = continuation; } - return &top.End(); + return top.End(); } } diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index d94e6e06..0945fe55 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -10,74 +10,85 @@ namespace search { VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); - root_.under = &gen.root_; } namespace { + const uint64_t kCompleteAdd = static_cast(-1); -} // namespace -void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); - if (!got.second) { - // Found it already. - Final &exists = *got.first->second; - if (exists.Bound() < partial.score) { - exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - } - return; +// Parallel structure to VertexNode. 
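// [Editor's note, not part of the original patch] The trie below is keyed on
// the words a hypothesis reveals (left pointers, then right words, with
// kCompleteAdd as a terminator), so hypotheses that share a revealed prefix
// also share VertexNodes and can be recombined.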
+struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + unsigned char left = 0, right = 0; - Trie *node = &root_; + Trie *node = &root; while (true) { if (left == state.left.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); for (; right < state.right.length; ++right) { - node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); } break; } - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); left++; if (right == state.right.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); for (; left < state.left.length; ++left) { - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); } break; } - node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); right++; } - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, note, partial); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); } -VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { - VertexGenerator::Trie &next = node.extend[added]; - if (!next.under) { - next.under = context_.NewVertexNode(); - lm::ngram::ChartState &writing = next.under->MutableState(); - 
writing = state; - writing.left.full &= left_full && state.left.full; - next.under->MutableRightFull() = right_full && state.left.full; - writing.left.length = left; - writing.right.length = right; - node.under->AddExtend(next.under); - } - return next; -} +} // namespace -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { - VertexNode &node = *starter.under; - assert(node.State().left.full == state.left.full); - assert(!node.End()); - Final *final = context_.NewFinal(); - final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - node.SetEnd(final); - return final; +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); } } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 6b98da3e..60e86112 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,13 +1,11 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/note.hh" +#include "search/edge.hh" #include "search/vertex.hh" #include -#include - namespace lm { namespace ngram { class ChartState; @@ -18,40 +16,29 @@ namespace search { class ContextBase; class Final; -struct PartialEdge; class VertexGenerator { public: VertexGenerator(ContextBase &context, Vertex &gen); - void NewHypothesis(const PartialEdge &partial, Note note); - - void FinishedSearch() { - root_.under->SortAndSet(context_, NULL); + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (!ret.second && ret.first->second < partial) { + ret.first->second = partial; + } } + void FinishedSearch(); + const Vertex &Generating() const { return gen_; } private: - // Parallel structure to VertexNode. - struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; - }; - - Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); - ContextBase &context_; Vertex &gen_; - Trie root_; - - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; }; diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5ceccf2c..5306850f 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \ file_piece.cc \ mmap.cc \ murmur_hash.cc \ + pool.cc \ + string_piece.cc \ usage.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index ff4d590f..9909736d 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,7 +4,7 @@ #include #include -#include +#include // Ersatz version of boost::progress so core language model doesn't depend on // boost. Also adds option to print nothing. 
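[Editor's aside, not part of the patch series] util::Pool, added below, is a
bump allocator: Allocate() hands out bytes from the current block, More()
grabs a geometrically larger block when the current one runs out, and
FreeAll() releases everything at once. That is what lets search::Header treat
Final and PartialEdge as variable-length records over raw pool memory. A
minimal usage sketch, assuming only the interface shown in pool.hh
(CopyRecord is a hypothetical helper, not part of the patch):

  #include "util/pool.hh"

  #include <cstring>

  // Lay out a length-prefixed record in pool memory, in the same spirit as
  // search::Header: a fixed header followed by a variable-length payload.
  // The record lives until pool.FreeAll(); there is no per-record free.
  char *CopyRecord(util::Pool &pool, const char *data, unsigned char len) {
    char *base = static_cast<char*>(pool.Allocate(1 + static_cast<std::size_t>(len)));
    base[0] = static_cast<char>(len);  // fixed-size header
    std::memcpy(base + 1, data, len);  // trailing payload
    return base;
  }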
diff --git a/klm/util/exception.hh b/klm/util/exception.hh
index 83f99cd6..053a850b 100644
--- a/klm/util/exception.hh
+++ b/klm/util/exception.hh
@@ -6,7 +6,7 @@
 #include <sstream>
 #include <string>
 
-#include <inttypes.h>
+#include <stdint.h>
 
 namespace util {
diff --git a/klm/util/pool.cc b/klm/util/pool.cc
new file mode 100644
index 00000000..2dffd06f
--- /dev/null
+++ b/klm/util/pool.cc
@@ -0,0 +1,35 @@
+#include "util/pool.hh"
+
+#include <stdlib.h>
+
+namespace util {
+
+Pool::Pool() {
+  current_ = NULL;
+  current_end_ = NULL;
+}
+
+Pool::~Pool() {
+  FreeAll();
+}
+
+void Pool::FreeAll() {
+  for (std::vector<void *>::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) {
+    free(*i);
+  }
+  free_list_.clear();
+  current_ = NULL;
+  current_end_ = NULL;
+}
+
+void *Pool::More(std::size_t size) {
+  std::size_t amount = std::max(static_cast<std::size_t>(32) << free_list_.size(), size);
+  uint8_t *ret = static_cast<uint8_t*>(malloc(amount));
+  if (!ret) throw std::bad_alloc();
+  free_list_.push_back(ret);
+  current_ = ret + size;
+  current_end_ = ret + amount;
+  return ret;
+}
+
+} // namespace util
diff --git a/klm/util/pool.hh b/klm/util/pool.hh
new file mode 100644
index 00000000..72f8a0c8
--- /dev/null
+++ b/klm/util/pool.hh
@@ -0,0 +1,45 @@
+// Very simple pool.  It can only allocate memory.  And all of the memory it
+// allocates must be freed at the same time.
+
+#ifndef UTIL_POOL__
+#define UTIL_POOL__
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace util {
+
+class Pool {
+  public:
+    Pool();
+
+    ~Pool();
+
+    void *Allocate(std::size_t size) {
+      void *ret = current_;
+      current_ += size;
+      if (current_ < current_end_) {
+        return ret;
+      } else {
+        return More(size);
+      }
+    }
+
+    void FreeAll();
+
+  private:
+    void *More(std::size_t size);
+
+    std::vector<void *> free_list_;
+
+    uint8_t *current_, *current_end_;
+
+    // no copying
+    Pool(const Pool &);
+    Pool &operator=(const Pool &);
+};
+
+} // namespace util
+
+#endif // UTIL_POOL__
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 770faa7e..4a8aff35 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -8,7 +8,7 @@
 #include <functional>
 
 #include <assert.h>
-#include <inttypes.h>
+#include <stdint.h>
 
 namespace util {
diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc
new file mode 100644
index 00000000..b422cefc
--- /dev/null
+++ b/klm/util/string_piece.cc
@@ -0,0 +1,192 @@
+// Copyright 2004 The RE2 Authors.  All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in string_piece.hh.
+
+#include "util/string_piece.hh"
+
+#include <algorithm>
+
+#include <limits.h>
+
+#ifndef HAVE_ICU
+
+typedef StringPiece::size_type size_type;
+
+void StringPiece::CopyToString(std::string* target) const {
+  target->assign(ptr_, length_);
+}
+
+size_type StringPiece::find(const StringPiece& s, size_type pos) const {
+  if (length_ < 0 || pos > static_cast<size_type>(length_))
+    return npos;
+
+  const char* result = std::search(ptr_ + pos, ptr_ + length_,
+                                   s.ptr_, s.ptr_ + s.length_);
+  const size_type xpos = result - ptr_;
+  return xpos + s.length_ <= length_ ? xpos : npos;
+}
+
+size_type StringPiece::find(char c, size_type pos) const {
+  if (length_ <= 0 || pos >= static_cast<size_type>(length_)) {
+    return npos;
+  }
+  const char* result = std::find(ptr_ + pos, ptr_ + length_, c);
+  return result != ptr_ + length_ ? result - ptr_ : npos;
+}
+
+size_type StringPiece::rfind(const StringPiece& s, size_type pos) const {
+  if (length_ < s.length_) return npos;
+  const size_t ulen = length_;
+  if (s.length_ == 0) return std::min(ulen, pos);
+
+  const char* last = ptr_ + std::min(ulen - s.length_, pos) + s.length_;
+  const char* result = std::find_end(ptr_, last, s.ptr_, s.ptr_ + s.length_);
+  return result != last ? result - ptr_ : npos;
+}
+
+size_type StringPiece::rfind(char c, size_type pos) const {
+  if (length_ <= 0) return npos;
+  for (int i = std::min(pos, static_cast<size_type>(length_ - 1));
+       i >= 0; --i) {
+    if (ptr_[i] == c) {
+      return i;
+    }
+  }
+  return npos;
+}
+
+// For each character in characters_wanted, sets the index corresponding
+// to the ASCII code of that character to 1 in table.  This is used by
+// the find_.*_of methods below to tell whether or not a character is in
+// the lookup table in constant time.
+// The argument `table' must be an array that is large enough to hold all
+// the possible values of an unsigned char.  Thus it should be declared
+// as follows:
+//   bool table[UCHAR_MAX + 1]
+static inline void BuildLookupTable(const StringPiece& characters_wanted,
+                                    bool* table) {
+  const size_type length = characters_wanted.length();
+  const char* const data = characters_wanted.data();
+  for (size_type i = 0; i < length; ++i) {
+    table[static_cast<unsigned char>(data[i])] = true;
+  }
+}
+
+size_type StringPiece::find_first_of(const StringPiece& s,
+                                     size_type pos) const {
+  if (length_ == 0 || s.length_ == 0)
+    return npos;
+
+  // Avoid the cost of BuildLookupTable() for a single-character search.
+  if (s.length_ == 1)
+    return find_first_of(s.ptr_[0], pos);
+
+  bool lookup[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, lookup);
+  for (size_type i = pos; i < length_; ++i) {
+    if (lookup[static_cast<unsigned char>(ptr_[i])]) {
+      return i;
+    }
+  }
+  return npos;
+}
+
+size_type StringPiece::find_first_not_of(const StringPiece& s,
+                                         size_type pos) const {
+  if (length_ == 0)
+    return npos;
+
+  if (s.length_ == 0)
+    return 0;
+
+  // Avoid the cost of BuildLookupTable() for a single-character search.
+  if (s.length_ == 1)
+    return find_first_not_of(s.ptr_[0], pos);
+
+  bool lookup[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, lookup);
+  for (size_type i = pos; i < length_; ++i) {
+    if (!lookup[static_cast<unsigned char>(ptr_[i])]) {
+      return i;
+    }
+  }
+  return npos;
+}
+
+size_type StringPiece::find_first_not_of(char c, size_type pos) const {
+  if (length_ == 0)
+    return npos;
+
+  for (; pos < length_; ++pos) {
+    if (ptr_[pos] != c) {
+      return pos;
+    }
+  }
+  return npos;
+}
+
+size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const {
+  if (length_ == 0 || s.length_ == 0)
+    return npos;
+
+  // Avoid the cost of BuildLookupTable() for a single-character search.
+  if (s.length_ == 1)
+    return find_last_of(s.ptr_[0], pos);
+
+  bool lookup[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, lookup);
+  for (size_type i = std::min(pos, length_ - 1); ; --i) {
+    if (lookup[static_cast<unsigned char>(ptr_[i])])
+      return i;
+    if (i == 0)
+      break;
+  }
+  return npos;
+}
+
+size_type StringPiece::find_last_not_of(const StringPiece& s,
+                                        size_type pos) const {
+  if (length_ == 0)
+    return npos;
+
+  size_type i = std::min(pos, length_ - 1);
+  if (s.length_ == 0)
+    return i;
+
+  // Avoid the cost of BuildLookupTable() for a single-character search.
+  if (s.length_ == 1)
+    return find_last_not_of(s.ptr_[0], pos);
+
+  bool lookup[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, lookup);
+  for (; ; --i) {
+    if (!lookup[static_cast<unsigned char>(ptr_[i])])
+      return i;
+    if (i == 0)
+      break;
+  }
+  return npos;
+}
+
+size_type StringPiece::find_last_not_of(char c, size_type pos) const {
+  if (length_ == 0)
+    return npos;
+
+  for (size_type i = std::min(pos, length_ - 1); ; --i) {
+    if (ptr_[i] != c)
+      return i;
+    if (i == 0)
+      break;
+  }
+  return npos;
+}
+
+StringPiece StringPiece::substr(size_type pos, size_type n) const {
+  if (pos > length_) pos = length_;
+  if (n > length_ - pos) n = length_ - pos;
+  return StringPiece(ptr_ + pos, n);
+}
+
+const size_type StringPiece::npos = size_type(-1);
+
+#endif  // !HAVE_ICU
diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh
index c7e1c863..4a7f5460 100644
--- a/klm/util/tokenize_piece.hh
+++ b/klm/util/tokenize_piece.hh
@@ -54,6 +54,18 @@ class AnyCharacter {
     StringPiece chars_;
 };
 
+class AnyCharacterLast {
+  public:
+    explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {}
+
+    StringPiece Find(const StringPiece &in) const {
+      return StringPiece(std::find_end(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1);
+    }
+
+  private:
+    StringPiece chars_;
+};
+
 template <class Find, bool SkipEmpty = false> class TokenIter : public boost::iterator_facade<TokenIter<Find, SkipEmpty>, const StringPiece, boost::forward_traversal_tag> {
   public:
     TokenIter() {}
diff --git a/mira/Makefile.am b/mira/Makefile.am
index 7b4a4e12..3f8f17cd 100644
--- a/mira/Makefile.am
+++ b/mira/Makefile.am
@@ -1,6 +1,6 @@
 bin_PROGRAMS = kbest_mira
 
 kbest_mira_SOURCES = kbest_mira.cc
-kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/training/Makefile.am b/training/Makefile.am
index 5254333a..f9c25391 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -32,60 +32,60 @@ libtraining_a_SOURCES = \
   risk.cc
 
 mpi_online_optimize_SOURCES = mpi_online_optimize.cc
-mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc
-mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc
-mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 mpi_extract_features_SOURCES = mpi_extract_features.cc
-mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc
-mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc
-mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 augment_grammar_SOURCES = augment_grammar.cc
-augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 test_ngram_SOURCES = test_ngram.cc
-test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 fast_align_SOURCES = fast_align.cc ttables.cc
-fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 lbl_model_SOURCES = lbl_model.cc
-lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 grammar_convert_SOURCES = grammar_convert.cc
-grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 optimize_test_SOURCES = optimize_test.cc
-optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 collapse_weights_SOURCES = collapse_weights.cc
-collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 lbfgs_test_SOURCES = lbfgs_test.cc
-lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc
-mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc
-mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc
-mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc
-mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 plftools_SOURCES = plftools.cc
-plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
+plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz
 
 AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm
--
cgit v1.2.3


From 20ac14856accd1532fe7cbeb3ad4cf26dbeb80b1 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Mon, 22 Oct 2012 11:35:31 -0400
Subject: Remove global variable, have decoder hold a pointer

---
 decoder/decoder.cc     |  9 ++++++-
 decoder/incremental.cc | 69 +++++++++++++++++++-------------------------------
 decoder/incremental.h  | 14 +++++++++-
 3 files changed, 47 insertions(+), 45 deletions(-)

(limited to 'decoder/decoder.cc')

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index fe812011..b5f4b9b6 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -4,6 +4,7 @@
 #include <iostream>
 #include <fstream>
 #include <sstream>
+#include <boost/scoped_ptr.hpp>
 
 #include "program_options.h"
 #include "stringlib.h"
@@ -325,6 +326,8 @@ struct DecoderImpl {
   bool feature_expectations; // TODO Observer
   bool output_training_vector; // TODO Observer
   bool remove_intersected_rule_annotations;
+  boost::scoped_ptr<IncrementalBase> incremental;
+
   static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
     for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
@@ -727,6 +730,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
   sent_id = -1;
   acc_obj = 0; // accumulate objective
   g_count = 0;    // number of gradient pieces computed
+
+  if (conf.count("incremental_search")) {
+    incremental.reset(IncrementalBase::Load(conf["incremental_search"].as<std::string>().c_str(), CurrentWeightVector()));
+  }
 }
 
 Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); }
@@ -829,7 +836,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
     HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
 
   if (conf.count("incremental_search")) {
-    PassToIncremental(conf["incremental_search"].as<std::string>().c_str(), CurrentWeightVector(), pop_limit, forest);
+    incremental->Search(pop_limit, forest);
     o->NotifyDecodingComplete(smeta);
     return true;
   }
diff --git a/decoder/incremental.cc b/decoder/incremental.cc
index a9369374..46615b0b 100644
--- a/decoder/incremental.cc
+++ b/decoder/incremental.cc
@@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab {
   std::vector<lm::WordIndex> out_;
 };
 
-class IncrementalBase {
+template <class Model> class Incremental : public IncrementalBase {
   public:
-    IncrementalBase(const std::vector<weight_t> &weights) :
-      cdec_weights_(weights),
-      weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) {
+    Incremental(const char *model_file, const std::vector<weight_t> &weights) :
+      IncrementalBase(weights),
+      m_(model_file, GetConfig()),
+      weights_(
+          weights[FD::Convert("KLanguageModel")],
+          weights[FD::Convert("KLanguageModel_OOV")],
+          weights[FD::Convert("WordPenalty")]) {
       std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl;
     }
 
+    void Search(unsigned int pop_limit, const Hypergraph &hg) const;
 
-    virtual ~IncrementalBase() {}
-
-    virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0;
-
-    static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights);
+  private:
+    void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
 
-  protected:
     lm::ngram::Config GetConfig() {
       lm::ngram::Config ret;
       ret.enumerate_vocab = &vocab_;
@@ -66,36 +67,11 @@ class IncrementalBase {
 
     MapVocab vocab_;
 
-    const std::vector<weight_t> &cdec_weights_;
+    const Model m_;
 
     const search::Weights weights_;
 };
 
-template <class Model> class Incremental : public IncrementalBase {
-  public:
-    Incremental(const char *model_file, const std::vector<weight_t> &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {}
-
-    void Search(unsigned int pop_limit, const Hypergraph &hg) const;
-
-  private:
-    void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
-
-    const Model m_;
-};
-
-IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
-  lm::ngram::ModelType model_type;
-  if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
-  switch (model_type) {
-    case lm::ngram::PROBING:
-      return new Incremental<lm::ngram::ProbingModel>(model_file, weights);
-    case lm::ngram::REST_PROBING:
-      return new Incremental<lm::ngram::RestProbingModel>(model_file, weights);
-    default:
-      UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
-  }
-}
-
 void PrintFinal(const Hypergraph &hg, const search::Final final) {
   const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e();
   const search::Final *child(final.Children());
@@ -171,14 +147,21 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex
   gen.AddEdge(out);
 }
 
-boost::scoped_ptr<IncrementalBase> AwfulGlobalIncremental;
-
 } // namespace
 
-void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) {
-  if (!AwfulGlobalIncremental.get()) {
-    std::cerr << "Pop limit " << pop_limit << std::endl;
-    AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights));
+IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
+  lm::ngram::ModelType model_type;
+  if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
+  switch (model_type) {
+    case lm::ngram::PROBING:
+      return new Incremental<lm::ngram::ProbingModel>(model_file, weights);
+    case lm::ngram::REST_PROBING:
+      return new Incremental<lm::ngram::RestProbingModel>(model_file, weights);
+    default:
+      UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
   }
-  AwfulGlobalIncremental->Search(pop_limit, hg);
 }
+
+IncrementalBase::~IncrementalBase() {}
+
+IncrementalBase::IncrementalBase(const std::vector<weight_t> &weights) : cdec_weights_(weights) {}
diff --git a/decoder/incremental.h b/decoder/incremental.h
index 180383ce..f791a626 100644
--- a/decoder/incremental.h
+++ b/decoder/incremental.h
@@ -6,6 +6,18 @@
 
 class Hypergraph;
 
-void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg);
+class IncrementalBase {
+  public:
+    static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights);
+
+    virtual ~IncrementalBase();
+
+    virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0;
+
+  protected:
+    IncrementalBase(const std::vector<weight_t> &weights);
+
+    const std::vector<weight_t> &cdec_weights_;
+};
 
 #endif // _INCREMENTAL_H_
--
cgit v1.2.3
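The util::Pool added in the kenlm update is a bump-pointer arena: Allocate hands out the next size bytes of the current block, and More falls back to malloc only when the block is exhausted. Growth is geometric; More requests max(32 << free_list_.size(), size) bytes, so successive fallback blocks are at least 32, 64, 128, ... bytes, and FreeAll releases every block at once. A minimal usage sketch against the header shown above:

#include "util/pool.hh"

#include <cstring>
#include <iostream>

int main() {
  util::Pool pool;
  // Bump-pointer allocation: no per-object free, no destructors run.
  char *greeting = static_cast<char*>(pool.Allocate(6));
  std::memcpy(greeting, "hello", 6);
  std::cout << greeting << std::endl;
  // All allocations are released together (also happens in ~Pool).
  pool.FreeAll();
  return 0;
}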
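The find_.*_of family in string_piece.cc leans on BuildLookupTable: one pass marks each wanted byte in a (UCHAR_MAX + 1)-entry boolean array, after which each haystack character costs a single table probe instead of a scan over the needle set. The same technique in standalone form; FindFirstOf here is an illustrative helper, not part of the library:

#include <climits>
#include <cstddef>
#include <cstring>
#include <iostream>

// Mark each wanted byte in a table indexed by unsigned char, then scan the
// haystack with O(1) membership tests, as StringPiece::find_first_of does.
static std::size_t FindFirstOf(const char *hay, std::size_t len, const char *wanted) {
  bool table[UCHAR_MAX + 1] = { false };
  for (const char *c = wanted; *c; ++c)
    table[static_cast<unsigned char>(*c)] = true;
  for (std::size_t i = 0; i < len; ++i) {
    if (table[static_cast<unsigned char>(hay[i])]) return i;
  }
  return static_cast<std::size_t>(-1);  // npos-style sentinel
}

int main() {
  const char text[] = "hello, world";
  std::cout << FindFirstOf(text, std::strlen(text), ", ") << std::endl;  // prints 5
  return 0;
}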
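AnyCharacterLast plugs into the same TokenIter machinery as the existing delimiter finders, but its Find uses std::find_end, which locates the last occurrence of the whole chars_ sequence; it is therefore most predictable with a single-character needle ("split at the final delimiter"). A usage sketch of the iterator with the ordinary AnyCharacter finder; this assumes the <class Find, bool SkipEmpty> TokenIter signature reconstructed above and the usual constructor taking the input plus a finder:

#include "util/tokenize_piece.hh"

#include <iostream>

int main() {
  StringPiece line("one two  three");
  // Split on spaces; SkipEmpty = true collapses the double space.
  for (util::TokenIter<util::AnyCharacter, true> it(line, util::AnyCharacter(" ")); it; ++it) {
    std::cout << *it << std::endl;
  }
  return 0;
}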
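The final commit's point in miniature: instead of a global searcher rebuilt on first use, the decoder owns a single IncrementalBase chosen once by KenLM binary format (probing vs. rest-probing). Below is a sketch of that lifecycle under stated assumptions: the model path is a placeholder, and the weight vector must define KLanguageModel, KLanguageModel_OOV, and WordPenalty, which the Incremental constructor reads. Note that IncrementalBase stores a reference to the weight vector (cdec_weights_), so the vector must outlive the searcher; this is why DecoderImpl wires it to CurrentWeightVector() in its constructor.

#include "incremental.h"
#include "hg.h"

#include <boost/scoped_ptr.hpp>
#include <vector>

void SearchOnce(const Hypergraph &forest, const std::vector<weight_t> &weights) {
  // Load dispatches on the model's binary format, as in IncrementalBase::Load above.
  boost::scoped_ptr<IncrementalBase> inc(
      IncrementalBase::Load("lm.klm", weights));  // "lm.klm" is a placeholder path
  inc->Search(200, forest);  // 200 = an example cube-pruning pop limit
}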