diff options
author | mjdenkowski <michael.j.denkowski@gmail.com> | 2013-08-28 18:07:42 -0400 |
---|---|---|
committer | mjdenkowski <michael.j.denkowski@gmail.com> | 2013-08-28 18:07:42 -0400 |
commit | 0bc21f0fbcf5e060c1a9b249527e094436a383d8 (patch) | |
tree | 070c863b4a6734f2c8072e2cb610ac878a96bb31 /training | |
parent | ca9b58716214148eeaeaa3076e1a1dc8f8bb5892 (diff) |
Stream support for MIRA (part of realtime)
Diffstat (limited to 'training')
-rw-r--r-- | training/mira/kbest_cut_mira.cc | 122 |
1 files changed, 88 insertions, 34 deletions
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc index 7df9a18f..d8c42db7 100644 --- a/training/mira/kbest_cut_mira.cc +++ b/training/mira/kbest_cut_mira.cc @@ -48,6 +48,7 @@ int hope_select; bool pseudo_doc; bool sent_approx; bool checkloss; +bool stream; void SanityCheck(const vector<double>& w) { for (int i = 0; i < w.size(); ++i) { @@ -99,6 +100,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles") ("update_k_best,b", po::value<int>()->default_value(1), "Size of good, bad lists to perform update with") ("unique_k_best,u", "Unique k-best translation list") + ("stream,t", "Stream mode (used for realtime)") ("weights_output,O",po::value<string>(),"Directory to write weights to") ("output_dir,D",po::value<string>(),"Directory to place output in") ("decoder_config,c",po::value<string>(),"Decoder configuration file"); @@ -117,7 +119,11 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { } po::notify(*conf); - if (conf->count("help") || !conf->count("input_weights") || !conf->count("decoder_config") || !conf->count("reference")) { + if (conf->count("help") + || !conf->count("input_weights") + || !conf->count("decoder_config") + || (!conf->count("stream") && (!conf->count("reference") || !conf->count("weights_output") || !conf->count("output_dir"))) + ) { cerr << dcmdline_options << endl; return false; } @@ -321,6 +327,25 @@ struct GoodBadOracle { vector<shared_ptr<HypothesisInfo> > bad; }; +struct BasicObserver: public DecoderObserver { + Hypergraph* hypergraph; + BasicObserver() : hypergraph(NULL) {} + ~BasicObserver() { + if(hypergraph != NULL) delete hypergraph; + } + void NotifyDecodingStart(const SentenceMetadata& smeta) {} + void NotifySourceParseFailure(const SentenceMetadata& smeta) {} + void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + if(hypergraph != NULL) delete hypergraph; + hypergraph = new Hypergraph(*hg); + } + void NotifyAlignmentFailure(const SentenceMetadata& semta) { + if(hypergraph != NULL) delete hypergraph; + } + void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {} + void NotifyDecodingComplete(const SentenceMetadata& smeta) {} +}; + struct TrainingObserver : public DecoderObserver { TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o, vector<ScoreP>* cbs) : ds(d), oracles(*o), corpus_bleu_sent_stats(*cbs), kbest_size(k) { @@ -619,14 +644,15 @@ int main(int argc, char** argv) { no_select = conf.count("no_select"); update_list_size = conf["update_k_best"].as<int>(); unique_kbest = conf.count("unique_k_best"); + stream = conf.count("stream"); pseudo_doc = conf.count("pseudo_doc"); sent_approx = conf.count("sent_approx"); cerr << "Using pseudo-doc:" << pseudo_doc << " Sent:" << sent_approx << endl; if(pseudo_doc) mt_metric_scale=1; - const string weights_dir = conf["weights_output"].as<string>(); - const string output_dir = conf["output_dir"].as<string>(); + const string weights_dir = stream ? "-" : conf["weights_output"].as<string>(); + const string output_dir = stream ? "-" : conf["output_dir"].as<string>(); ScoreType type = ScoreTypeFromString(metric_name); //establish metric used for tuning @@ -636,16 +662,22 @@ int main(int argc, char** argv) { invert_score = false; } - //load references - DocScorer ds(type, conf["reference"].as<vector<string> >(), ""); - cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; + shared_ptr<DocScorer> ds; + //normal: load references, stream: start stream scorer + if (stream) { + ds = shared_ptr<DocScorer>(new DocStreamScorer(type, vector<string>(0), "")); + cerr << "Scoring doc stream with " << metric_name << endl; + } else { + ds = shared_ptr<DocScorer>(new DocScorer(type, conf["reference"].as<vector<string> >(), "")); + cerr << "Loaded " << ds->size() << " references for scoring with " << metric_name << endl; + } vector<ScoreP> corpus_bleu_sent_stats; //check training pass,if >0, then use previous iterations corpus bleu stats - cur_pass = conf["pass"].as<int>(); + cur_pass = stream ? 0 : conf["pass"].as<int>(); if(cur_pass > 0) { - ReadPastTranslationForScore(cur_pass, &corpus_bleu_sent_stats, ds, output_dir); + ReadPastTranslationForScore(cur_pass, &corpus_bleu_sent_stats, *ds, output_dir); } cerr << "Using optimizer:" << optimizer << endl; @@ -659,7 +691,7 @@ int main(int argc, char** argv) { Weights::InitFromFile(conf["input_weights"].as<string>(), &dense_weights); Weights::InitSparseVector(dense_weights, &lambdas); - const string input = decoder.GetConf()["input"].as<string>(); + const string input = stream ? "-" : decoder.GetConf()["input"].as<string>(); if (!SILENT) cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl; ReadFile in_read(input); istream *in = in_read.stream(); @@ -668,9 +700,10 @@ int main(int argc, char** argv) { const double max_step_size = conf["max_step_size"].as<double>(); - vector<GoodBadOracle> oracles(ds.size()); + vector<GoodBadOracle> oracles(ds->size()); - TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles, &corpus_bleu_sent_stats); + BasicObserver bobs; + TrainingObserver observer(conf["k_best_size"].as<int>(), *ds, &oracles, &corpus_bleu_sent_stats); int cur_sent = 0; int lcount = 0; @@ -689,12 +722,30 @@ int main(int argc, char** argv) { while(*in) { getline(*in, buf); if (buf.empty()) continue; + if (stream) { + int delim = buf.find(" ||| "); + // Translate only + if (delim == -1) { + cur_sent = 0; + decoder.SetId(cur_sent); + decoder.Decode(buf, &bobs); + vector<WordID> trans; + ViterbiESentence(bobs.hypergraph[0], &trans); + cout << TD::GetString(trans) << endl; + continue; + // Translate and update (normal MIRA) + } else { + cur_sent = 1; + ds->update(buf.substr(delim + 5)); + buf = buf.substr(0, delim); + } + } //TODO: allow batch updating lambdas.init_vector(&dense_weights); dense_w_local = dense_weights; decoder.SetId(cur_sent); decoder.Decode(buf, &observer); // decode the sentence, calling Notify to get the hope,fear, and model best hyps. - + cur_sent = observer.GetCurrentSent(); cerr << "SENT: " << cur_sent << endl; const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); @@ -708,15 +759,15 @@ int main(int argc, char** argv) { tot_loss += cur_hyp.mt_metric; //score hyps to be able to compute corpus level bleu after we finish this iteration through the corpus - ScoreP sentscore = ds[cur_sent]->ScoreCandidate(cur_hyp.hyp); + ScoreP sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_hyp.hyp); if (!acc) { acc = sentscore->GetZero(); } acc->PlusEquals(*sentscore); - ScoreP hope_sentscore = ds[cur_sent]->ScoreCandidate(cur_good.hyp); + ScoreP hope_sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_good.hyp); if (!acc_h) { acc_h = hope_sentscore->GetZero(); } acc_h->PlusEquals(*hope_sentscore); - ScoreP fear_sentscore = ds[cur_sent]->ScoreCandidate(cur_bad.hyp); + ScoreP fear_sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_bad.hyp); if (!acc_f) { acc_f = fear_sentscore->GetZero(); } acc_f->PlusEquals(*fear_sentscore); @@ -915,11 +966,11 @@ int main(int argc, char** argv) { } - if ((cur_sent * 40 / ds.size()) > dots) { ++dots; cerr << '.'; } + if ((cur_sent * 40 / ds->size()) > dots) { ++dots; cerr << '.'; } tot += lambdas; ++lcount; cur_sent++; - + cout << TD::GetString(cur_good_v[0]->hyp) << " ||| " << TD::GetString(cur_best_v[0]->hyp) << " ||| " << TD::GetString(cur_bad_v[0]->hyp) << endl; } @@ -929,24 +980,27 @@ int main(int argc, char** argv) { cerr << "Translated " << lcount << " sentences " << endl; cerr << " [AVG METRIC LAST PASS=" << (tot_loss / lcount) << "]\n"; tot_loss = 0; + + // Write weights unless streaming + if (!stream) { + int node_id = rng->next() * 100000; + cerr << " Writing weights to " << node_id << endl; + Weights::ShowLargestFeatures(dense_weights); + dots = 0; + ostringstream os; + os << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << ".gz"; + string msg = "# MIRA tuned weights ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); + lambdas.init_vector(&dense_weights); + Weights::WriteToFile(os.str(), dense_weights, true, &msg); - int node_id = rng->next() * 100000; - cerr << " Writing weights to " << node_id << endl; - Weights::ShowLargestFeatures(dense_weights); - dots = 0; - ostringstream os; - os << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << ".gz"; - string msg = "# MIRA tuned weights ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); - lambdas.init_vector(&dense_weights); - Weights::WriteToFile(os.str(), dense_weights, true, &msg); - - SparseVector<double> x = tot; - x /= lcount+1; - ostringstream sa; - string msga = "# MIRA tuned weights AVERAGED ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); - sa << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << "-avg.gz"; - x.init_vector(&dense_weights); - Weights::WriteToFile(sa.str(), dense_weights, true, &msga); + SparseVector<double> x = tot; + x /= lcount+1; + ostringstream sa; + string msga = "# MIRA tuned weights AVERAGED ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); + sa << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << "-avg.gz"; + x.init_vector(&dense_weights); + Weights::WriteToFile(sa.str(), dense_weights, true, &msga); + } cerr << "Optimization complete.\n"; return 0; |