diff options
Diffstat (limited to 'training/mira')
| -rw-r--r-- | training/mira/kbest_cut_mira.cc | 122 | 
1 file changed, 88 insertions, 34 deletions
| diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc index 7df9a18f..d8c42db7 100644 --- a/training/mira/kbest_cut_mira.cc +++ b/training/mira/kbest_cut_mira.cc @@ -48,6 +48,7 @@ int hope_select;  bool pseudo_doc;  bool sent_approx;  bool checkloss; +bool stream;  void SanityCheck(const vector<double>& w) {    for (int i = 0; i < w.size(); ++i) { @@ -99,6 +100,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {      ("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles")      ("update_k_best,b", po::value<int>()->default_value(1), "Size of good, bad lists to perform update with")      ("unique_k_best,u", "Unique k-best translation list") +    ("stream,t", "Stream mode (used for realtime)")      ("weights_output,O",po::value<string>(),"Directory to write weights to")      ("output_dir,D",po::value<string>(),"Directory to place output in")      ("decoder_config,c",po::value<string>(),"Decoder configuration file"); @@ -117,7 +119,11 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {    }    po::notify(*conf); -  if (conf->count("help") || !conf->count("input_weights") || !conf->count("decoder_config") || !conf->count("reference")) { +  if (conf->count("help") +          || !conf->count("input_weights") +          || !conf->count("decoder_config") +          || (!conf->count("stream") && (!conf->count("reference") || !conf->count("weights_output") || !conf->count("output_dir"))) +          ) {      cerr << dcmdline_options << endl;      return false;    } @@ -321,6 +327,25 @@ struct GoodBadOracle {    vector<shared_ptr<HypothesisInfo> > bad;  }; +struct BasicObserver: public DecoderObserver { +    Hypergraph* hypergraph; +    BasicObserver() : hypergraph(NULL) {} +    ~BasicObserver() { +        if(hypergraph != NULL) delete hypergraph; +    } +    void NotifyDecodingStart(const SentenceMetadata& smeta) {} +    void NotifySourceParseFailure(const 
SentenceMetadata& smeta) {} +    void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { +        if(hypergraph != NULL) delete hypergraph; +        hypergraph = new Hypergraph(*hg); +    } +    void NotifyAlignmentFailure(const SentenceMetadata& semta) { +        if(hypergraph != NULL) delete hypergraph; +    } +    void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {} +    void NotifyDecodingComplete(const SentenceMetadata& smeta) {} +}; +  struct TrainingObserver : public DecoderObserver {    TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o, vector<ScoreP>* cbs) : ds(d), oracles(*o), corpus_bleu_sent_stats(*cbs), kbest_size(k) { @@ -619,14 +644,15 @@ int main(int argc, char** argv) {    no_select = conf.count("no_select");    update_list_size = conf["update_k_best"].as<int>();    unique_kbest = conf.count("unique_k_best"); +  stream = conf.count("stream");    pseudo_doc = conf.count("pseudo_doc");    sent_approx = conf.count("sent_approx");    cerr << "Using pseudo-doc:" << pseudo_doc << " Sent:" << sent_approx << endl;    if(pseudo_doc)      mt_metric_scale=1; -  const string weights_dir = conf["weights_output"].as<string>(); -  const string output_dir = conf["output_dir"].as<string>(); +  const string weights_dir = stream ? "-" : conf["weights_output"].as<string>(); +  const string output_dir = stream ? 
"-" : conf["output_dir"].as<string>();    ScoreType type = ScoreTypeFromString(metric_name);    //establish metric used for tuning @@ -636,16 +662,22 @@ int main(int argc, char** argv) {      invert_score = false;    } -  //load references -  DocScorer ds(type, conf["reference"].as<vector<string> >(), ""); -  cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; +  shared_ptr<DocScorer> ds; +  //normal: load references, stream: start stream scorer +  if (stream) { +	  ds = shared_ptr<DocScorer>(new DocStreamScorer(type, vector<string>(0), "")); +	  cerr << "Scoring doc stream with " << metric_name << endl; +  } else { +      ds = shared_ptr<DocScorer>(new DocScorer(type, conf["reference"].as<vector<string> >(), "")); +      cerr << "Loaded " << ds->size() << " references for scoring with " << metric_name << endl; +  }    vector<ScoreP> corpus_bleu_sent_stats;    //check training pass,if >0, then use previous iterations corpus bleu stats -  cur_pass = conf["pass"].as<int>(); +  cur_pass = stream ? 0 : conf["pass"].as<int>();    if(cur_pass > 0)      { -      ReadPastTranslationForScore(cur_pass, &corpus_bleu_sent_stats, ds, output_dir); +      ReadPastTranslationForScore(cur_pass, &corpus_bleu_sent_stats, *ds, output_dir);      }    cerr << "Using optimizer:" << optimizer << endl; @@ -659,7 +691,7 @@ int main(int argc, char** argv) {    Weights::InitFromFile(conf["input_weights"].as<string>(), &dense_weights);    Weights::InitSparseVector(dense_weights, &lambdas); -  const string input = decoder.GetConf()["input"].as<string>(); +  const string input = stream ? "-" : decoder.GetConf()["input"].as<string>();    if (!SILENT) cerr << "Reading input from " << ((input == "-") ? 
"STDIN" : input.c_str()) << endl;    ReadFile in_read(input);    istream *in = in_read.stream(); @@ -668,9 +700,10 @@ int main(int argc, char** argv) {    const double max_step_size = conf["max_step_size"].as<double>(); -  vector<GoodBadOracle> oracles(ds.size()); +  vector<GoodBadOracle> oracles(ds->size()); -  TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles, &corpus_bleu_sent_stats); +  BasicObserver bobs; +  TrainingObserver observer(conf["k_best_size"].as<int>(), *ds, &oracles, &corpus_bleu_sent_stats);    int cur_sent = 0;    int lcount = 0; @@ -689,12 +722,30 @@ int main(int argc, char** argv) {    while(*in) {        getline(*in, buf);        if (buf.empty()) continue; +      if (stream) { +    	  int delim = buf.find(" ||| "); +    	  // Translate only +    	  if (delim == -1) { +    		  cur_sent = 0; +    		  decoder.SetId(cur_sent); +    		  decoder.Decode(buf, &bobs); +    		  vector<WordID> trans; +    		  ViterbiESentence(bobs.hypergraph[0], &trans); +    		  cout << TD::GetString(trans) << endl; +    		  continue; +    	  // Translate and update (normal MIRA) +    	  } else { +    		  cur_sent = 1; +    		  ds->update(buf.substr(delim + 5)); +    		  buf = buf.substr(0, delim); +    	  } +      }        //TODO: allow batch updating        lambdas.init_vector(&dense_weights);        dense_w_local = dense_weights;        decoder.SetId(cur_sent);        decoder.Decode(buf, &observer);  // decode the sentence, calling Notify to get the hope,fear, and model best hyps.  
-       +        cur_sent = observer.GetCurrentSent();        cerr << "SENT: " << cur_sent << endl;        const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); @@ -708,15 +759,15 @@ int main(int argc, char** argv) {        tot_loss += cur_hyp.mt_metric;        //score hyps to be able to compute corpus level bleu after we finish this iteration through the corpus -      ScoreP sentscore = ds[cur_sent]->ScoreCandidate(cur_hyp.hyp); +      ScoreP sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_hyp.hyp);        if (!acc) { acc = sentscore->GetZero(); }        acc->PlusEquals(*sentscore); -      ScoreP hope_sentscore = ds[cur_sent]->ScoreCandidate(cur_good.hyp); +      ScoreP hope_sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_good.hyp);        if (!acc_h) { acc_h = hope_sentscore->GetZero(); }        acc_h->PlusEquals(*hope_sentscore); -      ScoreP fear_sentscore = ds[cur_sent]->ScoreCandidate(cur_bad.hyp); +      ScoreP fear_sentscore = (*ds)[cur_sent]->ScoreCandidate(cur_bad.hyp);        if (!acc_f) { acc_f = fear_sentscore->GetZero(); }        acc_f->PlusEquals(*fear_sentscore); @@ -915,11 +966,11 @@ int main(int argc, char** argv) {  	  } -      if ((cur_sent * 40 / ds.size()) > dots) { ++dots; cerr << '.'; } +      if ((cur_sent * 40 / ds->size()) > dots) { ++dots; cerr << '.'; }        tot += lambdas;        ++lcount;        cur_sent++; -       +        cout << TD::GetString(cur_good_v[0]->hyp) << " ||| " << TD::GetString(cur_best_v[0]->hyp) << " ||| " << TD::GetString(cur_bad_v[0]->hyp) << endl;      } @@ -929,24 +980,27 @@ int main(int argc, char** argv) {      cerr << "Translated " << lcount << " sentences " << endl;      cerr << " [AVG METRIC LAST PASS=" << (tot_loss / lcount) << "]\n";      tot_loss = 0; + +    // Write weights unless streaming +    if (!stream) { +		int node_id = rng->next() * 100000; +		cerr << " Writing weights to " << node_id << endl; +		Weights::ShowLargestFeatures(dense_weights); +		dots = 0; +		ostringstream os; +		os 
<< weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << ".gz"; +		string msg = "# MIRA tuned weights ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); +		lambdas.init_vector(&dense_weights); +		Weights::WriteToFile(os.str(), dense_weights, true, &msg); -    int node_id = rng->next() * 100000; -    cerr << " Writing weights to " << node_id << endl; -    Weights::ShowLargestFeatures(dense_weights); -    dots = 0; -    ostringstream os; -    os << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << ".gz"; -    string msg = "# MIRA tuned weights ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); -    lambdas.init_vector(&dense_weights); -    Weights::WriteToFile(os.str(), dense_weights, true, &msg); - -    SparseVector<double> x = tot; -    x /= lcount+1; -    ostringstream sa; -    string msga = "# MIRA tuned weights AVERAGED ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); -    sa << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << "-avg.gz"; -    x.init_vector(&dense_weights); -    Weights::WriteToFile(sa.str(), dense_weights, true, &msga); +		SparseVector<double> x = tot; +		x /= lcount+1; +		ostringstream sa; +		string msga = "# MIRA tuned weights AVERAGED ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount); +		sa << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << "-avg.gz"; +		x.init_vector(&dense_weights); +		Weights::WriteToFile(sa.str(), dense_weights, true, &msga); +    }      cerr << "Optimization complete.\n";      return 0; | 
