From bffa08c6a335ac03bc43ab178724ec176921aa5a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 28 Feb 2011 22:10:50 -0500 Subject: src language LM phrase scoring --- training/Makefile.am | 8 +++- training/augment_grammar.cc | 93 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 training/augment_grammar.cc (limited to 'training') diff --git a/training/Makefile.am b/training/Makefile.am index 8218ff0a..b046c698 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -10,7 +10,8 @@ bin_PROGRAMS = \ collapse_weights \ cllh_filter_grammar \ mpi_online_optimize \ - mpi_batch_optimize + mpi_batch_optimize \ + augment_grammar noinst_PROGRAMS = \ lbfgs_test \ @@ -34,6 +35,9 @@ endif cllh_filter_grammar_SOURCES = cllh_filter_grammar.cc cllh_filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +augment_grammar_SOURCES = augment_grammar.cc +augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + atools_SOURCES = atools.cc atools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz @@ -67,4 +71,4 @@ mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils plftools_SOURCES = plftools.cc plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc new file mode 100644 index 00000000..9b7fc7be --- /dev/null +++ b/training/augment_grammar.cc @@ -0,0 +1,93 @@ +#include +#include + +#include +#include + +#include "rule_lexer.h" +#include "trule.h" +#include "filelib.h" +#include "tdict.h" +#include "lm/model.hh" +#include "lm/enumerate_vocab.hh" +#include "wordid.h" + +namespace po = boost::program_options; +using namespace std; + +vector word_map; +lm::ngram::ProbingModel* ngram; +struct VMapper : public lm::ngram::EnumerateVocab { + VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("source_lm,l",po::value(),"Source language LM (KLM)") + ("add_shape_types,s", "Add rule shape types"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help")) { + cerr << "Usage " << argv[0] << " [OPTIONS]\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +template float Score(const vector& str, const Model &model) { + typename Model::State state, out; + lm::FullScoreReturn ret; + float total = 0.0f; + state = model.NullContextState(); + + for (int i = 0; i < str.size(); ++i) { + lm::WordIndex vocab = ((str[i] < word_map.size() && str[i] > 0) ? word_map[str[i]] : 0); + ret = model.FullScore(state, vocab, out); + total += ret.prob; + state = out; + } + return total; +} + +static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { + cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << endl; +} + +int main(int argc, char** argv) { + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + if (conf.count("source_lm")) { + lm::ngram::Config kconf; + VMapper vm(&word_map); + kconf.enumerate_vocab = &vm; + ngram = new lm::ngram::ProbingModel(conf["source_lm"].as().c_str(), kconf); + cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n"; + } else { ngram = NULL; } + assert(ngram); + RuleLexer::ReadRules(&cin, &RuleHelper, NULL); + return 0; +} + -- cgit v1.2.3