summaryrefslogtreecommitdiff
path: root/decoder/oracle_bleu.h
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/oracle_bleu.h')
-rwxr-xr-xdecoder/oracle_bleu.h59
1 files changed, 31 insertions, 28 deletions
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 4dc86bc7..4a2cbbe5 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -94,6 +94,7 @@ struct OracleBleu {
("references,R", value<Refs >(&refs), "Translation reference files")
("oracle_loss", value<string>(&loss_name)->default_value("IBM_BLEU_3"), "IBM_BLEU_3 (default), IBM_BLEU etc")
("bleu_weight", value<double>(&bleu_weight)->default_value(1.), "weight to give the hope/fear loss function vs. model score")
+ ("show_derivation", bool_switch(&show_derivation), "show derivation tree in kbest")
("verbose",bool_switch(&verbose),"detailed logs")
;
}
@@ -248,46 +249,48 @@ struct OracleBleu {
// dest_forest->SortInEdgesByEdgeWeights();
}
-// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
- void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) {
+ bool show_derivation;
+ template <class Filter> // Filter: k-best filtering policy handed to KBest::KBestDerivations (callers use KBest::FilterUnique / KBest::NoFilter)
+ void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { // prints up to k derivations to kbest_out; trees (if show_derivation) to deriv_out
using namespace std;
using namespace boost;
- cerr << "In kbest\n";
-
- ofstream kbest_out;
- kbest_out.open(kbest_out_filename_.c_str());
- cerr << "Output kbest to " << kbest_out_filename_;
-
+ typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K;
+ K kbest(forest,k); // NOTE: this local shadows the member function name inside the body
//add length (f side) src length of this sentence to the psuedo-doc src length count
float curr_src_length = doc_src_length + tmp_src_length;
-
- if (unique) {
- KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
- for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- //calculate score in context of psuedo-doc
+ for (int i = 0; i < k; ++i) {
+ typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); // queried at the last node index (presumably the goal node - confirm)
+ if (!d) break; // fewer than k derivations available
+ kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ << d->feature_values << " ||| " << log(d->score);
+ if (!refs.empty()) { // the extra BLEU column is only written when refs is non-empty
ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
float bleu = curr_src_length * sentscore->ComputeScore();
- kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl;
- // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- // << d->feature_values << " ||| " << log(d->score) << endl;
+ kbest_out << " ||| " << bleu;
}
- } else {
- KBest::KBestDerivations<Sentence, ESentenceTraversal> kbest(forest, k);
- for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<Sentence, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score) << endl;
+ kbest_out<<endl<<flush;
+ if (show_derivation) {
+ deriv_out<<"\nsent_id="<<sent_id<<"\n";
+ forest.show_tree(deriv_out,*d->edge); // was cerr: keep the tree on the same stream as its sent_id header
+ deriv_out<<flush;
}
}
}
+// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) {
+ // Writes the k-best list for sent_id to kbest_out_filename_ via kbest<Filter>(); derivation output is directed to std::cerr.
+ WriteFile ko(kbest_out_filename_);
+ std::cerr << "Output kbest to " << kbest_out_filename_;
+
+ if (unique) // unique selects KBest::FilterUnique, as the code this replaces did (the two branches were swapped)
+ kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),std::cerr);
+ else {
+ kbest<KBest::NoFilter>(sent_id,forest,k,ko.get(),std::cerr);
+ }
+ }
+
void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output)
{
std::ostringstream kbest_string_stream;