summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authormjdenkowski <michael.j.denkowski@gmail.com>2013-08-28 18:07:42 -0400
committermjdenkowski <michael.j.denkowski@gmail.com>2013-08-28 18:07:42 -0400
commit0bc21f0fbcf5e060c1a9b249527e094436a383d8 (patch)
tree070c863b4a6734f2c8072e2cb610ac878a96bb31 /training
parentca9b58716214148eeaeaa3076e1a1dc8f8bb5892 (diff)
Stream support for MIRA (part of realtime)
Diffstat (limited to 'training')
-rw-r--r--training/mira/kbest_cut_mira.cc122
1 files changed, 88 insertions, 34 deletions
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc
index 7df9a18f..d8c42db7 100644
--- a/training/mira/kbest_cut_mira.cc
+++ b/training/mira/kbest_cut_mira.cc
@@ -48,6 +48,7 @@ int hope_select;
bool pseudo_doc;
bool sent_approx;
bool checkloss;
+bool stream;
void SanityCheck(const vector<double>& w) {
for (int i = 0; i < w.size(); ++i) {
@@ -99,6 +100,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles")
("update_k_best,b", po::value<int>()->default_value(1), "Size of good, bad lists to perform update with")
("unique_k_best,u", "Unique k-best translation list")
+ ("stream,t", "Stream mode (used for realtime)")
("weights_output,O",po::value<string>(),"Directory to write weights to")
("output_dir,D",po::value<string>(),"Directory to place output in")
("decoder_config,c",po::value<string>(),"Decoder configuration file");
@@ -117,7 +119,11 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
po::notify(*conf);
- if (conf->count("help") || !conf->count("input_weights") || !conf->count("decoder_config") || !conf->count("reference")) {
+ if (conf->count("help")
+ || !conf->count("input_weights")
+ || !conf->count("decoder_config")
+ || (!conf->count("stream") && (!conf->count("reference") || !conf->count("weights_output") || !conf->count("output_dir")))
+ ) {
cerr << dcmdline_options << endl;
return false;
}
@@ -321,6 +327,25 @@ struct GoodBadOracle {
vector<shared_ptr<HypothesisInfo> > bad;
};
+struct BasicObserver: public DecoderObserver {
+ Hypergraph* hypergraph;
+ BasicObserver() : hypergraph(NULL) {}
+ ~BasicObserver() {
+ if(hypergraph != NULL) delete hypergraph;
+ }
+ void NotifyDecodingStart(const SentenceMetadata& smeta) {}
+ void NotifySourceParseFailure(const SentenceMetadata& smeta) {}
+ void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ if(hypergraph != NULL) delete hypergraph;
+ hypergraph = new Hypergraph(*hg);
+ }
+ void NotifyAlignmentFailure(const SentenceMetadata& semta) {
+ if(hypergraph != NULL) delete hypergraph;
+ }
+ void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {}
+ void NotifyDecodingComplete(const SentenceMetadata& smeta) {}
+};
+
struct TrainingObserver : public DecoderObserver {
TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o, vector<ScoreP>* cbs) : ds(d), oracles(*o), corpus_bleu_sent_stats(*cbs), kbest_size(k) {
@@ -619,14 +644,15 @@ int main(int argc, char** argv) {
no_select = conf.count("no_select");
update_list_size = conf["update_k_best"].as<int>();
unique_kbest = conf.count("unique_k_best");
+ stream = conf.count("stream");
pseudo_doc = conf.count("pseudo_doc");
sent_approx = conf.count("sent_approx");
cerr << "Using pseudo-doc:" << pseudo_doc << " Sent:" << sent_approx << endl;
if(pseudo_doc)
mt_metric_scale=1;
- const string weights_dir = conf["weights_output"].as<string>();
- const string output_dir = conf["output_dir"].as<string>();
+ const string weights_dir = stream ? "-" : conf["weights_output"].as<string>();
+ const string output_dir = stream ? "-" : conf["output_dir"].as<string>();
ScoreType type = ScoreTypeFromString(metric_name);
//establish metric used for tuning
@@ -636,16 +662,22 @@ int main(int argc, char** argv) {
invert_score = false;
}
- //load references
- DocScorer ds(type, conf["reference"].as<vector<string> >(), "");
- cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl;
+ shared_ptr<DocScorer> ds;
+ //normal: load references, stream: start stream scorer
+ if (stream) {
+ ds = shared_ptr<DocScorer>(new DocStreamScorer(type, vector<string>(0), ""));
+ cerr << "Scoring doc stream with " << metric_name << endl;
+ } else {
+ ds = shared_ptr<DocScorer>(new DocScorer(type, conf["reference"].as<vector<string> >(), ""));
+ cerr << "Loaded " << ds->size() << " references for scoring with " << metric_name << endl;
+ }
vector<ScoreP> corpus_bleu_sent_stats;
//check training pass,if >0, then use previous iterations corpus bleu stats
- cur_pass = conf["pass"].as<int>();
+ cur_pass = stream ? 0 : conf["pass"].as<int>();
if(cur_pass > 0)
{
- ReadPastTranslationForScore(cur_pass, &corpus_bleu_sent_stats, ds, output_dir);
+ ReadPastTranslationForScore(cur_pass, &corpus_bleu_sent_stats, *ds, output_dir);
}
cerr << "Using optimizer:" << optimizer << endl;
@@ -659,7 +691,7 @@ int main(int argc, char** argv) {
Weights::InitFromFile(conf["input_weights"].as<string>(), &dense_weights);
Weights::InitSparseVector(dense_weights, &lambdas);
- const string input = decoder.GetConf()["input"].as<string>();
+ const string input = stream ? "-" : decoder.GetConf()["input"].as<string>();
if (!SILENT) cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl;
ReadFile in_read(input);
istream *in = in_read.stream();
@@ -668,9 +700,10 @@ int main(int argc, char** argv) {
const double max_step_size = conf["max_step_size"].as<double>();
- vector<GoodBadOracle> oracles(ds.size());
+ vector<GoodBadOracle> oracles(ds->size());
- TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles, &corpus_bleu_sent_stats);
+ BasicObserver bobs;
+ TrainingObserver observer(conf["k_best_size"].as<int>(), *ds, &oracles, &corpus_bleu_sent_stats);
int cur_sent = 0;
int lcount = 0;
@@ -689,12 +722,30 @@ int main(int argc, char** argv) {
while(*in) {
getline(*in, buf);
if (buf.empty()) continue;
+ if (stream) {
+ int delim = buf.find(" ||| ");
+ // Translate only
+ if (delim == -1) {
+ cur_sent = 0;
+ decoder.SetId(cur_sent);
+ decoder.Decode(buf, &bobs);
+ vector<WordID> trans;
+ ViterbiESentence(bobs.hypergraph[0], &trans);
+ cout << TD::GetString(trans) << endl;
+ continue;
+ // Translate and update (normal MIRA)
+ } else {
+ cur_sent = 1;
+ ds->update(buf.substr(delim + 5));
+ buf = buf.substr(0, delim);
+ }
+ }
//TODO: allow batch updating
lambdas.init_vector(&dense_weights);
dense_w_local = dense_weights;
decoder.SetId(cur_sent);
decoder.Decode(buf, &observer); // decode the sentence, calling Notify to get the hope,fear, and model best hyps.
-
+
cur_sent = observer.GetCurrentSent();
cerr << "SENT: " << cur_sent << endl;
const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis();
@@ -708,15 +759,15 @@ int main(int argc, char** argv) {
tot_loss += cur_hyp.mt_metric;
//score hyps to be able to compute corpus level bleu after we finish this iteration through the corpus
- ScoreP sentscore = ds[cur_sent]->ScoreCandidate(cur_hyp.hyp);
+ ScoreP sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_hyp.hyp);
if (!acc) { acc = sentscore->GetZero(); }
acc->PlusEquals(*sentscore);
- ScoreP hope_sentscore = ds[cur_sent]->ScoreCandidate(cur_good.hyp);
+ ScoreP hope_sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_good.hyp);
if (!acc_h) { acc_h = hope_sentscore->GetZero(); }
acc_h->PlusEquals(*hope_sentscore);
- ScoreP fear_sentscore = ds[cur_sent]->ScoreCandidate(cur_bad.hyp);
+ ScoreP fear_sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_bad.hyp);
if (!acc_f) { acc_f = fear_sentscore->GetZero(); }
acc_f->PlusEquals(*fear_sentscore);
@@ -915,11 +966,11 @@ int main(int argc, char** argv) {
}
- if ((cur_sent * 40 / ds.size()) > dots) { ++dots; cerr << '.'; }
+ if ((cur_sent * 40 / ds->size()) > dots) { ++dots; cerr << '.'; }
tot += lambdas;
++lcount;
cur_sent++;
-
+
cout << TD::GetString(cur_good_v[0]->hyp) << " ||| " << TD::GetString(cur_best_v[0]->hyp) << " ||| " << TD::GetString(cur_bad_v[0]->hyp) << endl;
}
@@ -929,24 +980,27 @@ int main(int argc, char** argv) {
cerr << "Translated " << lcount << " sentences " << endl;
cerr << " [AVG METRIC LAST PASS=" << (tot_loss / lcount) << "]\n";
tot_loss = 0;
+
+ // Write weights unless streaming
+ if (!stream) {
+ int node_id = rng->next() * 100000;
+ cerr << " Writing weights to " << node_id << endl;
+ Weights::ShowLargestFeatures(dense_weights);
+ dots = 0;
+ ostringstream os;
+ os << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << ".gz";
+ string msg = "# MIRA tuned weights ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount);
+ lambdas.init_vector(&dense_weights);
+ Weights::WriteToFile(os.str(), dense_weights, true, &msg);
- int node_id = rng->next() * 100000;
- cerr << " Writing weights to " << node_id << endl;
- Weights::ShowLargestFeatures(dense_weights);
- dots = 0;
- ostringstream os;
- os << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << ".gz";
- string msg = "# MIRA tuned weights ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount);
- lambdas.init_vector(&dense_weights);
- Weights::WriteToFile(os.str(), dense_weights, true, &msg);
-
- SparseVector<double> x = tot;
- x /= lcount+1;
- ostringstream sa;
- string msga = "# MIRA tuned weights AVERAGED ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount);
- sa << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << "-avg.gz";
- x.init_vector(&dense_weights);
- Weights::WriteToFile(sa.str(), dense_weights, true, &msga);
+ SparseVector<double> x = tot;
+ x /= lcount+1;
+ ostringstream sa;
+ string msga = "# MIRA tuned weights AVERAGED ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount);
+ sa << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << "-avg.gz";
+ x.init_vector(&dense_weights);
+ Weights::WriteToFile(sa.str(), dense_weights, true, &msga);
+ }
cerr << "Optimization complete.\n";
return 0;