author     Patrick Simianer <p@simianer.de>  2014-02-16 00:13:17 +0100
committer  Patrick Simianer <p@simianer.de>  2014-02-16 00:13:17 +0100
commit     ab71c44e61d00c788e84b44156d0be16191e267d (patch)
tree       27d1c7e74e8b07276312766a5908853465a3ed18
parent     4494c2cae3bed81f9d2d24d749e99bf66a734bc5 (diff)
parent     9e2f7fcfa76213f5e41abb4f4c9a264ebe8f9d8c (diff)
Merge remote-tracking branch 'upstream/master'
-rwxr-xr-x  compound-split/compound-split.pl               29
-rwxr-xr-x  compound-split/install-data-deps.sh             9
-rw-r--r--  decoder/bottom_up_parser.cc                    80
-rw-r--r--  mteval/ns.h                                     1
-rw-r--r--  mteval/ns_docscorer.cc                         34
-rw-r--r--  mteval/scorer_test.cc                           9
-rw-r--r--  mteval/test_data/devset.txt                     2
-rw-r--r--  tests/system_tests/multigram/cdec.ini           1
-rw-r--r--  tests/system_tests/multigram/g1.scfg            4
-rw-r--r--  tests/system_tests/multigram/g2.scfg            1
-rw-r--r--  tests/system_tests/multigram/gold.statistics    3
-rw-r--r--  tests/system_tests/multigram/gold.stdout        1
-rw-r--r--  tests/system_tests/multigram/input.txt          1
-rw-r--r--  tests/system_tests/multigram/weights            1
-rw-r--r--  training/mira/Makefile.am                      13
-rw-r--r--  training/mira/ada_opt_sm.cc                   198
-rw-r--r--  training/utils/candidate_set.cc                15
-rw-r--r--  training/utils/candidate_set.h                  2
-rw-r--r--  utils/stringlib.h                              15
-rw-r--r--  utils/tdict.cc                                  4
-rw-r--r--  utils/tdict.h                                   2
21 files changed, 406 insertions(+), 19 deletions(-)
diff --git a/compound-split/compound-split.pl b/compound-split/compound-split.pl
index 62259146..93ac3b20 100755
--- a/compound-split/compound-split.pl
+++ b/compound-split/compound-split.pl
@@ -35,6 +35,7 @@ die "Don't know about language: $LANG\n" unless -d "./$LANG";
my $CONFIG="cdec-$LANG.ini";
die "Can't find $CONFIG" unless -f $CONFIG;
die "--output must be '1best' or 'plf'\n" unless ($OUTPUT =~ /^(plf|1best)$/);
+check_dependencies($CONFIG, $LANG);
print STDERR "(Run with --help for options)\n";
print STDERR "LANGUAGE: $LANG\n";
print STDERR " OUTPUT: $OUTPUT\n";
@@ -146,3 +147,31 @@ Usage: $0 [OPTIONS] < file.txt
EOT
exit(1);
}
+
+sub check_dependencies {
+ my ($conf, $lang) = @_;
+ my @files = ();
+ open F, "<$conf" or die "Can't read $conf: $!";
+ while(<F>){
+ chomp;
+ my @x = split /\s+/;
+ for my $f (@x) {
+ push @files, $f if ($f =~ /\.gz$/);
+ }
+ }
+ close F;
+ my $c = 0;
+ for my $file (@files) {
+ $c++ if -f $file;
+ }
+ if ($c != scalar @files) {
+ print STDERR <<EOT;
+Missing data dependencies; to install, please run:
+
+ $script_dir/install-data-deps.sh
+
+EOT
+ exit(1);
+ }
+}
+
diff --git a/compound-split/install-data-deps.sh b/compound-split/install-data-deps.sh
new file mode 100755
index 00000000..942bfdcd
--- /dev/null
+++ b/compound-split/install-data-deps.sh
@@ -0,0 +1,9 @@
+#!/bin/sh
+set -e
+
+data_version=csplit-data-01.tar.gz
+
+curl -f http://demo.clab.cs.cmu.edu/cdec/$data_version -o $data_version
+
+tar xzf $data_version
+
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index 606b8d7e..8738c8f1 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -45,6 +45,7 @@ class PassiveChart {
const float lattice_cost);
void ApplyUnaryRules(const int i, const int j);
+ void TopoSortUnaries();
const vector<GrammarPtr>& grammars_;
const Lattice& input_;
@@ -57,6 +58,7 @@ class PassiveChart {
TRulePtr goal_rule_;
int goal_idx_; // index of goal node, if found
const int lc_fid_;
+ vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
static WordID kGOAL; // [Goal]
};
@@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal,
goal_cat_(TD::Convert(goal) * -1),
goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),
goal_idx_(-1),
- lc_fid_(FD::Convert("LatticeCost")) {
+ lc_fid_(FD::Convert("LatticeCost")),
+ unaries_() {
act_chart_.resize(grammars_.size());
- for (unsigned i = 0; i < grammars_.size(); ++i)
+ for (unsigned i = 0; i < grammars_.size(); ++i) {
act_chart_[i] = new ActiveChart(forest, *this);
+ const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules();
+ for (unsigned j = 0; j < u.size(); ++j)
+ unaries_.push_back(u[j]);
+ }
+ TopoSortUnaries();
if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl;
}
+static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) {
+ if (mark[node] == 1) {
+ cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n";
+ return false; // cycle detected
+ } else if (mark[node] == 2) {
+ return true; // already been
+ }
+ mark[node] = 1;
+ const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node);
+ if (nit != g.end()) {
+ const vector<TRulePtr>& edges = nit->second;
+ vector<bool> okay(edges.size(), true);
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark);
+ if (!okay[i]) {
+ cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl;
+ }
+ }
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ if (okay[i]) u.push_back(edges[i]);
+ //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl;
+ }
+ }
+ mark[node] = 2;
+ return true;
+}
+
+void PassiveChart::TopoSortUnaries() {
+ vector<TRulePtr> u; u.reserve(unaries_.size());
+ map<int, vector<TRulePtr> > g;
+ map<int, int> mark;
+ //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl;
+ mark[goal_cat_] = 2;
+ for (unsigned i = 0; i < unaries_.size(); ++i) {
+ //cerr << "Adding: " << unaries_[i]->AsString() << endl;
+ g[unaries_[i]->f()[0]].push_back(unaries_[i]);
+ }
+ //m[unaries_[i]->lhs_].push_back(unaries_[i]);
+ for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) {
+ //cerr << "PROC: " << TD::Convert(-it->first) << endl;
+ if (mark[it->first] > 0) {
+ //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n";
+ } else {
+ TopoSortVisit(it->first, u, g, mark);
+ }
+ }
+ unaries_.clear();
+ for (int i = u.size() - 1; i >= 0; --i)
+ unaries_.push_back(u[i]);
+}
+
void PassiveChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
const float lattice_cost) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
- //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
+ // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
new_edge->prev_j_ = r->prev_j;
new_edge->i_ = i;
@@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i,
void PassiveChart::ApplyUnaryRules(const int i, const int j) {
const vector<int>& nodes = chart_(i,j); // reference is important!
- for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
- if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue;
- for (unsigned di = 0; di < nodes.size(); ++di) {
- const WordID& cat = forest_->nodes_[nodes[di]].cat_;
- const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat);
- for (unsigned ri = 0; ri < unaries.size(); ++ri) {
- // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl;
+ for (unsigned di = 0; di < nodes.size(); ++di) {
+ const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+ for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
+ //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
+ if (unaries_[ri]->f()[0] == cat) {
+ //cerr << " --MATCH\n";
const Hypergraph::TailNodeVector ant(1, nodes[di]);
- ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes
+ ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes
}
}
}
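
[The rewritten ApplyUnaryRules above relies on unaries_ being topologically sorted, so one left-to-right pass over the list applies every reachable unary rule; TopoSortUnaries builds that order with a three-state depth-first search and drops rules that close a cycle. A minimal standalone sketch of the same scheme, with a hypothetical simplified rule type standing in for TRulePtr and integer nonterminal ids assumed:]

    #include <iostream>
    #include <map>
    #include <vector>

    // Hypothetical simplified unary rule: LHS <- RHS, both nonterminal ids.
    struct UnaryRule { int lhs, rhs; };

    // DFS states: 0 = unvisited, 1 = on the current path (revisit = cycle), 2 = done.
    static bool Visit(int node, const std::map<int, std::vector<UnaryRule> >& g,
                      std::map<int, int>& mark, std::vector<UnaryRule>* post) {
      if (mark[node] == 1) return false;  // back edge: cycle detected
      if (mark[node] == 2) return true;   // already finished
      mark[node] = 1;
      std::map<int, std::vector<UnaryRule> >::const_iterator it = g.find(node);
      if (it != g.end()) {
        for (unsigned i = 0; i < it->second.size(); ++i) {
          const UnaryRule& r = it->second[i];
          if (Visit(r.lhs, g, mark, post))
            post->push_back(r);           // keep the rule only if its LHS is cycle-free
          else
            std::cerr << "dropping cyclic unary " << r.lhs << " <- " << r.rhs << "\n";
        }
      }
      mark[node] = 2;
      return true;
    }

    // Returns the rules ordered so a rule producing X precedes rules consuming X.
    std::vector<UnaryRule> TopoSortUnaries(const std::vector<UnaryRule>& unaries) {
      std::map<int, std::vector<UnaryRule> > g;  // RHS category -> rules consuming it
      for (unsigned i = 0; i < unaries.size(); ++i)
        g[unaries[i].rhs].push_back(unaries[i]);
      std::map<int, int> mark;
      std::vector<UnaryRule> post;
      for (std::map<int, std::vector<UnaryRule> >::iterator it = g.begin(); it != g.end(); ++it)
        if (mark[it->first] == 0) Visit(it->first, g, mark, &post);
      return std::vector<UnaryRule>(post.rbegin(), post.rend());  // reverse postorder
    }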
diff --git a/mteval/ns.h b/mteval/ns.h
index ac7b0a23..153bf0b8 100644
--- a/mteval/ns.h
+++ b/mteval/ns.h
@@ -78,6 +78,7 @@ inline const SufficientStats operator-(const SufficientStats& a, const Sufficien
struct SegmentEvaluator {
virtual ~SegmentEvaluator();
virtual void Evaluate(const std::vector<WordID>& hyp, SufficientStats* out) const = 0;
+ std::string src; // this may not always be available
};
// Instructions for implementing a new metric
diff --git a/mteval/ns_docscorer.cc b/mteval/ns_docscorer.cc
index 83bd1a29..242f134a 100644
--- a/mteval/ns_docscorer.cc
+++ b/mteval/ns_docscorer.cc
@@ -13,6 +13,40 @@ DocumentScorer::~DocumentScorer() {}
DocumentScorer::DocumentScorer() {}
+DocumentScorer::DocumentScorer(const EvaluationMetric* metric,
+ const string& src_ref_file) {
+ const WordID kDIV = TD::Convert("|||");
+ assert(!src_ref_file.empty());
+ cerr << "Loading source and references from " << src_ref_file << "...\n";
+ ReadFile rf(src_ref_file);
+ istream& in = *rf.stream();
+ unsigned lc = 0;
+ string src_ref;
+ vector<WordID> tmp;
+ vector<vector<WordID> > refs;
+ while(getline(in, src_ref)) {
+ ++lc;
+ size_t end_src = src_ref.find(" ||| ");
+ if (end_src == string::npos) {
+ cerr << "Expected SRC ||| REF [||| REF2 ||| REF3 ...] in line " << lc << endl;
+ abort();
+ }
+ refs.clear();
+ tmp.clear();
+ TD::ConvertSentence(src_ref, &tmp, end_src + 5);
+ unsigned last = 0;
+ for (unsigned j = 0; j < tmp.size(); ++j) {
+ if (tmp[j] == kDIV) {
+ refs.push_back(vector<WordID>(tmp.begin() + last, tmp.begin() + j));
+ last = j + 1;
+ }
+ }
+ refs.push_back(vector<WordID>(tmp.begin() + last, tmp.end()));
+ scorers_.push_back(metric->CreateSegmentEvaluator(refs));
+ scorers_.back()->src = src_ref.substr(0, end_src);
+ }
+}
+
void DocumentScorer::Init(const EvaluationMetric* metric,
const vector<string>& ref_files,
const string& src_file,
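
[The new DocumentScorer constructor accepts a single development file in which each line packs the source and its references: SRC ||| REF1 [||| REF2 ...]. A minimal usage sketch, mirroring the unit test added below; the path is the test fixture, and any file in the same format works:]

    #include <iostream>
    #include "ns.h"
    #include "ns_docscorer.h"

    int main() {
      EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU");
      // One segment per line: SRC ||| REF1 [||| REF2 ...]
      DocumentScorer ds(metric, "mteval/test_data/devset.txt");
      std::cerr << "segments: " << ds.size() << std::endl;
      std::cerr << "src[0]:   " << ds[0]->src << std::endl;  // source side kept on the evaluator
      return 0;
    }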
diff --git a/mteval/scorer_test.cc b/mteval/scorer_test.cc
index da07f154..cd27f020 100644
--- a/mteval/scorer_test.cc
+++ b/mteval/scorer_test.cc
@@ -3,6 +3,7 @@
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include "ns_docscorer.h"
#include "ns.h"
#include "tdict.h"
#include "scorer.h"
@@ -223,4 +224,12 @@ BOOST_AUTO_TEST_CASE(NewScoreAPI) {
//cerr << metric->ComputeScore(statse) << endl;
}
+BOOST_AUTO_TEST_CASE(HybridSourceReferenceFileFormat) {
+ std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
+ EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU");
+ DocumentScorer ds(metric, path + "/devset.txt");
+ BOOST_CHECK_EQUAL(2, ds.size());
+ BOOST_CHECK_EQUAL("Quelltext hier .", ds[0]->src);
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/mteval/test_data/devset.txt b/mteval/test_data/devset.txt
new file mode 100644
index 00000000..f9135d98
--- /dev/null
+++ b/mteval/test_data/devset.txt
@@ -0,0 +1,2 @@
+Quelltext hier . ||| source text here . ||| original text . ||| some source text .
+ein anderer Satz . ||| another sentence . ||| a different sentence .
diff --git a/tests/system_tests/multigram/cdec.ini b/tests/system_tests/multigram/cdec.ini
new file mode 100644
index 00000000..f31becb8
--- /dev/null
+++ b/tests/system_tests/multigram/cdec.ini
@@ -0,0 +1 @@
+formalism=scfg
diff --git a/tests/system_tests/multigram/g1.scfg b/tests/system_tests/multigram/g1.scfg
new file mode 100644
index 00000000..a3a59699
--- /dev/null
+++ b/tests/system_tests/multigram/g1.scfg
@@ -0,0 +1,4 @@
+[X] ||| [Z] ||| [1] ||| Top=1
+[Y] ||| foo ||| foo ||| F1=1
+[Z] ||| [Z] [Y] ||| [1] [2] ||| W1=1
+[Z] ||| [Y] ||| [1] ||| W2=1
diff --git a/tests/system_tests/multigram/g2.scfg b/tests/system_tests/multigram/g2.scfg
new file mode 100644
index 00000000..40962517
--- /dev/null
+++ b/tests/system_tests/multigram/g2.scfg
@@ -0,0 +1 @@
+[Y] ||| bar ||| bar ||| F2=1
diff --git a/tests/system_tests/multigram/gold.statistics b/tests/system_tests/multigram/gold.statistics
new file mode 100644
index 00000000..ef23a685
--- /dev/null
+++ b/tests/system_tests/multigram/gold.statistics
@@ -0,0 +1,3 @@
+-lm_nodes 11
+-lm_edges 12
+-lm_paths 2
diff --git a/tests/system_tests/multigram/gold.stdout b/tests/system_tests/multigram/gold.stdout
new file mode 100644
index 00000000..d675fa44
--- /dev/null
+++ b/tests/system_tests/multigram/gold.stdout
@@ -0,0 +1 @@
+foo bar
diff --git a/tests/system_tests/multigram/input.txt b/tests/system_tests/multigram/input.txt
new file mode 100644
index 00000000..2aef01a0
--- /dev/null
+++ b/tests/system_tests/multigram/input.txt
@@ -0,0 +1 @@
+<seg id="0" grammar="g1.scfg" grammar1="g2.scfg">foo bar</seg>
diff --git a/tests/system_tests/multigram/weights b/tests/system_tests/multigram/weights
new file mode 100644
index 00000000..a6b6698a
--- /dev/null
+++ b/tests/system_tests/multigram/weights
@@ -0,0 +1 @@
+Glue -1
diff --git a/training/mira/Makefile.am b/training/mira/Makefile.am
index 44bf1063..a318cf6e 100644
--- a/training/mira/Makefile.am
+++ b/training/mira/Makefile.am
@@ -1,15 +1,20 @@
-bin_PROGRAMS = kbest_mira \
- kbest_cut_mira
+bin_PROGRAMS = \
+ kbest_mira \
+ kbest_cut_mira \
+ ada_opt_sm
EXTRA_DIST = mira.py
+ada_opt_sm_SOURCES = ada_opt_sm.cc
+ada_opt_sm_LDFLAGS= -rdynamic
+ada_opt_sm_LDADD = ../utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
+
kbest_mira_SOURCES = kbest_mira.cc
kbest_mira_LDFLAGS= -rdynamic
kbest_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
-
kbest_cut_mira_SOURCES = kbest_cut_mira.cc
kbest_cut_mira_LDFLAGS= -rdynamic
kbest_cut_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils
diff --git a/training/mira/ada_opt_sm.cc b/training/mira/ada_opt_sm.cc
new file mode 100644
index 00000000..18ddbf8f
--- /dev/null
+++ b/training/mira/ada_opt_sm.cc
@@ -0,0 +1,198 @@
+#include "config.h"
+
+#include <boost/container/flat_map.hpp>
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "weights.h"
+#include "sparse_vector.h"
+#include "candidate_set.h"
+#include "sentence_metadata.h"
+#include "ns.h"
+#include "ns_docscorer.h"
+#include "verbose.h"
+#include "hg.h"
+#include "ff_register.h"
+#include "decoder.h"
+#include "fdict.h"
+#include "sampler.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+boost::shared_ptr<MT19937> rng;
+vector<training::CandidateSet> kbests;
+SparseVector<weight_t> G, u, lambdas;
+double pseudo_doc_decay = 0.9;
+
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file")
+ ("devset,d",po::value<string>(),"[REQ] Source/reference development set")
+ ("weights,w",po::value<string>(),"Initial feature weights file")
+ ("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)")
+ ("size",po::value<unsigned>()->default_value(0), "Process rank (for multiprocess mode)")
+ ("rank",po::value<unsigned>()->default_value(1), "Number of processes (for multiprocess mode)")
+ ("optimizer,o",po::value<unsigned>()->default_value(1), "Optimizer (Adaptive MIRA=1)")
+ ("fear,f",po::value<unsigned>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)")
+ ("hope,h",po::value<unsigned>()->default_value(1), "Hope selection (model+cost=1, mincost=2)")
+ ("eta0", po::value<double>()->default_value(0.1), "Initial step size")
+ ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
+ ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount")
+ ("pseudo_doc,e", "Use pseudo-documents for approximate scoring")
+ ("k_best_size,k", po::value<unsigned>()->default_value(500), "Size of hypothesis list to search for oracles");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,H", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help")
+ || !conf->count("decoder_config")
+ || !conf->count("devset")) {
+ cerr << dcmdline_options << endl;
+ return false;
+ }
+ return true;
+}
+
+struct TrainingObserver : public DecoderObserver {
+ explicit TrainingObserver(const EvaluationMetric& m, const int k) : metric(m), kbest_size(k), cur_eval() {}
+
+ const EvaluationMetric& metric;
+ const int kbest_size;
+ const SegmentEvaluator* cur_eval;
+ SufficientStats pdoc;
+ unsigned hi, vi, fi; // hope, viterbi, fear
+
+ void SetSegmentEvaluator(const SegmentEvaluator* eval) {
+ cur_eval = eval;
+ }
+
+ virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) {
+ cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl;
+ abort();
+ }
+
+ unsigned CostAugmentedDecode(const training::CandidateSet& cs,
+ const SparseVector<double>& w,
+ double alpha = 0) {
+ unsigned best_i = 0;
+ double best = -numeric_limits<double>::infinity();
+ for (unsigned i = 0; i < cs.size(); ++i) {
+ double s = cs[i].fmap.dot(w);
+ if (alpha)
+ s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc);
+ if (s > best) {
+ best = s;
+ best_i = i;
+ }
+ }
+ return best_i;
+ }
+
+ virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ pdoc *= pseudo_doc_decay;
+ const unsigned sent_id = smeta.GetSentenceID();
+ kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval);
+ vi = CostAugmentedDecode(kbests[sent_id], lambdas);
+ hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0);
+ fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0);
+ cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl;
+ pdoc += kbests[sent_id][vi].eval_feats; // update pseudodoc stats
+ }
+};
+
+int main(int argc, char** argv) {
+ SetSilent(true); // turn off verbose decoder output
+ register_feature_functions();
+
+ po::variables_map conf;
+ if (!InitCommandLine(argc, argv, &conf)) return 1;
+
+ if (conf.count("random_seed"))
+ rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+ else
+ rng.reset(new MT19937);
+
+ string metric_name = UppercaseString(conf["mt_metric"].as<string>());
+ if (metric_name == "COMBI") {
+ cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n";
+ metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5";
+ } else if (metric_name == "BLEU") {
+ cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n";
+ metric_name = "IBM_BLEU";
+ }
+ EvaluationMetric* metric = EvaluationMetric::Instance(metric_name);
+ DocumentScorer ds(metric, conf["devset"].as<string>());
+ cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl;
+ kbests.resize(ds.size());
+ double eta = 0.001;
+
+ ReadFile ini_rf(conf["decoder_config"].as<string>());
+ Decoder decoder(ini_rf.stream());
+
+ vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
+ if (conf.count("weights")) {
+ Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights);
+ Weights::InitSparseVector(dense_weights, &lambdas);
+ }
+
+ TrainingObserver observer(*metric, conf["k_best_size"].as<unsigned>());
+
+ unsigned num = 200;
+ for (unsigned iter = 1; iter < num; ++iter) {
+ lambdas.init_vector(&dense_weights);
+ unsigned sent_id = rng->next() * ds.size();
+ cerr << "Learning from sentence id: " << sent_id << endl;
+ observer.SetSegmentEvaluator(ds[sent_id]);
+ decoder.SetId(sent_id);
+ decoder.Decode(ds[sent_id]->src, &observer);
+ if (observer.vi != observer.hi) { // viterbi != hope
+ SparseVector<double> grad = kbests[sent_id][observer.fi].fmap;
+ grad -= kbests[sent_id][observer.hi].fmap;
+ cerr << "GRAD: " << grad << endl;
+ const SparseVector<double>& g = grad;
+#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4)
+ for (auto& gi : g) {
+#else
+ for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
+ const pair<unsigned,double>& gi = *it;
+#endif
+ if (gi.second) {
+ u[gi.first] += gi.second;
+ G[gi.first] += gi.second * gi.second;
+ lambdas.set_value(gi.first, 1.0); // this is a dummy value to trigger recomputation
+ }
+ }
+ for (SparseVector<double>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
+ const pair<unsigned,double>& xi = *it;
+ double z = fabs(u[xi.first] / iter) - 0.0;
+ double s = 1;
+ if (u[xi.first] > 0) s = -1;
+ if (z > 0 && G[xi.first]) {
+ lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first]));
+ } else {
+ lambdas.set_value(xi.first, 0.0);
+ }
+ }
+ }
+ }
+ cerr << "Optimization complete.\n";
+ Weights::WriteToFile("-", dense_weights, true);
+ return 0;
+}
+
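[The update loop above maintains, per feature i, the accumulated gradient sum u_i and squared-gradient sum G_i, and rewrites every active weight from scratch at step t = iter. Read as math, this is a hedged reconstruction of the code, taking the literal `- 0.0` as an L1 threshold lambda that is currently zero:]

    % Per feature i at step t, with gradient sum u_i and squared-gradient sum G_i:
    \[
      z_i = \left|\tfrac{u_i}{t}\right| - \lambda, \qquad
      w_i \leftarrow \begin{cases}
        -\eta\,\operatorname{sign}(u_i)\, z_i\, \tfrac{t}{\sqrt{G_i}} & \text{if } z_i > 0 \text{ and } G_i > 0,\\
        0 & \text{otherwise.}
      \end{cases}
    \]
    % With lambda = 0 this reduces to the AdaGrad dual-averaging step
    % w_i <- -eta * u_i / sqrt(G_i).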
diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc
index 33dae9a3..36f5b271 100644
--- a/training/utils/candidate_set.cc
+++ b/training/utils/candidate_set.cc
@@ -171,4 +171,19 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, c
Dedup();
}
+void CandidateSet::AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
+ typedef KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> K;
+ K kbest(hg, kbest_size);
+
+ for (unsigned i = 0; i < kbest_size; ++i) {
+ const K::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cs.push_back(Candidate(d->yield, d->feature_values));
+ if (scorer)
+ scorer->Evaluate(d->yield, &cs.back().eval_feats);
+ }
+ Dedup();
+}
+
}
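
[AddUniqueKBestCandidates differs from the existing AddKBestCandidates only in the KBest::FilterUnique traversal, which discards derivations whose yield has already been emitted before they reach the candidate pool. A minimal usage sketch; the forest and evaluator are assumed to come from the decoder and DocumentScorer, as in ada_opt_sm.cc above:]

    #include "candidate_set.h"
    #include "hg.h"
    #include "ns.h"

    // hg: translation forest from the decoder; eval: per-segment scorer (may be NULL).
    void CollectUnique(const Hypergraph& hg, const SegmentEvaluator* eval,
                       training::CandidateSet* cs) {
      // Adds up to 500 candidates with pairwise-distinct yields; when eval is
      // non-NULL each candidate also carries its sufficient statistics (eval_feats).
      cs->AddUniqueKBestCandidates(hg, 500, eval);
    }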
diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h
index 9d326ed0..17a650f5 100644
--- a/training/utils/candidate_set.h
+++ b/training/utils/candidate_set.h
@@ -47,7 +47,7 @@ class CandidateSet {
void ReadFromFile(const std::string& file);
void WriteToFile(const std::string& file) const;
void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
- // TODO add code to do unique k-best
+ void AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
// TODO add code to draw k samples
private:
diff --git a/utils/stringlib.h b/utils/stringlib.h
index f60b7867..2fdbfff8 100644
--- a/utils/stringlib.h
+++ b/utils/stringlib.h
@@ -242,6 +242,21 @@ void VisitTokens(std::string const& s,F f) {
VisitTokens(mp.p,mp.p+s.size(),f);
}
+template <class F>
+void VisitTokens(std::string const& s,F f, unsigned start) {
+ if (0) {
+ std::vector<std::string> ss=SplitOnWhitespace(s);
+ for (unsigned i=0;i<ss.size();++i)
+ f(ss[i]);
+ return;
+ }
+ //FIXME:
+ if (s.empty()) return;
+ mutable_c_str mp(s);
+ SLIBDBG("mp="<<mp.p);
+ VisitTokens(mp.p+start,mp.p+s.size(),f);
+}
+
inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::string* param) {
cmd->clear();
param->clear();
diff --git a/utils/tdict.cc b/utils/tdict.cc
index fd2b76cb..c99f1697 100644
--- a/utils/tdict.cc
+++ b/utils/tdict.cc
@@ -70,7 +70,7 @@ struct add_wordids {
}
-void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids) {
+void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids, unsigned start) {
ids->clear();
- VisitTokens(s,add_wordids(ids));
+ VisitTokens(s,add_wordids(ids),start);
}
diff --git a/utils/tdict.h b/utils/tdict.h
index 03afc2e6..bb19ecd5 100644
--- a/utils/tdict.h
+++ b/utils/tdict.h
@@ -9,7 +9,7 @@
struct TD {
static WordID end(); // next id to be assigned; [begin,end) give the non-reserved tokens seen so far
- static void ConvertSentence(std::string const& sent, std::vector<WordID>* ids);
+ static void ConvertSentence(std::string const& sent, std::vector<WordID>* ids, unsigned start=0);
static void GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids);
static std::string GetString(const std::vector<WordID>& str);
static std::string GetString(WordID const* i,WordID const* e);
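
[The stringlib/tdict change threads a byte offset through tokenization, so callers such as the new DocumentScorer constructor can convert only the reference portion of a SRC ||| REFS line without copying the substring. A small sketch of the intended call pattern; the sentence content is made up:]

    #include <string>
    #include <vector>
    #include "tdict.h"

    int main() {
      const std::string line = "Quelltext hier . ||| source text here .";
      const size_t end_src = line.find(" ||| ");
      std::vector<WordID> ids;
      // Skip past "SRC ||| " and tokenize the rest; the delimiter is 5 bytes wide.
      TD::ConvertSentence(line, &ids, end_src + 5);
      return 0;
    }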