summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/Makefile.am1
-rw-r--r--decoder/cdec_ff.cc2
-rw-r--r--decoder/decoder.cc2
-rw-r--r--decoder/ff_spans.cc50
-rw-r--r--decoder/ff_spans.h26
5 files changed, 80 insertions, 1 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index f0c5f73e..9cf4c3c4 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -63,6 +63,7 @@ libcdec_a_SOURCES = \
ff_charset.cc \
ff_lm.cc \
ff_klm.cc \
+ ff_spans.cc \
ff_ruleshape.cc \
ff_wordalign.cc \
ff_csplit.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index c87396a6..7bcee6b8 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -1,6 +1,7 @@
#include <boost/shared_ptr.hpp>
#include "ff.h"
+#include "ff_spans.h"
#include "ff_lm.h"
#include "ff_klm.h"
#include "ff_csplit.h"
@@ -49,6 +50,7 @@ void register_feature_functions() {
#ifdef HAVE_RANDLM
ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>);
#endif
+ ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
ff_registry.Register("KLanguageModel_Sorted", new FFFactory<KLanguageModel<lm::ngram::SortedModel> >());
ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >());
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index e28080aa..f37e8a37 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -838,12 +838,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
HypergraphIO::WriteAsCFG(forest);
if (has_ref) {
if (HG::Intersect(ref, &forest)) {
- if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
// if (crf_uniform_empirical) {
// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
// for (int i = 0; i < forest.edges_.size(); ++i)
// forest.edges_[i].edge_prob_=prob_t::One(); }
forest.Reweight(feature_weights);
+ if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
if (conf.count("show_partition")) {
const prob_t z = Inside<prob_t, EdgeProb>(forest);
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
new file mode 100644
index 00000000..6fa49d45
--- /dev/null
+++ b/decoder/ff_spans.cc
@@ -0,0 +1,50 @@
+#include "ff_spans.h"
+
+#include <sstream>
+#include <cassert>
+
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+
+using namespace std;
+
+SpanFeatures::SpanFeatures(const string& param) :
+ kS(TD::Convert("S") * -1),
+ kX(TD::Convert("X") * -1) {}
+
+void SpanFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+// char& res = *static_cast<char*>(context);
+// res = edge.j_ - edge.i_;
+// assert(res >= 0);
+ assert(edge.j_ < end_span_ids_.size());
+ assert(edge.j_ >= 0);
+ features->set_value(end_span_ids_[edge.j_], 1);
+
+ if (edge.Arity() == 2) {
+ const TRule& rule = *edge.rule_;
+ if (rule.f_[0] == kS && rule.f_[1] == kX) {
+// char x_width = *static_cast<const char*>(ant_contexts[1]);
+ }
+ }
+}
+
+void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ const Lattice& lattice = smeta.GetSourceLattice();
+ WordID eos = TD::Convert("</s>");
+ end_span_ids_.resize(lattice.size() + 1);
+ for (int i = 0; i <= lattice.size(); ++i) {
+ WordID word = eos;
+ if (i < lattice.size())
+ word = lattice[i][0].label; // rather arbitrary for lattices
+ ostringstream sfid;
+ sfid << "ES:" << TD::Convert(word);
+ end_span_ids_[i] = FD::Convert(sfid.str());
+ }
+}
+
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
new file mode 100644
index 00000000..588956c9
--- /dev/null
+++ b/decoder/ff_spans.h
@@ -0,0 +1,26 @@
+#ifndef _FF_SPANS_H_
+#define _FF_SPANS_H_
+
+#include <vector>
+#include "ff.h"
+#include "array2d.h"
+
+class SpanFeatures : public FeatureFunction {
+ public:
+ SpanFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ const int kS;
+ const int kX;
+ Array2D<int> span_feats_;
+ std::vector<int> end_span_ids_;
+};
+
+#endif