From e4c5e87db2139aa0f8655b063da7d8b5199cb46d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 20 Dec 2011 18:34:14 -0500 Subject: migrate fast_score to the new API --- pro-train/dist-pro.pl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pro-train') diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index 5db053de..ba9cdc06 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -288,7 +288,7 @@ while (1){ $retries++; } die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); - my $dec_score = check_output("cat $runFile | $SCORER $refs_comma_sep -l $metric"); + my $dec_score = check_output("cat $runFile | $SCORER $refs_comma_sep -m $metric"); chomp $dec_score; print STDERR "DECODER SCORE: $dec_score\n"; -- cgit v1.2.3 From dbf367e0fc9d3faf906340d1f51f2dbda1892081 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 3 Feb 2012 17:19:16 -0500 Subject: make pro use new interface --- .gitignore | 77 ++++++++++++++++++++++++++++++++++++++++--------- mteval/ns.cc | 4 +++ mteval/ns.h | 4 +++ mteval/ns_ter.h | 1 + pro-train/dist-pro.pl | 4 +-- pro-train/mr_pro_map.cc | 37 +++++++++++++++--------- 6 files changed, 98 insertions(+), 29 deletions(-) (limited to 'pro-train') diff --git a/.gitignore b/.gitignore index 5efe37b0..ab8bf2c7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,46 @@ +mira/kbest_mira +sa-extract/calignment.c +sa-extract/calignment.so +sa-extract/cdat.c +sa-extract/cdat.so +sa-extract/cfloatlist.c +sa-extract/cfloatlist.so +sa-extract/cintlist.c +sa-extract/cintlist.so +sa-extract/clex.c +sa-extract/clex.so +sa-extract/cn.pyc +sa-extract/context_model.pyc +sa-extract/cstrmap.c +sa-extract/cstrmap.so +sa-extract/csuf.c +sa-extract/csuf.so +sa-extract/cveb.c +sa-extract/cveb.so +sa-extract/lcp.c +sa-extract/lcp.so +sa-extract/log.pyc +sa-extract/manager.pyc +sa-extract/model.pyc +sa-extract/monitor.pyc +sa-extract/precomputation.c +sa-extract/precomputation.so +sa-extract/rule.c +sa-extract/rule.so +sa-extract/rulefactory.c +sa-extract/rulefactory.so +sa-extract/sgml.pyc +sa-extract/sym.c +sa-extract/sym.so +training/mpi_flex_optimize +training/test_ngram +utils/dict_test +utils/logval_test +utils/mfcr_test +utils/phmt +utils/small_vector_test +utils/ts +utils/weights_test pro-train/.deps pro-train/mr_pro_map pro-train/mr_pro_reduce @@ -38,8 +81,8 @@ utils/.deps/ utils/libutils.a *swp *.o -vest/sentserver -vest/sentclient +dpmert/sentserver +dpmert/sentclient gi/pyp-topics/src/contexts_lexer.cc config.guess config.sub @@ -61,12 +104,12 @@ training/mr_em_map_adapter training/mr_reduce_to_weights training/optimize_test training/plftools -vest/fast_score -vest/lo_test -vest/mr_vest_map -vest/mr_vest_reduce -vest/scorer_test -vest/union_forests +dpmert/fast_score +dpmert/lo_test +dpmert/mr_dpmert_map +dpmert/mr_dpmert_reduce +dpmert/scorer_test +dpmert/union_forests Makefile Makefile.in aclocal.m4 @@ -99,11 +142,11 @@ training/Makefile.in training/*.o training/grammar_convert training/model1 -vest/.deps/ -vest/Makefile -vest/Makefile.in -vest/mr_vest_generate_mapper_input -vest/*.o +dpmert/.deps/ +dpmert/Makefile +dpmert/Makefile.in +dpmert/mr_dpmert_generate_mapper_input +dpmert/*.o decoder/logval_test extools/build_lexical_translation extools/filter_grammar @@ -124,7 +167,6 @@ m4/ltoptions.m4 m4/ltsugar.m4 m4/ltversion.m4 m4/lt~obsolete.m4 -vest/mbr_kbest extools/featurize_grammar extools/filter_score_grammar gi/posterior-regularisation/prjava/build/ @@ -143,3 +185,10 @@ gi/posterior-regularisation/prjava/lib/prjava-20100715.jar *.ps *.toc *~ +gi/pf/align-lexonly +gi/pf/align-lexonly-pyp +gi/pf/condnaive +mteval/scorer_test +phrasinator/gibbs_train_plm +phrasinator/gibbs_train_plm_notables +.* diff --git a/mteval/ns.cc b/mteval/ns.cc index da678b84..788f809a 100644 --- a/mteval/ns.cc +++ b/mteval/ns.cc @@ -21,6 +21,10 @@ map EvaluationMetric::instances_; SegmentEvaluator::~SegmentEvaluator() {} EvaluationMetric::~EvaluationMetric() {} +bool EvaluationMetric::IsErrorMetric() const { + return false; +} + struct DefaultSegmentEvaluator : public SegmentEvaluator { DefaultSegmentEvaluator(const vector >& refs, const EvaluationMetric* em) : refs_(refs), em_(em) {} void Evaluate(const vector& hyp, SufficientStats* out) const { diff --git a/mteval/ns.h b/mteval/ns.h index d88c263b..4e4c6975 100644 --- a/mteval/ns.h +++ b/mteval/ns.h @@ -94,6 +94,10 @@ class EvaluationMetric { public: const std::string& MetricId() const { return name_; } + // returns true for metrics like WER and TER where lower scores are better + // false for metrics like BLEU and METEOR where higher scores are better + virtual bool IsErrorMetric() const; + virtual unsigned SufficientStatisticsVectorSize() const; virtual float ComputeScore(const SufficientStats& stats) const = 0; virtual std::string DetailedScore(const SufficientStats& stats) const; diff --git a/mteval/ns_ter.h b/mteval/ns_ter.h index 3190fc1b..c5c25413 100644 --- a/mteval/ns_ter.h +++ b/mteval/ns_ter.h @@ -9,6 +9,7 @@ class TERMetric : public EvaluationMetric { TERMetric() : EvaluationMetric("TER") {} public: + virtual bool IsErrorMetric() const; virtual unsigned SufficientStatisticsVectorSize() const; virtual std::string DetailedScore(const SufficientStats& stats) const; virtual void ComputeSufficientStatistics(const std::vector& hyp, diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index ba9cdc06..31258fa6 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -12,7 +12,7 @@ use POSIX ":sys_wait_h"; my $QSUB_CMD = qsub_args(mert_memory()); my $default_jobs = env_default_jobs(); -my $VEST_DIR="$SCRIPT_DIR/../vest"; +my $VEST_DIR="$SCRIPT_DIR/../dpmert"; require "$VEST_DIR/libcall.pl"; # Default settings @@ -338,7 +338,7 @@ while (1){ $mapoutput =~ s/mapinput/mapoutput/; push @mapoutputs, "$dir/splag.$im1/$mapoutput"; $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; - my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep -w $inweights -K $dir/kbest < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; + my $script = "$MAPPER -s $srcFile -m $metric $refs_comma_sep -w $inweights -K $dir/kbest < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; if ($use_make) { my $script_file = "$dir/scripts/map.$shard"; open F, ">$script_file" or die "Can't write $script_file: $!"; diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index 0a9b75d7..52b67f32 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -13,11 +13,12 @@ #include "filelib.h" #include "stringlib.h" #include "weights.h" -#include "scorer.h" #include "inside_outside.h" #include "hg_io.h" #include "kbest.h" #include "viterbi.h" +#include "ns.h" +#include "ns_docscorer.h" // This is Figure 4 (Algorithm Sampler) from Hopkins&May (2011) @@ -80,7 +81,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("kbest_repository,K",po::value()->default_value("./kbest"),"K-best list repository (directory)") ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") ("source,s",po::value()->default_value(""), "Source file (ignored, except for AER)") - ("loss_function,l",po::value()->default_value("ibm_bleu"), "Loss function being optimized") + ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") ("kbest_size,k",po::value()->default_value(1500u), "Top k-hypotheses to extract") ("candidate_pairs,G", po::value()->default_value(5000u), "Number of pairs to sample per hypothesis (Gamma)") ("best_pairs,X", po::value()->default_value(50u), "Number of pairs, ranked by magnitude of objective delta, to retain (Xi)") @@ -109,9 +110,12 @@ struct HypInfo { HypInfo(const vector& h, const SparseVector& feats) : hyp(h), g_(-100.0f), x(feats) {} // lazy evaluation - double g(const SentenceScorer& scorer) const { - if (g_ == -100.0f) - g_ = scorer.ScoreCandidate(hyp)->ComputeScore(); + double g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const { + if (g_ == -100.0f) { + SufficientStats ss; + scorer.Evaluate(hyp, &ss); + g_ = metric->ComputeScore(ss); + } return g_; } vector hyp; @@ -233,15 +237,21 @@ struct DiffOrder { } }; -void Sample(const unsigned gamma, const unsigned xi, const vector& J_i, const SentenceScorer& scorer, const bool invert_score, vector* pv) { +void Sample(const unsigned gamma, + const unsigned xi, + const vector& J_i, + const SegmentEvaluator& scorer, + const EvaluationMetric* metric, + vector* pv) { + const bool invert_score = metric->IsErrorMetric(); vector v1, v2; float avg_diff = 0; for (unsigned i = 0; i < gamma; ++i) { const size_t a = rng->inclusive(0, J_i.size() - 1)(); const size_t b = rng->inclusive(0, J_i.size() - 1)(); if (a == b) continue; - float ga = J_i[a].g(scorer); - float gb = J_i[b].g(scorer); + float ga = J_i[a].g(scorer, metric); + float gb = J_i[b].g(scorer, metric); bool positive = gb < ga; if (invert_score) positive = !positive; const float gdiff = fabs(ga - gb); @@ -288,11 +298,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937(conf["random_seed"].as())); else rng.reset(new MT19937); - const string loss_function = conf["loss_function"].as(); + const string evaluation_metric = conf["evaluation_metric"].as(); + + EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); + DocumentScorer ds(metric, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; - ScoreType type = ScoreTypeFromString(loss_function); - DocScorer ds(type, conf["reference"].as >(), conf["source"].as()); - cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl; Hypergraph hg; string last_file; ReadFile in_read(conf["input"].as()); @@ -335,7 +346,7 @@ int main(int argc, char** argv) { Dedup(&J_i); WriteKBest(kbest_file, J_i); - Sample(gamma, xi, J_i, *ds[sent_id], (type == TER), &v); + Sample(gamma, xi, J_i, *ds[sent_id], metric, &v); for (unsigned i = 0; i < v.size(); ++i) { const TrainingInstance& vi = v[i]; cout << vi.y << "\t" << vi.x << endl; -- cgit v1.2.3