summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 22:51:33 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 22:51:33 +0000
commita2e4142d6a737bff040c3f2a583da6e8244db01a (patch)
treedce70b212c143f3149c8280698ee5abce7fd6cda /decoder
parent1b606343b7368aa4c61d5088b22b8916486f0073 (diff)
shared_ptr for scores. todo: intrusive.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@327 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rw-r--r--decoder/apply_models.cc3
-rw-r--r--decoder/cdec.cc2
-rw-r--r--decoder/ff_bleu.cc6
-rwxr-xr-xdecoder/oracle_bleu.h22
-rw-r--r--decoder/sparse_vector.h8
5 files changed, 28 insertions, 13 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index ba573984..0e83582f 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -409,7 +409,8 @@ void ApplyModelSet(const Hypergraph& in,
const ModelSet& models,
const IntersectionConfiguration& config,
Hypergraph* out) {
- if (models.stateless() && config.algorithm == 0) {
+ //force exhaustive if there's no state req. for model
+ if (models.stateless() || config.algorithm == 0) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
} else if (config.algorithm == 1) {
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index a9c1cb3b..be554774 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -613,7 +613,7 @@ int main(int argc, char** argv) {
/*Oracle Rescoring*/
if(get_oracle_forest) {
- Oracle o=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),&cerr,10,conf["forest_output"].as<std::string>());
+ Oracle o=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>());
cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
o.hope.Print(cerr," +Oracle BLEU");
diff --git a/decoder/ff_bleu.cc b/decoder/ff_bleu.cc
index f8d62aa2..19564bd0 100644
--- a/decoder/ff_bleu.cc
+++ b/decoder/ff_bleu.cc
@@ -182,7 +182,8 @@ class BLEUModelImpl {
cerr << ")\n";
*/
- Score *node_score = smeta.GetDocScorer()[smeta.GetSentenceID()]->ScoreCCandidate(vs);
+ ScoreP node_score_p = smeta.GetDocScorer()[smeta.GetSentenceID()]->ScoreCCandidate(vs);
+ Score *node_score=node_score_p.get();
string details;
node_score->ScoreDetails(&details);
const Score *base_score= &smeta.GetScore();
@@ -194,6 +195,7 @@ class BLEUModelImpl {
//how it seems to be done in code
//TODO: might need to reverse the -1/+1 of the oracle/neg examples
+ //TO VLADIMIR: the polarity would be reversed if you switched error (1-BLEU) for BLEU.
approx_bleu = ( rule.FWords() * oracledoc_factor ) * node_score->ComputeScore();
//how I thought it was done from the paper
//approx_bleu = ( rule.FWords()+ smeta.GetDocLen() ) * node_score->ComputeScore();
@@ -277,7 +279,7 @@ void BLEUModel::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const DocScorer *ds = &smeta.GetDocScorer();
*/
- cerr<< "Loading sentence " << smeta.GetSentenceID() << endl;
+// cerr<< "ff_bleu loading sentence " << smeta.GetSentenceID() << endl;
//}
features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state, smeta));
//cerr << "FID" << fid_ << " " << DebugStateToString(state) << endl;
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 4800e9c1..470d311d 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -37,9 +37,12 @@ struct Translation {
ViterbiESentence(hg,&sentence);
features=ViterbiFeatures(hg,feature_weights,true);
}
- void Print(std::ostream &out,std::string pre=" +Oracle BLEU ") const {
+ void Print(std::ostream &out,std::string pre=" +Oracle BLEU ",bool include_0_fid=true) const {
out<<pre<<"Viterbi: "<<TD::GetString(sentence)<<"\n";
- out<<pre<<"features: "<<features<<std::endl;
+ out<<pre<<"features: "<<features;
+ if (include_0_fid && features.nonzero(0))
+ out<< " dummy-feature(0)="<<features[0];
+ out<<std::endl;
}
bool is_null() {
return features.empty() /* && sentence.empty() */;
@@ -91,6 +94,7 @@ struct OracleBleu {
("references,R", value<Refs >(&refs), "Translation reference files")
("oracle_loss", value<string>(&loss_name)->default_value("IBM_BLEU_3"), "IBM_BLEU_3 (default), IBM_BLEU etc")
("bleu_weight", value<double>(&bleu_weight)->default_value(1.), "weight to give the hope/fear loss function vs. model score")
+ ("verbose",bool_switch(&verbose),"detailed logs")
;
}
int order;
@@ -122,6 +126,7 @@ struct OracleBleu {
double bleu_weight;
// you have to call notify(conf) yourself, once, in main or similar
+ bool verbose;
void UseConf(boost::program_options::variables_map const& /* conf */) {
using namespace std;
// bleu_weight=conf["bleu_weight"].as<double>();
@@ -162,12 +167,12 @@ struct OracleBleu {
return;
}
assert(refs.size());
- ds.Init(loss,refs);
+ ds.Init(loss,refs,"",verbose);
ensure_doc_score();
-// doc_score.reset();
std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl;
}
+ // metadata has plain pointer, not shared, so we need to exist as long as it does
SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) {
std::vector<WordID> srcsent;
ViterbiFSentence(forest,&srcsent);
@@ -180,7 +185,7 @@ struct OracleBleu {
}
// destroys forest (replaces it w/ rescored oracle one)
- Oracle ComputeOracle(SentenceMetadata const& smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+ Oracle ComputeOracle(SentenceMetadata const& smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,unsigned kbest=0,std::string const& forest_output="") {
Hypergraph &forest=*forest_in_out;
Oracle r;
int sent_id=smeta.GetSentenceID();
@@ -189,7 +194,7 @@ struct OracleBleu {
{
Timer t("Forest Oracle rescoring:");
Hypergraph oracle_forest;
- Rescore(smeta,forest,&oracle_forest,feature_weights,bleu_weight,log);
+ Rescore(smeta,forest,&oracle_forest,feature_weights,bleu_weight);
forest.swap(oracle_forest);
}
r.hope=Translation(forest);
@@ -202,10 +207,10 @@ struct OracleBleu {
// if doc_score wasn't init, add 1 counts to ngram acc.
void ensure_doc_score() {
- if (!doc_score) { doc_score.reset(Score::GetOne(loss)); }
+ if (!doc_score) { doc_score=Score::GetOne(loss); }
}
- void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
+ void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0) {
// the sentence bleu stats will get added to doc only if you call IncludeLastScore
ensure_doc_score();
sentscore=GetScore(forest,smeta.GetSentenceID());
@@ -216,7 +221,6 @@ struct OracleBleu {
feature_weights_.set_value(0,bleu_weight);
feature_weights.init_vector(&w);
ModelSet oracle_models(w,vector<FeatureFunction const*>(1,pff.get()));
- if (log) *log << "Going to call Apply Model " << endl;
ApplyModelSet(forest,
smeta,
oracle_models,
diff --git a/decoder/sparse_vector.h b/decoder/sparse_vector.h
index f41bedf5..5e785210 100644
--- a/decoder/sparse_vector.h
+++ b/decoder/sparse_vector.h
@@ -1,5 +1,7 @@
#ifndef _SPARSE_VECTOR_H_
#define _SPARSE_VECTOR_H_
+/* hack: index 0 never gets printed because cdyer is creative and efficient. features which have no weight got feature dict id 0, see, and the models all clobered that value. nobody wants to see it. except that vlad is also creative and efficient and stored the oracle bleu there. */
+
// this is a modified version of code originally written
// by Phil Blunsom
@@ -54,6 +56,12 @@ public:
}
+ // warning: exploits the fact that 0 values are always removed from map. change this if you change that.
+ bool nonzero(int index) const {
+ return values_.find(index) != values_.end();
+ }
+
+
const T operator[](int index) const {
typename MapType::const_iterator found = values_.find(index);
if (found == values_.end())