diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-03 21:41:40 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-03 21:41:40 -0500 |
commit | e40c7341549126a045eeca412c9910edf1eacef8 (patch) | |
tree | ee26f494c1be6aa26eeece6006fbf621c90f8211 /training | |
parent | 9bef5114bb2fc923ad02b3245517c97289f40f83 (diff) |
more options for augment grammar
Diffstat (limited to 'training')
-rw-r--r-- | training/augment_grammar.cc | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc index 48ef23fc..f1b1b355 100644 --- a/training/augment_grammar.cc +++ b/training/augment_grammar.cc @@ -4,6 +4,7 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "weights.h" #include "rule_lexer.h" #include "trule.h" #include "filelib.h" @@ -33,6 +34,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("source_lm,l",po::value<string>(),"Source language LM (KLM)") + ("collapse_weights,w",po::value<string>(), "Collapse weights into a single feature X using the coefficients from this weights file") ("add_shape_types,s", "Add rule shape types"); po::options_description clo("Command line options"); clo.add_options() @@ -78,8 +80,22 @@ template <class Model> float Score(const vector<WordID>& str, const Model &model return total; } +int kSrcLM; +vector<double> col_weights; + static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { - cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << endl; + static const int kSrcLM = FD::Convert("SrcLM"); + static const int kPC = FD::Convert("PC"); + static const int kX = FD::Convert("X"); + TRule r(*new_rule); + if (ngram) r.scores_.set_value(kSrcLM, Score(r.f_, *ngram)); + r.scores_.set_value(kPC, 1.0); + if (col_weights.size()) { + double score = r.scores_.dot(col_weights); + r.scores_.clear(); + r.scores_.set_value(kX, score); + } + cout << r << endl; } @@ -95,7 +111,11 @@ int main(int argc, char** argv) { cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n"; cerr << " <s> = " << kSOS << endl; } else { ngram = NULL; } - assert(ngram); + if (conf.count("collapse_weights")) { + Weights w; + w.InitFromFile(conf["collapse_weights"].as<string>()); + w.InitVector(&col_weights); + } RuleLexer::ReadRules(&cin, &RuleHelper, NULL); return 0; } |