summaryrefslogtreecommitdiff
path: root/decoder/viterbi.h
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 16:44:20 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 16:44:20 +0000
commit337b647baac15609a0a493902d58c473d25d2ed8 (patch)
tree3827f59a180164878ee217f27f445c259c9e5bab /decoder/viterbi.h
parent7386574a2c70c7ed6e937eb94e3add8023cd7327 (diff)
--show_features
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@184 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/viterbi.h')
-rw-r--r--decoder/viterbi.h102
1 files changed, 82 insertions, 20 deletions
diff --git a/decoder/viterbi.h b/decoder/viterbi.h
index 6b8bbed1..7e1e2c0e 100644
--- a/decoder/viterbi.h
+++ b/decoder/viterbi.h
@@ -8,13 +8,21 @@
std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false);
-// V must implement:
-// void operator()(const vector<const T*>& ants, T* result);
-template<typename T, typename Traversal, typename WeightType, typename WeightFunction>
-WeightType Viterbi(const Hypergraph& hg,
- T* result,
- const Traversal& traverse = Traversal(),
- const WeightFunction& weight = WeightFunction()) {
+/// computes for each hg node the best (according to WeightType/WeightFunction) derivation, and some homomorphism (bottom up expression tree applied through Traversal) of it. T is the "return type" of Traversal, which is called only once for the best edge for a node's result (i.e. result will start default constructed)
+//TODO: make T a typename inside Traversal and WeightType a typename inside WeightFunction?
+// Traversal must implement:
+// typedef T Result;
+// void operator()(Hypergraph::Edge const& e,const vector<const Result*>& ants, Result* result) const;
+// WeightFunction must implement:
+// typedef prob_t Weight;
+// Weight operator()(Hypergraph::Edge const& e) const;
+template<typename Traversal,typename WeightFunction>
+typename WeightFunction::Weight Viterbi(const Hypergraph& hg,
+ typename Traversal::Result* result,
+ const Traversal& traverse,
+ const WeightFunction& weight) {
+ typedef typename Traversal::Result T;
+ typedef typename WeightFunction::Weight WeightType;
const int num_nodes = hg.nodes_.size();
std::vector<T> vit_result(num_nodes);
std::vector<WeightType> vit_weight(num_nodes, WeightType::Zero());
@@ -51,7 +59,41 @@ WeightType Viterbi(const Hypergraph& hg,
return vit_weight.back();
}
+
+/*
+template<typename Traversal,typename WeightFunction>
+typename WeightFunction::Weight Viterbi(const Hypergraph& hg,
+ typename Traversal::Result* result)
+{
+ Traversal traverse;
+ WeightFunction weight;
+ return Viterbi(hg,result,traverse,weight);
+}
+
+template<class Traversal,class WeightFunction=EdgeProb>
+typename WeightFunction::Weight Viterbi(const Hypergraph& hg,
+ typename Traversal::Result* result,
+ Traversal const& traverse=Traversal()
+ )
+{
+ WeightFunction weight;
+ return Viterbi(hg,result,traverse,weight);
+}
+*/
+
+//spec for EdgeProb
+template<class Traversal>
+prob_t Viterbi(const Hypergraph& hg,
+ typename Traversal::Result* result,
+ Traversal const& traverse=Traversal()
+ )
+{
+ EdgeProb weight;
+ return Viterbi(hg,result,traverse,weight);
+}
+
struct PathLengthTraversal {
+ typedef int Result;
void operator()(const Hypergraph::Edge& edge,
const std::vector<const int*>& ants,
int* result) const {
@@ -62,14 +104,16 @@ struct PathLengthTraversal {
};
struct ESentenceTraversal {
+ typedef std::vector<WordID> Result;
void operator()(const Hypergraph::Edge& edge,
- const std::vector<const std::vector<WordID>*>& ants,
- std::vector<WordID>* result) const {
+ const std::vector<const Result*>& ants,
+ Result* result) const {
edge.rule_->ESubstitute(ants, result);
}
};
struct ELengthTraversal {
+ typedef int Result;
void operator()(const Hypergraph::Edge& edge,
const std::vector<const int*>& ants,
int* result) const {
@@ -79,9 +123,10 @@ struct ELengthTraversal {
};
struct FSentenceTraversal {
+ typedef std::vector<WordID> Result;
void operator()(const Hypergraph::Edge& edge,
- const std::vector<const std::vector<WordID>*>& ants,
- std::vector<WordID>* result) const {
+ const std::vector<const Result*>& ants,
+ Result* result) const {
edge.rule_->FSubstitute(ants, result);
}
};
@@ -92,10 +137,11 @@ struct ETreeTraversal {
const std::string left;
const std::string space;
const std::string right;
+ typedef std::vector<WordID> Result;
void operator()(const Hypergraph::Edge& edge,
- const std::vector<const std::vector<WordID>*>& ants,
- std::vector<WordID>* result) const {
- std::vector<WordID> tmp;
+ const std::vector<const Result*>& ants,
+ Result* result) const {
+ Result tmp;
edge.rule_->ESubstitute(ants, &tmp);
const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1);
if (cat == "Goal")
@@ -111,10 +157,11 @@ struct FTreeTraversal {
const std::string left;
const std::string space;
const std::string right;
+ typedef std::vector<WordID> Result;
void operator()(const Hypergraph::Edge& edge,
- const std::vector<const std::vector<WordID>*>& ants,
- std::vector<WordID>* result) const {
- std::vector<WordID> tmp;
+ const std::vector<const Result*>& ants,
+ Result* result) const {
+ Result tmp;
edge.rule_->FSubstitute(ants, &tmp);
const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1);
if (cat == "Goal")
@@ -126,10 +173,10 @@ struct FTreeTraversal {
};
struct ViterbiPathTraversal {
+ typedef std::vector<Hypergraph::Edge const*> Result;
void operator()(const Hypergraph::Edge& edge,
- const std::vector<const std::vector<const Hypergraph::Edge*>* >& ants,
- std::vector<const Hypergraph::Edge*>* result) const {
- result->clear();
+ std::vector<Result const*> const& ants,
+ Result* result) const {
for (int i = 0; i < ants.size(); ++i)
for (int j = 0; j < ants[i]->size(); ++j)
result->push_back((*ants[i])[j]);
@@ -137,6 +184,18 @@ struct ViterbiPathTraversal {
}
};
+struct FeatureVectorTraversal {
+ typedef FeatureVector Result;
+ void operator()(Hypergraph::Edge const& edge,
+ std::vector<Result const*> const& ants,
+ Result* result) const {
+ for (int i = 0; i < ants.size(); ++i)
+ *result+=*ants[i];
+ *result+=edge.feature_values_;
+ }
+};
+
+
std::string JoshuaVisualizationString(const Hypergraph& hg);
prob_t ViterbiESentence(const Hypergraph& hg, std::vector<WordID>* result);
std::string ViterbiETree(const Hypergraph& hg);
@@ -145,4 +204,7 @@ std::string ViterbiFTree(const Hypergraph& hg);
int ViterbiELength(const Hypergraph& hg);
int ViterbiPathLength(const Hypergraph& hg);
+/// if weights supplied, assert viterbi prob = features.dot(*weights). return features (sum over all edges in viterbi derivation)
+FeatureVector ViterbiFeatures(Hypergraph const& hg,FeatureWeights const* weights=0);
+
#endif