summaryrefslogtreecommitdiff
path: root/training/augment_grammar.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/augment_grammar.cc')
-rw-r--r--training/augment_grammar.cc52
1 files changed, 39 insertions, 13 deletions
diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc
index f1b1b355..19120d00 100644
--- a/training/augment_grammar.cc
+++ b/training/augment_grammar.cc
@@ -35,24 +35,29 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
opts.add_options()
("source_lm,l",po::value<string>(),"Source language LM (KLM)")
("collapse_weights,w",po::value<string>(), "Collapse weights into a single feature X using the coefficients from this weights file")
- ("add_shape_types,s", "Add rule shape types");
+ ("add_shape_types,s", "Add rule shape types")
+ ("replace_files,r", "Replace files with transformed variants (requires loading full grammar into memory)")
+ ("grammar,g", po::value<vector<string> >(), "Input (also output) grammar file(s)");
po::options_description clo("Command line options");
clo.add_options()
("config", po::value<string>(), "Configuration file")
("help,h", "Print this help message and exit");
po::options_description dconfig_options, dcmdline_options;
+ po::positional_options_description p;
+ p.add("grammar", -1);
+
dconfig_options.add(opts);
dcmdline_options.add(opts).add(clo);
-
- po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+
+ po::store(po::command_line_parser(argc, argv).options(dcmdline_options).positional(p).run(), *conf);
if (conf->count("config")) {
ifstream config((*conf)["config"].as<string>().c_str());
po::store(po::parse_config_file(config, dconfig_options), *conf);
}
po::notify(*conf);
- if (conf->count("help")) {
- cerr << "Usage " << argv[0] << " [OPTIONS]\n";
+ if (conf->count("help") || conf->count("grammar")==0) {
+ cerr << "Usage " << argv[0] << " [OPTIONS] file.scfg [file2.scfg...]\n";
cerr << dcmdline_options << endl;
return false;
}
@@ -82,20 +87,26 @@ template <class Model> float Score(const vector<WordID>& str, const Model &model
int kSrcLM;
vector<double> col_weights;
+bool gather_rules;
+vector<TRulePtr> rules;
static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) {
static const int kSrcLM = FD::Convert("SrcLM");
static const int kPC = FD::Convert("PC");
static const int kX = FD::Convert("X");
- TRule r(*new_rule);
- if (ngram) r.scores_.set_value(kSrcLM, Score(r.f_, *ngram));
- r.scores_.set_value(kPC, 1.0);
+ TRulePtr r; r.reset(new TRule(*new_rule));
+ if (ngram) r->scores_.set_value(kSrcLM, Score(r->f_, *ngram));
+ r->scores_.set_value(kPC, 1.0);
if (col_weights.size()) {
- double score = r.scores_.dot(col_weights);
- r.scores_.clear();
- r.scores_.set_value(kX, score);
+ double score = r->scores_.dot(col_weights);
+ r->scores_.clear();
+ r->scores_.set_value(kX, score);
+ }
+ if (gather_rules) {
+ rules.push_back(r);
+ } else {
+ cout << *r << endl;
}
- cout << r << endl;
}
@@ -116,7 +127,22 @@ int main(int argc, char** argv) {
w.InitFromFile(conf["collapse_weights"].as<string>());
w.InitVector(&col_weights);
}
- RuleLexer::ReadRules(&cin, &RuleHelper, NULL);
+ gather_rules = false;
+ bool replace_files = conf.count("replace_files");
+ if (replace_files) gather_rules = true;
+ vector<string> files = conf["grammar"].as<vector<string> >();
+ for (int i=0; i < files.size(); ++i) {
+ cerr << "Processing " << files[i] << " ..." << endl;
+ if (true) {
+ ReadFile rf(files[i]);
+ rules.clear();
+ RuleLexer::ReadRules(rf.stream(), &RuleHelper, NULL);
+ }
+ if (replace_files) {
+ WriteFile wf(files[i]);
+ for (int i = 0; i < rules.size(); ++i) { (*wf.stream()) << *rules[i] << endl; }
+ }
+ }
return 0;
}