summaryrefslogtreecommitdiff
path: root/decoder/cdec.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/cdec.cc')
-rw-r--r--decoder/cdec.cc154
1 files changed, 146 insertions, 8 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index b6cc6f66..5f06b0c8 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -32,6 +32,7 @@
#include "inside_outside.h"
#include "exp_semiring.h"
#include "sentence_metadata.h"
+#include "../vest/scorer.h"
using namespace std;
using namespace std::tr1;
@@ -143,7 +144,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) {
("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)")
("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM")
- ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
+ ("get_oracle_forest,OO", "Calculate rescored hypregraph using approximate BLEU scoring of rules")
+ ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
+ ("references,R", po::value<vector<string> >(), "Translation reference files")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to")
@@ -258,16 +261,30 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
}
// TODO decoder output should probably be moved to another file
-void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique) {
+void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, const char *kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, Score* doc_score) {
cerr << "In kbest\n";
+
+ ofstream kbest_out;
+ kbest_out.open(kbest_out_filename_);
+ cerr << "Output kbest to " << kbest_out_filename_;
+
+ //add length (f side) src length of this sentence to the psuedo-doc src length count
+ float curr_src_length = doc_src_length + tmp_src_length;
+
if (unique) {
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
for (int i = 0; i < k; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score) << endl;
+ //calculate score in context of psuedo-doc
+ Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
+ sentscore->PlusEquals(*doc_score,float(1));
+ float bleu = curr_src_length * sentscore->ComputeScore();
+ kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl;
+ // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ // << d->feature_values << " ||| " << log(d->score) << endl;
}
} else {
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
@@ -498,6 +515,48 @@ int main(int argc, char** argv) {
const bool kbest = conf.count("k_best");
const bool unique_kbest = conf.count("unique_k_best");
const bool crf_uniform_empirical = conf.count("crf_uniform_empirical");
+ const bool get_oracle_forest = conf.count("get_oracle_forest");
+
+ /*Oracle Extraction Prep*/
+ vector<const FeatureFunction*> oracle_model_ffs;
+ vector<double> oracle_feature_weights;
+ shared_ptr<FeatureFunction> oracle_pff;
+ if(get_oracle_forest) {
+
+ /*Add feature for oracle rescoring */
+ string ff, param;
+ ff="BLEUModel";
+ //pass the location of the references file via param to BLEUModel
+ for(int kk=0;kk < conf["references"].as<vector<string> >().size();kk++)
+ {
+ param = param + " " + conf["references"].as<vector<string> >()[kk];
+ }
+ cerr << "Feature: " << ff << "->" << param << endl;
+ oracle_pff = global_ff_registry->Create(ff,param);
+ if (!oracle_pff) { exit(1); }
+ oracle_model_ffs.push_back(oracle_pff.get());
+ oracle_feature_weights.push_back(1.0);
+
+ }
+
+ ModelSet oracle_models(oracle_feature_weights, oracle_model_ffs);
+
+ const string loss_function3 = "IBM_BLEU_3";
+ ScoreType type3 = ScoreTypeFromString(loss_function3);
+ const DocScorer ds(type3, conf["references"].as<vector<string> >(), "");
+ cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function3 << endl;
+
+
+ std::ostringstream kbest_string_stream;
+ Score* doc_score=NULL;
+ float doc_src_length=0;
+ float tmp_src_length=0;
+ int oracle_doc_size= 10; //used for scaling/weighting oracle doc
+ float scale_oracle= 1-float(1)/oracle_doc_size;
+
+ /*End Oracle Extraction Prep*/
+
+
shared_ptr<WriteFile> extract_file;
if (conf.count("extract_rules"))
extract_file.reset(new WriteFile(str("extract_rules",conf)));
@@ -610,6 +669,87 @@ int main(int argc, char** argv) {
maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
+ vector<WordID> trans;
+ ViterbiESentence(forest, &trans);
+
+ /*Oracle Rescoring*/
+ if(get_oracle_forest)
+ {
+ Timer t("Forest Oracle rescoring:");
+ vector<WordID> model_trans;
+ model_trans = trans;
+
+ trans=model_trans;
+ Score* sentscore = ds[sent_id]->ScoreCandidate(model_trans);
+ //initilize psuedo-doc vector to 1 counts
+ if (!doc_score) { doc_score = sentscore->GetOne(); }
+ double bleu_scale_ = doc_src_length * doc_score->ComputeScore();
+ tmp_src_length = smeta.GetSourceLength();
+ smeta.SetScore(doc_score);
+ smeta.SetDocLen(doc_src_length);
+ smeta.SetDocScorer(&ds);
+
+ feature_weights[0]=1.0;
+
+ kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_model" << "." << sent_id;
+ DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score);
+ kbest_string_stream.str("");
+
+
+ forest.SortInEdgesByEdgeWeights();
+ Hypergraph lm_forest;
+ const IntersectionConfiguration inter_conf_oracle(0, 0);
+ cerr << "Going to call Apply Model " << endl;
+ ApplyModelSet(forest,
+ smeta,
+ oracle_models,
+ inter_conf_oracle,
+ &lm_forest);
+
+ forest.swap(lm_forest);
+ forest.Reweight(feature_weights);
+ forest.SortInEdgesByEdgeWeights();
+ vector<WordID> oracle_trans;
+
+ ViterbiESentence(forest, &oracle_trans);
+ cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
+ cerr << " +Oracle BLEU Viterbi: " << TD::GetString(oracle_trans) << endl;
+
+ //compute kbest for oracle
+ kbest_string_stream << conf["forest_output"].as<string>() <<"/kbest_oracle" << "." << sent_id;
+ DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score);
+ kbest_string_stream.str("");
+
+
+ //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples
+ feature_weights[0]=-1.0;
+ forest.Reweight(feature_weights);
+ forest.SortInEdgesByEdgeWeights();
+ vector<WordID> neg_trans;
+ ViterbiESentence(forest, &neg_trans);
+ cerr << " -Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " -Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
+ cerr << " -Oracle BLEU Viterbi: " << TD::GetString(neg_trans) << endl;
+
+ //compute kbest for negative
+ kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_negative" << "." << sent_id;
+ DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length,ds, doc_score);
+ kbest_string_stream.str("");
+
+ //Add 1-best translation (trans) to psuedo-doc vectors
+ doc_score->PlusEquals(*sentscore, scale_oracle);
+ delete sentscore;
+
+ doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle;
+
+
+ string details;
+ doc_score->ScoreDetails(&details);
+ cerr << "SCALED SCORE: " << bleu_scale_ << "DOC BLEU " << doc_score->ComputeScore() << " " <<details << endl;
+ }
+
+
if (conf.count("forest_output") && !has_ref) {
ForestWriter writer(str("forest_output",conf), sent_id);
if (FileExists(writer.fname_)) {
@@ -632,11 +772,9 @@ int main(int argc, char** argv) {
if (sample_max_trans) {
MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
} else {
- vector<WordID> trans;
- ViterbiESentence(forest, &trans);
-
+
if (kbest) {
- DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest);
+ DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", doc_src_length, tmp_src_length, ds, doc_score);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {