summaryrefslogtreecommitdiff
path: root/gi/pf/pfbrat.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pf/pfbrat.cc')
-rw-r--r--gi/pf/pfbrat.cc554
1 files changed, 554 insertions, 0 deletions
diff --git a/gi/pf/pfbrat.cc b/gi/pf/pfbrat.cc
new file mode 100644
index 00000000..4c6ba3ef
--- /dev/null
+++ b/gi/pf/pfbrat.cc
@@ -0,0 +1,554 @@
+#include <iostream>
+#include <tr1/memory>
+#include <queue>
+
+#include <boost/functional.hpp>
+#include <boost/multi_array.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "viterbi.h"
+#include "hg.h"
+#include "trule.h"
+#include "tdict.h"
+#include "filelib.h"
+#include "dict.h"
+#include "sampler.h"
+#include "ccrp_nt.h"
+#include "cfg_wfst_composer.h"
+
+using namespace std;
+using namespace tr1;
+namespace po = boost::program_options;
+
+static unsigned kMAX_SRC_PHRASE;
+static unsigned kMAX_TRG_PHRASE;
+struct FSTState;
+
+size_t hash_value(const TRule& r) {
+ size_t h = 2 - r.lhs_;
+ boost::hash_combine(h, boost::hash_value(r.e_));
+ boost::hash_combine(h, boost::hash_value(r.f_));
+ return h;
+}
+
+bool operator==(const TRule& a, const TRule& b) {
+ return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_);
+}
+
+double log_poisson(unsigned x, const double& lambda) {
+ assert(lambda > 0.0);
+ return log(lambda) * x - lgamma(x + 1) - lambda;
+}
+
+struct ConditionalBase {
+ explicit ConditionalBase(const double m1mixture, const unsigned vocab_e_size, const string& model1fname) :
+ kM1MIXTURE(m1mixture),
+ kUNIFORM_MIXTURE(1.0 - m1mixture),
+ kUNIFORM_TARGET(1.0 / vocab_e_size),
+ kNULL(TD::Convert("<eps>")) {
+ assert(m1mixture >= 0.0 && m1mixture <= 1.0);
+ assert(vocab_e_size > 0);
+ LoadModel1(model1fname);
+ }
+
+ void LoadModel1(const string& fname) {
+ cerr << "Loading Model 1 parameters from " << fname << " ..." << endl;
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ string line;
+ unsigned lc = 0;
+ while(getline(in, line)) {
+ ++lc;
+ int cur = 0;
+ int start = 0;
+ while(cur < line.size() && line[cur] != ' ') { ++cur; }
+ assert(cur != line.size());
+ line[cur] = 0;
+ const WordID src = TD::Convert(&line[0]);
+ ++cur;
+ start = cur;
+ while(cur < line.size() && line[cur] != ' ') { ++cur; }
+ assert(cur != line.size());
+ line[cur] = 0;
+ WordID trg = TD::Convert(&line[start]);
+ const double logprob = strtod(&line[cur + 1], NULL);
+ if (src >= ttable.size()) ttable.resize(src + 1);
+ ttable[src][trg].logeq(logprob);
+ }
+ cerr << " read " << lc << " parameters.\n";
+ }
+
+ // return logp0 of rule.e_ | rule.f_
+ prob_t operator()(const TRule& rule) const {
+ const int flen = rule.f_.size();
+ const int elen = rule.e_.size();
+ prob_t uniform_src_alignment; uniform_src_alignment.logeq(-log(flen + 1));
+ prob_t p;
+ p.logeq(log_poisson(elen, flen + 0.01)); // elen | flen ~Pois(flen + 0.01)
+ for (int i = 0; i < elen; ++i) { // for each position i in e-RHS
+ const WordID trg = rule.e_[i];
+ prob_t tp = prob_t::Zero();
+ for (int j = -1; j < flen; ++j) {
+ const WordID src = j < 0 ? kNULL : rule.f_[j];
+ const map<WordID, prob_t>::const_iterator it = ttable[src].find(trg);
+ if (it != ttable[src].end()) {
+ tp += kM1MIXTURE * it->second;
+ }
+ tp += kUNIFORM_MIXTURE * kUNIFORM_TARGET;
+ }
+ tp *= uniform_src_alignment; // draw a_i ~uniform
+ p *= tp; // draw e_i ~Model1(f_a_i) / uniform
+ }
+ return p;
+ }
+
+ const prob_t kM1MIXTURE; // Model 1 mixture component
+ const prob_t kUNIFORM_MIXTURE; // uniform mixture component
+ const prob_t kUNIFORM_TARGET;
+ const WordID kNULL;
+ vector<map<WordID, prob_t> > ttable;
+};
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples")
+ ("input,i",po::value<string>(),"Read parallel data from")
+ ("max_src_phrase",po::value<unsigned>()->default_value(3),"Maximum length of source language phrases")
+ ("max_trg_phrase",po::value<unsigned>()->default_value(3),"Maximum length of target language phrases")
+ ("model1,m",po::value<string>(),"Model 1 parameters (used in base distribution)")
+ ("model1_interpolation_weight",po::value<double>()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution")
+ ("random_seed,S",po::value<uint32_t>(), "Random seed");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || (conf->count("input") == 0)) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+void ReadParallelCorpus(const string& filename,
+ vector<vector<WordID> >* f,
+ vector<vector<int> >* e,
+ set<int>* vocab_f,
+ set<int>* vocab_e) {
+ f->clear();
+ e->clear();
+ vocab_f->clear();
+ vocab_e->clear();
+ istream* in;
+ if (filename == "-")
+ in = &cin;
+ else
+ in = new ifstream(filename.c_str());
+ assert(*in);
+ string line;
+ const WordID kDIV = TD::Convert("|||");
+ vector<WordID> tmp;
+ while(*in) {
+ getline(*in, line);
+ if (line.empty() && !*in) break;
+ e->push_back(vector<int>());
+ f->push_back(vector<int>());
+ vector<int>& le = e->back();
+ vector<int>& lf = f->back();
+ tmp.clear();
+ TD::ConvertSentence(line, &tmp);
+ bool isf = true;
+ for (unsigned i = 0; i < tmp.size(); ++i) {
+ const int cur = tmp[i];
+ if (isf) {
+ if (kDIV == cur) { isf = false; } else {
+ lf.push_back(cur);
+ vocab_f->insert(cur);
+ }
+ } else {
+ assert(cur != kDIV);
+ le.push_back(cur);
+ vocab_e->insert(cur);
+ }
+ }
+ assert(isf == false);
+ }
+ if (in != &cin) delete in;
+}
+
+struct UniphraseLM {
+ UniphraseLM(const vector<vector<int> >& corpus,
+ const set<int>& vocab,
+ const po::variables_map& conf) :
+ phrases_(1,1),
+ gen_(1,1),
+ corpus_(corpus),
+ uniform_word_(1.0 / vocab.size()),
+ gen_p0_(0.5),
+ p_end_(0.5),
+ use_poisson_(conf.count("poisson_length") > 0) {}
+
+ void ResampleHyperparameters(MT19937* rng) {
+ phrases_.resample_hyperparameters(rng);
+ gen_.resample_hyperparameters(rng);
+ cerr << " " << phrases_.concentration();
+ }
+
+ CCRP_NoTable<vector<int> > phrases_;
+ CCRP_NoTable<bool> gen_;
+ vector<vector<bool> > z_; // z_[i] is there a phrase boundary after the ith word
+ const vector<vector<int> >& corpus_;
+ const double uniform_word_;
+ const double gen_p0_;
+ const double p_end_; // in base length distribution, p of the end of a phrase
+ const bool use_poisson_;
+};
+
+struct Reachability {
+ boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring?
+ boost::multi_array<short, 2> max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid
+
+ Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) :
+ edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]),
+ max_src_delta(boost::extents[srclen][trglen]) {
+ ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len);
+ }
+
+ private:
+ struct SState {
+ SState() : prev_src_covered(), prev_trg_covered() {}
+ SState(int i, int j) : prev_src_covered(i), prev_trg_covered(j) {}
+ int prev_src_covered;
+ int prev_trg_covered;
+ };
+
+ struct NState {
+ NState() : next_src_covered(), next_trg_covered() {}
+ NState(int i, int j) : next_src_covered(i), next_trg_covered(j) {}
+ int next_src_covered;
+ int next_trg_covered;
+ };
+
+ void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) {
+ typedef boost::multi_array<vector<SState>, 2> array_type;
+ array_type a(boost::extents[srclen + 1][trglen + 1]);
+ a[0][0].push_back(SState());
+ for (int i = 0; i < srclen; ++i) {
+ for (int j = 0; j < trglen; ++j) {
+ if (a[i][j].size() == 0) continue;
+ const SState prev(i,j);
+ for (int k = 1; k <= src_max_phrase_len; ++k) {
+ if ((i + k) > srclen) continue;
+ for (int l = 1; l <= trg_max_phrase_len; ++l) {
+ if ((j + l) > trglen) continue;
+ a[i + k][j + l].push_back(prev);
+ }
+ }
+ }
+ }
+ a[0][0].clear();
+ cerr << "Final cell contains " << a[srclen][trglen].size() << " back pointers\n";
+ assert(a[srclen][trglen].size() > 0);
+
+ typedef boost::multi_array<bool, 2> rarray_type;
+ rarray_type r(boost::extents[srclen + 1][trglen + 1]);
+// typedef boost::multi_array<vector<NState>, 2> narray_type;
+// narray_type b(boost::extents[srclen + 1][trglen + 1]);
+ r[srclen][trglen] = true;
+ for (int i = srclen; i >= 0; --i) {
+ for (int j = trglen; j >= 0; --j) {
+ vector<SState>& prevs = a[i][j];
+ if (!r[i][j]) { prevs.clear(); }
+// const NState nstate(i,j);
+ for (int k = 0; k < prevs.size(); ++k) {
+ r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true;
+ int src_delta = i - prevs[k].prev_src_covered;
+ edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true;
+ short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered];
+ if (src_delta > msd) msd = src_delta;
+// b[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(nstate);
+ }
+ }
+ }
+ assert(!edges[0][0][1][0]);
+ assert(!edges[0][0][0][1]);
+ assert(!edges[0][0][0][0]);
+ cerr << " MAX SRC DELTA[0][0] = " << max_src_delta[0][0] << endl;
+ assert(max_src_delta[0][0] > 0);
+ //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n";
+ //for (int i = 0; i < b[0][0].size(); ++i) {
+ // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n";
+ //}
+ }
+};
+
+ostream& operator<<(ostream& os, const FSTState& q);
+struct FSTState {
+ explicit FSTState(int src_size) :
+ trg_covered_(),
+ src_covered_(),
+ src_coverage_(src_size) {}
+
+ FSTState(short trg_covered, short src_covered, const vector<bool>& src_coverage, const vector<short>& src_prefix) :
+ trg_covered_(trg_covered),
+ src_covered_(src_covered),
+ src_coverage_(src_coverage),
+ src_prefix_(src_prefix) {
+ if (src_coverage_.size() == src_covered) {
+ assert(src_prefix.size() == 0);
+ }
+ }
+
+ // if we extend by the word at src_position, what are
+ // the next states that are reachable and lie on a valid
+ // path to the final state?
+ vector<FSTState> Extensions(int src_position, int src_len, int trg_len, const Reachability& r) const {
+ assert(src_position < src_coverage_.size());
+ if (src_coverage_[src_position]) {
+ cerr << "Trying to extend " << *this << " with position " << src_position << endl;
+ abort();
+ }
+ vector<bool> ncvg = src_coverage_;
+ ncvg[src_position] = true;
+
+ vector<FSTState> res;
+ const int trg_remaining = trg_len - trg_covered_;
+ if (trg_remaining <= 0) {
+ cerr << "Target appears to have been covered: " << *this << " (trg_len=" << trg_len << ",trg_covered=" << trg_covered_ << ")" << endl;
+ abort();
+ }
+ const int src_remaining = src_len - src_covered_;
+ if (src_remaining <= 0) {
+ cerr << "Source appears to have been covered: " << *this << endl;
+ abort();
+ }
+
+ for (int tc = 1; tc <= kMAX_TRG_PHRASE; ++tc) {
+ if (r.edges[src_covered_][trg_covered_][src_prefix_.size() + 1][tc]) {
+ int nc = src_prefix_.size() + 1 + src_covered_;
+ res.push_back(FSTState(trg_covered_ + tc, nc, ncvg, vector<short>()));
+ }
+ }
+
+ if ((src_prefix_.size() + 1) < r.max_src_delta[src_covered_][trg_covered_]) {
+ vector<short> nsp = src_prefix_;
+ nsp.push_back(src_position);
+ res.push_back(FSTState(trg_covered_, src_covered_, ncvg, nsp));
+ }
+
+ if (res.size() == 0) {
+ cerr << *this << " can't be extended!\n";
+ abort();
+ }
+ return res;
+ }
+
+ short trg_covered_, src_covered_;
+ vector<bool> src_coverage_;
+ vector<short> src_prefix_;
+};
+bool operator<(const FSTState& q, const FSTState& r) {
+ if (q.trg_covered_ != r.trg_covered_) return q.trg_covered_ < r.trg_covered_;
+ if (q.src_covered_!= r.src_covered_) return q.src_covered_ < r.src_covered_;
+ if (q.src_coverage_ != r.src_coverage_) return q.src_coverage_ < r.src_coverage_;
+ return q.src_prefix_ < r.src_prefix_;
+}
+
+ostream& operator<<(ostream& os, const FSTState& q) {
+ os << "[" << q.trg_covered_ << " : ";
+ for (int i = 0; i < q.src_coverage_.size(); ++i)
+ os << q.src_coverage_[i];
+ os << " : <";
+ for (int i = 0; i < q.src_prefix_.size(); ++i) {
+ if (i != 0) os << ' ';
+ os << q.src_prefix_[i];
+ }
+ return os << ">]";
+}
+
+struct MyModel {
+ MyModel(ConditionalBase& rcp0) : rp0(rcp0) {}
+ typedef unordered_map<vector<WordID>, CCRP_NoTable<TRule>, boost::hash<vector<WordID> > > SrcToRuleCRPMap;
+
+ void DecrementRule(const TRule& rule) {
+ SrcToRuleCRPMap::iterator it = rules.find(rule.f_);
+ assert(it != rules.end());
+ it->second.decrement(rule);
+ if (it->second.num_customers() == 0) rules.erase(it);
+ }
+
+ void IncrementRule(const TRule& rule) {
+ SrcToRuleCRPMap::iterator it = rules.find(rule.f_);
+ if (it == rules.end()) {
+ CCRP_NoTable<TRule> crp(1,1);
+ it = rules.insert(make_pair(rule.f_, crp)).first;
+ }
+ it->second.increment(rule);
+ }
+
+ // conditioned on rule.f_
+ prob_t RuleConditionalProbability(const TRule& rule) const {
+ const prob_t base = rp0(rule);
+ SrcToRuleCRPMap::const_iterator it = rules.find(rule.f_);
+ if (it == rules.end()) {
+ return base;
+ } else {
+ const double lp = it->second.logprob(rule, log(base));
+ prob_t q; q.logeq(lp);
+ return q;
+ }
+ }
+
+ const ConditionalBase& rp0;
+ SrcToRuleCRPMap rules;
+};
+
+struct MyFST : public WFST {
+ MyFST(const vector<WordID>& ssrc, const vector<WordID>& strg, MyModel* m) :
+ src(ssrc), trg(strg),
+ r(src.size(),trg.size(),kMAX_SRC_PHRASE, kMAX_TRG_PHRASE),
+ model(m) {
+ FSTState in(src.size());
+ cerr << " INIT: " << in << endl;
+ init = GetNode(in);
+ for (int i = 0; i < in.src_coverage_.size(); ++i) in.src_coverage_[i] = true;
+ in.src_covered_ = src.size();
+ in.trg_covered_ = trg.size();
+ cerr << "FINAL: " << in << endl;
+ final = GetNode(in);
+ }
+ virtual const WFSTNode* Final() const;
+ virtual const WFSTNode* Initial() const;
+
+ const WFSTNode* GetNode(const FSTState& q);
+ map<FSTState, boost::shared_ptr<WFSTNode> > m;
+ const vector<WordID>& src;
+ const vector<WordID>& trg;
+ Reachability r;
+ const WFSTNode* init;
+ const WFSTNode* final;
+ MyModel* model;
+};
+
+struct MyNode : public WFSTNode {
+ MyNode(const FSTState& q, MyFST* fst) : state(q), container(fst) {}
+ virtual vector<pair<const WFSTNode*, TRulePtr> > ExtendInput(unsigned srcindex) const;
+ const FSTState state;
+ mutable MyFST* container;
+};
+
+vector<pair<const WFSTNode*, TRulePtr> > MyNode::ExtendInput(unsigned srcindex) const {
+ cerr << "EXTEND " << state << " with " << srcindex << endl;
+ vector<FSTState> ext = state.Extensions(srcindex, container->src.size(), container->trg.size(), container->r);
+ vector<pair<const WFSTNode*,TRulePtr> > res(ext.size());
+ for (unsigned i = 0; i < ext.size(); ++i) {
+ res[i].first = container->GetNode(ext[i]);
+ if (ext[i].src_prefix_.size() == 0) {
+ const unsigned trg_from = state.trg_covered_;
+ const unsigned trg_to = ext[i].trg_covered_;
+ const unsigned prev_prfx_size = state.src_prefix_.size();
+ res[i].second.reset(new TRule);
+ res[i].second->lhs_ = -TD::Convert("X");
+ vector<WordID>& src = res[i].second->f_;
+ vector<WordID>& trg = res[i].second->e_;
+ src.resize(prev_prfx_size + 1);
+ for (unsigned j = 0; j < prev_prfx_size; ++j)
+ src[j] = container->src[state.src_prefix_[j]];
+ src[prev_prfx_size] = container->src[srcindex];
+ for (unsigned j = trg_from; j < trg_to; ++j)
+ trg.push_back(container->trg[j]);
+ res[i].second->scores_.set_value(FD::Convert("Proposal"), log(container->model->RuleConditionalProbability(*res[i].second)));
+ }
+ }
+ return res;
+}
+
+const WFSTNode* MyFST::GetNode(const FSTState& q) {
+ boost::shared_ptr<WFSTNode>& res = m[q];
+ if (!res) {
+ res.reset(new MyNode(q, this));
+ }
+ return &*res;
+}
+
+const WFSTNode* MyFST::Final() const {
+ return final;
+}
+
+const WFSTNode* MyFST::Initial() const {
+ return init;
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ kMAX_TRG_PHRASE = conf["max_trg_phrase"].as<unsigned>();
+ kMAX_SRC_PHRASE = conf["max_src_phrase"].as<unsigned>();
+
+ if (!conf.count("model1")) {
+ cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n";
+ return 1;
+ }
+ shared_ptr<MT19937> prng;
+ if (conf.count("random_seed"))
+ prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+ else
+ prng.reset(new MT19937);
+ MT19937& rng = *prng;
+
+ vector<vector<int> > corpuse, corpusf;
+ set<int> vocabe, vocabf;
+ ReadParallelCorpus(conf["input"].as<string>(), &corpusf, &corpuse, &vocabf, &vocabe);
+ cerr << "f-Corpus size: " << corpusf.size() << " sentences\n";
+ cerr << "f-Vocabulary size: " << vocabf.size() << " types\n";
+ cerr << "f-Corpus size: " << corpuse.size() << " sentences\n";
+ cerr << "f-Vocabulary size: " << vocabe.size() << " types\n";
+ assert(corpusf.size() == corpuse.size());
+
+ ConditionalBase lp0(conf["model1_interpolation_weight"].as<double>(),
+ vocabe.size(),
+ conf["model1"].as<string>());
+ MyModel m(lp0);
+
+ TRule x("[X] ||| kAnwntR myN ||| at the convent ||| 0");
+ m.IncrementRule(x);
+ TRule y("[X] ||| nY dyN ||| gave ||| 0");
+ m.IncrementRule(y);
+
+
+ MyFST fst(corpusf[0], corpuse[0], &m);
+ ifstream in("./kimura.g");
+ assert(in);
+ CFG_WFSTComposer comp(fst);
+ Hypergraph hg;
+ bool succeed = comp.Compose(&in, &hg);
+ hg.PrintGraphviz();
+ if (succeed) { cerr << "SUCCESS.\n"; } else { cerr << "FAILURE REPORTED.\n"; }
+
+#if 0
+ ifstream in2("./amnabooks.g");
+ assert(in2);
+ MyFST fst2(corpusf[1], corpuse[1], &m);
+ CFG_WFSTComposer comp2(fst2);
+ Hypergraph hg2;
+ bool succeed2 = comp2.Compose(&in2, &hg2);
+ if (succeed2) { cerr << "SUCCESS.\n"; } else { cerr << "FAILURE REPORTED.\n"; }
+#endif
+
+ SparseVector<double> w; w.set_value(FD::Convert("Proposal"), 1.0);
+ hg.Reweight(w);
+ cerr << ViterbiFTree(hg) << endl;
+ return 0;
+}
+