diff options
author | Patrick Simianer <p@simianer.de> | 2014-02-16 00:13:17 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2014-02-16 00:13:17 +0100 |
commit | bb5b6464826c765f4795381830acae158987f46b (patch) | |
tree | f29658c1278f2b450a81fa4207c8d809558e95cf | |
parent | 7bfe96c2a706d375362c054619f28dd40c7c33e8 (diff) | |
parent | 8015250ddd3983320b6e54ca7f1914a465bc8a59 (diff) |
Merge remote-tracking branch 'upstream/master'
-rwxr-xr-x | compound-split/compound-split.pl | 29 | ||||
-rwxr-xr-x | compound-split/install-data-deps.sh | 9 | ||||
-rw-r--r-- | decoder/bottom_up_parser.cc | 80 | ||||
-rw-r--r-- | mteval/ns.h | 1 | ||||
-rw-r--r-- | mteval/ns_docscorer.cc | 34 | ||||
-rw-r--r-- | mteval/scorer_test.cc | 9 | ||||
-rw-r--r-- | mteval/test_data/devset.txt | 2 | ||||
-rw-r--r-- | tests/system_tests/multigram/cdec.ini | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/g1.scfg | 4 | ||||
-rw-r--r-- | tests/system_tests/multigram/g2.scfg | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/gold.statistics | 3 | ||||
-rw-r--r-- | tests/system_tests/multigram/gold.stdout | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/input.txt | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/weights | 1 | ||||
-rw-r--r-- | training/mira/Makefile.am | 13 | ||||
-rw-r--r-- | training/mira/ada_opt_sm.cc | 198 | ||||
-rw-r--r-- | training/utils/candidate_set.cc | 15 | ||||
-rw-r--r-- | training/utils/candidate_set.h | 2 | ||||
-rw-r--r-- | utils/stringlib.h | 15 | ||||
-rw-r--r-- | utils/tdict.cc | 4 | ||||
-rw-r--r-- | utils/tdict.h | 2 |
21 files changed, 406 insertions, 19 deletions
diff --git a/compound-split/compound-split.pl b/compound-split/compound-split.pl index 62259146..93ac3b20 100755 --- a/compound-split/compound-split.pl +++ b/compound-split/compound-split.pl @@ -35,6 +35,7 @@ die "Don't know about language: $LANG\n" unless -d "./$LANG"; my $CONFIG="cdec-$LANG.ini"; die "Can't find $CONFIG" unless -f $CONFIG; die "--output must be '1best' or 'plf'\n" unless ($OUTPUT =~ /^(plf|1best)$/); +check_dependencies($CONFIG, $LANG); print STDERR "(Run with --help for options)\n"; print STDERR "LANGUAGE: $LANG\n"; print STDERR " OUTPUT: $OUTPUT\n"; @@ -146,3 +147,31 @@ Usage: $0 [OPTIONS] < file.txt EOT exit(1); } + +sub check_dependencies { + my ($conf, $lang) = @_; + my @files = (); + open F, "<$conf" or die "Can't read $conf: $!"; + while(<F>){ + chomp; + my @x = split /\s+/; + for my $f (@x) { + push @files, $f if ($f =~ /\.gz$/); + } + } + close F; + my $c = 0; + for my $file (@files) { + $c++ if -f $file; + } + if ($c != scalar @files) { + print STDERR <<EOT; +Missing data dependencies; to install, please run: + + $script_dir/install-data-deps.sh + +EOT + exit(1); + } +} + diff --git a/compound-split/install-data-deps.sh b/compound-split/install-data-deps.sh new file mode 100755 index 00000000..942bfdcd --- /dev/null +++ b/compound-split/install-data-deps.sh @@ -0,0 +1,9 @@ +#!/bin/sh +set -e + +data_version=csplit-data-01.tar.gz + +curl -f http://demo.clab.cs.cmu.edu/cdec/$data_version -o $data_version + +tar xzf $data_version + diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 606b8d7e..8738c8f1 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -45,6 +45,7 @@ class PassiveChart { const float lattice_cost); void ApplyUnaryRules(const int i, const int j); + void TopoSortUnaries(); const vector<GrammarPtr>& grammars_; const Lattice& input_; @@ -57,6 +58,7 @@ class PassiveChart { TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found const int lc_fid_; + vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars static WordID kGOAL; // [Goal] }; @@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal, goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), goal_idx_(-1), - lc_fid_(FD::Convert("LatticeCost")) { + lc_fid_(FD::Convert("LatticeCost")), + unaries_() { act_chart_.resize(grammars_.size()); - for (unsigned i = 0; i < grammars_.size(); ++i) + for (unsigned i = 0; i < grammars_.size(); ++i) { act_chart_[i] = new ActiveChart(forest, *this); + const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; } +static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector<TRulePtr>& edges = nit->second; + vector<bool> okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void PassiveChart::TopoSortUnaries() { + vector<TRulePtr> u(unaries_.size()); u.clear(); + map<int, vector<TRulePtr> > g; + map<int, int> mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + void PassiveChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, const float lattice_cost) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); - //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; new_edge->prev_j_ = r->prev_j; new_edge->i_ = i; @@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i, void PassiveChart::ApplyUnaryRules(const int i, const int j) { const vector<int>& nodes = chart_(i,j); // reference is important! - for (unsigned gi = 0; gi < grammars_.size(); ++gi) { - if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; - for (unsigned di = 0; di < nodes.size(); ++di) { - const WordID& cat = forest_->nodes_[nodes[di]].cat_; - const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); - for (unsigned ri = 0; ri < unaries.size(); ++ri) { - // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; + for (unsigned di = 0; di < nodes.size(); ++di) { + const WordID& cat = forest_->nodes_[nodes[di]].cat_; + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; const Hypergraph::TailNodeVector ant(1, nodes[di]); - ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes + ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes } } } diff --git a/mteval/ns.h b/mteval/ns.h index ac7b0a23..153bf0b8 100644 --- a/mteval/ns.h +++ b/mteval/ns.h @@ -78,6 +78,7 @@ inline const SufficientStats operator-(const SufficientStats& a, const Sufficien struct SegmentEvaluator { virtual ~SegmentEvaluator(); virtual void Evaluate(const std::vector<WordID>& hyp, SufficientStats* out) const = 0; + std::string src; // this may not always be available }; // Instructions for implementing a new metric diff --git a/mteval/ns_docscorer.cc b/mteval/ns_docscorer.cc index 83bd1a29..242f134a 100644 --- a/mteval/ns_docscorer.cc +++ b/mteval/ns_docscorer.cc @@ -13,6 +13,40 @@ DocumentScorer::~DocumentScorer() {} DocumentScorer::DocumentScorer() {} +DocumentScorer::DocumentScorer(const EvaluationMetric* metric, + const string& src_ref_file) { + const WordID kDIV = TD::Convert("|||"); + assert(!src_ref_file.empty()); + cerr << "Loading source and references from " << src_ref_file << "...\n"; + ReadFile rf(src_ref_file); + istream& in = *rf.stream(); + unsigned lc = 0; + string src_ref; + vector<WordID> tmp; + vector<vector<WordID> > refs; + while(getline(in, src_ref)) { + ++lc; + size_t end_src = src_ref.find(" ||| "); + if (end_src == string::npos) { + cerr << "Expected SRC ||| REF [||| REF2 ||| REF3 ...] in line " << lc << endl; + abort(); + } + refs.clear(); + tmp.clear(); + TD::ConvertSentence(src_ref, &tmp, end_src + 5); + unsigned last = 0; + for (unsigned j = 0; j < tmp.size(); ++j) { + if (tmp[j] == kDIV) { + refs.push_back(vector<WordID>(tmp.begin() + last, tmp.begin() + j)); + last = j + 1; + } + } + refs.push_back(vector<WordID>(tmp.begin() + last, tmp.end())); + scorers_.push_back(metric->CreateSegmentEvaluator(refs)); + scorers_.back()->src = src_ref.substr(0, end_src); + } +} + void DocumentScorer::Init(const EvaluationMetric* metric, const vector<string>& ref_files, const string& src_file, diff --git a/mteval/scorer_test.cc b/mteval/scorer_test.cc index da07f154..cd27f020 100644 --- a/mteval/scorer_test.cc +++ b/mteval/scorer_test.cc @@ -3,6 +3,7 @@ #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> +#include "ns_docscorer.h" #include "ns.h" #include "tdict.h" #include "scorer.h" @@ -223,4 +224,12 @@ BOOST_AUTO_TEST_CASE(NewScoreAPI) { //cerr << metric->ComputeScore(statse) << endl; } +BOOST_AUTO_TEST_CASE(HybridSourceReferenceFileFormat) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU"); + DocumentScorer ds(metric, path + "/devset.txt"); + BOOST_CHECK_EQUAL(2, ds.size()); + BOOST_CHECK_EQUAL("Quelltext hier .", ds[0]->src); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/mteval/test_data/devset.txt b/mteval/test_data/devset.txt new file mode 100644 index 00000000..f9135d98 --- /dev/null +++ b/mteval/test_data/devset.txt @@ -0,0 +1,2 @@ +Quelltext hier . ||| source text here . ||| original text . ||| some source text . +ein anderer Satz . ||| another sentence . ||| a different sentece . diff --git a/tests/system_tests/multigram/cdec.ini b/tests/system_tests/multigram/cdec.ini new file mode 100644 index 00000000..f31becb8 --- /dev/null +++ b/tests/system_tests/multigram/cdec.ini @@ -0,0 +1 @@ +formalism=scfg diff --git a/tests/system_tests/multigram/g1.scfg b/tests/system_tests/multigram/g1.scfg new file mode 100644 index 00000000..a3a59699 --- /dev/null +++ b/tests/system_tests/multigram/g1.scfg @@ -0,0 +1,4 @@ +[X] ||| [Z] ||| [1] ||| Top=1 +[Y] ||| foo ||| foo ||| F1=1 +[Z] ||| [Z] [Y] ||| [1] [2] ||| W1=1 +[Z] ||| [Y] ||| [1] ||| W2=1 diff --git a/tests/system_tests/multigram/g2.scfg b/tests/system_tests/multigram/g2.scfg new file mode 100644 index 00000000..40962517 --- /dev/null +++ b/tests/system_tests/multigram/g2.scfg @@ -0,0 +1 @@ +[Y] ||| bar ||| bar ||| F2=1 diff --git a/tests/system_tests/multigram/gold.statistics b/tests/system_tests/multigram/gold.statistics new file mode 100644 index 00000000..ef23a685 --- /dev/null +++ b/tests/system_tests/multigram/gold.statistics @@ -0,0 +1,3 @@ +-lm_nodes 11 +-lm_edges 12 +-lm_paths 2 diff --git a/tests/system_tests/multigram/gold.stdout b/tests/system_tests/multigram/gold.stdout new file mode 100644 index 00000000..d675fa44 --- /dev/null +++ b/tests/system_tests/multigram/gold.stdout @@ -0,0 +1 @@ +foo bar diff --git a/tests/system_tests/multigram/input.txt b/tests/system_tests/multigram/input.txt new file mode 100644 index 00000000..2aef01a0 --- /dev/null +++ b/tests/system_tests/multigram/input.txt @@ -0,0 +1 @@ +<seg id="0" grammar="g1.scfg" grammar1="g2.scfg">foo bar</seg> diff --git a/tests/system_tests/multigram/weights b/tests/system_tests/multigram/weights new file mode 100644 index 00000000..a6b6698a --- /dev/null +++ b/tests/system_tests/multigram/weights @@ -0,0 +1 @@ +Glue -1 diff --git a/training/mira/Makefile.am b/training/mira/Makefile.am index 44bf1063..a318cf6e 100644 --- a/training/mira/Makefile.am +++ b/training/mira/Makefile.am @@ -1,15 +1,20 @@ -bin_PROGRAMS = kbest_mira \ - kbest_cut_mira +bin_PROGRAMS = \ + kbest_mira \ + kbest_cut_mira \ + ada_opt_sm EXTRA_DIST = mira.py +ada_opt_sm_SOURCES = ada_opt_sm.cc +ada_opt_sm_LDFLAGS= -rdynamic +ada_opt_sm_LDADD = ../utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a + kbest_mira_SOURCES = kbest_mira.cc kbest_mira_LDFLAGS= -rdynamic kbest_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a - kbest_cut_mira_SOURCES = kbest_cut_mira.cc kbest_cut_mira_LDFLAGS= -rdynamic kbest_cut_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils diff --git a/training/mira/ada_opt_sm.cc b/training/mira/ada_opt_sm.cc new file mode 100644 index 00000000..18ddbf8f --- /dev/null +++ b/training/mira/ada_opt_sm.cc @@ -0,0 +1,198 @@ +#include "config.h" + +#include <boost/container/flat_map.hpp> +#include <boost/shared_ptr.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "filelib.h" +#include "stringlib.h" +#include "weights.h" +#include "sparse_vector.h" +#include "candidate_set.h" +#include "sentence_metadata.h" +#include "ns.h" +#include "ns_docscorer.h" +#include "verbose.h" +#include "hg.h" +#include "ff_register.h" +#include "decoder.h" +#include "fdict.h" +#include "sampler.h" + +using namespace std; +namespace po = boost::program_options; + +boost::shared_ptr<MT19937> rng; +vector<training::CandidateSet> kbests; +SparseVector<weight_t> G, u, lambdas; +double pseudo_doc_decay = 0.9; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file") + ("devset,d",po::value<string>(),"[REQ] Source/reference development set") + ("weights,w",po::value<string>(),"Initial feature weights file") + ("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)") + ("size",po::value<unsigned>()->default_value(0), "Process rank (for multiprocess mode)") + ("rank",po::value<unsigned>()->default_value(1), "Number of processes (for multiprocess mode)") + ("optimizer,o",po::value<unsigned>()->default_value(1), "Optimizer (Adaptive MIRA=1)") + ("fear,f",po::value<unsigned>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)") + ("hope,h",po::value<unsigned>()->default_value(1), "Hope selection (model+cost=1, mincost=2)") + ("eta0", po::value<double>()->default_value(0.1), "Initial step size") + ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") + ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount") + ("pseudo_doc,e", "Use pseudo-documents for approximate scoring") + ("k_best_size,k", po::value<unsigned>()->default_value(500), "Size of hypothesis list to search for oracles"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,H", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") + || !conf->count("decoder_config") + || !conf->count("devset")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +struct TrainingObserver : public DecoderObserver { + explicit TrainingObserver(const EvaluationMetric& m, const int k) : metric(m), kbest_size(k), cur_eval() {} + + const EvaluationMetric& metric; + const int kbest_size; + const SegmentEvaluator* cur_eval; + SufficientStats pdoc; + unsigned hi, vi, fi; // hope, viterbi, fear + + void SetSegmentEvaluator(const SegmentEvaluator* eval) { + cur_eval = eval; + } + + virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) { + cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl; + abort(); + } + + unsigned CostAugmentedDecode(const training::CandidateSet& cs, + const SparseVector<double>& w, + double alpha = 0) { + unsigned best_i = 0; + double best = -numeric_limits<double>::infinity(); + for (unsigned i = 0; i < cs.size(); ++i) { + double s = cs[i].fmap.dot(w); + if (alpha) + s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc); + if (s > best) { + best = s; + best_i = i; + } + } + return best_i; + } + + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + pdoc *= pseudo_doc_decay; + const unsigned sent_id = smeta.GetSentenceID(); + kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval); + vi = CostAugmentedDecode(kbests[sent_id], lambdas); + hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0); + fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0); + cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl; + pdoc += kbests[sent_id][vi].eval_feats; // update pseudodoc stats + } +}; + +int main(int argc, char** argv) { + SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); + else + rng.reset(new MT19937); + + string metric_name = UppercaseString(conf["mt_metric"].as<string>()); + if (metric_name == "COMBI") { + cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n"; + metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5"; + } else if (metric_name == "BLEU") { + cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n"; + metric_name = "IBM_BLEU"; + } + EvaluationMetric* metric = EvaluationMetric::Instance(metric_name); + DocumentScorer ds(metric, conf["devset"].as<string>()); + cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; + kbests.resize(ds.size()); + double eta = 0.001; + + ReadFile ini_rf(conf["decoder_config"].as<string>()); + Decoder decoder(ini_rf.stream()); + + vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); + if (conf.count("weights")) { + Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights); + Weights::InitSparseVector(dense_weights, &lambdas); + } + + TrainingObserver observer(*metric, conf["k_best_size"].as<unsigned>()); + + unsigned num = 200; + for (unsigned iter = 1; iter < num; ++iter) { + lambdas.init_vector(&dense_weights); + unsigned sent_id = rng->next() * ds.size(); + cerr << "Learning from sentence id: " << sent_id << endl; + observer.SetSegmentEvaluator(ds[sent_id]); + decoder.SetId(sent_id); + decoder.Decode(ds[sent_id]->src, &observer); + if (observer.vi != observer.hi) { // viterbi != hope + SparseVector<double> grad = kbests[sent_id][observer.fi].fmap; + grad -= kbests[sent_id][observer.hi].fmap; + cerr << "GRAD: " << grad << endl; + const SparseVector<double>& g = grad; +#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4) + for (auto& gi : g) { +#else + for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) { + const pair<unsigned,double>& gi = *it; +#endif + if (gi.second) { + u[gi.first] += gi.second; + G[gi.first] += gi.second * gi.second; + lambdas.set_value(gi.first, 1.0); // this is a dummy value to trigger recomputation + } + } + for (SparseVector<double>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) { + const pair<unsigned,double>& xi = *it; + double z = fabs(u[xi.first] / iter) - 0.0; + double s = 1; + if (u[xi.first] > 0) s = -1; + if (z > 0 && G[xi.first]) { + lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first])); + } else { + lambdas.set_value(xi.first, 0.0); + } + } + } + } + cerr << "Optimization complete.\n"; + Weights::WriteToFile("-", dense_weights, true); + return 0; +} + diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc index 33dae9a3..36f5b271 100644 --- a/training/utils/candidate_set.cc +++ b/training/utils/candidate_set.cc @@ -171,4 +171,19 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, c Dedup(); } +void CandidateSet::AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { + typedef KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> K; + K kbest(hg, kbest_size); + + for (unsigned i = 0; i < kbest_size; ++i) { + const K::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cs.push_back(Candidate(d->yield, d->feature_values)); + if (scorer) + scorer->Evaluate(d->yield, &cs.back().eval_feats); + } + Dedup(); +} + } diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h index 9d326ed0..17a650f5 100644 --- a/training/utils/candidate_set.h +++ b/training/utils/candidate_set.h @@ -47,7 +47,7 @@ class CandidateSet { void ReadFromFile(const std::string& file); void WriteToFile(const std::string& file) const; void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); - // TODO add code to do unique k-best + void AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); // TODO add code to draw k samples private: diff --git a/utils/stringlib.h b/utils/stringlib.h index f60b7867..2fdbfff8 100644 --- a/utils/stringlib.h +++ b/utils/stringlib.h @@ -242,6 +242,21 @@ void VisitTokens(std::string const& s,F f) { VisitTokens(mp.p,mp.p+s.size(),f); } +template <class F> +void VisitTokens(std::string const& s,F f, unsigned start) { + if (0) { + std::vector<std::string> ss=SplitOnWhitespace(s); + for (unsigned i=0;i<ss.size();++i) + f(ss[i]); + return; + } + //FIXME: + if (s.empty()) return; + mutable_c_str mp(s); + SLIBDBG("mp="<<mp.p); + VisitTokens(mp.p+start,mp.p+s.size(),f); +} + inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::string* param) { cmd->clear(); param->clear(); diff --git a/utils/tdict.cc b/utils/tdict.cc index fd2b76cb..c99f1697 100644 --- a/utils/tdict.cc +++ b/utils/tdict.cc @@ -70,7 +70,7 @@ struct add_wordids { } -void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids) { +void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids, unsigned start) { ids->clear(); - VisitTokens(s,add_wordids(ids)); + VisitTokens(s,add_wordids(ids),start); } diff --git a/utils/tdict.h b/utils/tdict.h index 03afc2e6..bb19ecd5 100644 --- a/utils/tdict.h +++ b/utils/tdict.h @@ -9,7 +9,7 @@ struct TD { static WordID end(); // next id to be assigned; [begin,end) give the non-reserved tokens seen so far - static void ConvertSentence(std::string const& sent, std::vector<WordID>* ids); + static void ConvertSentence(std::string const& sent, std::vector<WordID>* ids, unsigned start=0); static void GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids); static std::string GetString(const std::vector<WordID>& str); static std::string GetString(WordID const* i,WordID const* e); |