diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-15 04:05:17 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-15 04:05:17 +0000 |
commit | 5660d720dabcc010a402129139499a8f3d3130e7 (patch) | |
tree | e884ca9a6489afb7f38436579127120a17d7479b /decoder | |
parent | 20ab64d519569d09f9e286425cdcd7ecac236bf2 (diff) |
move kbest to oracle_bleu
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@262 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/cdec.cc | 54 | ||||
-rwxr-xr-x | decoder/oracle_bleu.h | 50 |
2 files changed, 55 insertions, 49 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index bec342ef..e616f1bb 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -261,51 +261,6 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { } } -// TODO decoder output should probably be moved to another file - how about oracle_bleu.h -void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const&kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score) { -cerr << "In kbest\n"; - - ofstream kbest_out; - kbest_out.open(kbest_out_filename_.c_str()); - cerr << "Output kbest to " << kbest_out_filename_; - - //add length (f side) src length of this sentence to the psuedo-doc src length count - float curr_src_length = doc_src_length + tmp_src_length; - - if (unique) { - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); - for (int i = 0; i < k; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - //calculate score in context of psuedo-doc - Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield); - sentscore->PlusEquals(*doc_score,float(1)); - float bleu = curr_src_length * sentscore->ComputeScore(); - kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " - << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl; - // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " - // << d->feature_values << " ||| " << log(d->score) << endl; - } - } else { - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k); - for (int i = 0; i < k; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " - << d->feature_values << " ||| " << log(d->score) << endl; - } - } -} - -void DumpKBest(po::variables_map const& conf,string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score) -{ - ostringstream kbest_string_stream; - kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_"<<suffix<< "." << sent_id; - DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), doc_src_length, tmp_src_length, ds, doc_score); - -} struct ELengthWeightFunction { @@ -650,7 +605,7 @@ int main(int argc, char** argv) { if(get_oracle_forest) { Timer t("Forest Oracle rescoring:"); - DumpKBest(conf,"model",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score); + oracle.DumpKBest(conf,"model",sent_id, forest, 10, true); Translation best(forest); { @@ -664,14 +619,14 @@ int main(int argc, char** argv) { cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; oracle_trans.Print(cerr," +Oracle BLEU"); //compute kbest for oracle - DumpKBest(conf,"oracle",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score); + oracle.DumpKBest(conf,"oracle",sent_id, forest, 10, true); //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples oracle.ReweightBleu(&forest,-1.0); Translation neg_trans(forest); neg_trans.Print(cerr," -Oracle BLEU"); //compute kbest for negative - DumpKBest(conf,"negative",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length,oracle.ds,oracle.doc_score); + oracle.DumpKBest(conf,"negative",sent_id, forest, 10, true); //Add 1-best translation (trans) to psuedo-doc vectors oracle.IncludeLastScore(&cerr); @@ -701,7 +656,8 @@ int main(int argc, char** argv) { MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0); } else { if (kbest) { - DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", oracle.doc_src_length,oracle.tmp_src_length, oracle.ds,oracle.doc_score); + //TODO: does this work properly? + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,""); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 273e14b8..32525466 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -1,6 +1,7 @@ #ifndef ORACLE_BLEU_H #define ORACLE_BLEU_H +#include <sstream> #include <iostream> #include <string> #include <vector> @@ -14,6 +15,7 @@ #include "viterbi.h" #include "sentence_metadata.h" #include "apply_models.h" +#include "kbest.h" //TODO: put function impls into .cc //TODO: disentangle @@ -148,6 +150,54 @@ struct OracleBleu { // dest_forest->SortInEdgesByEdgeWeights(); } +// TODO decoder output should probably be moved to another file - how about oracle_bleu.h + void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) { + using namespace std; + using namespace boost; + cerr << "In kbest\n"; + + ofstream kbest_out; + kbest_out.open(kbest_out_filename_.c_str()); + cerr << "Output kbest to " << kbest_out_filename_; + + //add length (f side) src length of this sentence to the psuedo-doc src length count + float curr_src_length = doc_src_length + tmp_src_length; + + if (unique) { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + //calculate score in context of psuedo-doc + Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield); + sentscore->PlusEquals(*doc_score,float(1)); + float bleu = curr_src_length * sentscore->ComputeScore(); + kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl; + // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + // << d->feature_values << " ||| " << log(d->score) << endl; + } + } else { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + << d->feature_values << " ||| " << log(d->score) << endl; + } + } + } + + void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique) + { + std::ostringstream kbest_string_stream; + kbest_string_stream << conf["forest_output"].as<std::string>() << "/kbest_"<<suffix<< "." << sent_id; + DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str()); + + } + }; |