summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-30 21:03:45 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-09-30 21:03:45 +0000
commit6c00ca8e04e1a1398ca753d41e5cd474a37626b6 (patch)
tree9246b1a9ef2bd0495698acdd9adccb1b8274c34c /training
parent446fde8f67d4ad8c2699a8e9327a8988c3380723 (diff)
grammar filter
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@666 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training')
-rw-r--r--training/Makefile.am4
-rw-r--r--training/cllh_filter_grammar.cc137
2 files changed, 141 insertions, 0 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index 7cdf10d7..83c15ecc 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -8,6 +8,7 @@ bin_PROGRAMS = \
atools \
plftools \
collapse_weights \
+ cllh_filter_grammar \
online_train
noinst_PROGRAMS = \
@@ -28,6 +29,9 @@ mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc
mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
endif
+cllh_filter_grammar_SOURCES = cllh_filter_grammar.cc
+cllh_filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
+
online_train_SOURCES = online_train.cc online_optimizer.cc
online_train_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
diff --git a/training/cllh_filter_grammar.cc b/training/cllh_filter_grammar.cc
new file mode 100644
index 00000000..b5a4c35d
--- /dev/null
+++ b/training/cllh_filter_grammar.cc
@@ -0,0 +1,137 @@
+#include <iostream>
+#include <vector>
+#include <cassert>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "tdict.h"
+#include "ff_register.h"
+#include "verbose.h"
+#include "hg.h"
+#include "decoder.h"
+#include "filelib.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("training_data,t",po::value<string>(),"Training data corpus")
+ ("decoder_config,c",po::value<string>(),"Decoder configuration file");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+void ReadTrainingCorpus(const string& fname, vector<string>* c) {
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ string line;
+ while(in) {
+ getline(in, line);
+ if (!in) break;
+ c->push_back(line);
+ }
+}
+
+struct TrainingObserver : public DecoderObserver {
+ TrainingObserver() : s_lhs(-TD::Convert("S")), goal_lhs(-TD::Convert("Goal")) {}
+
+ void Reset() {
+ total_complete = 0;
+ }
+
+ virtual void NotifyDecodingStart(const SentenceMetadata& smeta) {
+ state = 1;
+ used.clear();
+ failed = true;
+ }
+
+ virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ assert(state == 1);
+ for (int i = 0; i < hg->edges_.size(); ++i) {
+ const TRulePtr& rule = hg->edges_[i].rule_;
+ if (rule->lhs_ == s_lhs || rule->lhs_ == goal_lhs) // fragile hack to filter out glue rules
+ continue;
+ if (rule->prev_i == -1)
+ used.insert(rule);
+ }
+ state = 2;
+ }
+
+ virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ assert(state == 2);
+ state = 3;
+ }
+
+ virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) {
+ if (state == 3) {
+ failed = false;
+ } else {
+ failed = true;
+ }
+ }
+
+ set<TRulePtr> used;
+
+ const int s_lhs;
+ const int goal_lhs;
+ bool failed;
+ int total_complete;
+ int state;
+};
+
+int main(int argc, char** argv) {
+ SetSilent(true); // turn off verbose decoder output
+ register_feature_functions();
+
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+
+ ReadFile ini_rf(conf["decoder_config"].as<string>());
+ Decoder decoder(ini_rf.stream());
+ if (decoder.GetConf()["input"].as<string>() != "-") {
+ cerr << "cdec.ini must not set an input file\n";
+ abort();
+ }
+
+ vector<string> corpus;
+ ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus);
+ assert(corpus.size() > 0);
+
+ TrainingObserver observer;
+ for (int i = 0; i < corpus.size(); ++i) {
+ int ex_num = i;
+ decoder.SetId(ex_num);
+ decoder.Decode(corpus[ex_num], &observer);
+ if (observer.failed) {
+ cerr << "*** id=" << ex_num << " is unreachable\n";
+ observer.used.clear();
+ } else {
+ cerr << corpus[ex_num] << endl;
+ for (set<TRulePtr>::iterator it = observer.used.begin(); it != observer.used.end(); ++it) {
+ cout << **it << endl;
+ (*it)->prev_i = 0;
+ }
+ }
+ }
+ return 0;
+}