summaryrefslogtreecommitdiff
path: root/decoder/cdec.cc
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-15 03:50:05 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-15 03:50:05 +0000
commit27ed3c0fecde089a761ccf718748413bb572a3a4 (patch)
tree69e84990a9c4842ccbb7783f76e73b2dc1e3a7fa /decoder/cdec.cc
parent12b09cc1069ee4401074e0e0f6d8f9c120318aa0 (diff)
oracle bleu refactor
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@259 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/cdec.cc')
-rw-r--r--decoder/cdec.cc167
1 files changed, 47 insertions, 120 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index c15408b5..bec342ef 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -7,6 +7,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "oracle_bleu.h"
#include "timing_stats.h"
#include "translator.h"
#include "phrasebased_translator.h"
@@ -146,11 +147,11 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) {
("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM")
("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules")
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
- ("references,R", po::value<vector<string> >(), "Translation reference files")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to")
("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
+ OracleBleu::AddOptions(&opts);
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<string>(), "Configuration file")
@@ -260,14 +261,14 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
}
}
-// TODO decoder output should probably be moved to another file
-void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, const char *kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, Score* doc_score) {
+// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
+void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const&kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score) {
cerr << "In kbest\n";
ofstream kbest_out;
- kbest_out.open(kbest_out_filename_);
+ kbest_out.open(kbest_out_filename_.c_str());
cerr << "Output kbest to " << kbest_out_filename_;
-
+
//add length (f side) src length of this sentence to the psuedo-doc src length count
float curr_src_length = doc_src_length + tmp_src_length;
@@ -298,6 +299,15 @@ cerr << "In kbest\n";
}
}
+void DumpKBest(po::variables_map const& conf,string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score)
+{
+ ostringstream kbest_string_stream;
+ kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_"<<suffix<< "." << sent_id;
+ DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), doc_src_length, tmp_src_length, ds, doc_score);
+
+}
+
+
struct ELengthWeightFunction {
double operator()(const Hypergraph::Edge& e) const {
return e.rule_->ELength() - e.rule_->Arity();
@@ -517,45 +527,9 @@ int main(int argc, char** argv) {
const bool crf_uniform_empirical = conf.count("crf_uniform_empirical");
const bool get_oracle_forest = conf.count("get_oracle_forest");
- /*Oracle Extraction Prep*/
- vector<const FeatureFunction*> oracle_model_ffs;
- vector<double> oracle_feature_weights;
- shared_ptr<FeatureFunction> oracle_pff;
- if(get_oracle_forest) {
-
- /*Add feature for oracle rescoring */
- string ff, param;
- ff="BLEUModel";
- //pass the location of the references file via param to BLEUModel
- for(int kk=0;kk < conf["references"].as<vector<string> >().size();kk++)
- {
- param = param + " " + conf["references"].as<vector<string> >()[kk];
- }
- cerr << "Feature: " << ff << "->" << param << endl;
- oracle_pff = global_ff_registry->Create(ff,param);
- if (!oracle_pff) { exit(1); }
- oracle_model_ffs.push_back(oracle_pff.get());
- oracle_feature_weights.push_back(1.0);
-
- }
-
- ModelSet oracle_models(oracle_feature_weights, oracle_model_ffs);
-
- const string loss_function3 = "IBM_BLEU_3";
- ScoreType type3 = ScoreTypeFromString(loss_function3);
- const DocScorer ds(type3, conf["references"].as<vector<string> >(), "");
- cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function3 << endl;
-
-
- std::ostringstream kbest_string_stream;
- Score* doc_score=NULL;
- float doc_src_length=0;
- float tmp_src_length=0;
- int oracle_doc_size= 10; //used for scaling/weighting oracle doc
- float scale_oracle= 1-float(1)/oracle_doc_size;
-
- /*End Oracle Extraction Prep*/
-
+ OracleBleu oracle;
+ if (get_oracle_forest)
+ oracle.UseConf(conf);
shared_ptr<WriteFile> extract_file;
if (conf.count("extract_rules"))
@@ -671,83 +645,37 @@ int main(int argc, char** argv) {
vector<WordID> trans;
ViterbiESentence(forest, &trans);
-
+
/*Oracle Rescoring*/
- if(get_oracle_forest)
+ if(get_oracle_forest) {
+ Timer t("Forest Oracle rescoring:");
+
+ DumpKBest(conf,"model",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score);
+
+ Translation best(forest);
{
- Timer t("Forest Oracle rescoring:");
- vector<WordID> model_trans;
- model_trans = trans;
-
- trans=model_trans;
- Score* sentscore = ds[sent_id]->ScoreCandidate(model_trans);
- //initilize psuedo-doc vector to 1 counts
- if (!doc_score) { doc_score = sentscore->GetOne(); }
- double bleu_scale_ = doc_src_length * doc_score->ComputeScore();
- tmp_src_length = smeta.GetSourceLength();
- smeta.SetScore(doc_score);
- smeta.SetDocLen(doc_src_length);
- smeta.SetDocScorer(&ds);
-
- feature_weights[0]=1.0;
-
- kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_model" << "." << sent_id;
- DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score);
- kbest_string_stream.str("");
-
-
- forest.SortInEdgesByEdgeWeights();
- Hypergraph lm_forest;
- const IntersectionConfiguration inter_conf_oracle(0, 0);
- cerr << "Going to call Apply Model " << endl;
- ApplyModelSet(forest,
- smeta,
- oracle_models,
- inter_conf_oracle,
- &lm_forest);
-
- forest.swap(lm_forest);
- forest.Reweight(feature_weights);
- forest.SortInEdgesByEdgeWeights();
- vector<WordID> oracle_trans;
-
- ViterbiESentence(forest, &oracle_trans);
- cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
- cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
- cerr << " +Oracle BLEU Viterbi: " << TD::GetString(oracle_trans) << endl;
-
- //compute kbest for oracle
- kbest_string_stream << conf["forest_output"].as<string>() <<"/kbest_oracle" << "." << sent_id;
- DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score);
- kbest_string_stream.str("");
-
-
- //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples
- feature_weights[0]=-1.0;
- forest.Reweight(feature_weights);
- forest.SortInEdgesByEdgeWeights();
- vector<WordID> neg_trans;
- ViterbiESentence(forest, &neg_trans);
- cerr << " -Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
- cerr << " -Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
- cerr << " -Oracle BLEU Viterbi: " << TD::GetString(neg_trans) << endl;
-
- //compute kbest for negative
- kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_negative" << "." << sent_id;
- DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length,ds, doc_score);
- kbest_string_stream.str("");
-
- //Add 1-best translation (trans) to psuedo-doc vectors
- doc_score->PlusEquals(*sentscore, scale_oracle);
- delete sentscore;
-
- doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle;
-
-
- string details;
- doc_score->ScoreDetails(&details);
- cerr << "SCALED SCORE: " << bleu_scale_ << "DOC BLEU " << doc_score->ComputeScore() << " " <<details << endl;
+ Hypergraph oracle_forest;
+ oracle.Rescore(smeta,forest,&oracle_forest,feature_weights,1.0);
+ forest.swap(oracle_forest);
}
+ Translation oracle_trans(forest);
+
+ cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
+ oracle_trans.Print(cerr," +Oracle BLEU");
+ //compute kbest for oracle
+ DumpKBest(conf,"oracle",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score);
+
+ //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples
+ oracle.ReweightBleu(&forest,-1.0);
+ Translation neg_trans(forest);
+ neg_trans.Print(cerr," -Oracle BLEU");
+ //compute kbest for negative
+ DumpKBest(conf,"negative",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length,oracle.ds,oracle.doc_score);
+
+ //Add 1-best translation (trans) to psuedo-doc vectors
+ oracle.IncludeLastScore(&cerr);
+ }
if (conf.count("forest_output") && !has_ref) {
@@ -772,9 +700,8 @@ int main(int argc, char** argv) {
if (sample_max_trans) {
MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
} else {
-
if (kbest) {
- DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", doc_src_length, tmp_src_length, ds, doc_score);
+ DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", oracle.doc_src_length,oracle.tmp_src_length, oracle.ds,oracle.doc_score);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {