summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-03-03 21:41:40 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-03-03 21:41:40 -0500
commite40c7341549126a045eeca412c9910edf1eacef8 (patch)
treeee26f494c1be6aa26eeece6006fbf621c90f8211 /training
parent9bef5114bb2fc923ad02b3245517c97289f40f83 (diff)
more options for augment grammar
Diffstat (limited to 'training')
-rw-r--r--training/augment_grammar.cc24
1 files changed, 22 insertions, 2 deletions
diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc
index 48ef23fc..f1b1b355 100644
--- a/training/augment_grammar.cc
+++ b/training/augment_grammar.cc
@@ -4,6 +4,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "weights.h"
#include "rule_lexer.h"
#include "trule.h"
#include "filelib.h"
@@ -33,6 +34,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
("source_lm,l",po::value<string>(),"Source language LM (KLM)")
+ ("collapse_weights,w",po::value<string>(), "Collapse weights into a single feature X using the coefficients from this weights file")
("add_shape_types,s", "Add rule shape types");
po::options_description clo("Command line options");
clo.add_options()
@@ -78,8 +80,22 @@ template <class Model> float Score(const vector<WordID>& str, const Model &model
return total;
}
+int kSrcLM;
+vector<double> col_weights;
+
static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) {
- cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << endl;
+ static const int kSrcLM = FD::Convert("SrcLM");
+ static const int kPC = FD::Convert("PC");
+ static const int kX = FD::Convert("X");
+ TRule r(*new_rule);
+ if (ngram) r.scores_.set_value(kSrcLM, Score(r.f_, *ngram));
+ r.scores_.set_value(kPC, 1.0);
+ if (col_weights.size()) {
+ double score = r.scores_.dot(col_weights);
+ r.scores_.clear();
+ r.scores_.set_value(kX, score);
+ }
+ cout << r << endl;
}
@@ -95,7 +111,11 @@ int main(int argc, char** argv) {
cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n";
cerr << " <s> = " << kSOS << endl;
} else { ngram = NULL; }
- assert(ngram);
+ if (conf.count("collapse_weights")) {
+ Weights w;
+ w.InitFromFile(conf["collapse_weights"].as<string>());
+ w.InitVector(&col_weights);
+ }
RuleLexer::ReadRules(&cin, &RuleHelper, NULL);
return 0;
}