From fe91371a77ec43cc08d284ac49f00af8baa1a298 Mon Sep 17 00:00:00 2001 From: graehl Date: Sun, 25 Jul 2010 19:32:34 +0000 Subject: fixed CreateViterbiHypergraph (old impl did not work), so --show_derivation works git-svn-id: https://ws10smt.googlecode.com/svn/trunk@408 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/ff.cc | 17 ++++++++++-- decoder/ff.h | 20 ++++++++++---- decoder/ff_from_fsa.h | 6 ++--- decoder/hg.cc | 70 +++++++++++++++++++++++++------------------------ decoder/hg.h | 16 ++++++----- decoder/indices_after.h | 19 +++++++++++++- decoder/small_vector.h | 5 +++- decoder/stringlib.h | 5 ++++ decoder/viterbi.cc | 1 - 9 files changed, 105 insertions(+), 54 deletions(-) diff --git a/decoder/ff.cc b/decoder/ff.cc index 9fc2dbd8..d21bf3fe 100644 --- a/decoder/ff.cc +++ b/decoder/ff.cc @@ -3,6 +3,7 @@ //TODO: actually score rule_feature()==true features once only, hash keyed on rule or modify TRule directly? need to keep clear in forest which features come from models vs. rules; then rescoring could drop all the old models features at once #include +#include #include "ff.h" #include "tdict.h" @@ -97,6 +98,17 @@ WordPenalty::WordPenalty(const string& param) : } } +void FeatureFunction::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + throw std::runtime_error("TraversalFeaturesImpl not implemented - override it or TraversalFeaturesLog.\n"); + abort(); +} + + void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const std::vector& ant_states, @@ -189,8 +201,9 @@ void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, Hypergraph::Edge* edge, string* context, prob_t* combination_cost_estimate) const { + edge->reset_info(); context->resize(state_size_); - memset(&(*context)[0], 0, state_size_); //FIXME: only context.data() is required to be contiguous, and it become sinvalid after next string operation. use SmallVector? ValueArray? (higher performance perhaps, fixed size) + memset(&(*context)[0], 0, state_size_); //FIXME: only context.data() is required to be contiguous, and it becomes invalid after next string operation. use SmallVector? ValueArray? (higher performance perhaps, fixed size) SparseVector est_vals; // only computed if combination_cost_estimate is non-NULL if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); for (int i = 0; i < models_.size(); ++i) { @@ -214,7 +227,7 @@ void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, void ModelSet::AddFinalFeatures(const std::string& state, Hypergraph::Edge* edge,SentenceMetadata const& smeta) const { assert(1 == edge->rule_->Arity()); - + edge->reset_info(); for (int i = 0; i < models_.size(); ++i) { const FeatureFunction& ff = *models_[i]; const void* ant_state = NULL; diff --git a/decoder/ff.h b/decoder/ff.h index b8ca71c4..5c1f214f 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -43,12 +43,12 @@ public: // Compute the feature values and (if this applies) the estimates of the // feature values when this edge is used incorporated into a larger context inline void TraversalFeatures(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + Hypergraph::Edge& edge, const std::vector& ant_contexts, FeatureVector* features, FeatureVector* estimated_features, void* out_state) const { - TraversalFeaturesImpl(smeta, edge, ant_contexts, + TraversalFeaturesLog(smeta, edge, ant_contexts, features, estimated_features, out_state); // TODO it's easy for careless feature function developers to overwrite // the end of their state and clobber someone else's memory. These bugs @@ -66,7 +66,7 @@ protected: public: //override either this or one of above. virtual void FinalTraversalFeatures(const SentenceMetadata& /* smeta */, - const Hypergraph::Edge& /* edge */, + Hypergraph::Edge& /* edge */, // so you can log() const void* residual_state, FeatureVector* final_features) const { FinalTraversalFeatures(residual_state,final_features); @@ -81,12 +81,22 @@ public: // of the particular FeatureFunction class. There is one exception: // equality of the contents (i.e., memcmp) is required to determine whether // two states can be combined. + virtual void TraversalFeaturesLog(const SentenceMetadata& smeta, + Hypergraph::Edge& edge, // this is writable only so you can use log() + const std::vector& ant_contexts, + FeatureVector* features, + FeatureVector* estimated_features, + void* context) const { + TraversalFeaturesImpl(smeta,edge,ant_contexts,features,estimated_features,context); + } + + // override above or below. virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + Hypergraph::Edge const& edge, const std::vector& ant_contexts, FeatureVector* features, FeatureVector* estimated_features, - void* context) const = 0; + void* context) const; // !!! ONLY call this from subclass *CONSTRUCTORS* !!! void SetStateSize(size_t state_size) { diff --git a/decoder/ff_from_fsa.h b/decoder/ff_from_fsa.h index f9f707d7..adb704de 100755 --- a/decoder/ff_from_fsa.h +++ b/decoder/ff_from_fsa.h @@ -43,8 +43,8 @@ public: //TODO: add source span to Fsa FF interface, pass along //TODO: read/debug VERY CAREFULLY - void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + void TraversalFeaturesLog(const SentenceMetadata& smeta, + Hypergraph::Edge& edge, const std::vector& ant_contexts, FeatureVector* features, FeatureVector* estimated_features, @@ -144,7 +144,7 @@ public: //FIXME: it's assumed that the final rule is just a unary no-target-terminal rewrite (same as ff_lm) virtual void FinalTraversalFeatures(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, + Hypergraph::Edge& edge, const void* residual_state, FeatureVector* final_features) const { diff --git a/decoder/hg.cc b/decoder/hg.cc index 8292639b..88e95337 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -626,11 +626,10 @@ struct EdgeWeightSorter { std::string Hypergraph::show_viterbi_tree(bool indent,int show_mask,int maxdepth,int depth) const { HypergraphP v=CreateViterbiHypergraph(); - //FIXME: remove dbg print, fix. - cerr<NumberOfEdges()) { Edge const* beste=&v->edges_.back(); //FIXME: this doesn't work. check CreateViterbiHypergraph ? - return beste->derivation_tree(*this,beste,indent,show_mask,maxdepth,depth); + return beste->derivation_tree(*v,beste,indent,show_mask,maxdepth,depth); } return std::string(); } @@ -640,6 +639,30 @@ HypergraphP Hypergraph::CreateEdgeSubset(EdgeMask &keep_edges) const { return CreateEdgeSubset(keep_edges,kn); } +HypergraphP Hypergraph::CreateEdgeSubset(EdgeMask &keep_edges,NodeMask &kn) const { + kn.clear(); + kn.resize(nodes_.size()); + for (int n=0;n* edges) const } else { Viterbi(*this, &vit_edges, ViterbiPathTraversal() ,EdgeProb()); } -#if 0 +#if 1 # if 1 - check_ids(); + check_ids(); # else - set_ids(); + set_ids(); # endif - EdgeMask used(edges_.size()); - for (int i = 0; i < vit_edges.size(); ++i) - used[vit_edges[i]->id_]=true; - return CreateEdgeSubset(used); + EdgeMask used(edges_.size()); + for (int i = 0; i < vit_edges.size(); ++i) + used[vit_edges[i]->id_]=true; + return CreateEdgeSubset(used); #else map old2new_node; int num_new_nodes = 0; @@ -762,7 +764,7 @@ HypergraphP Hypergraph::CreateViterbiHypergraph(const vector* edges) const out->nodes_[new_tail_node].out_edges_.push_back(i); } } -#endif return ret; +#endif } diff --git a/decoder/hg.h b/decoder/hg.h index 90ae0935..10a24910 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -6,8 +6,8 @@ #define USE_INFO_EDGE 1 #if USE_INFO_EDGE # include -# define INFO_EDGE(e,msg) do { std::ostringstream &o=const_cast(e.info_);o<(e.info_);if (o.empty()) o<<' ';o< // STATIC_CONSTANT #include //swap +#include // iterator wrapper. inverts boolean value. template @@ -47,7 +48,8 @@ unsigned new_indices(KEEP keep,O out) { return new_indices(keep.begin(),keep.end(),out); } -// given a vector and a parallel sequence of bools where true means keep, keep only the marked elements while maintaining order +// given a vector and a parallel sequence of bools where true means keep, keep only the marked elements while maintaining order. +// this is done with a parallel sequence to the input, marked with positions the kept items would map into in a destination array, with removed items marked with the index -1. the reverse would be more compact (parallel to destination array, index of input item that goes into it) but would require the input sequence be random access. struct indices_after { BOOST_STATIC_CONSTANT(unsigned,REMOVED=(unsigned)-1); @@ -142,6 +144,21 @@ struct indices_after to[map[i]]=v[i]; } + //transform collection of indices into what we're remapping. (input/output iterators) + template + void reindex(IndexI i,IndexI const end,IndexO o) const { + for(;i + void reindex_push_back(VecI const& i,VecO &o) const { + reindex(i.begin(),i.end(),std::back_inserter(o)); + } + private: indices_after(indices_after const& o) { diff --git a/decoder/small_vector.h b/decoder/small_vector.h index 7ed99e77..25c52359 100644 --- a/decoder/small_vector.h +++ b/decoder/small_vector.h @@ -25,12 +25,15 @@ class SmallVector { typedef T const* const_iterator; typedef T* iterator; + typedef T value_type; + typedef T &reference; + typedef T const& const_reference; + T *begin() { return size_>SV_MAX?data_.ptr:data_.vals; } T const* begin() const { return const_cast(this)->begin(); } T *end() { return begin()+size_; } T const* end() const { return begin()+size_; } - explicit SmallVector(size_t s) : size_(s) { assert(s < 0xA000); if (s <= SV_MAX) { diff --git a/decoder/stringlib.h b/decoder/stringlib.h index 9efe3f36..b3097bd1 100644 --- a/decoder/stringlib.h +++ b/decoder/stringlib.h @@ -1,6 +1,10 @@ #ifndef CDEC_STRINGLIB_H_ #define CDEC_STRINGLIB_H_ +//usage: string s=MAKESTRE(1<<" "<(ostringstream()< #define SLIBDBG(x) do { std::cerr<<"DBG(stringlib): "< #include #include +#include template inline bool match_begin(Istr bstr,Istr estr,Isubstr bsub,Isubstr esub) diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc index 7214c600..46b6a884 100644 --- a/decoder/viterbi.cc +++ b/decoder/viterbi.cc @@ -19,7 +19,6 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool es if (etree) { o<