diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-30 21:03:45 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-30 21:03:45 +0000 |
commit | 907ce4f93417bef73c3772cd4f7b641961f0fa25 (patch) | |
tree | fb4009f6d4c1965a82aa0a73c778ef568068f512 | |
parent | adfa44a02ea08cde8b1490258aefbd766617f447 (diff) |
grammar filter
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@666 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | training/Makefile.am | 4 | ||||
-rw-r--r-- | training/cllh_filter_grammar.cc | 137 |
2 files changed, 141 insertions, 0 deletions
diff --git a/training/Makefile.am b/training/Makefile.am index 7cdf10d7..83c15ecc 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -8,6 +8,7 @@ bin_PROGRAMS = \ atools \ plftools \ collapse_weights \ + cllh_filter_grammar \ online_train noinst_PROGRAMS = \ @@ -28,6 +29,9 @@ mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz endif +cllh_filter_grammar_SOURCES = cllh_filter_grammar.cc +cllh_filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + online_train_SOURCES = online_train.cc online_optimizer.cc online_train_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz diff --git a/training/cllh_filter_grammar.cc b/training/cllh_filter_grammar.cc new file mode 100644 index 00000000..b5a4c35d --- /dev/null +++ b/training/cllh_filter_grammar.cc @@ -0,0 +1,137 @@ +#include <iostream> +#include <vector> +#include <cassert> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "tdict.h" +#include "ff_register.h" +#include "verbose.h" +#include "hg.h" +#include "decoder.h" +#include "filelib.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value<string>(),"Training data corpus") + ("decoder_config,c",po::value<string>(),"Decoder configuration file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadTrainingCorpus(const string& fname, vector<string>* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (!in) break; + c->push_back(line); + } +} + +struct TrainingObserver : public DecoderObserver { + TrainingObserver() : s_lhs(-TD::Convert("S")), goal_lhs(-TD::Convert("Goal")) {} + + void Reset() { + total_complete = 0; + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + state = 1; + used.clear(); + failed = true; + } + + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + for (int i = 0; i < hg->edges_.size(); ++i) { + const TRulePtr& rule = hg->edges_[i].rule_; + if (rule->lhs_ == s_lhs || rule->lhs_ == goal_lhs) // fragile hack to filter out glue rules + continue; + if (rule->prev_i == -1) + used.insert(rule); + } + state = 2; + } + + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + if (state == 3) { + failed = false; + } else { + failed = true; + } + } + + set<TRulePtr> used; + + const int s_lhs; + const int goal_lhs; + bool failed; + int total_complete; + int state; +}; + +int main(int argc, char** argv) { + SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + ReadFile ini_rf(conf["decoder_config"].as<string>()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as<string>() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + vector<string> corpus; + ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus); + assert(corpus.size() > 0); + + TrainingObserver observer; + for (int i = 0; i < corpus.size(); ++i) { + int ex_num = i; + decoder.SetId(ex_num); + decoder.Decode(corpus[ex_num], &observer); + if (observer.failed) { + cerr << "*** id=" << ex_num << " is unreachable\n"; + observer.used.clear(); + } else { + cerr << corpus[ex_num] << endl; + for (set<TRulePtr>::iterator it = observer.used.begin(); it != observer.used.end(); ++it) { + cout << **it << endl; + (*it)->prev_i = 0; + } + } + } + return 0; +} |