summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-07-15 04:05:17 +0000
committer graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-07-15 04:05:17 +0000
commit 5660d720dabcc010a402129139499a8f3d3130e7 (patch)
tree e884ca9a6489afb7f38436579127120a17d7479b
parent 20ab64d519569d09f9e286425cdcd7ecac236bf2 (diff)
move kbest to oracle_bleu
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@262 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- decoder/cdec.cc 54
-rwxr-xr-x decoder/oracle_bleu.h 50
2 files changed, 55 insertions, 49 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index bec342ef..e616f1bb 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -261,51 +261,6 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
}
}
-// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
-void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const&kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score) {
-cerr << "In kbest\n";
-
- ofstream kbest_out;
- kbest_out.open(kbest_out_filename_.c_str());
- cerr << "Output kbest to " << kbest_out_filename_;
-
- //add length (f side) src length of this sentence to the psuedo-doc src length count
- float curr_src_length = doc_src_length + tmp_src_length;
-
- if (unique) {
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
- for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- //calculate score in context of psuedo-doc
- Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
- sentscore->PlusEquals(*doc_score,float(1));
- float bleu = curr_src_length * sentscore->ComputeScore();
- kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl;
- // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- // << d->feature_values << " ||| " << log(d->score) << endl;
- }
- } else {
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
- for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score) << endl;
- }
- }
-}
-
-void DumpKBest(po::variables_map const& conf,string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score)
-{
- ostringstream kbest_string_stream;
- kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), doc_src_length, tmp_src_length, ds, doc_score);
-
-}
struct ELengthWeightFunction {
@@ -650,7 +605,7 @@ int main(int argc, char** argv) {
if(get_oracle_forest) {
Timer t("Forest Oracle rescoring:");
- DumpKBest(conf,"model",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score);
+ oracle.DumpKBest(conf,"model",sent_id, forest, 10, true);
Translation best(forest);
{
@@ -664,14 +619,14 @@ int main(int argc, char** argv) {
cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
oracle_trans.Print(cerr," +Oracle BLEU");
//compute kbest for oracle
- DumpKBest(conf,"oracle",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score);
+ oracle.DumpKBest(conf,"oracle",sent_id, forest, 10, true);
//reweight the model with -1 for the BLEU feature to compute k-best list for negative examples
oracle.ReweightBleu(&forest,-1.0);
Translation neg_trans(forest);
neg_trans.Print(cerr," -Oracle BLEU");
//compute kbest for negative
- DumpKBest(conf,"negative",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length,oracle.ds,oracle.doc_score);
+ oracle.DumpKBest(conf,"negative",sent_id, forest, 10, true);
//Add 1-best translation (trans) to psuedo-doc vectors
oracle.IncludeLastScore(&cerr);
@@ -701,7 +656,8 @@ int main(int argc, char** argv) {
MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
} else {
if (kbest) {
- DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", oracle.doc_src_length,oracle.tmp_src_length, oracle.ds,oracle.doc_score);
+ //TODO: does this work properly?
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"");
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 273e14b8..32525466 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -1,6 +1,7 @@
#ifndef ORACLE_BLEU_H
#define ORACLE_BLEU_H
+#include <sstream>
#include <iostream>
#include <string>
#include <vector>
@@ -14,6 +15,7 @@
#include "viterbi.h"
#include "sentence_metadata.h"
#include "apply_models.h"
+#include "kbest.h"
//TODO: put function impls into .cc
//TODO: disentangle
@@ -148,6 +150,54 @@ struct OracleBleu {
// dest_forest->SortInEdgesByEdgeWeights();
}
+// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) {
+ using namespace std;
+ using namespace boost;
+ cerr << "In kbest\n";
+
+ ofstream kbest_out;
+ kbest_out.open(kbest_out_filename_.c_str());
+ cerr << "Output kbest to " << kbest_out_filename_;
+
+ //add length (f side) src length of this sentence to the psuedo-doc src length count
+ float curr_src_length = doc_src_length + tmp_src_length;
+
+ if (unique) {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ //calculate score in context of psuedo-doc
+ Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
+ sentscore->PlusEquals(*doc_score,float(1));
+ float bleu = curr_src_length * sentscore->ComputeScore();
+ kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl;
+ // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ // << d->feature_values << " ||| " << log(d->score) << endl;
+ }
+ } else {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ << d->feature_values << " ||| " << log(d->score) << endl;
+ }
+ }
+ }
+
+ void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique)
+ {
+ std::ostringstream kbest_string_stream;
+ kbest_string_stream << conf["forest_output"].as<std::string>() << "/kbest_"<<suffix<< "." << sent_id;
+ DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str());
+
+ }
+
};