diff options
-rw-r--r-- | decoder/Makefile.am | 1 | ||||
-rw-r--r-- | decoder/cdec_ff.cc | 2 | ||||
-rw-r--r-- | decoder/decoder.cc | 2 | ||||
-rw-r--r-- | decoder/ff_spans.cc | 50 | ||||
-rw-r--r-- | decoder/ff_spans.h | 26 |
5 files changed, 80 insertions, 1 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index f0c5f73e..9cf4c3c4 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -63,6 +63,7 @@ libcdec_a_SOURCES = \ ff_charset.cc \ ff_lm.cc \ ff_klm.cc \ + ff_spans.cc \ ff_ruleshape.cc \ ff_wordalign.cc \ ff_csplit.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index c87396a6..7bcee6b8 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -1,6 +1,7 @@ #include <boost/shared_ptr.hpp> #include "ff.h" +#include "ff_spans.h" #include "ff_lm.h" #include "ff_klm.h" #include "ff_csplit.h" @@ -49,6 +50,7 @@ void register_feature_functions() { #ifdef HAVE_RANDLM ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>); #endif + ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>()); ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); ff_registry.Register("KLanguageModel_Sorted", new FFFactory<KLanguageModel<lm::ngram::SortedModel> >()); ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >()); diff --git a/decoder/decoder.cc b/decoder/decoder.cc index e28080aa..f37e8a37 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -838,12 +838,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { HypergraphIO::WriteAsCFG(forest); if (has_ref) { if (HG::Intersect(ref, &forest)) { - if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); // if (crf_uniform_empirical) { // if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n"; // for (int i = 0; i < forest.edges_.size(); ++i) // forest.edges_[i].edge_prob_=prob_t::One(); } forest.Reweight(feature_weights); + if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; if (conf.count("show_partition")) { const prob_t z = Inside<prob_t, EdgeProb>(forest); diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc new file mode 100644 index 00000000..6fa49d45 --- /dev/null +++ b/decoder/ff_spans.cc @@ -0,0 +1,50 @@ +#include "ff_spans.h" + +#include <sstream> +#include <cassert> + +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" + +using namespace std; + +SpanFeatures::SpanFeatures(const string& param) : + kS(TD::Convert("S") * -1), + kX(TD::Convert("X") * -1) {} + +void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { +// char& res = *static_cast<char*>(context); +// res = edge.j_ - edge.i_; +// assert(res >= 0); + assert(edge.j_ < end_span_ids_.size()); + assert(edge.j_ >= 0); + features->set_value(end_span_ids_[edge.j_], 1); + + if (edge.Arity() == 2) { + const TRule& rule = *edge.rule_; + if (rule.f_[0] == kS && rule.f_[1] == kX) { +// char x_width = *static_cast<const char*>(ant_contexts[1]); + } + } +} + +void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { + const Lattice& lattice = smeta.GetSourceLattice(); + WordID eos = TD::Convert("</s>"); + end_span_ids_.resize(lattice.size() + 1); + for (int i = 0; i <= lattice.size(); ++i) { + WordID word = eos; + if (i < lattice.size()) + word = lattice[i][0].label; // rather arbitrary for lattices + ostringstream sfid; + sfid << "ES:" << TD::Convert(word); + end_span_ids_[i] = FD::Convert(sfid.str()); + } +} + diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h new file mode 100644 index 00000000..588956c9 --- /dev/null +++ b/decoder/ff_spans.h @@ -0,0 +1,26 @@ +#ifndef _FF_SPANS_H_ +#define _FF_SPANS_H_ + +#include <vector> +#include "ff.h" +#include "array2d.h" + +class SpanFeatures : public FeatureFunction { + public: + SpanFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + const int kS; + const int kX; + Array2D<int> span_feats_; + std::vector<int> end_span_ids_; +}; + +#endif |