From a046645ca3e2ac1ac8839ba2856c49bd771be62f Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 5 Nov 2015 20:17:40 +0100 Subject: dtrain_net_interface output rules too --- training/dtrain/dtrain_net_interface.cc | 4 ++++ training/dtrain/dtrain_net_interface.h | 1 + training/dtrain/sample_net_interface.h | 6 +++++- 3 files changed, 10 insertions(+), 1 deletion(-) (limited to 'training') diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc index 01b110b4..f16b9304 100644 --- a/training/dtrain/dtrain_net_interface.cc +++ b/training/dtrain/dtrain_net_interface.cc @@ -28,6 +28,7 @@ main(int argc, char** argv) boost::split(dense_features, conf["dense_features"].as(), boost::is_any_of(" ")); const bool output_derivation = conf["output_derivation"].as(); + const bool output_rules = conf["output_rules"].as(); // setup decoder register_feature_functions(); @@ -132,6 +133,9 @@ main(int argc, char** argv) } else { PrintWordIDVec((*samples)[0].w, os); } + if (output_rules) { + os << observer->GetViterbiRules() << endl; + } sock.send(os.str().c_str(), os.str().size()+1, 0); cerr << "[dtrain] done translating, looping again" << endl; continue; diff --git a/training/dtrain/dtrain_net_interface.h b/training/dtrain/dtrain_net_interface.h index 3c7665a2..816237c3 100644 --- a/training/dtrain/dtrain_net_interface.h +++ b/training/dtrain/dtrain_net_interface.h @@ -65,6 +65,7 @@ dtrain_net_init(int argc, char** argv, po::variables_map* conf) ("learning_rate,l", po::value()->default_value(1.0), "learning rate") ("learning_rate_sparse,l", po::value()->default_value(1.0), "learning rate for sparse features") ("output_derivation,E", po::bool_switch()->default_value(false), "output derivation, not viterbi str") + ("output_rules,R", po::bool_switch()->default_value(false), "also output rules") ("dense_features,D", po::value()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV Shape_S01111_T11011 Shape_S11110_T11011 Shape_S11100_T11000 Shape_S01110_T01110 Shape_S01111_T01111 Shape_S01100_T11000 Shape_S10000_T10000 Shape_S11100_T11100 Shape_S11110_T11110 Shape_S11110_T11010 Shape_S01100_T11100 Shape_S01000_T01000 Shape_S01010_T01010 Shape_S01111_T01011 Shape_S01100_T01100 Shape_S01110_T11010 Shape_S11000_T11000 Shape_S11000_T01100 IsSupportedOnline ForceRule"), "dense features") ("debug_output,d", po::value()->default_value(""), "file for debug output"); diff --git a/training/dtrain/sample_net_interface.h b/training/dtrain/sample_net_interface.h index a2b5f87d..6d00e5d5 100644 --- a/training/dtrain/sample_net_interface.h +++ b/training/dtrain/sample_net_interface.h @@ -17,7 +17,7 @@ struct ScoredKbest : public DecoderObserver vector* ref_ngs_; vector* ref_ls_; bool dont_score; - string viterbiTreeStr_; + string viterbiTreeStr_, viterbiRules_; ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : k_(k), scorer_(scorer), dont_score(false) {} @@ -44,6 +44,9 @@ struct ScoredKbest : public DecoderObserver effective_sz_++; feature_count_ += h.f.size(); viterbiTreeStr_ = hg->show_viterbi_tree(false); + ostringstream ss; + ViterbiRules(*hg, &ss); + viterbiRules_ = ss.str(); } } @@ -56,6 +59,7 @@ struct ScoredKbest : public DecoderObserver inline size_t GetFeatureCount() { return feature_count_; } inline size_t GetSize() { return effective_sz_; } inline string GetViterbiTreeStr() { return viterbiTreeStr_; } + inline string GetViterbiRules() { return viterbiRules_; } }; } // namespace -- cgit v1.2.3