summaryrefslogtreecommitdiff
path: root/training/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain')
-rw-r--r--training/dtrain/dtrain_net_interface.cc7
-rw-r--r--training/dtrain/dtrain_net_interface.h1
-rw-r--r--training/dtrain/sample.h3
3 files changed, 10 insertions, 1 deletions
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
index 38fad160..e21920d0 100644
--- a/training/dtrain/dtrain_net_interface.cc
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -27,6 +27,7 @@ main(int argc, char** argv)
vector<string> dense_features;
boost::split(dense_features, conf["dense_features"].as<string>(),
boost::is_any_of(" "));
+ const bool output_derivation = conf["output_derivation"].as<bool>();
// setup decoder
register_feature_functions();
@@ -125,7 +126,11 @@ main(int argc, char** argv)
vector<ScoredHyp>* samples = observer->GetSamples();
ostringstream os;
cerr << "[dtrain] 1best features " << (*samples)[0].f << endl;
- PrintWordIDVec((*samples)[0].w, os);
+ if (output_derivation) {
+ os << observer->GetViterbiTreeString() << endl;
+ } else {
+ PrintWordIDVec((*samples)[0].w, os);
+ }
sock.send(os.str().c_str(), os.str().size()+1, 0);
cerr << "[dtrain] done translating, looping again" << endl;
continue;
diff --git a/training/dtrain/dtrain_net_interface.h b/training/dtrain/dtrain_net_interface.h
index eb0aa668..3c7665a2 100644
--- a/training/dtrain/dtrain_net_interface.h
+++ b/training/dtrain/dtrain_net_interface.h
@@ -64,6 +64,7 @@ dtrain_net_init(int argc, char** argv, po::variables_map* conf)
("input_weights,w", po::value<string>(), "input weights file")
("learning_rate,l", po::value<weight_t>()->default_value(1.0), "learning rate")
("learning_rate_sparse,l", po::value<weight_t>()->default_value(1.0), "learning rate for sparse features")
+ ("output_derivation,E", po::bool_switch()->default_value(false), "output derivation, not viterbi str")
("dense_features,D", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV Shape_S01111_T11011 Shape_S11110_T11011 Shape_S11100_T11000 Shape_S01110_T01110 Shape_S01111_T01111 Shape_S01100_T11000 Shape_S10000_T10000 Shape_S11100_T11100 Shape_S11110_T11110 Shape_S11110_T11010 Shape_S01100_T11100 Shape_S01000_T01000 Shape_S01010_T01010 Shape_S01111_T01011 Shape_S01100_T01100 Shape_S01110_T11010 Shape_S11000_T11000 Shape_S11000_T01100 IsSupportedOnline ForceRule"),
"dense features")
("debug_output,d", po::value<string>()->default_value(""), "file for debug output");
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h
index 03cc82c3..e24b65cf 100644
--- a/training/dtrain/sample.h
+++ b/training/dtrain/sample.h
@@ -16,6 +16,7 @@ struct ScoredKbest : public DecoderObserver
PerSentenceBleuScorer* scorer_;
vector<Ngrams>* ref_ngs_;
vector<size_t>* ref_ls_;
+ string viterbi_tree_str;
ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
k_(k), scorer_(scorer) {}
@@ -40,6 +41,7 @@ struct ScoredKbest : public DecoderObserver
samples_.push_back(h);
effective_sz_++;
feature_count_ += h.f.size();
+ viterbi_tree_str = hg->show_viterbi_tree(false);
}
}
@@ -51,6 +53,7 @@ struct ScoredKbest : public DecoderObserver
}
inline size_t GetFeatureCount() { return feature_count_; }
inline size_t GetSize() { return effective_sz_; }
+ inline string GetViterbiTreeString() { return viterbi_tree_str; }
};
} // namespace