diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/Makefile.am | 8 | ||||
-rw-r--r-- | training/augment_grammar.cc | 93 |
2 files changed, 99 insertions, 2 deletions
diff --git a/training/Makefile.am b/training/Makefile.am index 8218ff0a..b046c698 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -10,7 +10,8 @@ bin_PROGRAMS = \ collapse_weights \ cllh_filter_grammar \ mpi_online_optimize \ - mpi_batch_optimize + mpi_batch_optimize \ + augment_grammar noinst_PROGRAMS = \ lbfgs_test \ @@ -34,6 +35,9 @@ endif cllh_filter_grammar_SOURCES = cllh_filter_grammar.cc cllh_filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +augment_grammar_SOURCES = augment_grammar.cc +augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + atools_SOURCES = atools.cc atools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz @@ -67,4 +71,4 @@ mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils plftools_SOURCES = plftools.cc plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc new file mode 100644 index 00000000..9b7fc7be --- /dev/null +++ b/training/augment_grammar.cc @@ -0,0 +1,93 @@ +#include <iostream> +#include <vector> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "rule_lexer.h" +#include "trule.h" +#include "filelib.h" +#include "tdict.h" +#include "lm/model.hh" +#include "lm/enumerate_vocab.hh" +#include "wordid.h" + +namespace po = boost::program_options; +using namespace std; + +vector<lm::WordIndex> word_map; +lm::ngram::ProbingModel* ngram; +struct VMapper : public lm::ngram::EnumerateVocab { + VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector<lm::WordIndex>* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("source_lm,l",po::value<string>(),"Source language LM (KLM)") + ("add_shape_types,s", "Add rule shape types"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help")) { + cerr << "Usage " << argv[0] << " [OPTIONS]\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +template <class Model> float Score(const vector<WordID>& str, const Model &model) { + typename Model::State state, out; + lm::FullScoreReturn ret; + float total = 0.0f; + state = model.NullContextState(); + + for (int i = 0; i < str.size(); ++i) { + lm::WordIndex vocab = ((str[i] < word_map.size() && str[i] > 0) ? word_map[str[i]] : 0); + ret = model.FullScore(state, vocab, out); + total += ret.prob; + state = out; + } + return total; +} + +static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { + cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << endl; +} + +int main(int argc, char** argv) { + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + if (conf.count("source_lm")) { + lm::ngram::Config kconf; + VMapper vm(&word_map); + kconf.enumerate_vocab = &vm; + ngram = new lm::ngram::ProbingModel(conf["source_lm"].as<string>().c_str(), kconf); + cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n"; + } else { ngram = NULL; } + assert(ngram); + RuleLexer::ReadRules(&cin, &RuleHelper, NULL); + return 0; +} + |