summaryrefslogtreecommitdiff
path: root/training/augment_grammar.cc
blob: 48ef23fc64917a5d86ae233d08135e0c0d0544af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <cstddef>
#include <fstream>
#include <iostream>
#include <vector>

#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "rule_lexer.h"
#include "trule.h"
#include "filelib.h"
#include "tdict.h"
#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
#include "wordid.h"

namespace po = boost::program_options;
using namespace std;

vector<lm::WordIndex> word_map;
lm::ngram::ProbingModel* ngram;
struct VMapper : public lm::ngram::EnumerateVocab {
  VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
  void Add(lm::WordIndex index, const StringPiece &str) {
    const WordID cdec_id = TD::Convert(str.as_string());
    if (cdec_id >= out_->size())
      out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
    (*out_)[cdec_id] = index;
  }
  vector<lm::WordIndex>* out_;
  const lm::WordIndex kLM_UNKNOWN_TOKEN;
};

bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  opts.add_options()
        ("source_lm,l",po::value<string>(),"Source language LM (KLM)")
        ("add_shape_types,s", "Add rule shape types");
  po::options_description clo("Command line options");
  clo.add_options()
        ("config", po::value<string>(), "Configuration file")
        ("help,h", "Print this help message and exit");
  po::options_description dconfig_options, dcmdline_options;
  dconfig_options.add(opts);
  dcmdline_options.add(opts).add(clo);
  
  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
  if (conf->count("config")) {
    ifstream config((*conf)["config"].as<string>().c_str());
    po::store(po::parse_config_file(config, dconfig_options), *conf);
  }
  po::notify(*conf);

  if (conf->count("help")) {
    cerr << "Usage " << argv[0] << " [OPTIONS]\n";
    cerr << dcmdline_options << endl;
    return false;
  }
  return true;
}

lm::WordIndex kSOS;

template <class Model> float Score(const vector<WordID>& str, const Model &model) {
  typename Model::State state, out;
  lm::FullScoreReturn ret;
  float total = 0.0f;
  state = model.NullContextState();

  for (int i = 0; i < str.size(); ++i) {
    lm::WordIndex vocab = ((str[i] < word_map.size() && str[i] > 0) ? word_map[str[i]] : 0);
    if (vocab == kSOS) {
      state = model.BeginSentenceState();
    } else {
      ret = model.FullScore(state, vocab, out);
      total += ret.prob;
      state = out;
    }
  }
  return total;
}

// RuleLexer callback: writes each parsed rule to stdout followed by
// " SrcLM=<score>", where the score is Score() over the rule's source side
// (new_rule->f_) under the global source LM `ngram`, which main() must have
// loaded. The other parameters belong to the callback signature and are unused.
static void RuleHelper(const TRulePtr& new_rule, const unsigned int /*ctf_level*/, const TRulePtr& /*coarse_rule*/, void* /*extra*/) {
  // '\n' rather than endl: flushing after every rule is needlessly slow on
  // large grammars, and cout is flushed when main() returns.
  cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << '\n';
}


int main(int argc, char** argv) {
  po::variables_map conf;
  if (!InitCommandLine(argc, argv, &conf)) return 1;
  if (conf.count("source_lm")) {
    lm::ngram::Config kconf;
    VMapper vm(&word_map);
    kconf.enumerate_vocab = &vm; 
    ngram = new lm::ngram::ProbingModel(conf["source_lm"].as<string>().c_str(), kconf);
    kSOS = word_map[TD::Convert("<s>")];
    cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n";
    cerr << "  <s> = " << kSOS << endl;
  } else { ngram = NULL; }
  assert(ngram);
  RuleLexer::ReadRules(&cin, &RuleHelper, NULL);
  return 0;
}