Diffstat (limited to 'decoder')
-rw-r--r--  decoder/apply_models.h  |  2
-rw-r--r--  decoder/cdec.cc         | 18
-rwxr-xr-x  decoder/oracle_bleu.h   | 28
-rwxr-xr-x  decoder/sentences.h     | 53
-rw-r--r--  decoder/sparse_vector.h | 20
-rw-r--r--  decoder/stringlib.h     |  5
6 files changed, 103 insertions, 23 deletions
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index 5c220afd..61a5b8f7 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -11,7 +11,7 @@ struct IntersectionConfiguration {
const int algorithm; // 0 = full intersection, 1 = cube pruning
const int pop_limit; // max number of pops off the heap at each node
IntersectionConfiguration(int alg, int k) : algorithm(alg), pop_limit(k) {}
- IntersectionConfiguration(exhaustive_t t) : algorithm(0), pop_limit() {(void)t;}
+ IntersectionConfiguration(exhaustive_t /* t */) : algorithm(0), pop_limit() {}
};
void ApplyModelSet(const Hypergraph& in,
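As a quick illustration of the cleaned-up constructor pair, a minimal sketch (exhaustive_t is the tag type already referenced in this header; the function name is hypothetical):

  #include "apply_models.h"

  void intersection_config_example() {
    // cube pruning (algorithm 1) with a pop limit of 200 per node
    IntersectionConfiguration cube(1, 200);
    // exhaustive (full) intersection via the tag type; the extra parentheses
    // keep C++ from parsing this as a function declaration
    IntersectionConfiguration full((exhaustive_t()));
    (void)cube; (void)full;
  }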
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 77179948..8827cce3 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -323,6 +323,12 @@ void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_featur
}
}
+void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights) {
+ WeightVector fw(feature_weights);
+ forest_stats(forest,name,show_tree,show_features,&fw);
+}
+
+
void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) {
double beam_prune=0,density_prune=0;
bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen);
@@ -390,9 +396,9 @@ int main(int argc, char** argv) {
prelm_w.InitFromFile(plmw);
prelm_feature_weights.resize(FD::NumFeats());
prelm_w.InitVector(&prelm_feature_weights);
-// cerr << "prelm_weights: " << FeatureVector(prelm_feature_weights)<<endl;
+// cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl;
}
-// cerr << "+LM weights: " << FeatureVector(feature_weights)<<endl;
+// cerr << "+LM weights: " << WeightVector(feature_weights)<<endl;
}
bool warn0=conf.count("warn_0_weight");
bool freeze=!conf.count("no_freeze_feature_set");
@@ -548,7 +554,7 @@ int main(int argc, char** argv) {
}
const bool show_tree_structure=conf.count("show_tree_structure");
const bool show_features=conf.count("show_features");
- forest_stats(forest," -LM forest",show_tree_structure,show_features,&feature_weights);
+ forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights);
if (conf.count("show_expected_length")) {
const PRPair<double, double> res =
Inside<PRPair<double, double>,
@@ -574,7 +580,7 @@ int main(int argc, char** argv) {
&prelm_forest);
forest.swap(prelm_forest);
forest.Reweight(prelm_feature_weights);
- forest_stats(forest," prelm forest",show_tree_structure,show_features,&prelm_feature_weights);
+ forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights);
}
maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen);
@@ -593,7 +599,7 @@ int main(int argc, char** argv) {
&lm_forest);
forest.swap(lm_forest);
forest.Reweight(feature_weights);
- forest_stats(forest," +LM forest",show_tree_structure,show_features,&feature_weights);
+ forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights);
}
maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
@@ -604,7 +610,7 @@ int main(int argc, char** argv) {
/*Oracle Rescoring*/
if(get_oracle_forest) {
- Oracles o=oracles.ComputeOracles(smeta,&forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>());
+ Oracle o=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),&cerr,10,conf["forest_output"].as<std::string>());
cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
o.hope.Print(cerr," +Oracle BLEU");
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index b58117c1..cc19fbca 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -17,6 +17,7 @@
#include "apply_models.h"
#include "kbest.h"
#include "timing_stats.h"
+#include "sentences.h"
//TODO: put function impls into .cc
//TODO: disentangle
@@ -44,7 +45,7 @@ struct Translation {
};
-struct Oracles {
+struct Oracle {
bool is_null() {
return model.is_null() /* && fear.is_null() && hope.is_null() */;
}
@@ -52,13 +53,13 @@ struct Oracles {
Translation model,fear,hope;
// feature 0 will be the error rate in fear and hope
// move toward hope
- FeatureVector ModelHopeGradient() {
+ FeatureVector ModelHopeGradient() const {
FeatureVector r=hope.features-model.features;
r.set_value(0,0);
return r;
}
// move toward hope from fear
- FeatureVector FearHopeGradient() {
+ FeatureVector FearHopeGradient() const {
FeatureVector r=hope.features-fear.features;
r.set_value(0,0);
return r;
@@ -150,9 +151,9 @@ struct OracleBleu {
}
// destroys forest (replaces it w/ rescored oracle one)
- Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+ Oracle ComputeOracle(SentenceMetadata const& smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
Hypergraph &forest=*forest_in_out;
- Oracles r;
+ Oracle r;
int sent_id=smeta.GetSentenceID();
r.model=Translation(forest);
if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output);
@@ -169,23 +170,24 @@ struct OracleBleu {
if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output);
return r;
}
- typedef std::vector<WordID> Sentence;
- void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
+ void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
// the sentence bleu stats will get added to doc only if you call IncludeLastScore
sentscore=GetScore(forest,smeta.GetSentenceID());
if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?
using namespace std;
- ModelSet oracle_models(WeightVector(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get()));
- const IntersectionConfiguration inter_conf_oracle(0, 0);
+ DenseWeightVector w;
+ feature_weights_=feature_weights;
+ feature_weights_.set_value(0,bleu_weight);
+ feature_weights.init_vector(&w);
+ ModelSet oracle_models(w,vector<FeatureFunction const*>(1,pff.get()));
if (log) *log << "Going to call Apply Model " << endl;
ApplyModelSet(forest,
smeta,
oracle_models,
- inter_conf_oracle,
+ IntersectionConfiguration(exhaustive_t()),
dest_forest);
- feature_weights_=feature_weights;
ReweightBleu(dest_forest,bleu_weight);
}
@@ -202,7 +204,7 @@ struct OracleBleu {
}
void ReweightBleu(Hypergraph *dest_forest,double bleu_weight=-1.) {
- feature_weights_[0]=bleu_weight;
+ feature_weights_.set_value(0,bleu_weight);
dest_forest->Reweight(feature_weights_);
// dest_forest->SortInEdgesByEdgeWeights();
}
@@ -227,7 +229,7 @@ struct OracleBleu {
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
      // calculate score in context of pseudo-doc
- Score* sentscore = GetScore(d->yield,sent_id);
+ ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
float bleu = curr_src_length * sentscore->ComputeScore();
kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
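With Oracles renamed to Oracle and the gradient accessors made const, a caller holding an Oracle const& can compute update directions directly. A hedged sketch (the update rule itself is illustrative and not part of this commit):

  // illustrative only: nudge the current weights toward the hope derivation
  void UpdateTowardHope(Oracle const& o, WeightVector *weights) {
    // hope - model, with feature 0 (the oracle BLEU pseudo-feature) zeroed
    FeatureVector g = o.ModelHopeGradient();
    *weights += g;  // SparseVector's operator+=, as used elsewhere in the decoder
  }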
diff --git a/decoder/sentences.h b/decoder/sentences.h
new file mode 100755
index 00000000..842072b9
--- /dev/null
+++ b/decoder/sentences.h
@@ -0,0 +1,53 @@
+#ifndef CDEC_SENTENCES_H
+#define CDEC_SENTENCES_H
+
+#include <algorithm>
+#include <vector>
+#include <iostream>
+#include "filelib.h"
+#include "tdict.h"
+#include "stringlib.h"
+typedef std::vector<WordID> Sentence;
+
+inline void StringToSentence(std::string const& str,Sentence &s) {
+ using namespace std;
+ vector<string> ss=SplitOnWhitespace(str);
+ s.clear();
+ transform(ss.begin(),ss.end(),back_inserter(s),ToTD());
+}
+
+inline Sentence StringToSentence(std::string const& str) {
+ Sentence s;
+ StringToSentence(str,s);
+ return s;
+}
+
+inline std::istream& operator >> (std::istream &in,Sentence &s) {
+ using namespace std;
+ string str;
+ if (getline(in,str)) {
+ StringToSentence(str,s);
+ }
+ return in;
+}
+
+
+class Sentences : public std::vector<Sentence> {
+ typedef std::vector<Sentence> VS;
+public:
+ Sentences() { }
+ Sentences(unsigned n,Sentence const& sentence) : VS(n,sentence) { }
+ Sentences(unsigned n,std::string const& sentence) : VS(n,StringToSentence(sentence)) { }
+ void Load(std::string file) {
+ ReadFile r(file);
+ Load(*r.stream());
+ }
+ void Load(std::istream &in) {
+ this->push_back(Sentence());
+ while(in>>this->back()) ;
+ this->pop_back();
+ }
+};
+
+
+#endif
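A short usage sketch for the new header (the corpus path and function name are hypothetical):

  #include "sentences.h"

  void sentences_example() {
    // tokenize on whitespace and map tokens to WordIDs via ToTD/TD
    Sentence s = StringToSentence("el gato negro");
    Sentences corpus;
    corpus.Load("dev.src");   // one whitespace-tokenized sentence per line
    Sentences copies(3, s);   // three copies of the same sentence
    (void)corpus; (void)copies;
  }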
diff --git a/decoder/sparse_vector.h b/decoder/sparse_vector.h
index 9c7c9c79..43880014 100644
--- a/decoder/sparse_vector.h
+++ b/decoder/sparse_vector.h
@@ -12,6 +12,13 @@
#include "fdict.h"
+template <class T>
+inline T & extend_vector(std::vector<T> &v,int i) {
+ if (i>=v.size())
+ v.resize(i+1);
+ return v[i];
+}
+
template <typename T>
class SparseVector {
public:
@@ -29,6 +36,17 @@ public:
}
+ void init_vector(std::vector<T> *vp) const {
+ init_vector(*vp);
+ }
+
+ void init_vector(std::vector<T> &v) const {
+ v.clear();
+ for (const_iterator i=values_.begin(),e=values_.end();i!=e;++i)
+ extend_vector(v,i->first)=i->second;
+ }
+
+
void set_new_value(int index, T const& val) {
assert(values_.find(index)==values_.end());
values_[index]=val;
@@ -312,7 +330,7 @@ private:
typedef SparseVector<double> FeatureVector;
typedef SparseVector<double> WeightVector;
-
+typedef std::vector<double> DenseWeightVector;
template <typename T>
SparseVector<T> operator+(const SparseVector<T>& a, const SparseVector<T>& b) {
SparseVector<T> result = a;
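A minimal sketch of the sparse/dense round trip these additions support; this mirrors what the new forest_stats overload in cdec.cc and OracleBleu::Rescore do with feature weights:

  #include "sparse_vector.h"

  void weight_conversion_example() {
    WeightVector sparse;             // SparseVector<double>
    sparse.set_value(2, -1.0);
    sparse.set_value(7, 0.5);

    DenseWeightVector dense;         // std::vector<double>
    sparse.init_vector(&dense);      // dense.size() == 8; gaps are value-initialized to 0

    WeightVector round_trip(dense);  // dense -> sparse, as in the new forest_stats overload
    (void)round_trip;
  }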
diff --git a/decoder/stringlib.h b/decoder/stringlib.h
index eac1dce6..6bb8cff0 100644
--- a/decoder/stringlib.h
+++ b/decoder/stringlib.h
@@ -1,4 +1,5 @@
-#ifndef _STRINGLIB_H_
+#ifndef CDEC_STRINGLIB_H_
+#define CDEC_STRINGLIB_H_
#include <map>
#include <vector>
@@ -14,7 +15,7 @@ void ParseTranslatorInput(const std::string& line, std::string* input, std::stri
struct Lattice;
void ParseTranslatorInputLattice(const std::string& line, std::string* input, Lattice* ref);
-inline const std::string Trim(const std::string& str, const std::string& dropChars = " \t") {
+inline std::string Trim(const std::string& str, const std::string& dropChars = " \t") {
std::string res = str;
res.erase(str.find_last_not_of(dropChars)+1);
return res.erase(0, res.find_first_not_of(dropChars));
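A quick sanity check of the Trim behavior shown above, using the default drop set (a minimal sketch):

  #include "stringlib.h"
  #include <cassert>

  void trim_example() {
    std::string t = Trim("\t  hello world  ");
    assert(t == "hello world");  // leading/trailing spaces and tabs removed
  }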