summaryrefslogtreecommitdiff
path: root/decoder/cdec.cc
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-15 04:05:17 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-15 04:05:17 +0000
commit03b8a687ff9537a75931d79657d9bb57aefb4bc6 (patch)
treef87c2ee29c2a590085d1dc395cc49f13f7089e31 /decoder/cdec.cc
parent289873e898104ed56d88f819cc5559b69d3f1f2d (diff)
move kbest to oracle_bleu
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@262 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/cdec.cc')
-rw-r--r--decoder/cdec.cc54
1 files changed, 5 insertions, 49 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index bec342ef..e616f1bb 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -261,51 +261,6 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
}
}
-// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
-void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const&kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score) {
-cerr << "In kbest\n";
-
- ofstream kbest_out;
- kbest_out.open(kbest_out_filename_.c_str());
- cerr << "Output kbest to " << kbest_out_filename_;
-
- //add length (f side) src length of this sentence to the psuedo-doc src length count
- float curr_src_length = doc_src_length + tmp_src_length;
-
- if (unique) {
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
- for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- //calculate score in context of psuedo-doc
- Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
- sentscore->PlusEquals(*doc_score,float(1));
- float bleu = curr_src_length * sentscore->ComputeScore();
- kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl;
- // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- // << d->feature_values << " ||| " << log(d->score) << endl;
- }
- } else {
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
- for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score) << endl;
- }
- }
-}
-
-void DumpKBest(po::variables_map const& conf,string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score)
-{
- ostringstream kbest_string_stream;
- kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), doc_src_length, tmp_src_length, ds, doc_score);
-
-}
struct ELengthWeightFunction {
@@ -650,7 +605,7 @@ int main(int argc, char** argv) {
if(get_oracle_forest) {
Timer t("Forest Oracle rescoring:");
- DumpKBest(conf,"model",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score);
+ oracle.DumpKBest(conf,"model",sent_id, forest, 10, true);
Translation best(forest);
{
@@ -664,14 +619,14 @@ int main(int argc, char** argv) {
cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
oracle_trans.Print(cerr," +Oracle BLEU");
//compute kbest for oracle
- DumpKBest(conf,"oracle",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score);
+ oracle.DumpKBest(conf,"oracle",sent_id, forest, 10, true);
//reweight the model with -1 for the BLEU feature to compute k-best list for negative examples
oracle.ReweightBleu(&forest,-1.0);
Translation neg_trans(forest);
neg_trans.Print(cerr," -Oracle BLEU");
//compute kbest for negative
- DumpKBest(conf,"negative",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length,oracle.ds,oracle.doc_score);
+ oracle.DumpKBest(conf,"negative",sent_id, forest, 10, true);
//Add 1-best translation (trans) to psuedo-doc vectors
oracle.IncludeLastScore(&cerr);
@@ -701,7 +656,8 @@ int main(int argc, char** argv) {
MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
} else {
if (kbest) {
- DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", oracle.doc_src_length,oracle.tmp_src_length, oracle.ds,oracle.doc_score);
+ //TODO: does this work properly?
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"");
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {