-rw-r--r--  .gitignore                            |    3
-rw-r--r--  decoder/Makefile.am                   |    7
-rw-r--r--  decoder/apply_fsa_models.h            |   65
-rw-r--r--  decoder/apply_models.cc               |   10
-rw-r--r--  decoder/cfg.cc                        |  656
-rw-r--r--  decoder/cfg.h                         |  382
-rw-r--r--  decoder/cfg_binarize.h                |   90
-rw-r--r--  decoder/cfg_format.h                  |  134
-rw-r--r--  decoder/cfg_options.h                 |   79
-rw-r--r--  decoder/cfg_test.cc                   |   98
-rw-r--r--  decoder/decoder.cc                    |   99
-rw-r--r--  decoder/decoder.h                     |   13
-rw-r--r--  decoder/ff_factory.cc                 |   47
-rw-r--r--  decoder/ff_factory.h                  |   38
-rw-r--r--  decoder/ff_klm.cc                     |    3
-rw-r--r--  decoder/hg_cfg.h                      |   54
-rw-r--r--  decoder/sentence_metadata.h           |    3
-rw-r--r--  decoder/viterbi.cc                    |   29
-rw-r--r--  extractor/Makefile.am                 |   50
-rw-r--r--  extractor/compile.cc                  |   32
-rw-r--r--  extractor/data_array.cc               |    4
-rw-r--r--  extractor/data_array.h                |    3
-rw-r--r--  extractor/data_array_test.cc          |    4
-rw-r--r--  extractor/extract.cc                  |  253
-rw-r--r--  extractor/grammar_extractor.cc        |    6
-rw-r--r--  extractor/grammar_extractor.h         |    5
-rw-r--r--  extractor/grammar_extractor_test.cc   |    4
-rw-r--r--  extractor/mocks/mock_data_array.h     |    1
-rw-r--r--  extractor/mocks/mock_rule_factory.h   |    6
-rw-r--r--  extractor/mocks/mock_sampler.h        |    4
-rw-r--r--  extractor/precomputation.cc           |  110
-rw-r--r--  extractor/precomputation.h            |   24
-rw-r--r--  extractor/precomputation_test.cc      |   41
-rw-r--r--  extractor/rule_factory.cc             |    7
-rw-r--r--  extractor/rule_factory.h              |    3
-rw-r--r--  extractor/rule_factory_test.cc        |    8
-rw-r--r--  extractor/run_extractor.cc            |   23
-rw-r--r--  extractor/sampler.cc                  |   31
-rw-r--r--  extractor/sampler.h                   |    4
-rw-r--r--  extractor/sampler_test.cc             |   30
-rw-r--r--  extractor/suffix_array.cc             |    4
-rw-r--r--  extractor/suffix_array_test.cc        |    6
-rw-r--r--  extractor/translation_table.cc        |   14
-rw-r--r--  extractor/translation_table_test.cc   |   10
-rw-r--r--  extractor/vocabulary.cc               |   11
-rw-r--r--  extractor/vocabulary.h                |   23
-rw-r--r--  extractor/vocabulary_test.cc          |   45
-rw-r--r--  mteval/external_scorer.h              |    3
-rw-r--r--  mteval/ns_ext.cc                      |    2
-rw-r--r--  mteval/ns_ter.cc                      |    2
-rw-r--r--  mteval/scorer.h                       |    4
-rwxr-xr-x  tests/run-system-tests.pl             |    4
-rwxr-xr-x  tests/tools/flex-diff.pl              |   46
-rw-r--r--  training/crf/Makefile.am              |    4
-rw-r--r--  training/crf/mpi_adagrad_optimize.cc  |  394
-rw-r--r--  training/crf/mpi_batch_optimize.cc    |    6
-rw-r--r--  utils/Makefile.am                     |    8
-rw-r--r--  utils/small_vector.h                  |   12
-rw-r--r--  utils/small_vector_test.cc            |   12
-rw-r--r--  utils/sv_test.cc                      |   24
-rw-r--r--  utils/swap_pod.h                      |   23
-rw-r--r--  utils/value_array.h                   |    9
-rw-r--r--  utils/weights.cc                      |   10
63 files changed, 1108 insertions(+), 2031 deletions(-)
diff --git a/.gitignore b/.gitignore
index 4acc057f..f964fa0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,6 +37,7 @@ decoder/Makefile.in
decoder/bin/
decoder/cdec
decoder/dict_test
+decoder/sv_test
decoder/ff_test
decoder/grammar_test
decoder/hg_test
@@ -71,6 +72,7 @@ extools/score_grammar
extools/sg_lexer.cc
extractor/*_test
extractor/compile
+extractor/extract
extractor/run_extractor
gi/clda/src/clda
gi/markov_al/ml
@@ -163,6 +165,7 @@ training/liblbfgs/bin/
training/liblbfgs/ll_test
training/model1
training/mpi_batch_optimize
+training/mpi_adagrad_optimize
training/mpi_compute_cllh
training/mpi_em_optimize
training/mpi_extract_features
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 8280b22c..b735756d 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -32,13 +32,8 @@ EXTRA_DIST = test_data rule_lexer.ll
libcdec_a_SOURCES = \
JSON_parser.h \
aligner.h \
- apply_fsa_models.h \
apply_models.h \
bottom_up_parser.h \
- cfg.h \
- cfg_binarize.h \
- cfg_format.h \
- cfg_options.h \
csplit.h \
decoder.h \
earley_composer.h \
@@ -74,7 +69,6 @@ libcdec_a_SOURCES = \
freqdict.h \
grammar.h \
hg.h \
- hg_cfg.h \
hg_intersect.h \
hg_io.h \
hg_remove_eps.h \
@@ -105,7 +99,6 @@ libcdec_a_SOURCES = \
bottom_up_parser.cc \
cdec.cc \
cdec_ff.cc \
- cfg.cc \
csplit.cc \
decoder.cc \
earley_composer.cc \
diff --git a/decoder/apply_fsa_models.h b/decoder/apply_fsa_models.h
deleted file mode 100644
index 6561c70c..00000000
--- a/decoder/apply_fsa_models.h
+++ /dev/null
@@ -1,65 +0,0 @@
-#ifndef _APPLY_FSA_MODELS_H_
-#define _APPLY_FSA_MODELS_H_
-
-#include <string>
-#include <iostream>
-#include "feature_vector.h"
-#include "named_enum.h"
-
-struct FsaFeatureFunction;
-struct Hypergraph;
-struct SentenceMetadata;
-struct HgCFG;
-
-
-#define FSA_BY(X,t) \
- X(t,BU_CUBE,) \
- X(t,BU_FULL,) \
- X(t,EARLEY,) \
-
-#define FSA_BY_TYPE FsaBy
-
-DECLARE_NAMED_ENUM(FSA_BY)
-
-struct ApplyFsaBy {
-/*enum {
- BU_CUBE,
- BU_FULL,
- EARLEY,
- N_ALGORITHMS
- };*/
- int pop_limit; // only applies to BU_FULL so far
- bool IsBottomUp() const {
- return algorithm==BU_FULL || algorithm==BU_CUBE;
- }
- int BottomUpAlgorithm() const;
- FsaBy algorithm;
- std::string name() const;
- friend inline std::ostream &operator << (std::ostream &o,ApplyFsaBy const& c) {
- o << c.name();
- if (c.algorithm==BU_CUBE)
- o << "("<<c.pop_limit<<")";
- return o;
- }
- explicit ApplyFsaBy(FsaBy alg, int poplimit=200);
- ApplyFsaBy(std::string const& name, int poplimit=200);
- ApplyFsaBy(const ApplyFsaBy &o) : algorithm(o.algorithm) { }
- static std::string all_names(); // space separated
-};
-
-void ApplyFsaModels(HgCFG &hg_or_cfg_in,
- const SentenceMetadata& smeta,
- const FsaFeatureFunction& fsa,
- DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this)
- ApplyFsaBy const& cfg,
- Hypergraph* out);
-
-void ApplyFsaModels(Hypergraph const& ih,
- const SentenceMetadata& smeta,
- const FsaFeatureFunction& fsa,
- DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this)
- ApplyFsaBy const& cfg,
- Hypergraph* out);
-
-
-#endif
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 4cd8b36f..9a8f60be 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -192,8 +192,6 @@ public:
assert(num_nodes >= 2);
int goal_id = num_nodes - 1;
int pregoal = goal_id - 1;
- int every = 1;
- if (num_nodes > 100) every = 10;
assert(in.nodes_[pregoal].out_edges_.size() == 1);
if (!SILENT) cerr << " ";
int has = 0;
@@ -563,12 +561,14 @@ struct NoPruningRescorer {
int num_nodes = in.nodes_.size();
int goal_id = num_nodes - 1;
int pregoal = goal_id - 1;
- int every = 1;
- if (num_nodes > 100) every = 10;
assert(in.nodes_[pregoal].out_edges_.size() == 1);
if (!SILENT) cerr << " ";
+ int has = 0;
for (int i = 0; i < in.nodes_.size(); ++i) {
- if (!SILENT && i % every == 0) cerr << '.';
+ if (!SILENT) {
+ int needs = (50 * i / in.nodes_.size());
+ while (has < needs) { cerr << '.'; ++has; }
+ }
ProcessOneNode(i, i == goal_id);
}
if (!SILENT) cerr << endl;
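[Note: the two hunks above replace the old print-a-dot-every-N-nodes logic with a progress bar normalized to 50 dots regardless of forest size; the NoPruningRescorer hunk gains the same has/needs counters the cube-pruning path already uses. A minimal standalone sketch of the technique, with illustrative names that are not from the patch:

#include <iostream>

// Print at most `width` dots over `total` steps, emitting a dot only
// when the cumulative fraction of work crosses the next dot's slot.
void report_progress(int step, int total, int& printed, int width = 50) {
  int needed = width * step / total;            // dots owed so far
  while (printed < needed) { std::cerr << '.'; ++printed; }
}

int main() {
  const int total = 1234;                       // any forest size
  int printed = 0;
  for (int i = 0; i < total; ++i)
    report_progress(i, total, printed);         // always ~50 dots
  std::cerr << std::endl;
}
]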
diff --git a/decoder/cfg.cc b/decoder/cfg.cc
deleted file mode 100644
index d6ee651a..00000000
--- a/decoder/cfg.cc
+++ /dev/null
@@ -1,656 +0,0 @@
-#include "cfg.h"
-#include "hg.h"
-#include "cfg_format.h"
-#include "cfg_binarize.h"
-#include "hash.h"
-#include "batched_append.h"
-#include <limits>
-#include "fast_lexical_cast.hpp"
-//#include "indices_after.h"
-#include "show.h"
-#include "null_traits.h"
-
-#define DUNIQ(x) x
-#define DBIN(x)
-#define DSP(x) x
-//SP:binarize by splitting.
-#define DCFG(x) IF_CFG_DEBUG(x)
-
-#undef CFG_FOR_RULES
-#define CFG_FOR_RULES(i,expr) \
- for (CFG::NTs::const_iterator n=nts.begin(),nn=nts.end();n!=nn;++n) { \
- NT const& nt=*n; \
- for (CFG::Ruleids::const_iterator ir=nt.ruleids.begin(),er=nt.ruleids.end();ir!=er;++ir) { \
- RuleHandle i=*ir; \
- expr; \
- } \
- }
-
-
-using namespace std;
-
-typedef CFG::Rule Rule;
-typedef CFG::NTOrder NTOrder;
-typedef CFG::RHS RHS;
-typedef CFG::BinRhs BinRhs;
-
-/////index ruleids:
-void CFG::UnindexRules() {
- for (NTs::iterator n=nts.begin(),nn=nts.end();n!=nn;++n)
- n->ruleids.clear();
-}
-
-void CFG::ReindexRules() {
- UnindexRules();
- for (int i=0,e=rules.size();i<e;++i)
- if (!rules[i].is_null())
- nts[rules[i].lhs].ruleids.push_back(i);
-}
-
-//////topo order:
-namespace {
-typedef std::vector<char> Seen; // 0 = unseen, 1 = seen+finished, 2 = open (for cycle detection; seen but not finished)
-enum { UNSEEN=0,SEEN,OPEN };
-
-// bottom -> top topo order (rev head->tails topo)
-template <class OutOrder>
-struct CFGTopo {
-// meaningless efficiency alternative: close over all the args except ni - so they're passed as a single pointer. also makes visiting tail_nts simpler.
- CFG const& cfg;
- OutOrder outorder;
- std::ostream *cerrp;
- CFGTopo(CFG const& cfg,OutOrder const& outorder,std::ostream *cerrp=&std::cerr)
- : cfg(cfg),outorder(outorder),cerrp(cerrp) // closure over args
- , seen(cfg.nts.size()) { }
-
- Seen seen;
- void operator()(CFG::NTHandle ni) {
- char &seenthis=seen[ni];
- if (seenthis==UNSEEN) {
- seenthis=OPEN;
-
- CFG::NT const& nt=cfg.nts[ni];
- for (CFG::Ruleids::const_iterator i=nt.ruleids.begin(),e=nt.ruleids.end();i!=e;++i) {
- Rule const& r=cfg.rules[*i];
- r.visit_rhs_nts(*this); // recurse.
- }
-
- *outorder++=ni; // dfs finishing time order = reverse topo.
- seenthis=SEEN;
- } else if (cerrp && seenthis==OPEN) {
- std::ostream &cerr=*cerrp;
- cerr<<"WARNING: CFG Topo order attempt failed: NT ";
- cfg.print_nt_name(cerr,ni);
- cerr<<" already reached from goal(top) ";
- cfg.print_nt_name(cerr,cfg.goal_nt);
- cerr<<". Continuing to reorder, but it's not fully topological.\n";
- }
- }
-
-};
-
-template <class O>
-void DoCFGTopo(CFG const& cfg,CFG::NTHandle goal,O const& o,std::ostream *w=0) {
- CFGTopo<O> ct(cfg,o,w);
- ct(goal);
-}
-
-}//ns
-
-// you would need to do this only if you didn't build from hg, or you Binarize without bin_topo option. note: this doesn't sort the list of rules; it's assumed that if you care about the topo order you'll iterate over nodes.
-void CFG::OrderNTsTopo(NTOrder *o_,std::ostream *cycle_complain) {
- NTOrder &o=*o_;
- o.resize(nts.size());
- DoCFGTopo(*this,goal_nt,o.begin(),cycle_complain);
-}
-
-
-/////sort/uniq:
-namespace {
-RHS null_rhs(1,INT_MIN);
-
-//sort
-struct ruleid_best_first {
- CFG::Rules const* rulesp;
- bool operator()(int a,int b) const { // true if a >(prob for ruleid) b
- return (*rulesp)[b].p < (*rulesp)[a].p;
- }
-};
-
-//uniq
-struct prob_pos {
- prob_pos() {}
- prob_pos(prob_t prob,int pos) : prob(prob),pos(pos) {}
- prob_t prob;
- int pos;
- bool operator <(prob_pos const& o) const { return prob<o.prob; }
-};
-}//ns
-
-int CFG::UniqRules(NTHandle ni) {
- typedef HASH_MAP<RHS,prob_pos,boost::hash<RHS> > BestRHS; // faster to use trie? maybe.
- BestRHS bestp; // once inserted, the position part (output index) never changes. but the prob may be improved (overwrite ruleid at that position).
- HASH_MAP_EMPTY(bestp,null_rhs);
- Ruleids &adj=nts[ni].ruleids;
- Ruleids oldadj=adj;
- int newpos=0;
- for (int i=0,e=oldadj.size();i!=e;++i) { // this beautiful complexity is to ensure that adj' is a subsequence of adj (without duplicates)
- int ri=oldadj[i];
- Rule const& r=rules[ri];
- prob_pos pi(r.p,newpos);
- prob_pos &oldpi=get_default(bestp,r.rhs,pi);
- if (oldpi.pos==newpos) {// newly inserted
- adj[newpos++]=ri;
- } else {
- SHOWP(DUNIQ,"Uniq duplicate: ") SHOW4(DUNIQ,oldpi.prob,pi.prob,oldpi.pos,newpos);
- SHOW(DUNIQ,ShowRule(ri));
- SHOW(DUNIQ,ShowRule(adj[oldpi.pos]));
- if (oldpi.prob<pi.prob) { // we improve prev. best (overwrite it @old pos)
- oldpi.prob=pi.prob;
- adj[oldpi.pos]=ri; // replace worse rule w/ better
- }
- }
-
- }
- // post: newpos = number of new adj
- adj.resize(newpos);
- return newpos;
-}
-
-void CFG::SortLocalBestFirst(NTHandle ni) {
- ruleid_best_first r;
- r.rulesp=&rules;
- Ruleids &adj=nts[ni].ruleids;
- std::stable_sort(adj.begin(),adj.end(),r);
-}
-
-
-/////binarization:
-namespace {
-
-BinRhs null_bin_rhs(std::numeric_limits<int>::min(),std::numeric_limits<int>::min());
-
-// index i >= N.size()? then it's in M[i-N.size()]
-//WordID first,WordID second,
-string BinStr(BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M)
-{
- int nn=N.size();
- ostringstream o;
-#undef BinNameOWORD
-#define BinNameOWORD(w) \
- do { \
- int n=w; if (n>0) o << TD::Convert(n); \
- else { \
- int i=-n; \
- if (i<nn) o<<N[i].from<<i; else o<<M[i-nn].from; \
- } \
- } while(0)
-
- BinNameOWORD(b.first);
- o<<'+';
- BinNameOWORD(b.second);
- return o.str();
-}
-
-string BinStr(RHS const& r,CFG::NTs const& N,CFG::NTs const& M)
-{
- int nn=N.size();
- ostringstream o;
- for (int i=0,e=r.size();i!=e;++i) {
- if (i)
- o<<'+';
- BinNameOWORD(r[i]);
- }
- return o.str();
-}
-
-
-WordID BinName(BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M)
-{
- return TD::Convert(BinStr(b,N,M));
-}
-
-WordID BinName(RHS const& b,CFG::NTs const& N,CFG::NTs const& M)
-{
- return TD::Convert(BinStr(b,N,M));
-}
-
-/*
-template <class Rhs>
-struct null_for;
-
-
-template <>
-struct null_for<BinRhs> {
- static BinRhs null;
-};
-
-template <>
-struct null_for<RHS> {
- static RHS null;
-};
-
-template <>
-BinRhs null_traits<BinRhs>::xnull(std::numeric_limits<int>::min(),std::numeric_limits<int>::min());
-
-template <>
-RHS null_traits<RHS>::xnull(1,std::numeric_limits<int>::min());
-*/
-
-template <class Rhs>
-struct add_virtual_rules {
- typedef CFG::RuleHandle RuleHandle;
- typedef CFG::NTHandle NTHandle;
- CFG::NTs &nts,new_nts;
- CFG::Rules &rules, new_rules;
-// above will be appended at the end, so we don't have to worry about iterator invalidation
- WordID newnt; //negative of NTHandle, or positive => unary lexical item (not to binarize). fit for rhs of a rule
- RuleHandle newruleid;
- typedef HASH_MAP<Rhs,WordID,boost::hash<Rhs> > R2L;
- R2L rhs2lhs; // an rhs maps to this -virtntid, or original id if length 1
- bool name_nts;
- add_virtual_rules(CFG &cfg,bool name_nts=false) : nts(cfg.nts),rules(cfg.rules),newnt(-nts.size()),newruleid(rules.size()),name_nts(name_nts) {
- HASH_MAP_EMPTY(rhs2lhs,null_traits<Rhs>::xnull);
- }
- NTHandle get_virt(Rhs const& r) {
- NTHandle nt=get_default(rhs2lhs,r,newnt);
- SHOW(DBIN,newnt) SHOWP(DBIN,"bin="<<BinStr(r,nts,new_nts)<<"=>") SHOW(DBIN,nt);
- if (newnt==nt) {
- create(r);
- }
- return nt;
- }
- inline void set_nt_name(Rhs const& r) {
- if (name_nts)
- new_nts.back().from.nt=BinName(r,nts,new_nts);
- }
- inline void create_nt(Rhs const& rhs) {
- new_nts.push_back(CFG::NT(newruleid++));
- set_nt_name(rhs);
- }
- inline void create_rule(Rhs const& rhs) {
- new_rules.push_back(CFG::Rule(-newnt,rhs));
- --newnt;
- }
- inline void create_adding(Rhs const& rhs) {
- NTHandle nt=get_default(rhs2lhs,rhs,newnt);
- assert(nt==newnt);
- create(rhs);
- }
- inline void create(Rhs const& rhs) {
- SHOWP(DSP,"Create ") SHOW3(DSP,newnt,newruleid,BinStr(rhs,nts,new_nts))
- create_nt(rhs);
- create_rule(rhs);
- assert(newruleid==rules.size()+new_rules.size());assert(-newnt==nts.size()+new_nts.size());
- }
-
- ~add_virtual_rules() {
- append_rules();
- }
- void append_rules() {
- // marginally more efficient
- batched_append_swap(nts,new_nts);
- batched_append_swap(rules,new_rules);
- }
- inline bool have(Rhs const& rhs,NTHandle &h) const {
- if (rhs.size()==1) { // stop creating virtual unary rules.
- h=rhs[0];
- return true;
- }
- typename R2L::const_iterator i=rhs2lhs.find(rhs);
- if (i==rhs2lhs.end())
- return false;
- h=i->second;
- return true;
- }
- //HACK: prevent this for instantiating for BinRhs. we need to use rule index because we'll be adding rules before we can update.
- // returns 1 per replaced NT (0,1, or 2)
- inline std::string Str(Rhs const& rhs) const {
- return BinStr(rhs,nts,new_nts);
- }
-
- template <class RHSi>
- int split_rhs(RHSi &rhs,bool only_free=false,bool only_reusing_1=false) {
- typedef WordID const* WP;
- //TODO: don't actually build substrings of rhs; define key type that stores ref to rhs in new_nts by index (because it may grow), and also allows an int* [b,e) range to serve as key (i.e. don't insert the latter type of key).
- int n=rhs.size();
- if (n<=2) return 0;
- int longest1=1; // all this other stuff is not uninitialized when used, based on checking this and other things (it's complicated, learn to prove theorems, gcc)
- int mid=n/2;
- int best_k;
- enum {HAVE_L=-1,HAVE_NONE=0,HAVE_R=1};
- int have1=HAVE_NONE; // will mean we already have some >1 length prefix or suffix as a virt. (it's free). if we have both we use it immediately and return.
- NTHandle ntr,ntl;
- NTHandle bestntr,bestntl;
- WP b=&rhs.front(),e=b+n;
- WP wk=b;
- SHOWM3(DSP,"Split",Str(rhs),only_free,only_reusing_1);
- int rlen=n;
- for (int k=1;k<n-1;++k) {
- //TODO: handle length 1 l and r parts without explicitly building Rhs?
- ++wk; assert(k==wk-b);
- --rlen; assert(rlen==n-k);
- Rhs l(b,wk);
- if (have(l,ntl)) {
- if (k>1) { SHOWM3(DSP,"Have l",k,n,Str(l)) }
- Rhs r(wk,e);
- if (have(r,ntr)) {
- SHOWM3(DSP,"Have r too",k,n,Str(r))
- rhs.resize(2);
- rhs[0]=ntl;
- rhs[1]=ntr;
- return 2;
- } else if (k>longest1) {
- longest1=k;
- have1=HAVE_L;
- bestntl=ntl;
- best_k=k;
- }
- } else if (rlen>longest1) { // > or >= favors l or r branching, maybe. who cares.
- Rhs r(wk,e);
- if (have(r,ntr)) {
- longest1=rlen;
- if (rlen>1) { SHOWM3(DSP,"Have r (only) ",k,n,Str(r)) }
- have1=HAVE_R;
- bestntr=ntr;
- best_k=k;
- }
- }
- //TODO: swap order of consideration (l first or r?) depending on pre/post midpoint? one will be useless to check for beating the longest single match so far. check that second
-
- }
- // now we know how we're going to split the rule; what follows is just doing the actual splitting:
-
- if (only_free) {
- if (have1==HAVE_NONE)
- return 0;
- if (have1==HAVE_L) {
- rhs.erase(rhs.begin()+1,rhs.begin()+best_k); //erase [1..best_k)
- rhs[0]=bestntl;
- } else {
- assert(have1==HAVE_R);
- rhs.erase(rhs.begin()+best_k+1,rhs.end()); // erase (best_k..)
- rhs[best_k]=bestntr;
- }
- return 1;
- }
- /* now we have to add some new virtual rules.
- some awkward constraints:
-
- 1. can't resize rhs until you save copy of l or r split portion
-
- 2. can't create new rule until you finished modifying rhs (this is why we store newnt then create). due to vector push_back invalidation. perhaps we could bypass this by reserving sufficient space first before a splitting pass (# rules and nts created is <= 2 * # of rules being passed over)
-
- */
- if (have1==HAVE_NONE) { // default: split down middle.
- DSP(assert(longest1==1));
- WP m=b+mid;
- if (n%2==0) {
- WP i=b;
- WP j=m;
- for (;i!=m;++i,++j)
- if (*i!=*j) goto notleqr;
- // [...mid]==[mid...]!
- RHS l(b,m); // 1. // this is equal to RHS(m,e).
- rhs.resize(2);
- rhs[0]=rhs[1]=newnt; //2.
- create_adding(l);
- return 1; // only had to create 1 total when splitting down middle when l==r
- }
- notleqr:
- if (only_reusing_1) return 0;
- best_k=mid; // rounds down
- if (mid==1) {
- RHS r(m,e); //1.
- rhs.resize(2);
- rhs[1]=newnt; //2.
- create_adding(r);
- return 1;
- } else {
- Rhs l(b,m);
- Rhs r(m,e); // 1.
- rhs.resize(2);
- rhs[0]=newnt;
- rhs[1]=newnt-1; // 2.
- create_adding(l);
- create_adding(r);
- return 2;
- }
- }
- WP best_wk=b+best_k;
- //we build these first because adding rules may invalidate the underlying pointers (we end up binarizing already split virt rules)!.
- //wow, that decision (not to use index into new_nts instead of pointer to rhs), while adding new nts to it really added some pain.
- if (have1==HAVE_L) {
- Rhs r(best_wk,e); //1.
- rhs.resize(2);
- rhs[0]=bestntl;
- DSP(assert(best_wk<e-1)); // because we would have returned having both if rhs was singleton
- rhs[1]=newnt; //2.
- create_adding(r);
- } else {
- DSP(assert(have1==HAVE_R));
- DSP(assert(best_wk>b+1)); // because we would have returned having both if lhs was singleton
- Rhs l(b,best_wk); //1.
- rhs.resize(2);
- rhs[0]=newnt; //2.
- rhs[1]=bestntr;
- create_adding(l);
- }
- return 1;
- }
-};
-
-}//ns
-
-void CFG::BinarizeSplit(CFGBinarize const& b) {
- add_virtual_rules<RHS> v(*this,b.bin_name_nts);
- CFG_FOR_RULES(i,v.split_rhs(rules[i].rhs,false,false));
- Rules &newr=v.new_rules;
-#undef CFG_FOR_VIRT
-#define CFG_FOR_VIRT(r,expr) \
- for (int i=0,e=newr.size();i<e;++i) { \
- Rule &r=newr[i];expr; } // NOTE: must use indices since we'll be adding rules as we iterate.
-
- int n_changed_total=0;
- int n_changed=0; // quiets a warning
-#define CFG_SPLIT_PASS(N,free,just1) \
- for (int pass=0;pass<b.N;++pass) { \
- n_changed=0; \
- CFG_FOR_VIRT(r,n_changed+=v.split_rhs(r.rhs,free,just1)); \
- if (!n_changed) { \
- break; \
- } n_changed_total+=n_changed; }
-
- CFG_SPLIT_PASS(split_passes,false,false)
- if (n_changed==0) return;
- CFG_SPLIT_PASS(split_share1_passes,false,true)
- CFG_SPLIT_PASS(split_free_passes,true,false)
-
-}
-
-void CFG::Binarize(CFGBinarize const& b) {
- if (!b.Binarizing()) return;
- cerr << "Binarizing "<<b<<endl;
- if (b.bin_thresh>0)
- BinarizeThresh(b);
- if (b.bin_split)
- BinarizeSplit(b);
- if (b.bin_l2r)
- BinarizeL2R(false,b.bin_name_nts);
- if (b.bin_topo) //TODO: more efficient (at least for l2r) maintenance of order?
- OrderNTsTopo();
-
-}
-
-namespace {
-}
-
-void CFG::BinarizeThresh(CFGBinarize const& b) {
- throw runtime_error("TODO: some fancy linked list thing - see NOTES.partial.binarize");
-}
-
-
-void CFG::BinarizeL2R(bool bin_unary,bool name) {
- add_virtual_rules<BinRhs> v(*this,name);
-cerr << "Binarizing left->right " << (bin_unary?"real to unary":"stop at binary") <<endl;
- HASH_MAP<BinRhs,NTHandle,boost::hash<BinRhs> > bin2lhs; // we're going to hash cons rather than build an explicit trie from right to left.
- HASH_MAP_EMPTY(bin2lhs,null_bin_rhs);
- // iterate using indices and not iterators because we'll be adding to both nts and rules list? we could instead pessimistically reserve space for both, but this is simpler. also: store original end of nts since we won't need to reprocess newly added ones.
- int rhsmin=bin_unary?0:1;
- //NTs new_nts;
- //Rules new_rules;
- //TODO: this could be factored easily into in-place (append to new_* like below) and functional (nondestructive copy) versions (copy orig to target and append to target)
-// int newnt=nts.size(); // we're going to store binary rhs with -nt to keep distinct from words (>=0)
-// int newruleid=rules.size();
- BinRhs bin;
- CFG_FOR_RULES(ruleid,
-/* for (NTs::const_iterator n=nts.begin(),nn=nts.end();n!=nn;++n) {
- NT const& nt=*n;
- for (Ruleids::const_iterator ir=nt.ruleids.begin(),er=nt.ruleids.end();ir!=er;++ir) {
- RuleHandle ruleid=*ir;*/
-// SHOW2(DBIN,ruleid,ShowRule(ruleid));
- Rule & rule=rules[ruleid];
- RHS &rhs=rule.rhs; // we're going to binarize this while adding newly created rules to new_...
- if (rhs.empty()) continue;
- int r=rhs.size()-2; // loop below: [r,r+1) is to be reduced into a (maybe new) binary NT
- if (rhsmin<=r) { // means r>=0 also
- bin.second=rhs[r+1];
- int bin_to; // the replacement for bin
-// assert(newruleid==rules.size()+new_rules.size());assert(newnt==nts.size()+new_nts.size());
- // also true at start/end of loop:
- for (;;) { // pairs from right to left (normally we leave the last pair alone)
- bin.first=rhs[r];
- bin_to=v.get_virt(bin);
-/* bin_to=get_default(bin2lhs,bin,v.newnt);
-// SHOW(DBIN,r) SHOW(DBIN,newnt) SHOWP(DBIN,"bin="<<BinStr(bin,nts,new_nts)<<"=>") SHOW(DBIN,bin_to);
- if (v.newnt==bin_to) { // it's new!
- new_nts.push_back(NT(newruleid++));
- //now newnt is the index of the last (after new_nts is appended) nt. bin is its rhs. bin_to is its lhs
- new_rules.push_back(Rule(newnt,bin));
- ++newnt;
- if (name) new_nts.back().from.nt=BinName(bin,nts,new_nts);
- }
-*/
- bin.second=bin_to;
- --r;
- if (r<rhsmin) {
- rhs[rhsmin]=bin_to;
- rhs.resize(rhsmin+1);
- break;
- }
- }
- })
- /*
- }
- }
- */
-#if 0
- // marginally more efficient
- batched_append_swap(nts,new_nts);
- batched_append_swap(rules,new_rules);
-//#else
- batched_append(nts,new_nts);
- batched_append(rules,new_rules);
-#endif
-}
-
-namespace {
-inline int nt_index(int nvar,Hypergraph::TailNodeVector const& t,bool target_side,int w) {
- assert(w<0 || (target_side&&w==0));
- return t[target_side?-w:nvar];
-}
-}
-
-void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool push_weights) {
- uninit=false;
- hg_=&hg;
- Hypergraph::NodeProbs np;
- goal_inside=hg.ComputeNodeViterbi(&np);
- pushed_inside=push_weights ? goal_inside : prob_t(1);
- int nn=hg.nodes_.size(),ne=hg.edges_.size();
- nts.resize(nn);
- goal_nt=nn-1;
- rules.resize(ne);
- for (int i=0;i<nn;++i) {
- nts[i].ruleids=hg.nodes_[i].in_edges_;
- hg.SetNodeOrigin(i,nts[i].from);
- }
- for (int i=0;i<ne;++i) {
- Rule &cfgr=rules[i];
- Hypergraph::Edge const& e=hg.edges_[i];
- prob_t &crp=cfgr.p;
- crp=e.edge_prob_;
- cfgr.lhs=e.head_node_;
- IF_CFG_TRULE(cfgr.rule=e.rule_;)
- if (copy_features) cfgr.f=e.feature_values_;
- if (push_weights) crp /=np[e.head_node_];
- TRule const& er=*e.rule_;
- vector<WordID> const& rule_rhs=target_side?er.e():er.f();
- int nr=rule_rhs.size();
- RHS &rhs_out=cfgr.rhs;
- rhs_out.resize(nr);
- Hypergraph::TailNodeVector const& tails=e.tail_nodes_;
- int nvar=0;
- //split out into separate target_side, source_side loops?
- for (int j=0;j<nr;++j) {
- WordID w=rule_rhs[j];
- if (w>0)
- rhs_out[j]=w;
- else {
- int n=nt_index(nvar,tails,target_side,w);
- ++nvar;
- if (push_weights) crp*=np[n];
- rhs_out[j]=-n;
- }
- }
- assert(nvar==er.Arity());
- assert(nvar==tails.size());
- }
-}
-
-void CFG::Clear() {
- rules.clear();
- nts.clear();
- goal_nt=-1;
- hg_=0;
-}
-
-namespace {
-CFGFormat form;
-}
-
-void CFG::PrintRule(std::ostream &o,RuleHandle rulei,CFGFormat const& f) const {
- Rule const& r=rules[rulei];
- f.print_lhs(o,*this,r.lhs);
- f.print_rhs(o,*this,r.rhs.begin(),r.rhs.end());
- f.print_features(o,r.p,r.f);
- IF_CFG_TRULE(if (r.rule) o<<f.partsep<<*r.rule;)
-}
-void CFG::PrintRule(std::ostream &o,RuleHandle rulei) const {
- PrintRule(o,rulei,form);
-}
-string CFG::ShowRule(RuleHandle i) const {
- ostringstream o;PrintRule(o,i);return o.str();
-}
-
-void CFG::Print(std::ostream &o,CFGFormat const& f) const {
- assert(!uninit);
- if (!f.goal_nt_name.empty()) {
- o << '['<<f.goal_nt_name <<']';
- WordID rhs=-goal_nt;
- f.print_rhs(o,*this,&rhs,&rhs+1);
- if (pushed_inside!=prob_t::One())
- f.print_features(o,pushed_inside);
- o<<'\n';
- }
- CFG_FOR_RULES(i,PrintRule(o,i,f);o<<'\n';)
-}
-
-void CFG::Print(std::ostream &o) const {
- Print(o,form);
-}
-
-std::ostream &operator<<(std::ostream &o,CFG const &x) {
- x.Print(o);
- return o;
-}
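[Note: the BinarizeL2R/add_virtual_rules machinery deleted above hash-conses virtual nonterminals: each candidate binary rhs is looked up in a map, and a fresh virtual NT (plus its defining rule) is minted only the first time that rhs is seen, so right-branching chains share suffixes. A simplified sketch of the pattern, using std::unordered_map and a plain nonnegative NT encoding in place of the project's HASH_MAP and negative-WordID encoding; all names here are illustrative:

#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

struct PairHash {
  std::size_t operator()(const std::pair<int,int>& p) const {
    return std::hash<long long>()((long long)p.first * 1000003LL + p.second);
  }
};

struct Binarizer {
  std::unordered_map<std::pair<int,int>, int, PairHash> rhs2nt;
  std::vector<std::pair<int,int>> virt_rules;  // body of each virtual NT's rule
  int next_nt = 0;

  // Hash-cons: identical (first,second) pairs map to the same virtual NT.
  int GetVirt(std::pair<int,int> rhs) {
    auto it = rhs2nt.find(rhs);
    if (it != rhs2nt.end()) return it->second;  // reuse a shared suffix
    int nt = next_nt++;
    rhs2nt.emplace(rhs, nt);
    virt_rules.push_back(rhs);                  // new rule: NT -> first second
    return nt;
  }

  // Right-branching: A -> x1 x2 ... xn becomes A -> x1 V1, V1 -> x2 V2, ...
  void BinarizeL2R(std::vector<int>& rhs) {
    while (rhs.size() > 2) {                    // pair up from right to left
      int v = GetVirt({rhs[rhs.size() - 2], rhs.back()});
      rhs.pop_back();
      rhs.back() = v;                           // replace the pair with its NT
    }
  }
};

int main() {
  Binarizer bin;
  std::vector<int> rhs = {10, 20, 30, 40};      // toy ids for B C D E
  bin.BinarizeL2R(rhs);
  assert(rhs.size() == 2 && bin.virt_rules.size() == 2);
}
]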
diff --git a/decoder/cfg.h b/decoder/cfg.h
deleted file mode 100644
index aeeacb83..00000000
--- a/decoder/cfg.h
+++ /dev/null
@@ -1,382 +0,0 @@
-#ifndef CDEC_CFG_H
-#define CDEC_CFG_H
-
-#define DVISITRULEID(x)
-
-// for now, debug means remembering and printing the TRule behind each CFG rule
-#ifndef CFG_DEBUG
-# define CFG_DEBUG 1
-#endif
-#ifndef CFG_KEEP_TRULE
-# define CFG_KEEP_TRULE 0
-#endif
-
-#if CFG_DEBUG
-# define IF_CFG_DEBUG(x) x;
-#else
-# define IF_CFG_DEBUG(x)
-#endif
-
-#if CFG_KEEP_TRULE
-# define IF_CFG_TRULE(x) x;
-#else
-# define IF_CFG_TRULE(x)
-#endif
-
-/* for target FSA intersection, we want to produce a simple (feature weighted) CFG using the target projection of a hg. this is essentially isomorphic to the hypergraph, and we're copying part of the rule info (we'll maintain a pointer to the original hg edge for posterity/debugging; and perhaps avoid making a copy of the feature vector). but we may also want to support CFG read from text files (w/ features), without needing to have a backing hypergraph. so hg pointer may be null? multiple types of CFG? always copy the feature vector? especially if we choose to binarize, we won't want to rely on 1:1 alignment w/ hg
-
- question: how much does making a copy (essentially) of hg simplify things? is the space used worth it? is the node in/out edges index really that much of a waste? is the use of indices that annoying?
-
- answer: access to the source side and target side rhs is less painful - less indirection; if not a word (w>0) then -w is the NT index. also, non-synchronous ops like binarization make sense. hg is a somewhat bulky encoding of non-synchronous forest
-
- using indices to refer to NTs saves space (32 bit index vs 64 bit pointer) and allows more efficient ancillary maps for e.g. chart info (if we used pointers to actual node structures, it would be tempting to add various void * or other slots for use by mapped-during-computation ephemera)
- */
-
-#include <sstream>
-#include <string>
-#include <vector>
-#include "feature_vector.h"
-#include "small_vector.h"
-#include "wordid.h"
-#include "tdict.h"
-#include "trule.h"
-#include "prob.h"
-//#include "int_or_pointer.h"
-#include "small_vector.h"
-#include "nt_span.h"
-#include <algorithm>
-#include "indices_after.h"
-#include <boost/functional/hash.hpp>
-
-class Hypergraph;
-class CFGFormat; // #include "cfg_format.h"
-class CFGBinarize; // #include "cfg_binarize.h"
-
-#undef CFG_MUST_EQ
-#define CFG_MUST_EQ(f) if (!(o.f==f)) return false;
-
-struct CFG {
- typedef int RuleHandle;
- typedef int NTHandle;
- typedef SmallVector<WordID> RHS; // same as in trule rhs: >0 means token, <=0 means -node index (not variable index)
- typedef std::vector<RuleHandle> Ruleids;
-
- void print_nt_name(std::ostream &o,NTHandle n) const {
- o << nts[n].from << n;
- }
- std::string nt_name(NTHandle n) const {
- std::ostringstream o;
- print_nt_name(o,n);
- return o.str();
- }
- void print_rhs_name(std::ostream &o,WordID w) const {
- if (w<=0) print_nt_name(o,-w);
- else o<<TD::Convert(w);
- }
- std::string rhs_name(WordID w) const {
- if (w<=0) return nt_name(-w);
- else return TD::Convert(w);
- }
- static void static_print_nt_name(std::ostream &o,NTHandle n) {
- o<<'['<<n<<']';
- }
- static std::string static_nt_name(NTHandle w) {
- std::ostringstream o;
- static_print_nt_name(o,w);
- return o.str();
- }
- static void static_print_rhs_name(std::ostream &o,WordID w) {
- if (w<=0) static_print_nt_name(o,-w);
- else o<<TD::Convert(w);
- }
- static std::string static_rhs_name(WordID w) {
- std::ostringstream o;
- static_print_rhs_name(o,w);
- return o.str();
- }
-
- typedef std::pair<WordID,WordID> BinRhs;
-
- struct Rule {
- std::size_t hash_impl() const {
- using namespace boost;
- std::size_t h=lhs;
- hash_combine(h,rhs);
- hash_combine(h,p);
- hash_combine(h,f);
- return h;
- }
- bool operator ==(Rule const &o) const {
- CFG_MUST_EQ(lhs)
- CFG_MUST_EQ(rhs)
- CFG_MUST_EQ(p)
- CFG_MUST_EQ(f)
- return true;
- }
- inline bool operator!=(Rule const& o) const { return !(o==*this); }
-
- // for binarizing - no costs/probs
- Rule() : lhs(-1) { }
- bool is_null() const { return lhs<0; }
- void set_null() { lhs=-1; rhs.clear();f.clear(); IF_CFG_TRULE(rule.reset();) }
-
- Rule(int lhs,BinRhs const& binrhs) : lhs(lhs),rhs(2),p(1) {
- rhs[0]=binrhs.first;
- rhs[1]=binrhs.second;
- }
- Rule(int lhs,RHS const& rhs) : lhs(lhs),rhs(rhs),p(1) {
- }
-
- int lhs; // index into nts
- RHS rhs;
- prob_t p; // h unused for now (there's nothing admissable, and p is already using 1st pass inside as pushed toward top)
- SparseVector<double> f; // may be empty, unless copy_features on Init
- IF_CFG_TRULE(TRulePtr rule;)
- int size() const { // for stats only
- return rhs.size();
- }
- void Swap(Rule &o) {
- using namespace std;
- swap(lhs,o.lhs);
- swap(rhs,o.rhs);
- swap(p,o.p);
- swap(f,o.f);
- IF_CFG_TRULE(swap(rule,o.rule);)
- }
- friend inline void swap(Rule &a,Rule &b) {
- a.Swap(b);
- }
-
- template<class V>
- void visit_rhs_nts(V &v) const {
- for (RHS::const_iterator i=rhs.begin(),e=rhs.end();i!=e;++i) {
- WordID w=*i;
- if (w<=0)
- v(-w);
- }
- }
- template<class V>
- void visit_rhs_nts(V const& v) const {
- for (RHS::const_iterator i=rhs.begin(),e=rhs.end();i!=e;++i) {
- WordID w=*i;
- if (w<=0)
- v(-w);
- }
- }
-
- template<class V>
- void visit_rhs(V &v) const {
- for (RHS::const_iterator i=rhs.begin(),e=rhs.end();i!=e;++i) {
- WordID w=*i;
- if (w<=0)
- v.visit_nt(-w);
- else
- v.visit_t(w);
- }
- }
-
- // returns 0 or 1 (# of non null rules in this rule).
- template <class O>
- bool reorder_from(O &order,NTHandle removed=-1) {
- for (RHS::iterator i=rhs.begin(),e=rhs.end();i!=e;++i) {
- WordID &w=*i;
- if (w<=0) {
- int oldnt=-w;
- NTHandle newnt=(NTHandle)order[oldnt]; // e.g. unsigned to int (-1) conversion should be ok
- if (newnt==removed) {
- set_null();
- return false;
- }
- w=-newnt;
- }
- }
- return true;
- }
- };
-
- struct NT {
- NT() { }
- explicit NT(RuleHandle r) : ruleids(1,r) { }
- std::size_t hash_impl() const { using namespace boost; return hash_value(ruleids); }
- bool operator ==(NT const &o) const {
- return ruleids==o.ruleids; // don't care about from
- }
- inline bool operator!=(NT const& o) const { return !(o==*this); }
- Ruleids ruleids; // index into CFG rules with lhs = this NT. aka in_edges_
- NTSpan from; // optional name - still needs id to disambiguate
- void Swap(NT &o) {
- using namespace std;
- swap(ruleids,o.ruleids);
- swap(from,o.from);
- }
- friend inline void swap(NT &a,NT &b) {
- a.Swap(b);
- }
- };
-
- CFG() : hg_() { uninit=true; }
-
- // provided hg will have weights pushed up to root
- CFG(Hypergraph const& hg,bool target_side=true,bool copy_features=false,bool push_weights=true) {
- Init(hg,target_side,copy_features,push_weights);
- }
- bool Uninitialized() const { return uninit; }
- void Clear();
- bool Empty() const { return nts.empty(); }
- void UnindexRules(); // save some space?
- void ReindexRules(); // scan over rules and rebuild NT::ruleids (e.g. after using UniqRules)
- int UniqRules(NTHandle ni); // keep only the highest prob rule for each rhs and lhs=nt - doesn't remove from Rules; just removes from nts[ni].ruleids. keeps the same order in this sense: for a given signature (rhs), that signature's first representative in the old ruleids will become the new position of the best. as a consequence, if you SortLocalBestFirst() then UniqRules(), the result is still best first. but you may also call this on unsorted ruleids. returns number of rules kept
- inline int UniqRules() {
- int nkept=0;
- for (int i=0,e=nts.size();i!=e;++i) nkept+=UniqRules(i);
- return nkept;
- }
- int rules_size() const {
- const int sz=rules.size();
- int sum=sz;
- for (int i=0;i<sz;++i)
- sum+=rules[i].size();
- return sum;
- }
-
- void SortLocalBestFirst(NTHandle ni); // post: nts[ni].ruleids lists rules from highest p to lowest. when doing best-first earley intersection/parsing, you don't want to use the global marginal viterbi; you want to ignore outside in ordering edges for a node, so call this. stable in case of ties
- inline void SortLocalBestFirst() {
- for (int i=0,e=nts.size();i!=e;++i) SortLocalBestFirst(i);
- }
- void Init(Hypergraph const& hg,bool target_side=true,bool copy_features=false,bool push_weights=true);
- void Print(std::ostream &o,CFGFormat const& format) const; // see cfg_format.h
- void Print(std::ostream &o) const; // default format
- void PrintRule(std::ostream &o,RuleHandle rulei,CFGFormat const& format) const;
- void PrintRule(std::ostream &o,RuleHandle rulei) const;
- std::string ShowRule(RuleHandle rulei) const;
- void Swap(CFG &o) { // make sure this includes all fields (easier to see here than in .cc)
- using namespace std;
- swap(uninit,o.uninit);
- swap(hg_,o.hg_);
- swap(goal_inside,o.goal_inside);
- swap(pushed_inside,o.pushed_inside);
- swap(rules,o.rules);
- swap(nts,o.nts);
- swap(goal_nt,o.goal_nt);
- }
-
- //NOTE: this checks exact equality of data structures only. it's well known that CFG equivalence (and intersection==empty) test is undecidable.
- bool operator ==(CFG const &o) const {
- // doesn't matter: hg, goal_inside
- CFG_MUST_EQ(uninit)
- CFG_MUST_EQ(pushed_inside)
- CFG_MUST_EQ(goal_nt)
- CFG_MUST_EQ(nts)
- CFG_MUST_EQ(rules)
- return true;
- }
- inline bool operator!=(CFG const& o) const { return !(o==*this); }
-
- typedef std::vector<NTHandle> NTOrder; // a list of nts, in definition-before-use order.
-
- //perhaps: give up on templated Order and move the below to .cc (NTOrder should be fine)
-
- // post: iterating nts 0,1... the same as order[0],order[1],... ; return number of non-null rules (these aren't actually deleted)
- // pre: order is (without duplicates) a range of NTHandle
- template <class Order>
- int ReorderNTs(Order const& order) {
- using namespace std;
- int nn=nts.size();
-#if 0
- NTs newnts(order.size()); // because the (sub)permutation order may have e.g. 1<->4
- int ni=0;
- for (typename Order::const_iterator i=order.begin(),e=order.end();i!=e;++i) {
- assert(*i<nn);
- swap(newnts[ni++],nts[*i]);
- }
- swap(newnts,nts);
-#endif
- indices_after remap_nti;
- remap_nti.init_inverse_order(nn,order);
- remap_nti.do_moves_swap(nts);// (equally efficient (or more?) than the disabled nt swapping above.
- goal_nt=remap_nti.map[goal_nt]; // remap goal, of course
- // fix rule ids
- return RemapRules(remap_nti.map,(NTHandle)indices_after::REMOVED);
- }
-
- // return # of kept rules (not null)
- template <class NTHandleRemap>
- int RemapRules(NTHandleRemap const& remap_nti,NTHandle removed=-1) {
- int n_non_null=0;
- for (int i=0,e=rules.size();i<e;++i)
- n_non_null+=rules[i].reorder_from(remap_nti,removed);
- return n_non_null;
- }
-
- // call after rules are indexed.
- template <class V>
- void VisitRuleIds(V &v) {
- for (int i=0,e=nts.size();i<e;++i) {
- SHOWM(DVISITRULEID,"VisitRuleIds nt",i);
- for (Ruleids::const_iterator j=nts[i].ruleids.begin(),jj=nts[i].ruleids.end();j!=jj;++j) {
- SHOWM2(DVISITRULEID,"VisitRuleIds",i,*j);
- v(*j);
- }
- }
-
- }
- template <class V>
- void VisitRuleIds(V const& v) {
- for (int i=0,e=nts.size();i<e;++i)
- for (Ruleids::const_iterator j=nts[i].ruleids.begin(),jj=nts[i].ruleids.end();j!=jj;++j)
- v(*j);
- }
-
- // no index needed
- template <class V>
- void VisitRulesUnindexed(V const &v) {
- for (int i=0,e=rules.size();i<e;++i)
- if (!rules[i].is_null())
- v(i,rules[i]);
- }
-
-
-
- void OrderNTsTopo(NTOrder *o,std::ostream *cycle_complain=0); // places NTs in defined (completely) bottom-up before use order. this is actually reverse topo order considering edges from lhs->rhs.
- // you would need to do this only if you didn't build from hg, or you Binarize without bin_topo option.
- // note: this doesn't sort the list of rules; it's assumed that if you care about the topo order you'll iterate over nodes.
- // cycle_complain means to warn in case of back edges. it's not necessary to prevent inf. loops. you get some order that's not topo if there are loops. starts from goal_nt, of course.
-
- void OrderNTsTopo(std::ostream *cycle_complain=0) {
- NTOrder o;
- OrderNTsTopo(&o,cycle_complain);
- ReorderNTs(o);
- }
-
- void BinarizeL2R(bool bin_unary=false,bool name_nts=false);
- void Binarize(CFGBinarize const& binarize_options); // see cfg_binarize.h for docs
- void BinarizeSplit(CFGBinarize const& binarize_options);
- void BinarizeThresh(CFGBinarize const& binarize_options); // maybe unbundle opts later
-
- typedef std::vector<NT> NTs;
- NTs nts;
- typedef std::vector<Rule> Rules;
- Rules rules;
- int goal_nt;
- prob_t goal_inside,pushed_inside; // when we push viterbi weights to goal, we store the removed probability in pushed_inside
-protected:
- bool uninit;
- Hypergraph const* hg_; // shouldn't be used for anything, esp. after binarization
- // rules/nts will have same index as hg edges/nodes
-};
-
-inline std::size_t hash_value(CFG::Rule const& r) {
- return r.hash_impl();
-}
-
-inline std::size_t hash_value(CFG::NT const& r) {
- return r.hash_impl();
-}
-
-inline void swap(CFG &a,CFG &b) {
- a.Swap(b);
-}
-
-std::ostream &operator<<(std::ostream &o,CFG const &x);
-
-#endif
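[Note: Rule::hash_impl and the free hash_value overloads in the deleted header follow the usual Boost hashing idiom: combine member hashes with boost::hash_combine and expose a hash_value overload that boost::hash picks up via argument-dependent lookup. A minimal self-contained instance of the pattern, assuming only <boost/functional/hash.hpp>; the Key struct is illustrative:

#include <boost/functional/hash.hpp>
#include <cstddef>
#include <vector>

struct Key {
  int lhs;
  std::vector<int> rhs;
};

// Found by boost::hash<Key> via argument-dependent lookup.
inline std::size_t hash_value(Key const& k) {
  std::size_t h = boost::hash<int>()(k.lhs);
  boost::hash_combine(h, k.rhs);  // Boost provides hash_value for std::vector
  return h;
}

int main() {
  Key a{1, {2, 3}}, b{1, {2, 3}};
  return boost::hash<Key>()(a) == boost::hash<Key>()(b) ? 0 : 1;  // equal keys hash equal
}

boost::hash<Key> can then serve as the hasher for maps like the HASH_MAP typedefs in cfg.cc.]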
diff --git a/decoder/cfg_binarize.h b/decoder/cfg_binarize.h
deleted file mode 100644
index ae06f8bf..00000000
--- a/decoder/cfg_binarize.h
+++ /dev/null
@@ -1,90 +0,0 @@
-#ifndef CFG_BINARIZE_H
-#define CFG_BINARIZE_H
-
-#include <iostream>
-
-/*
- binarization: decimate rhs of original rules until their rhs have been reduced to length 2 (or 1 if bin_unary). also decimate rhs of newly binarized rules until length 2. newly created rules are all binary (never unary/nullary).
-
- bin_name_nts: nts[i].from will be initialized, including adding new names to TD
-
- bin_l2r: right-branching (a (b c)) means suffixes are shared. if requested, the only other option that matters is bin_unary
-
- otherwise, greedy binarization: the pairs that are most frequent in the rules are binarized, one at a time. this should be done efficiently: each pair has a count of and list of its left and right adjacent pair+count (or maybe a non-count collapsed list of adjacent instances). this can be efficiently updated when a pair is chosen for replacement by a new virtual NT.
- */
-
-struct CFGBinarize {
- int bin_thresh;
- bool bin_l2r;
- int bin_unary;
- bool bin_name_nts;
- bool bin_topo;
- bool bin_split;
- int split_passes,split_share1_passes,split_free_passes;
- template <class Opts> // template to support both printable_opts and boost nonprintable
- void AddOptions(Opts *opts) {
- opts->add_options()
- ("cfg_binarize_threshold", defaulted_value(&bin_thresh),"(if >0) repeatedly binarize CFG rhs bigrams which appear at least this many times, most frequent first. resulting rules may be 1,2, or >2-ary. this happens before the other types of binarization.")
-// ("cfg_binarize_unary_threshold", defaulted_value(&bin_unary),"if >0, a rule-completing production A->BC may be binarized as A->U U->BC if U->BC would be used at least this many times. this happens last.")
- ("cfg_binarize_split", defaulted_value(&bin_split),"(DeNero et al) for each rule until binarized, pick a split point k of L->r[0..n) to make rules L->V1 V2, V1->r[0..k) V2->r[k..n), to minimize the number of new rules created")
- ("cfg_split_full_passes", defaulted_value(&split_passes),"pass through the virtual rules only (up to) this many times (all real rules will have been split if not already binary)")
- ("cfg_split_share1_passes", defaulted_value(&split_share1_passes),"after the full passes, for up to this many times split when at least 1 of the items has been seen before")
- ("cfg_split_free_passes", defaulted_value(&split_free_passes),"only split off from virtual nts pre/post nts that already exist - could check for interior phrases but after a few splits everything should be tiny already.")
- ("cfg_binarize_l2r", defaulted_value(&bin_l2r),"force left to right (a (b (c d))) binarization (ignore _at threshold)")
- ("cfg_binarize_name_nts", defaulted_value(&bin_name_nts),"create named virtual NT tokens e.g. 'A12+the' when binarizing 'B->[A12] the cat'")
- ("cfg_binarize_topo", defaulted_value(&bin_topo),"reorder nonterminals after binarization to maintain definition before use (topological order). otherwise the virtual NTs will all appear after the regular NTs")
- ;
- }
- void Validate() {
- if (bin_thresh>0&&!bin_l2r) {
-// std::cerr<<"\nWARNING: greedy binarization not yet supported; using l2r (right branching) instead.\n";
-// bin_l2r=true;
- }
- if (false && bin_l2r && bin_split) { // actually, split may be slightly incomplete due to finite number of passes.
- std::cerr<<"\nWARNING: l2r and split are both complete binarization and redundant. Using split.\n";
- bin_l2r=false;
- }
-
- }
-
- bool Binarizing() const {
- return bin_split || bin_l2r || bin_thresh>0;
- }
- void set_defaults() {
- bin_split=false;
- bin_topo=false;
- bin_thresh=0;
- bin_unary=0;
- bin_name_nts=true;
- bin_l2r=false;
- split_passes=10;split_share1_passes=0;split_free_passes=10;
- }
- CFGBinarize() { set_defaults(); }
- void print(std::ostream &o) const {
- o<<'(';
- if (!Binarizing())
- o << "Unbinarized";
- else {
- if (bin_unary)
- o << "unary-sharing ";
- if (bin_thresh)
- o<<"greedy bigram count>="<<bin_thresh<<" ";
- if (bin_l2r)
- o << "left->right";
- else
- o << "DeNero greedy split";
- if (bin_name_nts)
- o << " named-NTs";
- if (bin_topo)
- o<<" preserve-topo-order";
- }
- o<<')';
- }
- friend inline std::ostream &operator<<(std::ostream &o,CFGBinarize const& me) {
- me.print(o); return o;
- }
-
-};
-
-
-#endif
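[Note: for concreteness, here is how the options documented in the deleted header transform a 4-ary rule; this worked example is mine, not from the source. Starting from

  A -> B C D E

bin_l2r yields the right-branching chain A -> B V1, V1 -> C V2, V2 -> D E, so rules ending in the same suffix share V1/V2; bin_split instead picks a split point k to minimize new rules, here the midpoint k=2: A -> V1 V2, V1 -> B C, V2 -> D E; and bin_thresh>0 first collapses whichever rhs bigram is most frequent across all rules (e.g. every occurrence of "D E" becomes one virtual NT), which can leave rules of arity 1, 2, or more, as the option text above notes.]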
diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h
deleted file mode 100644
index d12da261..00000000
--- a/decoder/cfg_format.h
+++ /dev/null
@@ -1,134 +0,0 @@
-#ifndef CFG_FORMAT_H
-#define CFG_FORMAT_H
-
-#include <iostream>
-#include <string>
-#include "wordid.h"
-#include "feature_vector.h"
-#include "program_options.h"
-
-struct CFGFormat {
- bool identity_scfg;
- bool features;
- bool logprob_feat;
- bool comma_nt;
- bool nt_span;
- std::string goal_nt_name;
- std::string nt_prefix;
- std::string logprob_feat_name;
- std::string partsep;
- bool goal_nt() const { return !goal_nt_name.empty(); }
- template <class Opts> // template to support both printable_opts and boost nonprintable
- void AddOptions(Opts *opts) {
- //using namespace boost::program_options;
- //using namespace std;
- opts->add_options()
- ("identity_scfg",defaulted_value(&identity_scfg),"output an identity SCFG: add an identity target side - '[X12] ||| [X13,1] a ||| [1] a ||| feat= ...' - the redundant target '[1] a |||' is omitted otherwise.")
- ("features",defaulted_value(&features),"print the CFG feature vector")
- ("logprob_feat",defaulted_value(&logprob_feat),"print a LogProb=-1.5 feature irrespective of --features.")
- ("logprob_feat_name",defaulted_value(&logprob_feat_name),"alternate name for the LogProb feature")
- ("cfg_comma_nt",defaulted_value(&comma_nt),"if false, omit the usual [NP,1] ',1' variable index in the source side")
- ("goal_nt_name",defaulted_value(&goal_nt_name),"if nonempty, the first production will be '[goal_nt_name] ||| [x123] ||| LogProb=y' where x123 is the actual goal nt, and y is the pushed prob, if any")
- ("nt_prefix",defaulted_value(&nt_prefix),"NTs are [<nt_prefix>123] where 123 is the node number starting at 0, and the highest node (last in file) is the goal node in an acyclic hypergraph")
- ("nt_span",defaulted_value(&nt_span),"prefix A(i,j) for NT coming from hypergraph node with category A on span [i,j). this is after --nt_prefix if any")
- ;
- }
-
- void print(std::ostream &o) const {
- o<<"[";
- if (identity_scfg)
- o<<"Identity SCFG ";
- if (features)
- o<<"+Features ";
- if (logprob_feat)
- o<<logprob_feat_name<<"(logprob) ";
- if (nt_span)
- o<<"named-NTs ";
- if (comma_nt)
- o<<",N ";
- o << "CFG output format";
- o<<"]";
- }
- friend inline std::ostream &operator<<(std::ostream &o,CFGFormat const& me) {
- me.print(o); return o;
- }
-
- void Validate() { }
- template<class CFG>
- void print_source_nt(std::ostream &o,CFG const&cfg,int id,int position=1) const {
- o<<'[';
- print_nt_name(o,cfg,id);
- if (comma_nt) o<<','<<position;
- o<<']';
- }
-
- template <class CFG>
- void print_nt_name(std::ostream &o,CFG const& cfg,int id) const {
- o<<nt_prefix;
- if (nt_span)
- cfg.print_nt_name(o,id);
- else
- o<<id;
- }
-
- template <class CFG>
- void print_lhs(std::ostream &o,CFG const& cfg,int id) const {
- o<<'[';
- print_nt_name(o,cfg,id);
- o<<']';
- }
-
- template <class CFG,class Iter>
- void print_rhs(std::ostream &o,CFG const&cfg,Iter begin,Iter end) const {
- o<<partsep;
- int pos=0;
- for (Iter i=begin;i!=end;++i) {
- WordID w=*i;
- if (i!=begin) o<<' ';
- if (w>0) o << TD::Convert(w);
- else print_source_nt(o,cfg,-w,++pos);
- }
- if (identity_scfg) {
- o<<partsep;
- int pos=0;
- for (Iter i=begin;i!=end;++i) {
- WordID w=*i;
- if (i!=begin) o<<' ';
- if (w>0) o << TD::Convert(w);
- else o << '['<<++pos<<']';
- }
- }
- }
-
- void print_features(std::ostream &o,prob_t p,SparseVector<double> const& fv=SparseVector<double>()) const {
- bool logp=(logprob_feat && p!=prob_t::One());
- if (features || logp) {
- o << partsep;
- if (logp)
- o << logprob_feat_name<<'='<<log(p)<<' ';
- if (features)
- o << fv;
- }
- }
-
- //TODO: default to no nt names (nt_span=0)
- void set_defaults() {
- identity_scfg=false;
- features=true;
- logprob_feat=true;
- comma_nt=true;
- goal_nt_name="S";
- logprob_feat_name="LogProb";
- nt_prefix="";
- partsep=" ||| ";
- nt_span=true;
- }
-
- CFGFormat() {
- set_defaults();
- }
-};
-
-
-
-#endif
diff --git a/decoder/cfg_options.h b/decoder/cfg_options.h
deleted file mode 100644
index 7b59c05c..00000000
--- a/decoder/cfg_options.h
+++ /dev/null
@@ -1,79 +0,0 @@
-#ifndef CFG_OPTIONS_H
-#define CFG_OPTIONS_H
-
-#include "filelib.h"
-#include "hg_cfg.h"
-#include "cfg_format.h"
-#include "cfg_binarize.h"
-//#include "program_options.h"
-
-struct CFGOptions {
- CFGFormat format;
- CFGBinarize binarize;
- std::string out,source_out,unbin_out;
- bool uniq;
- void set_defaults() {
- format.set_defaults();
- binarize.set_defaults();
- out=source_out=unbin_out="";
- uniq=false;
- }
-
- CFGOptions() { set_defaults(); }
- template <class Opts> // template to support both printable_opts and boost nonprintable
- void AddOptions(Opts *opts) {
- opts->add_options()
- ("cfg_output", defaulted_value(&out),"write final target CFG (before FSA rescoring) to this file")
- ("source_cfg_output", defaulted_value(&source_out),"write source CFG (after prelm-scoring, prelm-prune) to this file")
- ("cfg_unbin_output", defaulted_value(&unbin_out),"write pre-binarization CFG to this file") //TODO:
- ("cfg_uniq", defaulted_value(&uniq),"in case of duplicate rules, keep only the one with highest prob")
-
- ;
- binarize.AddOptions(opts);
- format.AddOptions(opts);
- }
- void Validate() {
- format.Validate();
- binarize.Validate();
- }
- void maybe_output_source(Hypergraph const& hg) {
- if (source_out.empty()) return;
- std::cerr<<"Printing source CFG to "<<source_out<<": "<<format<<'\n';
- WriteFile o(source_out);
- CFG cfg(hg,false,format.features,format.goal_nt());
- cfg.Print(o.get(),format);
- }
- // executes all options except source_cfg_output, building target hgcfg
- void prepare(HgCFG &hgcfg) {
- if (out.empty() && unbin_out.empty()) return;
- CFG &cfg=hgcfg.GetCFG();
- maybe_print(cfg,unbin_out);
- maybe_uniq(hgcfg);
- maybe_binarize(hgcfg);
- maybe_print(cfg,out,"");
- }
-
- char const* description() const {
- return "CFG output options";
- }
- void maybe_print(CFG &cfg,std::string cfg_output,char const* desc=" unbinarized") {
- if (cfg_output.empty()) return;
- WriteFile o(cfg_output);
- std::cerr<<"Printing target"<<desc<<" CFG to "<<cfg_output<<": "<<format<<'\n';
- cfg.Print(o.get(),format);
- }
-
- void maybe_uniq(HgCFG &hgcfg) {
- if (hgcfg.uniqed) return;
- hgcfg.GetCFG().UniqRules();
- hgcfg.uniqed=true;
- }
- void maybe_binarize(HgCFG &hgcfg) {
- if (hgcfg.binarized) return;
- hgcfg.GetCFG().Binarize(binarize);
- hgcfg.binarized=true;
- }
-};
-
-
-#endif
diff --git a/decoder/cfg_test.cc b/decoder/cfg_test.cc
deleted file mode 100644
index cbe7d0be..00000000
--- a/decoder/cfg_test.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-#include <boost/tuple/tuple.hpp>
-#include <gtest/gtest.h>
-#include "cfg.h"
-#include "hg_test.h"
-#include "cfg_options.h"
-#include "show.h"
-
-/* TODO: easiest way to get meaningful confirmations that things work: implement conversion back to hg, and compare viterbi/inside etc. stats for equality to original hg. or you can define CSHOW_V and see lots of output */
-
-using namespace boost;
-
-#define CSHOW_V 0
-
-#if CSHOW_V
-# define CSHOWDO(x) x;
-#else
-# define CSHOWDO(x)
-#endif
-#define CSHOW(x) CSHOWDO(cerr<<#x<<'='<<x<<endl;)
-
-typedef std::pair<string,string> HgW; // hg file,weights
-
-struct CFGTest : public TestWithParam<HgW> {
- string hgfile;
- Hypergraph hg;
- CFG cfg;
- CFGFormat form;
- SparseVector<double> weights;
-
- static void JsonFN(Hypergraph &hg,CFG &cfg,SparseVector<double> &featw,std::string file
- ,std::string const& wts="Model_0 1 EgivenF 1 f1 1")
- {
- istringstream ws(wts);
- EXPECT_TRUE(ws>>featw);
- CSHOW(featw)
- std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
- HGSetup::JsonTestFile(&hg,path,file);
- hg.Reweight(featw);
- cfg.Init(hg,true,true,false);
- }
- static void SetUpTestCase() {
- }
- static void TearDownTestCase() {
- }
- CFGTest() {
- hgfile=GetParam().first;
- JsonFN(hg,cfg,weights,hgfile,GetParam().second);
- CSHOWDO(cerr<<"\nCFG Test: ")
- CSHOW(hgfile);
- form.nt_span=true;
- form.comma_nt=false;
- }
- ~CFGTest() { }
-};
-
-TEST_P(CFGTest,Binarize) {
- CFGBinarize b;
- b.bin_name_nts=1;
- CFG cfgu=cfg;
- EXPECT_EQ(cfgu,cfg);
- int nrules=cfg.rules.size();
- CSHOWDO(cerr<<"\nUniqing: "<<nrules<<"\n");
- int nrem=cfgu.UniqRules();
- cerr<<"\nCFG "<<hgfile<<" Uniqed - remaining: "<<nrem<<" of "<<nrules<<"\n";
- if (nrem==nrules) {
- EXPECT_EQ(cfgu,cfg);
- //TODO - check that 1best is still the same (that we removed only worse edges)
- }
-
- for (int i=-1;i<8;++i) {
- bool uniq;
- if (i>=0) {
- int f=i<<1;
- b.bin_l2r=1;
- b.bin_unary=(f>>=1)&1;
- b.bin_topo=(f>>=1)&1;
- uniq=(f>>=1)&1;
- } else
- b.bin_l2r=0;
- CFG cc=uniq?cfgu:cfg;
- CSHOW("\nBinarizing "<<(uniq?"uniqued ":"")<<": "<<i<<" "<<b);
- cc.Binarize(b);
- cerr<<"Binarized "<<b<<" rules size "<<cfg.rules_size()<<" => "<<cc.rules_size()<<"\n";
- CSHOWDO(cc.Print(cerr,form);cerr<<"\n\n";);
- }
-}
-
-INSTANTIATE_TEST_CASE_P(HypergraphsWeights,CFGTest,
- Values(
- HgW(perro_json,perro_wts)
- , HgW(small_json,small_wts)
- ,HgW(urdu_json,urdu_wts)
- ));
-
-int main(int argc, char **argv) {
- testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
-}
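[Note: the flag arithmetic in the deleted test's loop is terse: `int f=i<<1` followed by three `(f>>=1)&1` reads simply peels off bits 0, 1, and 2 of i as bin_unary, bin_topo, and uniq, with bin_l2r forced on (the extra i=-1 iteration is the no-l2r baseline). An equivalent, more direct spelling of the same enumeration:

#include <iostream>

int main() {
  for (int i = 0; i < 8; ++i) {
    bool bin_unary = (i >> 0) & 1;  // what the first (f>>=1)&1 reads
    bool bin_topo  = (i >> 1) & 1;  // second read
    bool uniq      = (i >> 2) & 1;  // third read: binarize the uniqued copy
    std::cout << i << ": unary=" << bin_unary
              << " topo=" << bin_topo << " uniq=" << uniq << '\n';
  }
}
]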
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index da65713a..9b41253b 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -11,7 +11,6 @@ namespace std { using std::tr1::unordered_map; }
#include <boost/make_shared.hpp>
#include <boost/scoped_ptr.hpp>
-#include "program_options.h"
#include "stringlib.h"
#include "weights.h"
#include "filelib.h"
@@ -49,13 +48,6 @@ namespace std { using std::tr1::unordered_map; }
#include "hg_io.h"
#include "aligner.h"
-#undef FSA_RESCORING
-#ifdef FSA_RESCORING
-#include "hg_cfg.h"
-#include "apply_fsa_models.h"
-#include "cfg_options.h"
-#endif
-
#ifdef CP_TIME
clock_t CpTime::time_;
void CpTime::Add(clock_t x){time_+=x;}
@@ -140,21 +132,6 @@ inline boost::shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose
return pf;
}
-#ifdef FSA_RESCORING
-inline boost::shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
- string ff, param;
- SplitCommandAndParam(ffp, &ff, &param);
- cerr << "FSA Feature: " << ff;
- if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
- else cerr << " (no config parameters)\n";
- boost::shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param);
- if (!pf) exit(1);
- if (verbose_feature_functions && !SILENT)
- cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl;
- return pf;
-}
-#endif
-
// when the translation forest is first built, it is scored by the features associated
// with the rules. To add other features (like language models, etc), cdec applies one or
// more "rescoring passes", which compute new features and optionally apply new weights
@@ -304,11 +281,6 @@ struct DecoderImpl {
boost::shared_ptr<Translator> translator;
boost::shared_ptr<vector<weight_t> > init_weights; // weights used with initial parse
vector<boost::shared_ptr<FeatureFunction> > pffs;
-#ifdef FSA_RESCORING
- CFGOptions cfg_options;
- vector<boost::shared_ptr<FsaFeatureFunction> > fsa_ffs;
- vector<string> fsa_names;
-#endif
boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
@@ -324,7 +296,6 @@ struct DecoderImpl {
SparseVector<prob_t> acc_vec; // accumulate gradient
double acc_obj; // accumulate objective
int g_count; // number of gradient pieces computed
- int pop_limit;
bool csplit_output_plf;
bool write_gradient; // TODO Observer
bool feature_expectations; // TODO Observer
@@ -372,6 +343,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2")
+ ("cubepruning_pop_limit,K",po::value<unsigned>()->default_value(200), "Max number of pops from the candidate heap at each node")
("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")
("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob")
("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
@@ -380,6 +352,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights2",po::value<string>(),"Optional pass 2")
("feature_function2",po::value<vector<string> >()->composing(), "Optional pass 2")
("intersection_strategy2",po::value<string>()->default_value("cube_pruning"), "Optional pass 2")
+ ("cubepruning_pop_limit2",po::value<unsigned>()->default_value(200), "Optional pass 2")
("summary_feature2", po::value<string>(), "Optional pass 2")
("density_prune2", po::value<double>(), "Optional pass 2")
("beam_prune2", po::value<double>(), "Optional pass 2")
@@ -387,18 +360,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights3",po::value<string>(),"Optional pass 3")
("feature_function3",po::value<vector<string> >()->composing(), "Optional pass 3")
("intersection_strategy3",po::value<string>()->default_value("cube_pruning"), "Optional pass 3")
+ ("cubepruning_pop_limit3",po::value<unsigned>()->default_value(200), "Optional pass 3")
("summary_feature3", po::value<string>(), "Optional pass 3")
("density_prune3", po::value<double>(), "Optional pass 3")
("beam_prune3", po::value<double>(), "Optional pass 3")
-#ifdef FSA_RESCORING
- ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
- ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
-#endif
("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
("k_best,k",po::value<int>(),"Extract the k best derivations")
("unique_k_best,r", "Unique k-best translation list")
- ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
("aligner,a", "Run as a word/phrase aligner (src & ref required)")
("aligner_use_viterbi", "If run in alignment mode, compute the Viterbi (rather than MAP) alignment")
("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)")
@@ -446,10 +415,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
// ob.AddOptions(&opts);
-#ifdef FSA_RESCORING
- po::options_description cfgo(cfg_options.description());
- cfg_options.AddOptions(&cfgo);
-#endif
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
@@ -459,15 +424,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
;
po::options_description dconfig_options, dcmdline_options;
-#ifdef FSA_RESCORING
- dconfig_options.add(opts).add(cfgo);
-#else
dconfig_options.add(opts);
-#endif
dcmdline_options.add(dconfig_options).add(clo);
if (argc) {
- argv_minus_to_underscore(argc,argv);
po::store(parse_command_line(argc, argv, dcmdline_options), conf);
if (conf.count("compgen")) {
print_options(cout,dcmdline_options);
@@ -511,10 +471,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (conf.count("list_feature_functions")) {
cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n";
ff_registry.DisplayList(); //TODO
-#ifdef FSA_RESCORING
- cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n";
- fsa_ff_registry.DisplayList(); // TODO
-#endif
cerr << endl;
exit(1);
}
@@ -574,9 +530,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (conf.count("weights"))
Weights::InitFromFile(str("weights",conf), init_weights.get());
- // cube pruning pop-limit: we may want to configure this on a per-pass basis
- pop_limit = conf["cubepruning_pop_limit"].as<int>();
-
if (conf.count("extract_rules")) {
if (!DirectoryExists(conf["extract_rules"].as<string>()))
MkDirP(conf["extract_rules"].as<string>());
@@ -620,6 +573,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (conf.count(dp)) { rp.density_prune = conf[dp].as<double>(); }
int palg = (has_stateful ? 1 : 0); // if there are no stateful features, default to FULL
string isn = "intersection_strategy" + StringSuffixForRescoringPass(pass);
+ string spl = "cubepruning_pop_limit" + StringSuffixForRescoringPass(pass);
+ unsigned pop_limit = 200;
+ if (conf.count(spl)) { pop_limit = conf[spl].as<unsigned>(); }
if (LowercaseString(str(isn.c_str(),conf)) == "full") {
palg = 0;
}
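The per-pass lookup above assumes StringSuffixForRescoringPass yields suffixes
matching the option names; a sketch of the presumed mapping (the helper is
defined elsewhere in decoder.cc, not shown in this diff):

    // presumed: pass 0 -> "", pass 1 -> "2", pass 2 -> "3", so that
    // "cubepruning_pop_limit" + suffix lines up with the options above.
    std::string StringSuffixForRescoringPass(int pass) {
      return pass == 0 ? std::string() : std::to_string(pass + 1);
    }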
@@ -686,21 +642,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
else
assert(!"error");
-#ifdef FSA_RESCORING
- store_conf(conf,"fsa_feature_function",&fsa_names);
- for (int i=0;i<fsa_names.size();++i)
- fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA "));
- if (fsa_ffs.size()>1) {
- //FIXME: support N fsa ffs.
- cerr<<"Only the first fsa FF will be used (FIXME).\n";
- fsa_ffs.resize(1);
- }
- if (!fsa_ffs.empty()) {
- cerr<<"FSA: ";
- show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true);
- }
-#endif
-
if (late_freeze) {
cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl;
FD::Freeze(); // this means we can't see the feature names of not-weighted features
@@ -720,10 +661,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
oracle.show_derivation=conf.count("show_derivations");
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
-#ifdef FSA_RESCORING
- cfg_options.Validate();
-#endif
-
if (conf.count("extract_rules")) {
stringstream ss;
ss << sent_id;
@@ -840,7 +777,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
}
if (conf.count("incremental_search")) {
- incremental->Search(pop_limit, forest);
+ incremental->Search(conf["cubepruning_pop_limit"].as<unsigned>(), forest);
}
if (conf.count("show_target_graph") || conf.count("incremental_search")) {
o->NotifyDecodingComplete(smeta);
@@ -851,9 +788,6 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
const RescoringPass& rp = rescoring_passes[pass];
const vector<weight_t>& cur_weights = *rp.weight_vector;
if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl;
-#ifdef FSA_RESCORING
- cfg_options.maybe_output_source(forest);
-#endif
string passtr = "Pass1"; passtr[4] += pass;
forest.Reweight(cur_weights);
@@ -949,25 +883,6 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass);
string fulldp = "density_prune" + StringSuffixForRescoringPass(pass);
maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen);
-
-#ifdef FSA_RESCORING
- HgCFG hgcfg(forest);
- cfg_options.prepare(hgcfg);
-
- if (!fsa_ffs.empty()) {
- Timer t("Target FSA rescoring:");
- if (!has_late_models)
- forest.Reweight(pass0_weights);
- Hypergraph fsa_forest;
- assert(fsa_ffs.size()==1);
- ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
- if (!SILENT) cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl;
- ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],pass0_weights,cfg,&fsa_forest);
- forest.swap(fsa_forest);
- forest.Reweight(pass0_weights);
- if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,oracle.show_derivation);
- }
-#endif
}
const vector<double>& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector);
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 79c7a602..8039a42b 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -25,9 +25,10 @@ private:
class SentenceMetadata;
class Hypergraph;
-struct DecoderImpl;
+class DecoderImpl;
-struct DecoderObserver {
+class DecoderObserver {
+ public:
virtual ~DecoderObserver();
virtual void NotifyDecodingStart(const SentenceMetadata& smeta);
virtual void NotifySourceParseFailure(const SentenceMetadata& smeta);
@@ -37,9 +38,10 @@ struct DecoderObserver {
virtual void NotifyDecodingComplete(const SentenceMetadata& smeta);
};
-struct Grammar; // TODO once the decoder interface is cleaned up,
- // this should be somewhere else
-struct Decoder {
+class Grammar; // TODO once the decoder interface is cleaned up,
+ // this should be somewhere else
+class Decoder {
+ public:
Decoder(int argc, char** argv);
Decoder(std::istream* config_file);
bool Decode(const std::string& input, DecoderObserver* observer = NULL);
@@ -49,6 +51,7 @@ struct Decoder {
std::vector<weight_t>& CurrentWeightVector();
const std::vector<weight_t>& CurrentWeightVector() const;
+ // this sets the current sentence ID
void SetId(int id);
~Decoder();
const boost::program_options::variables_map& GetConf() const { return conf; }
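The public surface of Decoder is small; a minimal hypothetical driver (error
handling omitted) shows the intended call order:

    #include <iostream>
    #include <string>
    #include "decoder.h"

    int main(int argc, char** argv) {
      Decoder decoder(argc, argv);  // parses config file(s) and flags
      std::string line;
      int id = 0;
      while (std::getline(std::cin, line)) {
        decoder.SetId(id++);        // set the sentence ID before decoding
        decoder.Decode(line);       // observer defaults to NULL
      }
      return 0;
    }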
diff --git a/decoder/ff_factory.cc b/decoder/ff_factory.cc
index 25d37648..f45d8695 100644
--- a/decoder/ff_factory.cc
+++ b/decoder/ff_factory.cc
@@ -7,26 +7,15 @@
using boost::shared_ptr;
using namespace std;
-UntypedFactory::~UntypedFactory() { }
+// global ff registry
+FFRegistry ff_registry;
-namespace {
-std::string const& debug_pre="debug";
-}
+UntypedFactory::~UntypedFactory() { }
void UntypedFactoryRegistry::clear() {
reg_.clear();
}
-bool UntypedFactoryRegistry::parse_debug(std::string & p) {
- int pl=debug_pre.size();
- bool space=false;
- bool debug=match_begin(p,debug_pre)&&
- (p.size()==pl || (space=(p[pl]==' ')));
- if (debug)
- p.erase(0,debug_pre.size()+space);
- return debug;
-}
-
bool UntypedFactoryRegistry::have(std::string const& ffname) {
return reg_.find(ffname)!=reg_.end();
}
@@ -54,34 +43,18 @@ void UntypedFactoryRegistry::Register(const string& ffname, UntypedFactory* fact
}
-void UntypedFactoryRegistry::Register(UntypedFactory* factory)
-{
+void UntypedFactoryRegistry::Register(UntypedFactory* factory) {
Register(factory->usage(false,false),factory);
}
-/*FIXME: I want these to go in ff_factory.cc, but extern etc. isn't workign right:
- ../decoder/libcdec.a(ff_factory.o): In function `~UntypedFactory':
-/nfs/topaz/graehl/ws10smt/decoder/ff_factory.cc:9: multiple definition of `global_ff_registry'
-mr_vest_generate_mapper_input.o:/nfs/topaz/graehl/ws10smt/vest/mr_vest_generate_mapper_input.cc:307: first defined here
-*/
-FsaFFRegistry fsa_ff_registry;
-FFRegistry ff_registry;
-
-/*
-#include "null_deleter.h"
-boost::shared_ptr<FsaFFRegistry> global_fsa_ff_registry(&fsa_ff_registry,null_deleter());
-boost::shared_ptr<FFRegistry> global_ff_registry(&ff_registry,null_deleter());
-*/
-
-void ff_usage(std::string const& n,std::ostream &out)
-{
+void ff_usage(std::string const& n,std::ostream &out) {
bool have=ff_registry.have(n);
if (have)
- cout<<"FF "<<ff_registry.usage(n,true,true)<<endl;
- if (fsa_ff_registry.have(n))
- cout<<"Fsa FF "<<fsa_ff_registry.usage(n,true,true)<<endl;
- else if (!have)
- throw std::runtime_error("Unknown feature "+n);
+ out << "FF " << ff_registry.usage(n,true,true) << endl;
+ else {
+ cerr << "Unknown feature: " << n << endl;
+ abort();
+ }
}
diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h
index bfdd3257..1aa8e55f 100644
--- a/decoder/ff_factory.h
+++ b/decoder/ff_factory.h
@@ -1,8 +1,6 @@
#ifndef _FF_FACTORY_H_
#define _FF_FACTORY_H_
-// FsaF* vs F* (regular ff/factory).
-
//TODO: use http://www.boost.org/doc/libs/1_43_0/libs/functional/factory/doc/html/index.html ?
/*TODO: register state identity separately from feature function identity? as
@@ -25,13 +23,15 @@ class FeatureFunction;
class FsaFeatureFunction;
-struct UntypedFactory {
+class UntypedFactory {
+ public:
virtual ~UntypedFactory();
virtual std::string usage(bool params,bool verbose) const = 0;
};
template <class FF>
-struct FactoryBase : public UntypedFactory {
+class FactoryBase : public UntypedFactory {
+ public:
typedef FF F;
typedef boost::shared_ptr<F> FP;
@@ -40,7 +40,8 @@ struct FactoryBase : public UntypedFactory {
/* see cdec_ff.cc for example usage: this create concrete factories to be registered */
template<class FF>
-struct FFFactory : public FactoryBase<FeatureFunction> {
+class FFFactory : public FactoryBase<FeatureFunction> {
+ public:
FP Create(std::string param) const {
FF *ret=new FF(param);
return FP(ret);
@@ -51,18 +52,6 @@ struct FFFactory : public FactoryBase<FeatureFunction> {
};
-// same as above, but we didn't want to require a typedef e.g. Parent in FF class, and template typedef isn't available
-template<class FF>
-struct FsaFactory : public FactoryBase<FsaFeatureFunction> {
- FP Create(std::string param) const {
- FF *ret=new FF(param);
- return FP(ret);
- }
- virtual std::string usage(bool params,bool verbose) const {
- return FF::usage(params,verbose);
- }
-};
-
struct UntypedFactoryRegistry {
std::string usage(std::string const& ffname,bool params=true,bool verbose=true) const;
bool have(std::string const& ffname);
@@ -70,7 +59,6 @@ struct UntypedFactoryRegistry {
void Register(const std::string& ffname, UntypedFactory* factory);
void Register(UntypedFactory* factory);
void clear();
- static bool parse_debug(std::string & param_in_out); // returns true iff param starts w/ debug (and remove that prefix from param)
protected:
typedef boost::shared_ptr<UntypedFactory> FactoryP;
typedef std::map<std::string, FactoryP > Factmap;
@@ -92,26 +80,16 @@ struct FactoryRegistry : public UntypedFactoryRegistry {
Factmap::const_iterator it = reg_.find(ffname);
if (it == reg_.end())
throw std::runtime_error("I don't know how to create feature "+ffname);
- bool debug=parse_debug(param);
- if (debug)
- cerr<<"debug enabled for "<<ffname<< " - remaining options: '"<<param<<"'\n";
FP res = dynamic_cast<FB const&>(*it->second).Create(param);
return res;
}
};
typedef FactoryRegistry<FeatureFunction> FFRegistry;
-typedef FactoryRegistry<FsaFeatureFunction> FsaFFRegistry;
-extern FsaFFRegistry fsa_ff_registry;
-inline FsaFFRegistry & global_fsa_ff_registry() { return fsa_ff_registry; }
extern FFRegistry ff_registry;
-inline FFRegistry & global_ff_registry() { return ff_registry; }
+inline FFRegistry& global_ff_registry() { return ff_registry; }
-void ff_usage(std::string const& name,std::ostream &out=std::cout);
+void ff_usage(std::string const& name,std::ostream& out=std::cerr);
-/*
-extern boost::shared_ptr<FsaFFRegistry> global_fsa_ff_registry;
-extern boost::shared_ptr<FFRegistry> global_ff_registry;
-*/
#endif
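For reference, registration and lookup through the surviving registry follow
this pattern (cf. cdec_ff.cc; MyFeature is a placeholder class):

    // at startup, one factory per feature class; the single-argument overload
    // derives the name from FF::usage(false,false):
    ff_registry.Register(new FFFactory<MyFeature>);
    ff_registry.Register("MyAlias", new FFFactory<MyFeature>);

    // later, creation by name with a raw parameter string:
    boost::shared_ptr<FeatureFunction> ff =
        ff_registry.Create("MyAlias", "-n 3");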
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index c8ca917a..339a10c3 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -187,6 +187,7 @@ class KLanguageModelImpl {
// this assumes no target words on final unary -> goal rule. is that ok?
// for <s> (n-1 left words) and (n-1 right words) </s>
double FinalTraversalCost(const void* state_void, double* oovs) {
+ *oovs = 0;
const BoundaryAnnotatedState &annotated = *static_cast<const BoundaryAnnotatedState*>(state_void);
if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
assert(!annotated.seen_bos);
@@ -344,7 +345,7 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme
const Hypergraph::Edge& edge,
const vector<const void*>& ant_states,
SparseVector<double>* features,
- SparseVector<double>* estimated_features,
+ SparseVector<double>* /*estimated_features*/,
void* state) const {
double est = 0;
double oovs = 0;
diff --git a/decoder/hg_cfg.h b/decoder/hg_cfg.h
deleted file mode 100644
index b90aca47..00000000
--- a/decoder/hg_cfg.h
+++ /dev/null
@@ -1,54 +0,0 @@
-#ifndef HG_CFG_H
-#define HG_CFG_H
-
-#include "cfg.h"
-
-class Hypergraph;
-
-// in case you might want the CFG whether or not you apply FSA models:
-struct HgCFG {
- void set_defaults() {
- have_cfg=binarized=have_features=uniqed=false;
- want_features=true;
- }
- HgCFG(Hypergraph const& ih) : ih(ih) {
- set_defaults();
- }
- Hypergraph const& ih;
- CFG cfg;
- bool have_cfg;
- bool have_features;
- bool want_features;
- void InitCFG(CFG &to) {
- to.Init(ih,true,want_features,true);
- }
- bool binarized;
- bool uniqed;
- CFG &GetCFG()
- {
- if (!have_cfg) {
- have_cfg=true;
- InitCFG(cfg);
- }
- return cfg;
- }
- void GiveCFG(CFG &to) {
- if (!have_cfg)
- InitCFG(to);
- else {
- cfg.VisitRuleIds(*this);
- have_cfg=false;
- to.Clear();
- swap(to,cfg);
- }
- }
- void operator()(int ri) const {
- }
- CFG const& GetCFG() const {
- assert(have_cfg);
- return cfg;
- }
-};
-
-
-#endif
diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h
index 52586331..f2a779f4 100644
--- a/decoder/sentence_metadata.h
+++ b/decoder/sentence_metadata.h
@@ -9,7 +9,8 @@
struct DocScorer; // deprecated, will be removed
struct Score; // deprecated, will be removed
-struct SentenceMetadata {
+class SentenceMetadata {
+ public:
friend class DecoderImpl;
SentenceMetadata(int id, const Lattice& ref) :
sent_id_(id),
diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc
index 9e381ac6..9204ad04 100644
--- a/decoder/viterbi.cc
+++ b/decoder/viterbi.cc
@@ -1,6 +1,8 @@
-#include "fast_lexical_cast.hpp"
#include "viterbi.h"
+#include <cmath>
+#include <stdexcept>
+#include "fast_lexical_cast.hpp"
#include <sstream>
#include <vector>
#include "hg.h"
@@ -110,30 +112,7 @@ string JoshuaVisualizationString(const Hypergraph& hg) {
return TD::GetString(tmp);
}
-
-//TODO: move to appropriate header if useful elsewhere
-/*
- The simple solution like abs(f1-f2) <= e does not work for very small or very big values. This floating-point comparison algorithm is based on the more confident solution presented by Knuth in [1]. For a given floating point values u and v and a tolerance e:
-
-| u - v | <= e * |u| and | u - v | <= e * |v|
-defines a "very close with tolerance e" relationship between u and v
- (1)
-
-| u - v | <= e * |u| or | u - v | <= e * |v|
-defines a "close enough with tolerance e" relationship between u and v
- (2)
-
-Both relationships are commutative but are not transitive. The relationship defined by inequations (1) is stronger that the relationship defined by inequations (2) (i.e. (1) => (2) ). Because of the multiplication in the right side of inequations, that could cause an unwanted underflow condition, the implementation is using modified version of the inequations (1) and (2) where all underflow, overflow conditions could be guarded safely:
-
-| u - v | / |u| <= e and | u - v | / |v| <= e
-| u - v | / |u| <= e or | u - v | / |v| <= e
- (1`)
-(2`)
-*/
-#include <cmath>
-#include <stdexcept>
-inline bool close_enough(double a,double b,double epsilon)
-{
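+// Knuth-style "close enough" test: |a-b| <= epsilon*|a| or |a-b| <= epsilon*|b|,
+// i.e. a relative comparison that behaves for tiny and huge magnitudes alike.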
+inline bool close_enough(double a,double b,double epsilon) {
using std::fabs;
double diff=fabs(a-b);
return diff<=epsilon*fabs(a) || diff<=epsilon*fabs(b);
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 65a3d436..7825012c 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -1,5 +1,5 @@
-bin_PROGRAMS = compile run_extractor
+bin_PROGRAMS = compile run_extractor extract
if HAVE_CXX11
@@ -24,7 +24,8 @@ EXTRA_PROGRAMS = alignment_test \
scorer_test \
suffix_array_test \
target_phrase_extractor_test \
- translation_table_test
+ translation_table_test \
+ vocabulary_test
if HAVE_GTEST
RUNNABLE_TESTS = alignment_test \
@@ -48,12 +49,14 @@ if HAVE_GTEST
scorer_test \
suffix_array_test \
target_phrase_extractor_test \
- translation_table_test
+ translation_table_test \
+ vocabulary_test
endif
noinst_PROGRAMS = $(RUNNABLE_TESTS)
-TESTS = $(RUNNABLE_TESTS)
+TESTS = $(RUNNABLE_TESTS)
alignment_test_SOURCES = alignment_test.cc
alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
@@ -99,44 +102,17 @@ target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc
target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
translation_table_test_SOURCES = translation_table_test.cc
translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+vocabulary_test_SOURCES = vocabulary_test.cc
+vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
-noinst_LIBRARIES = libextractor.a libcompile.a
+noinst_LIBRARIES = libextractor.a
compile_SOURCES = compile.cc
-compile_LDADD = libcompile.a
+compile_LDADD = libextractor.a
run_extractor_SOURCES = run_extractor.cc
run_extractor_LDADD = libextractor.a
-
-libcompile_a_SOURCES = \
- alignment.cc \
- data_array.cc \
- phrase_location.cc \
- precomputation.cc \
- suffix_array.cc \
- time_util.cc \
- translation_table.cc \
- alignment.h \
- data_array.h \
- fast_intersector.h \
- grammar.h \
- grammar_extractor.h \
- matchings_finder.h \
- matchings_trie.h \
- phrase.h \
- phrase_builder.h \
- phrase_location.h \
- precomputation.h \
- rule.h \
- rule_extractor.h \
- rule_extractor_helper.h \
- rule_factory.h \
- sampler.h \
- scorer.h \
- suffix_array.h \
- target_phrase_extractor.h \
- time_util.h \
- translation_table.h \
- vocabulary.h
+extract_SOURCES = extract.cc
+extract_LDADD = libextractor.a
libextractor_a_SOURCES = \
alignment.cc \
diff --git a/extractor/compile.cc b/extractor/compile.cc
index 0d62757e..3ee668ce 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -30,6 +30,8 @@ int main(int argc, char** argv) {
("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
("alignment,a", po::value<string>()->required(), "Bitext word alignment")
("output,o", po::value<string>()->required(), "Output path")
+ ("config,c", po::value<string>()->required(),
+ "Path where the config file will be generated")
("frequent", po::value<int>()->default_value(100),
"Number of precomputed frequent patterns")
("super_frequent", po::value<int>()->default_value(10),
@@ -82,8 +84,12 @@ int main(int argc, char** argv) {
target_data_array = make_shared<DataArray>(vm["target"].as<string>());
}
+ ofstream config_stream(vm["config"].as<string>());
+
Clock::time_point start_write = Clock::now();
- ofstream target_fstream((output_dir / fs::path("target.bin")).string());
+ string target_path = (output_dir / fs::path("target.bin")).string();
+ config_stream << "target = " << target_path << endl;
+ ofstream target_fstream(target_path);
ar::binary_oarchive target_stream(target_fstream);
target_stream << *target_data_array;
Clock::time_point stop_write = Clock::now();
@@ -100,7 +106,9 @@ int main(int argc, char** argv) {
make_shared<SuffixArray>(source_data_array);
start_write = Clock::now();
- ofstream source_fstream((output_dir / fs::path("source.bin")).string());
+ string source_path = (output_dir / fs::path("source.bin")).string();
+ config_stream << "source = " << source_path << endl;
+ ofstream source_fstream(source_path);
ar::binary_oarchive output_stream(source_fstream);
output_stream << *source_suffix_array;
stop_write = Clock::now();
@@ -116,7 +124,9 @@ int main(int argc, char** argv) {
make_shared<Alignment>(vm["alignment"].as<string>());
start_write = Clock::now();
- ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string());
+ string alignment_path = (output_dir / fs::path("alignment.bin")).string();
+ config_stream << "alignment = " << alignment_path << endl;
+ ofstream alignment_fstream(alignment_path);
ar::binary_oarchive alignment_stream(alignment_fstream);
alignment_stream << *alignment;
stop_write = Clock::now();
@@ -126,7 +136,7 @@ int main(int argc, char** argv) {
cerr << "Reading alignment took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
- shared_ptr<Vocabulary> vocabulary;
+ shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
start_time = Clock::now();
cerr << "Precomputing collocations..." << endl;
@@ -142,9 +152,17 @@ int main(int argc, char** argv) {
vm["min_frequency"].as<int>());
start_write = Clock::now();
- ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string());
+ string precomputation_path = (output_dir / fs::path("precomp.bin")).string();
+ config_stream << "precomputation = " << precomputation_path << endl;
+ ofstream precomp_fstream(precomputation_path);
ar::binary_oarchive precomp_stream(precomp_fstream);
precomp_stream << precomputation;
+
+ string vocabulary_path = (output_dir / fs::path("vocab.bin")).string();
+ config_stream << "vocabulary = " << vocabulary_path << endl;
+ ofstream vocab_fstream(vocabulary_path);
+ ar::binary_oarchive vocab_stream(vocab_fstream);
+ vocab_stream << *vocabulary;
stop_write = Clock::now();
write_duration += GetDuration(start_write, stop_write);
@@ -157,7 +175,9 @@ int main(int argc, char** argv) {
TranslationTable table(source_data_array, target_data_array, alignment);
start_write = Clock::now();
- ofstream table_fstream((output_dir / fs::path("bilex.bin")).string());
+ string table_path = (output_dir / fs::path("bilex.bin")).string();
+ config_stream << "ttable = " << table_path << endl;
+ ofstream table_fstream(table_path);
ar::binary_oarchive table_stream(table_fstream);
table_stream << table;
stop_write = Clock::now();
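The net effect of these compile.cc changes is a generated config file mapping
each serialized data structure to its path; for --output /tmp/out and
--config extract.ini the file would read (paths illustrative):

    target = /tmp/out/target.bin
    source = /tmp/out/source.bin
    alignment = /tmp/out/alignment.bin
    precomputation = /tmp/out/precomp.bin
    vocabulary = /tmp/out/vocab.bin
    ttable = /tmp/out/bilex.bin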
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index dacc4283..9612aa8a 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -127,10 +127,6 @@ int DataArray::GetSentenceId(int position) const {
return sentence_id[position];
}
-bool DataArray::HasWord(const string& word) const {
- return word2id.count(word);
-}
-
int DataArray::GetWordId(const string& word) const {
auto result = word2id.find(word);
return result == word2id.end() ? -1 : result->second;
diff --git a/extractor/data_array.h b/extractor/data_array.h
index e3823d18..b96901d1 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -73,9 +73,6 @@ class DataArray {
// Returns the number of distinct words in the data array.
virtual int GetVocabularySize() const;
- // Returns whether a word has ever been observed in the data array.
- virtual bool HasWord(const string& word) const;
-
// Returns the word id for a given word, or -1 if the word has never been
// observed.
virtual int GetWordId(const string& word) const;
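Call sites that previously used HasWord now test the sentinel directly, e.g.:

    // membership check after this change (replacing HasWord):
    if (data_array.GetWordId("banane") == -1) {
      // "banane" was never observed in the data array
    }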
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index 7b085cd9..99f79d91 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -70,16 +70,12 @@ TEST_F(DataArrayTest, TestSubstrings) {
TEST_F(DataArrayTest, TestVocabulary) {
EXPECT_EQ(9, source_data.GetVocabularySize());
- EXPECT_TRUE(source_data.HasWord("mere"));
EXPECT_EQ(4, source_data.GetWordId("mere"));
EXPECT_EQ("mere", source_data.GetWord(4));
- EXPECT_FALSE(source_data.HasWord("banane"));
EXPECT_EQ(11, target_data.GetVocabularySize());
- EXPECT_TRUE(target_data.HasWord("apples"));
EXPECT_EQ(4, target_data.GetWordId("apples"));
EXPECT_EQ("apples", target_data.GetWord(4));
- EXPECT_FALSE(target_data.HasWord("bananas"));
}
TEST_F(DataArrayTest, TestSentenceData) {
diff --git a/extractor/extract.cc b/extractor/extract.cc
new file mode 100644
index 00000000..387cbe9b
--- /dev/null
+++ b/extractor/extract.cc
@@ -0,0 +1,253 @@
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <omp.h>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/count_source_target.h"
+#include "features/feature.h"
+#include "features/is_source_singleton.h"
+#include "features/is_source_target_singleton.h"
+#include "features/max_lex_source_given_target.h"
+#include "features/max_lex_target_given_source.h"
+#include "features/sample_source_count.h"
+#include "features/target_given_source_coherent.h"
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "precomputation.h"
+#include "rule.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "translation_table.h"
+#include "vocabulary.h"
+
+namespace ar = boost::archive;
+namespace fs = boost::filesystem;
+namespace po = boost::program_options;
+using namespace extractor;
+using namespace features;
+using namespace std;
+
+// Returns the file path in which a given grammar should be written.
+fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
+ string file_name = "grammar." + to_string(file_number);
+ return grammar_path / file_name;
+}
+
+int main(int argc, char** argv) {
+ po::options_description general_options("General options");
+ int max_threads = 1;
+ #pragma omp parallel
+ max_threads = omp_get_num_threads();
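+  // omp_get_num_threads() returns 1 outside a parallel region, so the team
+  // size is probed inside one; every thread writes the same value, making
+  // the race benign.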
+  string threads_option = "Number of threads used for grammar extraction "
+      "(max " + to_string(max_threads) + ")";
+ general_options.add_options()
+ ("threads,t", po::value<int>()->required()->default_value(1),
+ threads_option.c_str())
+ ("grammars,g", po::value<string>()->required(), "Grammars output path")
+ ("max_rule_span", po::value<int>()->default_value(15),
+ "Maximum rule span")
+ ("max_rule_symbols", po::value<int>()->default_value(5),
+     "Maximum number of symbols (terminals + nonterminals) in a rule")
+ ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
+ ("max_nonterminals", po::value<int>()->default_value(2),
+ "Maximum number of nonterminals in a rule")
+ ("max_samples", po::value<int>()->default_value(300),
+ "Maximum number of samples")
+ ("tight_phrases", po::value<bool>()->default_value(true),
+ "False if phrases may be loose (better, but slower)")
+ ("leave_one_out", po::value<bool>()->zero_tokens(),
+     "Do leave-one-out estimation of grammars "
+     "(e.g. for extracting grammars for the training set)");
+
+ po::options_description cmdline_options("Command line options");
+ cmdline_options.add_options()
+ ("help", "Show available options")
+ ("config,c", po::value<string>()->required(), "Path to config file");
+ cmdline_options.add(general_options);
+
+ po::options_description config_options("Config file options");
+ config_options.add_options()
+ ("target", po::value<string>()->required(),
+ "Path to target data file in binary format")
+ ("source", po::value<string>()->required(),
+ "Path to source suffix array file in binary format")
+ ("alignment", po::value<string>()->required(),
+ "Path to alignment file in binary format")
+ ("precomputation", po::value<string>()->required(),
+ "Path to precomputation file in binary format")
+ ("vocabulary", po::value<string>()->required(),
+ "Path to vocabulary file in binary format")
+ ("ttable", po::value<string>()->required(),
+ "Path to translation table in binary format");
+ config_options.add(general_options);
+
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
+ if (vm.count("help")) {
+ po::options_description all_options;
+ all_options.add(cmdline_options).add(config_options);
+ cout << all_options << endl;
+ return 0;
+ }
+
+ po::notify(vm);
+
+ ifstream config_stream(vm["config"].as<string>());
+ po::store(po::parse_config_file(config_stream, config_options), vm);
+ po::notify(vm);
+
+ int num_threads = vm["threads"].as<int>();
+ cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
+
+ Clock::time_point read_start_time = Clock::now();
+
+ Clock::time_point start_time = Clock::now();
+ cerr << "Reading target data in binary format..." << endl;
+ shared_ptr<DataArray> target_data_array = make_shared<DataArray>();
+ ifstream target_fstream(vm["target"].as<string>());
+ ar::binary_iarchive target_stream(target_fstream);
+ target_stream >> *target_data_array;
+ Clock::time_point end_time = Clock::now();
+ cerr << "Reading target data took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading source suffix array in binary format..." << endl;
+ shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>();
+ ifstream source_fstream(vm["source"].as<string>());
+ ar::binary_iarchive source_stream(source_fstream);
+ source_stream >> *source_suffix_array;
+ end_time = Clock::now();
+ cerr << "Reading source suffix array took "
+ << GetDuration(start_time, end_time) << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading alignment in binary format..." << endl;
+ shared_ptr<Alignment> alignment = make_shared<Alignment>();
+ ifstream alignment_fstream(vm["alignment"].as<string>());
+ ar::binary_iarchive alignment_stream(alignment_fstream);
+ alignment_stream >> *alignment;
+ end_time = Clock::now();
+ cerr << "Reading alignment took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading precomputation in binary format..." << endl;
+ shared_ptr<Precomputation> precomputation = make_shared<Precomputation>();
+ ifstream precomputation_fstream(vm["precomputation"].as<string>());
+ ar::binary_iarchive precomputation_stream(precomputation_fstream);
+ precomputation_stream >> *precomputation;
+ end_time = Clock::now();
+ cerr << "Reading precomputation took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading vocabulary in binary format..." << endl;
+ shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+ ifstream vocabulary_fstream(vm["vocabulary"].as<string>());
+ ar::binary_iarchive vocabulary_stream(vocabulary_fstream);
+ vocabulary_stream >> *vocabulary;
+ end_time = Clock::now();
+ cerr << "Reading vocabulary took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading translation table in binary format..." << endl;
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>();
+ ifstream ttable_fstream(vm["ttable"].as<string>());
+ ar::binary_iarchive ttable_stream(ttable_fstream);
+ ttable_stream >> *table;
+ end_time = Clock::now();
+ cerr << "Reading translation table took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ Clock::time_point read_end_time = Clock::now();
+ cerr << "Total time spent loading data structures into memory: "
+ << GetDuration(read_start_time, read_end_time) << " seconds" << endl;
+
+ Clock::time_point extraction_start_time = Clock::now();
+ // Features used to score each grammar rule.
+ vector<shared_ptr<Feature>> features = {
+ make_shared<TargetGivenSourceCoherent>(),
+ make_shared<SampleSourceCount>(),
+ make_shared<CountSourceTarget>(),
+ make_shared<MaxLexSourceGivenTarget>(table),
+ make_shared<MaxLexTargetGivenSource>(table),
+ make_shared<IsSourceSingleton>(),
+ make_shared<IsSourceTargetSingleton>()
+ };
+ shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
+
+ GrammarExtractor extractor(
+ source_suffix_array,
+ target_data_array,
+ alignment,
+ precomputation,
+ scorer,
+ vocabulary,
+ vm["min_gap_size"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_nonterminals"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["max_samples"].as<int>(),
+ vm["tight_phrases"].as<bool>());
+
+ // Creates the grammars directory if it doesn't exist.
+ fs::path grammar_path = vm["grammars"].as<string>();
+ if (!fs::is_directory(grammar_path)) {
+ fs::create_directory(grammar_path);
+ }
+
+  // Reads all sentences for which we extract grammar rules (the parallelization
+  // is simplified if we read all sentences up front).
+ string sentence;
+ vector<string> sentences;
+ while (getline(cin, sentence)) {
+ sentences.push_back(sentence);
+ }
+
+ // Extracts the grammar for each sentence and saves it to a file.
+ vector<string> suffixes(sentences.size());
+ bool leave_one_out = vm.count("leave_one_out");
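+  // schedule(dynamic) load-balances the loop below: per-sentence extraction
+  // time varies widely, so static chunking would leave threads idle.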
+ #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ string suffix;
+ int position = sentences[i].find("|||");
+    size_t position = sentences[i].find("|||");
+    if (position != string::npos) {
+ sentences[i] = sentences[i].substr(0, position);
+ }
+ suffixes[i] = suffix;
+
+ unordered_set<int> blacklisted_sentence_ids;
+ if (leave_one_out) {
+ blacklisted_sentence_ids.insert(i);
+ }
+ Grammar grammar = extractor.GetGrammar(
+ sentences[i], blacklisted_sentence_ids);
+ ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
+ output << grammar;
+ }
+
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\""
+ << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
+ }
+
+ Clock::time_point extraction_stop_time = Clock::now();
+ cerr << "Overall extraction step took "
+ << GetDuration(extraction_start_time, extraction_stop_time)
+ << " seconds" << endl;
+
+ return 0;
+}
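End to end, extract consumes the config emitted by compile and reads one
sentence per line on stdin; a hypothetical invocation (paths illustrative):

    ./compile -b bitext.txt -a alignment.txt -o /tmp/out -c extract.ini
    ./extract -c extract.ini -g grammars -t 8 < dev.src > dev.sgm

Anything after an optional ||| separator on an input line (e.g. references) is
passed through unchanged after the <seg> element.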
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 4d0738f7..1dc94c25 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor(
vocabulary(vocabulary),
rule_factory(rule_factory) {}
-Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar GrammarExtractor::GetGrammar(
+ const string& sentence,
+ const unordered_set<int>& blacklisted_sentence_ids) {
vector<string> words = TokenizeSentence(sentence);
vector<int> word_ids = AnnotateWords(words);
- return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids);
}
vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 8f570df2..0f3069b0 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -15,7 +15,6 @@ class DataArray;
class Grammar;
class HieroCachingRuleFactory;
class Precomputation;
-class Rule;
class Scorer;
class SuffixArray;
class Vocabulary;
@@ -46,7 +45,9 @@ class GrammarExtractor {
// Converts the sentence to a vector of word ids and uses the RuleFactory to
// extract the SCFG rules which may be used to decode the sentence.
- Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
+ Grammar GetGrammar(
+ const string& sentence,
+ const unordered_set<int>& blacklisted_sentence_ids);
private:
// Splits the sentence in a vector of words.
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index f32a9599..719e90ff 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
Grammar grammar(rules, feature_names);
unordered_set<int> blacklisted_sentence_ids;
-  shared_ptr<DataArray> source_data_array;
- EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
+ EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids))
.WillOnce(Return(grammar));
GrammarExtractor extractor(vocabulary, factory);
string sentence = "Anna has many many apples .";
- extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
+ extractor.GetGrammar(sentence, blacklisted_sentence_ids);
}
} // namespace
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index 4bdcf21f..98e711d2 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -13,7 +13,6 @@ class MockDataArray : public DataArray {
MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size));
MOCK_CONST_METHOD0(GetSize, int());
MOCK_CONST_METHOD0(GetVocabularySize, int());
- MOCK_CONST_METHOD1(HasWord, bool(const string& word));
MOCK_CONST_METHOD1(GetWordId, int(const string& word));
MOCK_CONST_METHOD1(GetWord, string(int word_id));
MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id));
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 6b7b6586..53eb5022 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,9 +7,9 @@ namespace extractor {
class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
public:
- MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const
- unordered_set<int>& blacklisted_sentence_ids,
- const shared_ptr<DataArray> source_data_array));
+ MOCK_METHOD2(GetGrammar, Grammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids));
};
} // namespace extractor
diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h
index 75c43c27..b2742f62 100644
--- a/extractor/mocks/mock_sampler.h
+++ b/extractor/mocks/mock_sampler.h
@@ -7,7 +7,9 @@ namespace extractor {
class MockSampler : public Sampler {
public:
- MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location));
+ MOCK_CONST_METHOD2(Sample, PhraseLocation(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids));
};
} // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 38d8f489..b79daae3 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -5,60 +5,67 @@
#include "data_array.h"
#include "suffix_array.h"
+#include "time_util.h"
#include "vocabulary.h"
using namespace std;
namespace extractor {
-int Precomputation::NONTERMINAL = -1;
-
Precomputation::Precomputation(
shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
int num_frequent_patterns, int num_super_frequent_patterns,
int max_rule_span, int max_rule_symbols, int min_gap_size,
int max_frequent_phrase_len, int min_frequency) {
+ Clock::time_point start_time = Clock::now();
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ vector<int> data = data_array->GetData();
vector<vector<int>> frequent_patterns = FindMostFrequentPatterns(
- suffix_array, num_frequent_patterns, max_frequent_phrase_len,
+ suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
min_frequency);
+ Clock::time_point end_time = Clock::now();
+ cerr << "Finding most frequent patterns took "
+ << GetDuration(start_time, end_time) << " seconds..." << endl;
- // Construct sets containing the frequent and superfrequent contiguous
- // collocations.
- unordered_set<vector<int>, VectorHash> frequent_patterns_set;
- unordered_set<vector<int>, VectorHash> super_frequent_patterns_set;
+ vector<vector<int>> pattern_annotations(frequent_patterns.size());
+ unordered_map<vector<int>, int, VectorHash> frequent_patterns_index;
for (size_t i = 0; i < frequent_patterns.size(); ++i) {
- frequent_patterns_set.insert(frequent_patterns[i]);
- if (i < num_super_frequent_patterns) {
- super_frequent_patterns_set.insert(frequent_patterns[i]);
- }
+ frequent_patterns_index[frequent_patterns[i]] = i;
+ pattern_annotations[i] = AnnotatePattern(vocabulary, data_array,
+ frequent_patterns[i]);
}
- shared_ptr<DataArray> data_array = suffix_array->GetData();
+ start_time = Clock::now();
vector<tuple<int, int, int>> matchings;
- for (size_t i = 0; i < data_array->GetSize(); ++i) {
+ vector<vector<int>> annotations;
+ for (size_t i = 0; i < data.size(); ++i) {
// If the sentence is over, add all the discontiguous frequent patterns to
// the index.
- if (data_array->AtIndex(i) == DataArray::END_OF_LINE) {
- UpdateIndex(data_array, vocabulary, matchings, max_rule_span,
- min_gap_size, max_rule_symbols);
+ if (data[i] == DataArray::END_OF_LINE) {
+ UpdateIndex(matchings, annotations, max_rule_span, min_gap_size,
+ max_rule_symbols);
matchings.clear();
+ annotations.clear();
continue;
}
// Find all the contiguous frequent patterns starting at position i.
vector<int> pattern;
- for (int j = 1;
- j <= max_frequent_phrase_len && i + j <= data_array->GetSize();
- ++j) {
- pattern.push_back(data_array->AtIndex(i + j - 1));
- if (!frequent_patterns_set.count(pattern)) {
+ for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {
+ pattern.push_back(data[i + j - 1]);
+ auto it = frequent_patterns_index.find(pattern);
+ if (it == frequent_patterns_index.end()) {
// If the current pattern is not frequent, any longer pattern having the
// current pattern as prefix will not be frequent.
break;
}
- int is_super_frequent = super_frequent_patterns_set.count(pattern);
+ int is_super_frequent = it->second < num_super_frequent_patterns;
matchings.push_back(make_tuple(i, j, is_super_frequent));
+ annotations.push_back(pattern_annotations[it->second]);
}
}
+ end_time = Clock::now();
+ cerr << "Constructing collocations index took "
+ << GetDuration(start_time, end_time) << " seconds..." << endl;
}
Precomputation::Precomputation() {}
@@ -66,8 +73,8 @@ Precomputation::Precomputation() {}
Precomputation::~Precomputation() {}
vector<vector<int>> Precomputation::FindMostFrequentPatterns(
- shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
- int max_frequent_phrase_len, int min_frequency) {
+ shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
+ int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) {
vector<int> lcp = suffix_array->BuildLCPArray();
vector<int> run_start(max_frequent_phrase_len);
@@ -76,9 +83,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
for (size_t i = 1; i < lcp.size(); ++i) {
for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) {
int frequency = i - run_start[len];
- if (frequency >= min_frequency) {
- heap.push(make_pair(frequency,
- make_pair(suffix_array->GetSuffix(run_start[len]), len + 1)));
+ int start = suffix_array->GetSuffix(run_start[len]);
+ if (frequency >= min_frequency && start + len <= data.size()) {
+ heap.push(make_pair(frequency, make_pair(start, len + 1)));
}
run_start[len] = i;
}
@@ -101,9 +108,20 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
return frequent_patterns;
}
+vector<int> Precomputation::AnnotatePattern(
+ shared_ptr<Vocabulary> vocabulary, shared_ptr<DataArray> data_array,
+ const vector<int>& pattern) const {
+ vector<int> annotation;
+ for (int word_id: pattern) {
+ annotation.push_back(vocabulary->GetTerminalIndex(
+ data_array->GetWord(word_id)));
+ }
+ return annotation;
+}
+
void Precomputation::UpdateIndex(
- shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary,
const vector<tuple<int, int, int>>& matchings,
+ const vector<vector<int>>& annotations,
int max_rule_span, int min_gap_size, int max_rule_symbols) {
// Select the leftmost subpattern.
for (size_t i = 0; i < matchings.size(); ++i) {
@@ -121,15 +139,14 @@ void Precomputation::UpdateIndex(
if (start2 - start1 - size1 >= min_gap_size
&& start2 + size2 - start1 <= max_rule_span
&& size1 + size2 + 1 <= max_rule_symbols) {
- vector<int> pattern;
- AppendSubpattern(pattern, data_array, vocabulary, start1, size1);
- pattern.push_back(Precomputation::NONTERMINAL);
- AppendSubpattern(pattern, data_array, vocabulary, start2, size2);
- AppendCollocation(index[pattern], {start1, start2});
+ vector<int> pattern = annotations[i];
+ pattern.push_back(-1);
+ AppendSubpattern(pattern, annotations[j]);
+ AppendCollocation(index[pattern], start1, start2);
// Try extending the binary collocation to a ternary collocation.
if (is_super2) {
- pattern.push_back(Precomputation::NONTERMINAL);
+ pattern.push_back(-2);
// Select the rightmost subpattern.
for (size_t k = j + 1; k < matchings.size(); ++k) {
int start3, size3, is_super3;
@@ -142,8 +159,8 @@ void Precomputation::UpdateIndex(
&& start3 + size3 - start1 <= max_rule_span
&& size1 + size2 + size3 + 2 <= max_rule_symbols
&& (is_super1 || is_super3)) {
- AppendSubpattern(pattern, data_array, vocabulary, start3, size3);
- AppendCollocation(index[pattern], {start1, start2, start3});
+ AppendSubpattern(pattern, annotations[k]);
+ AppendCollocation(index[pattern], start1, start2, start3);
pattern.erase(pattern.end() - size3);
}
}
@@ -154,17 +171,22 @@ void Precomputation::UpdateIndex(
}
void Precomputation::AppendSubpattern(
- vector<int>& pattern, shared_ptr<DataArray> data_array,
- shared_ptr<Vocabulary> vocabulary, int start, int size) {
- vector<string> words = data_array->GetWords(start, size);
- for (const string& word: words) {
- pattern.push_back(vocabulary->GetTerminalIndex(word));
- }
+ vector<int>& pattern,
+ const vector<int>& subpattern) {
+ copy(subpattern.begin(), subpattern.end(), back_inserter(pattern));
+}
+
+void Precomputation::AppendCollocation(
+ vector<int>& collocations, int pos1, int pos2) {
+ collocations.push_back(pos1);
+ collocations.push_back(pos2);
}
void Precomputation::AppendCollocation(
- vector<int>& collocations, const vector<int>& collocation) {
- copy(collocation.begin(), collocation.end(), back_inserter(collocations));
+ vector<int>& collocations, int pos1, int pos2, int pos3) {
+ collocations.push_back(pos1);
+ collocations.push_back(pos2);
+ collocations.push_back(pos3);
}
bool Precomputation::Contains(const vector<int>& pattern) const {
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 6ade58df..2b34fc29 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -55,28 +55,32 @@ class Precomputation {
bool operator==(const Precomputation& other) const;
- static int NONTERMINAL;
-
private:
// Finds the most frequent contiguous collocations.
vector<vector<int>> FindMostFrequentPatterns(
- shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
- int max_frequent_phrase_len, int min_frequency);
+ shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
+ int num_frequent_patterns, int max_frequent_phrase_len,
+ int min_frequency);
+
+ vector<int> AnnotatePattern(shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<DataArray> data_array,
+ const vector<int>& pattern) const;
// Given the locations of the frequent contiguous collocations in a sentence,
// it adds new entries to the index for each discontiguous collocation
// matching the criteria specified in the class description.
void UpdateIndex(
- shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary,
const vector<tuple<int, int, int>>& matchings,
+ const vector<vector<int>>& annotations,
int max_rule_span, int min_gap_size, int max_rule_symbols);
- void AppendSubpattern(
- vector<int>& pattern, shared_ptr<DataArray> data_array,
- shared_ptr<Vocabulary> vocabulary, int start, int size);
+ void AppendSubpattern(vector<int>& pattern, const vector<int>& subpattern);
+
+ // Adds an occurrence of a binary collocation.
+ void AppendCollocation(vector<int>& collocations, int pos1, int pos2);
- // Adds an occurrence of a collocation.
- void AppendCollocation(vector<int>& collocations, const vector<int>& collocation);
+ // Adds an occurrence of a ternary collocation.
+ void AppendCollocation(vector<int>& collocations, int pos1, int pos2, int pos3);
friend class boost::serialization::access;
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index fd85fcf8..d5f5ef63 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -24,31 +24,12 @@ class PrecomputationTest : public Test {
virtual void SetUp() {
data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};
data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(data.size()));
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));
for (size_t i = 0; i < data.size(); ++i) {
EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
}
- vector<pair<int, int>> expected_calls = {{8, 1}, {8, 2}, {6, 1}};
- for (const auto& call: expected_calls) {
- int start = call.first;
- int size = call.second;
- vector<int> word_ids(data.begin() + start, data.begin() + start + size);
- EXPECT_CALL(*data_array, GetWordIds(start, size))
- .WillRepeatedly(Return(word_ids));
- }
-
- expected_calls = {{1, 1}, {5, 1}, {8, 1}, {9, 1}, {5, 2},
- {6, 1}, {8, 2}, {1, 2}, {2, 1}, {11, 1}};
- for (const auto& call: expected_calls) {
- int start = call.first;
- int size = call.second;
- vector<string> words;
- for (size_t j = start; j < start + size; ++j) {
- words.push_back(to_string(data[j]));
- }
- EXPECT_CALL(*data_array, GetWords(start, size))
- .WillRepeatedly(Return(words));
- }
+ EXPECT_CALL(*data_array, GetWord(2)).WillRepeatedly(Return("2"));
+ EXPECT_CALL(*data_array, GetWord(3)).WillRepeatedly(Return("3"));
vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};
vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0};
@@ -117,37 +98,37 @@ TEST_F(PrecomputationTest, TestCollocations) {
expected_value = {1, 5, 8, 5, 8, 11};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {2, -1, 2, -1, 3};
+ key = {2, -1, 2, -2, 3};
expected_value = {1, 5, 9};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {2, -1, 3, -1, 2};
+ key = {2, -1, 3, -2, 2};
expected_value = {1, 6, 8, 5, 9, 11};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {2, -1, 3, -1, 3};
+ key = {2, -1, 3, -2, 3};
expected_value = {1, 6, 9};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {3, -1, 2, -1, 2};
+ key = {3, -1, 2, -2, 2};
expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {3, -1, 2, -1, 3};
+ key = {3, -1, 2, -2, 3};
expected_value = {2, 5, 9};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {3, -1, 3, -1, 2};
+ key = {3, -1, 3, -2, 2};
expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {3, -1, 3, -1, 3};
+ key = {3, -1, 3, -2, 3};
expected_value = {2, 6, 9};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
// Exceeds max_rule_symbols.
- key = {2, -1, 2, -1, 2, 3};
+ key = {2, -1, 2, -2, 2, 3};
EXPECT_FALSE(precomputation.Contains(key));
// Contains a non-frequent pattern.
key = {2, -1, 5};
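A note on reading these keys: after this change -1 marks the first nonterminal
gap and -2 the second, so {3, -1, 2, -2, 2} stands for the discontiguous
pattern "3 X 2 Y 2"; each value is a flat list of subpattern start positions,
one tuple per occurrence. A small hypothetical helper makes the layout
explicit:

    #include <vector>

    // splits a flattened collocation value into per-occurrence position tuples
    std::vector<std::vector<int>> DecodeCollocations(
        const std::vector<int>& values, int num_subpatterns) {
      std::vector<std::vector<int>> occurrences;
      for (size_t i = 0; i + num_subpatterns <= values.size();
           i += num_subpatterns) {
        occurrences.emplace_back(values.begin() + i,
                                 values.begin() + i + num_subpatterns);
      }
      return occurrences;
    }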
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 6ae2d792..5b66f685 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar HieroCachingRuleFactory::GetGrammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids) {
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
double total_intersect_time = 0;
@@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u
Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
// Extract rules for the sampled set of occurrences.
- PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
+ PhraseLocation sample = sampler->Sample(
+ next_node->matchings, blacklisted_sentence_ids);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index a1ff76e4..1a9fa2af 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -74,8 +74,7 @@ class HieroCachingRuleFactory {
// (See class description for more details.)
virtual Grammar GetGrammar(
const vector<int>& word_ids,
- const unordered_set<int>& blacklisted_sentence_ids,
- const shared_ptr<DataArray> source_data_array);
+ const unordered_set<int>& blacklisted_sentence_ids);
protected:
HieroCachingRuleFactory();
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index f26cc567..332c5959 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -40,7 +40,7 @@ class RuleFactoryTest : public Test {
.WillRepeatedly(Return(feature_names));
sampler = make_shared<MockSampler>();
- EXPECT_CALL(*sampler, Sample(_))
+ EXPECT_CALL(*sampler, Sample(_, _))
.WillRepeatedly(Return(PhraseLocation(0, 1)));
Phrase phrase;
@@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
vector<int> word_ids = {2, 3, 4};
unordered_set<int> blacklisted_sentence_ids;
- shared_ptr<DataArray> source_data_array;
- Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
}
@@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
vector<int> word_ids = {2, 3, 4, 2, 3};
unordered_set<int> blacklisted_sentence_ids;
- shared_ptr<DataArray> source_data_array;
- Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
}
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 85c8a422..f1aa5e35 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -5,10 +5,10 @@
#include <string>
#include <vector>
-#include <omp.h>
#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include <omp.h>
#include "alignment.h"
#include "data_array.h"
@@ -78,7 +78,8 @@ int main(int argc, char** argv) {
("tight_phrases", po::value<bool>()->default_value(true),
"False if phrases may be loose (better, but slower)")
("leave_one_out", po::value<bool>()->zero_tokens(),
- "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set");
+ "do leave-one-out estimation of grammars "
+ "(e.g. for extracting grammars for the training set");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -99,11 +100,6 @@ int main(int argc, char** argv) {
return 1;
}
- bool leave_one_out = false;
- if (vm.count("leave_one_out")) {
- leave_one_out = true;
- }
-
int num_threads = vm["threads"].as<int>();
cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
@@ -178,8 +174,8 @@ int main(int argc, char** argv) {
<< GetDuration(preprocess_start_time, preprocess_stop_time)
<< " seconds" << endl;
- // Features used to score each grammar rule.
Clock::time_point extraction_start_time = Clock::now();
+ // Features used to score each grammar rule.
vector<shared_ptr<Feature>> features = {
make_shared<TargetGivenSourceCoherent>(),
make_shared<SampleSourceCount>(),
@@ -206,9 +202,6 @@ int main(int argc, char** argv) {
vm["max_samples"].as<int>(),
vm["tight_phrases"].as<bool>());
- // Releases extra memory used by the initial precomputation.
- precomputation.reset();
-
// Creates the grammars directory if it doesn't exist.
fs::path grammar_path = vm["grammars"].as<string>();
if (!fs::is_directory(grammar_path)) {
@@ -224,6 +217,7 @@ int main(int argc, char** argv) {
}
// Extracts the grammar for each sentence and saves it to a file.
+ bool leave_one_out = vm.count("leave_one_out");
vector<string> suffixes(sentences.size());
#pragma omp parallel for schedule(dynamic) num_threads(num_threads)
for (size_t i = 0; i < sentences.size(); ++i) {
@@ -236,8 +230,11 @@ int main(int argc, char** argv) {
suffixes[i] = suffix;
unordered_set<int> blacklisted_sentence_ids;
- if (leave_one_out) blacklisted_sentence_ids.insert(i);
- Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
+ if (leave_one_out) {
+ blacklisted_sentence_ids.insert(i);
+ }
+ Grammar grammar = extractor.GetGrammar(
+ sentences[i], blacklisted_sentence_ids);
ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
output << grammar;
}
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index 963afa7a..887aaec1 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -12,7 +12,10 @@ Sampler::Sampler() {}
Sampler::~Sampler() {}
-PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
+PhraseLocation Sampler::Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const {
+ shared_ptr<DataArray> source_data_array = suffix_array->GetData();
vector<int> sample;
int num_subpatterns;
if (location.matchings == NULL) {
@@ -20,30 +23,30 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s
num_subpatterns = 1;
int low = location.sa_low, high = location.sa_high;
double step = max(1.0, (double) (high - low) / max_samples);
- double i = low, last = i;
- bool found;
+ double i = low, last = i - 1;
while (sample.size() < max_samples && i < high) {
int x = suffix_array->GetSuffix(Round(i));
int id = source_data_array->GetSentenceId(x);
- if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
- found = false;
- double backoff_step = 1;
- while (true) {
- if ((double)backoff_step >= step) break;
+ bool found = false;
+ if (blacklisted_sentence_ids.count(id)) {
+ for (int backoff_step = 1; backoff_step <= step; ++backoff_step) {
double j = i - backoff_step;
x = suffix_array->GetSuffix(Round(j));
id = source_data_array->GetSentenceId(x);
- if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
- found = true; last = i; break;
+ if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) {
+ found = true;
+ last = i;
+ break;
}
double k = i + backoff_step;
x = suffix_array->GetSuffix(Round(k));
id = source_data_array->GetSentenceId(x);
- if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
- found = true; last = k; break;
+ if (k < min(i+step, (double) high) &&
+ !blacklisted_sentence_ids.count(id)) {
+ found = true;
+ last = k;
+ break;
}
- if (j <= last && k >= high) break;
- backoff_step++;
}
} else {
found = true;
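The rewritten backoff above scans the suffix-array range at stride step and, when a sampled position lands in a blacklisted sentence, probes positions i-1, i+1, i-2, i+2, ... (at most step away) before moving on. A self-contained sketch of that search, blacklisting positions directly instead of sentence ids and omitting the max_samples bound (names are illustrative, not from sampler.cc):

    #include <algorithm>
    #include <unordered_set>
    #include <vector>

    // Uniform sampling with backoff: take low, low+step, ...; on a
    // blacklist hit, try offsets +/-1, +/-2, ... without crossing the
    // previous sample (j > last) or the next stride (k < i + step).
    std::vector<int> SampleWithBackoff(int low, int high, double step,
                                       const std::unordered_set<int>& blacklist) {
      std::vector<int> sample;
      double last = low - 1;
      for (double i = low; i < high; i += step) {
        int pos = static_cast<int>(i + 0.5);  // Round()
        if (!blacklist.count(pos)) {
          sample.push_back(pos);
          last = i;
          continue;
        }
        for (int b = 1; b <= step; ++b) {
          double j = i - b, k = i + b;
          if (j > last && !blacklist.count(static_cast<int>(j + 0.5))) {
            sample.push_back(static_cast<int>(j + 0.5));
            last = i;
            break;
          }
          if (k < std::min(i + step, static_cast<double>(high)) &&
              !blacklist.count(static_cast<int>(k + 0.5))) {
            sample.push_back(static_cast<int>(k + 0.5));
            last = k;
            break;
          }
        }
      }
      return sample;
    }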
diff --git a/extractor/sampler.h b/extractor/sampler.h
index de450c48..bd8a5876 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -23,7 +23,9 @@ class Sampler {
virtual ~Sampler();
// Samples uniformly at most max_samples phrase occurrences.
- virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
+ virtual PhraseLocation Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const;
protected:
Sampler();
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
index 965567ba..14e72780 100644
--- a/extractor/sampler_test.cc
+++ b/extractor/sampler_test.cc
@@ -19,6 +19,8 @@ class SamplerTest : public Test {
source_data_array = make_shared<MockDataArray>();
EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData())
+ .WillRepeatedly(Return(source_data_array));
for (int i = 0; i < 10; ++i) {
EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
}
@@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) {
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
+  return;  // NB: the remaining checks in this test are currently skipped.
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 3, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 4);
expected_locations = {0, 3, 5, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 100);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
}
TEST_F(SamplerTest, TestSubstringsSample) {
@@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) {
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0, 1};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 1, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 1, 4, 5, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 7);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
}
} // namespace
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index ac230d13..4a514b12 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -187,12 +187,12 @@ shared_ptr<DataArray> SuffixArray::GetData() const {
PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,
int offset) const {
- if (!data_array->HasWord(word)) {
+ int word_id = data_array->GetWordId(word);
+ if (word_id == -1) {
// Return empty phrase location.
return PhraseLocation(0, 0);
}
- int word_id = data_array->GetWordId(word);
if (offset == 0) {
return PhraseLocation(word_start[word_id], word_start[word_id + 1]);
}
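The Lookup change above folds the HasWord membership test into GetWordId, which (per the updated mocks below) now returns -1 for out-of-vocabulary words, saving one hash lookup per query. The sentinel pattern in isolation (hypothetical helper, not from data_array.h):

    #include <string>
    #include <unordered_map>

    // One lookup doubles as membership test and id retrieval; -1 is the
    // out-of-vocabulary sentinel.
    int GetWordIdOrSentinel(const std::unordered_map<std::string, int>& dict,
                            const std::string& word) {
      auto it = dict.find(word);
      return it == dict.end() ? -1 : it->second;
    }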
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index a9fd1eab..161edbc0 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -55,22 +55,18 @@ TEST_F(SuffixArrayTest, TestLookup) {
EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
}
- EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0));
- EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
+ EXPECT_CALL(*data_array, GetWordId("word2")).WillRepeatedly(Return(-1));
EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0));
- EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1));
- EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2));
- EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3));
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 1b1ba112..11e29e1e 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -90,13 +90,12 @@ void TranslationTable::IncrementLinksCount(
double TranslationTable::GetTargetGivenSourceScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
@@ -107,13 +106,12 @@ double TranslationTable::GetTargetGivenSourceScore(
double TranslationTable::GetSourceGivenTargetScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
index 72551a12..3cfc0011 100644
--- a/extractor/translation_table_test.cc
+++ b/extractor/translation_table_test.cc
@@ -36,13 +36,10 @@ class TranslationTableTest : public Test {
.WillRepeatedly(Return(source_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*source_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*source_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*source_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*source_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
vector<int> target_sentence_start = {0, 7, 10, 13};
@@ -54,13 +51,10 @@ class TranslationTableTest : public Test {
.WillRepeatedly(Return(target_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*target_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*target_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*target_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*target_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<pair<int, int>> links1 = {
make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc
index 15795d1e..c9c2d6f4 100644
--- a/extractor/vocabulary.cc
+++ b/extractor/vocabulary.cc
@@ -8,12 +8,13 @@ int Vocabulary::GetTerminalIndex(const string& word) {
int word_id = -1;
#pragma omp critical (vocabulary)
{
- if (!dictionary.count(word)) {
+ auto it = dictionary.find(word);
+ if (it != dictionary.end()) {
+ word_id = it->second;
+ } else {
word_id = words.size();
dictionary[word] = word_id;
words.push_back(word);
- } else {
- word_id = dictionary[word];
}
}
return word_id;
@@ -34,4 +35,8 @@ string Vocabulary::GetTerminalValue(int symbol) {
return word;
}
+bool Vocabulary::operator==(const Vocabulary& other) const {
+ return words == other.words && dictionary == other.dictionary;
+}
+
} // namespace extractor
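GetTerminalIndex above now does a single find() instead of count() plus operator[], and the whole read-modify-write stays inside the named critical section because the method can mutate the dictionary. A hypothetical usage sketch (assumes vocabulary.h from this tree and an OpenMP-enabled build):

    #include <iostream>

    #include "vocabulary.h"

    int main() {
      extractor::Vocabulary vocabulary;
      // Safe under concurrent extraction threads: id assignment happens
      // inside the named critical section (vocabulary).
    #pragma omp parallel for
      for (int i = 0; i < 100; ++i) {
        vocabulary.GetTerminalIndex(i % 2 ? "foo" : "bar");
      }
      std::cout << vocabulary.GetTerminalValue(0) << std::endl;
      return 0;
    }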
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index c8fd9411..db092e99 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -5,6 +5,10 @@
#include <unordered_map>
#include <vector>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/string.hpp>
+#include <boost/serialization/vector.hpp>
+
using namespace std;
namespace extractor {
@@ -14,7 +18,7 @@ namespace extractor {
*
* This strucure contains words located in the frequent collocations and words
* encountered during the grammar extraction time. This dictionary is
- * considerably smaller than the dictionaries in the data arays (and so is the
+ * considerably smaller than the dictionaries in the data arrays (and so is the
* query time). Note that this is the single data structure that changes state
* and needs to have thread safe read/write operations.
*
@@ -38,7 +42,24 @@ class Vocabulary {
// Returns the word corresponding to the given word id.
virtual string GetTerminalValue(int symbol);
+  bool operator==(const Vocabulary& other) const;
+
private:
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ ar << words;
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ ar >> words;
+ for (size_t i = 0; i < words.size(); ++i) {
+ dictionary[words[i]] = i;
+ }
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
unordered_map<string, int> dictionary;
vector<string> words;
};
diff --git a/extractor/vocabulary_test.cc b/extractor/vocabulary_test.cc
new file mode 100644
index 00000000..cf5e3e36
--- /dev/null
+++ b/extractor/vocabulary_test.cc
@@ -0,0 +1,45 @@
+#include <gtest/gtest.h>
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+
+#include "vocabulary.h"
+
+using namespace std;
+using namespace ::testing;
+namespace ar = boost::archive;
+
+namespace extractor {
+namespace {
+
+TEST(VocabularyTest, TestIndexes) {
+ Vocabulary vocabulary;
+ EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+ EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+ EXPECT_EQ(1, vocabulary.GetTerminalIndex("one"));
+ EXPECT_EQ("one", vocabulary.GetTerminalValue(1));
+}
+
+TEST(VocabularyTest, TestSerialization) {
+ Vocabulary vocabulary;
+ EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+ EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+ stringstream stream(ios_base::out | ios_base::in);
+ ar::text_oarchive output_stream(stream, ar::no_header);
+ output_stream << vocabulary;
+
+ Vocabulary vocabulary_copy;
+ ar::text_iarchive input_stream(stream, ar::no_header);
+ input_stream >> vocabulary_copy;
+
+ EXPECT_EQ(vocabulary, vocabulary_copy);
+}
+
+} // namespace
+} // namespace extractor
diff --git a/mteval/external_scorer.h b/mteval/external_scorer.h
index a28fb920..85535655 100644
--- a/mteval/external_scorer.h
+++ b/mteval/external_scorer.h
@@ -24,7 +24,8 @@ class ScoreServer {
int c2p[2];
};
-struct ScoreServerManager {
+class ScoreServerManager {
+ public:
static ScoreServer* Instance(const std::string& score_type);
private:
static std::map<std::string, boost::shared_ptr<ScoreServer> > servers_;
diff --git a/mteval/ns_ext.cc b/mteval/ns_ext.cc
index 956708af..1e7e2bc1 100644
--- a/mteval/ns_ext.cc
+++ b/mteval/ns_ext.cc
@@ -118,7 +118,7 @@ void ExternalMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp,
}
float ExternalMetric::ComputeScore(const SufficientStats& stats) const {
- eval_server->ComputeScore(stats.fields);
+ return eval_server->ComputeScore(stats.fields);
}
ExternalMetric::ExternalMetric(const string& metric_name, const std::string& command) :
diff --git a/mteval/ns_ter.cc b/mteval/ns_ter.cc
index 680fb7b4..00b6eb01 100644
--- a/mteval/ns_ter.cc
+++ b/mteval/ns_ter.cc
@@ -298,7 +298,7 @@ class TERScorerImpl {
}
bool CalculateBestShift(const vector<WordID>& cur,
- const vector<WordID>& hyp,
+ const vector<WordID>& /*hyp*/,
float curerr,
const vector<TransType>& path,
vector<WordID>* new_hyp,
diff --git a/mteval/scorer.h b/mteval/scorer.h
index bb1e89ae..8d986612 100644
--- a/mteval/scorer.h
+++ b/mteval/scorer.h
@@ -103,7 +103,7 @@ class DocScorer {
virtual int size() const { return scorers_.size(); }
virtual ScorerP operator[](size_t i) const { return scorers_[i]; }
- virtual void update(const std::string& ref) {}
+ virtual void update(const std::string& /*ref*/) {}
private:
std::vector<ScorerP> scorers_;
};
@@ -124,7 +124,7 @@ class DocStreamScorer : public DocScorer {
{
Init(type,ref_files,src_file,verbose);
}
- ScorerP operator[](size_t i) const { return scorer; }
+ ScorerP operator[](size_t /*i*/) const { return scorer; }
int size() const { return 1; }
void update(const std::string& ref);
private:
diff --git a/tests/run-system-tests.pl b/tests/run-system-tests.pl
index 8555ef78..324763ae 100755
--- a/tests/run-system-tests.pl
+++ b/tests/run-system-tests.pl
@@ -11,6 +11,7 @@ my $TEMP_DIR = tempdir( CLEANUP => 1 );
my $DECODER = "$script_dir/../decoder/cdec";
my $FILTER = "$script_dir/tools/filter-stderr.pl";
my $COMPARE_STATS = "$script_dir/tools/compare-statistics.pl";
+my $XDIFF = "$script_dir/tools/flex-diff.pl";
die "Can't find $DECODER" unless -f $DECODER;
die "Can't execute $DECODER" unless -x $DECODER;
@@ -18,6 +19,7 @@ die "Can't find $FILTER" unless -f $FILTER;
die "Can't execute $FILTER" unless -x $FILTER;
die "Can't find $COMPARE_STATS" unless -f $COMPARE_STATS;
die "Can't execute $COMPARE_STATS" unless -x $COMPARE_STATS;
+die "Can't execute $XDIFF" unless -x $XDIFF;
my $TEST_DIR = "$script_dir/system_tests";
opendir DIR, $TEST_DIR or die "Can't open $TEST_DIR: $!";
@@ -62,7 +64,7 @@ for my $test (@tests) {
} else {
die unless -f "$TEMP_DIR/stdout";
my $failed = 0;
- run3 "diff gold.stdout $TEMP_DIR/stdout";
+ run3 "$XDIFF gold.stdout $TEMP_DIR/stdout";
if ($? != 0) {
print STDERR " FAILED differences in output!\n";
$failed = 1;
diff --git a/tests/tools/flex-diff.pl b/tests/tools/flex-diff.pl
new file mode 100755
index 00000000..30f73c4d
--- /dev/null
+++ b/tests/tools/flex-diff.pl
@@ -0,0 +1,46 @@
+#!/usr/bin/perl -w
+use strict;
+
+my $script_dir; BEGIN { use Cwd qw/ abs_path cwd /; use File::Basename; $script_dir = dirname(abs_path($0)); push @INC, "$script_dir/.."; }
+
+use IPC::Run3;
+
+# This script abstracts away differences due to different hash functions,
+# which lead to different orderings of features, n-best entries, etc.
+# It diffs the two files after sorting lines and the characters within
+# each trimmed line.
+
+die "Usage: $0 file1.txt file2.txt\n" unless scalar @ARGV == 2;
+my $tmpa = "tmp.$$.a";
+my $tmpb = "tmp.$$.b";
+create_sorted($ARGV[0], $tmpa);
+create_sorted($ARGV[1], $tmpb);
+
+my $failed = 0;
+run3 "diff $tmpa $tmpb";
+if ($? != 0) {
+ run3 "diff $ARGV[0] $ARGV[1]";
+ $failed = 1;
+}
+
+unlink $tmpa;
+unlink $tmpb;
+
+exit $failed;
+
+sub create_sorted {
+ my ($in, $out) = @_;
+ open A, "sort $in|" or die "Can't read $in: $!";
+ open AA, ">$out" or die "Can't write $out: $!";
+ while(<A>) {
+ chomp;
+ s/^\s*//;
+ s/\s*$//;
+ my @cs = split //;
+ @cs = sort @cs;
+ my $o = join('', @cs);
+ print AA "$o\n";
+ }
+ close AA;
+ close A;
+}
+
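flex-diff.pl canonicalizes each file before diffing: lines are sorted (via sort), and within each trimmed line the characters are sorted, so reorderings produced by different hash functions compare equal. A rough C++ equivalent of the per-line step, for illustration only:

    #include <algorithm>
    #include <string>

    // Mirrors create_sorted's per-line work: trim, then reduce the line
    // to its sorted multiset of characters.
    std::string Canonicalize(std::string line) {
      const char* ws = " \t\r\n";
      line.erase(0, line.find_first_not_of(ws));
      line.erase(line.find_last_not_of(ws) + 1);
      std::sort(line.begin(), line.end());
      return line;
    }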
diff --git a/training/crf/Makefile.am b/training/crf/Makefile.am
index 4a8c30fd..cd82161f 100644
--- a/training/crf/Makefile.am
+++ b/training/crf/Makefile.am
@@ -1,5 +1,6 @@
bin_PROGRAMS = \
mpi_batch_optimize \
+ mpi_adagrad_optimize \
mpi_compute_cllh \
mpi_extract_features \
mpi_extract_reachable \
@@ -10,6 +11,9 @@ bin_PROGRAMS = \
mpi_baum_welch_SOURCES = mpi_baum_welch.cc
mpi_baum_welch_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz
+mpi_adagrad_optimize_SOURCES = mpi_adagrad_optimize.cc cllh_observer.cc cllh_observer.h
+mpi_adagrad_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz
+
mpi_online_optimize_SOURCES = mpi_online_optimize.cc
mpi_online_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz
diff --git a/training/crf/mpi_adagrad_optimize.cc b/training/crf/mpi_adagrad_optimize.cc
new file mode 100644
index 00000000..af963e3a
--- /dev/null
+++ b/training/crf/mpi_adagrad_optimize.cc
@@ -0,0 +1,394 @@
+#include <sstream>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <cassert>
+#include <cmath>
+#include <ctime>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include "config.h"
+#include "stringlib.h"
+#include "verbose.h"
+#include "cllh_observer.h"
+#include "hg.h"
+#include "prob.h"
+#include "inside_outside.h"
+#include "ff_register.h"
+#include "decoder.h"
+#include "filelib.h"
+#include "online_optimizer.h"
+#include "fdict.h"
+#include "weights.h"
+#include "sparse_vector.h"
+#include "sampler.h"
+
+#ifdef HAVE_MPI
+#include <boost/mpi/timer.hpp>
+#include <boost/mpi.hpp>
+namespace mpi = boost::mpi;
+#endif
+
+using namespace std;
+namespace po = boost::program_options;
+
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("weights,w",po::value<string>(), "Initial feature weights")
+ ("training_data,d",po::value<string>(), "Training data corpus")
+ ("test_data,t",po::value<string>(), "(optional) Test data")
+ ("decoder_config,c",po::value<string>(), "Decoder configuration file")
+ ("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(8),
+ "Number of training instances evaluated per processor in each minibatch")
+ ("max_passes", po::value<double>()->default_value(20.0), "Maximum number of passes through the data")
+ ("max_walltime", po::value<unsigned>(), "Walltime to run (in minutes)")
+ ("write_every_n_minibatches", po::value<unsigned>()->default_value(100), "Write weights every N minibatches processed")
+ ("random_seed,S", po::value<uint32_t>(), "Random seed")
+ ("regularization,r", po::value<string>()->default_value("none"),
+ "Regularization 'none', 'l1', or 'l2'")
+ ("regularization_strength,C", po::value<double>(), "Regularization strength")
+ ("eta,e", po::value<double>()->default_value(1.0), "Initial learning rate (eta)");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) {
+ cerr << dcmdline_options << endl;
+ return false;
+ }
+ return true;
+}
+
+void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c, vector<int>* order) {
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ string line;
+ int id = 0;
+ while(getline(in, line)) {
+ if (id % size == rank) {
+ c->push_back(line);
+ order->push_back(id);
+ }
+ ++id;
+ }
+}
+
+static const double kMINUS_EPSILON = -1e-6;
+
+struct TrainingObserver : public DecoderObserver {
+ void Reset() {
+ acc_grad.clear();
+ acc_obj = 0;
+ total_complete = 0;
+ }
+
+ virtual void NotifyDecodingStart(const SentenceMetadata&) {
+ cur_model_exp.clear();
+ cur_obj = 0;
+ state = 1;
+ }
+
+ // compute model expectations, denominator of objective
+ virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) {
+ assert(state == 1);
+ state = 2;
+ const prob_t z = InsideOutside<prob_t,
+ EdgeProb,
+ SparseVector<prob_t>,
+ EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp);
+ cur_obj = log(z);
+ cur_model_exp /= z;
+ }
+
+ // compute "empirical" expectations, numerator of objective
+ virtual void NotifyAlignmentForest(const SentenceMetadata&, Hypergraph* hg) {
+ assert(state == 2);
+ state = 3;
+ SparseVector<prob_t> ref_exp;
+ const prob_t ref_z = InsideOutside<prob_t,
+ EdgeProb,
+ SparseVector<prob_t>,
+ EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp);
+ ref_exp /= ref_z;
+
+ double log_ref_z;
+#if 0
+ if (crf_uniform_empirical) {
+ log_ref_z = ref_exp.dot(feature_weights);
+ } else {
+ log_ref_z = log(ref_z);
+ }
+#else
+ log_ref_z = log(ref_z);
+#endif
+
+ // rounding errors means that <0 is too strict
+ if ((cur_obj - log_ref_z) < kMINUS_EPSILON) {
+ cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl;
+ exit(1);
+ }
+ assert(!std::isnan(log_ref_z));
+ ref_exp -= cur_model_exp;
+ acc_grad += ref_exp;
+ acc_obj += (cur_obj - log_ref_z);
+ }
+
+ virtual void NotifyDecodingComplete(const SentenceMetadata&) {
+ if (state == 3) {
+ ++total_complete;
+    }
+ }
+
+ void GetGradient(SparseVector<double>* g) const {
+ g->clear();
+#if HAVE_CXX11
+ for (auto& gi : acc_grad) {
+#else
+ for (FastSparseVector<prob_t>::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) {
+      const pair<unsigned, prob_t>& gi = *it;
+#endif
+ g->set_value(gi.first, -gi.second.as_float());
+ }
+ }
+
+ int total_complete;
+ SparseVector<prob_t> cur_model_exp;
+ SparseVector<prob_t> acc_grad;
+ double acc_obj;
+ double cur_obj;
+ int state;
+};
+
+#ifdef HAVE_MPI
+namespace boost { namespace mpi {
+ template<>
+ struct is_commutative<std::plus<SparseVector<double> >, SparseVector<double> >
+ : mpl::true_ { };
+} } // end namespace boost::mpi
+#endif
+
+class AdaGradOptimizer {
+ public:
+ explicit AdaGradOptimizer(double e) :
+ eta(e),
+ G() {}
+ void update(const SparseVector<double>& g, vector<double>* x) {
+ if (x->size() > G.size()) G.resize(x->size(), 0.0);
+#if HAVE_CXX11
+ for (auto& gi : g) {
+#else
+ for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
+ const pair<unsigned,double>& gi = *it;
+#endif
+ if (gi.second) {
+ G[gi.first] += gi.second * gi.second;
+ (*x)[gi.first] -= eta / sqrt(G[gi.first]) * gi.second;
+ }
+ }
+ }
+ const double eta;
+ vector<double> G;
+};
+
+class AdaGradL1Optimizer {
+ public:
+ explicit AdaGradL1Optimizer(double e, double l) :
+ t(),
+ eta(e),
+ lambda(l),
+ G() {}
+ void update(const SparseVector<double>& g, vector<double>* x) {
+ t += 1.0;
+ if (x->size() > G.size()) {
+ G.resize(x->size(), 0.0);
+ u.resize(x->size(), 0.0);
+ }
+#if HAVE_CXX11
+ for (auto& gi : g) {
+#else
+ for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
+ const pair<unsigned,double>& gi = *it;
+#endif
+ if (gi.second) {
+ u[gi.first] += gi.second;
+ G[gi.first] += gi.second * gi.second;
+ double z = fabs(u[gi.first] / t) - lambda;
+ double s = 1;
+ if (u[gi.first] > 0) s = -1;
+ if (z > 0 && G[gi.first])
+ (*x)[gi.first] = eta * s * z * t / sqrt(G[gi.first]);
+ else
+ (*x)[gi.first] = 0.0;
+ }
+ }
+ }
+ double t;
+ const double eta;
+ const double lambda;
+ vector<double> G, u;
+};
+
+unsigned non_zeros(const vector<double>& x) {
+ unsigned nz = 0;
+ for (unsigned i = 0; i < x.size(); ++i)
+ if (x[i]) ++nz;
+ return nz;
+}
+
+int main(int argc, char** argv) {
+#ifdef HAVE_MPI
+ mpi::environment env(argc, argv);
+ mpi::communicator world;
+ const int size = world.size();
+ const int rank = world.rank();
+#else
+ const int size = 1;
+ const int rank = 0;
+#endif
+ if (size > 1) SetSilent(true); // turn off verbose decoder output
+ register_feature_functions();
+
+ po::variables_map conf;
+ if (!InitCommandLine(argc, argv, &conf))
+ return 1;
+
+ ReadFile ini_rf(conf["decoder_config"].as<string>());
+ Decoder decoder(ini_rf.stream());
+
+ // load initial weights
+ vector<weight_t> init_weights;
+ if (conf.count("input_weights"))
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &init_weights);
+
+ vector<string> corpus, test_corpus;
+  vector<int> ids, test_ids;
+  ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus, &ids);
+  assert(corpus.size() > 0);
+  if (conf.count("test_data"))
+    ReadTrainingCorpus(conf["test_data"].as<string>(), rank, size, &test_corpus, &test_ids);
+
+ const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>();
+ if (size_per_proc > corpus.size()) {
+ cerr << "Minibatch size must be smaller than corpus size!\n";
+ return 1;
+ }
+ const double minibatch_size = size_per_proc * size;
+
+ size_t total_corpus_size = 0;
+#ifdef HAVE_MPI
+ reduce(world, corpus.size(), total_corpus_size, std::plus<size_t>(), 0);
+#else
+ total_corpus_size = corpus.size();
+#endif
+
+ if (rank == 0)
+ cerr << "Total corpus size: " << total_corpus_size << endl;
+
+ boost::shared_ptr<MT19937> rng;
+ if (conf.count("random_seed"))
+ rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+ else
+ rng.reset(new MT19937);
+
+ double passes_per_minibatch = static_cast<double>(size_per_proc) / total_corpus_size;
+
+ int write_weights_every_ith = conf["write_every_n_minibatches"].as<unsigned>();
+
+ unsigned max_iteration = conf["max_passes"].as<double>() / passes_per_minibatch;
+ ++max_iteration;
+ if (rank == 0)
+ cerr << "Max passes through data = " << conf["max_passes"].as<double>() << endl
+ << " --> max minibatches = " << max_iteration << endl;
+ unsigned timeout = 0;
+ if (conf.count("max_walltime"))
+ timeout = 60 * conf["max_walltime"].as<unsigned>();
+ vector<weight_t>& lambdas = decoder.CurrentWeightVector();
+ if (init_weights.size()) {
+ lambdas.swap(init_weights);
+ init_weights.clear();
+ }
+
+  //AdaGradOptimizer adagrad(conf["eta"].as<double>());
+  if (!conf.count("regularization_strength")) {
+    cerr << "Please specify --regularization_strength (-C)\n";
+    return 1;
+  }
+  AdaGradL1Optimizer adagrad(conf["eta"].as<double>(),
+                             conf["regularization_strength"].as<double>());
+ int iter = -1;
+ bool converged = false;
+
+ TrainingObserver observer;
+ ConditionalLikelihoodObserver cllh_observer;
+
+ const time_t start_time = time(NULL);
+ while (!converged) {
+#ifdef HAVE_MPI
+ mpi::timer timer;
+#endif
+ ++iter;
+ observer.Reset();
+ if (rank == 0) {
+ converged = (iter == max_iteration);
+ string fname = "weights.cur.gz";
+ if (iter % write_weights_every_ith == 0) {
+ ostringstream o; o << "weights." << iter << ".gz";
+ fname = o.str();
+ }
+ const time_t cur_time = time(NULL);
+ if (timeout && ((cur_time - start_time) > timeout)) {
+ converged = true;
+ fname = "weights.final.gz";
+ }
+ ostringstream vv;
+ double minutes = (cur_time - start_time) / 60.0;
+ vv << "total walltime=" << minutes << "min iter=" << iter << " minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << non_zeros(lambdas) << '/' << FD::NumFeats() << " passes_thru_data=" << (iter * size_per_proc / static_cast<double>(corpus.size()));
+ const string svv = vv.str();
+ cerr << svv << endl;
+ Weights::WriteToFile(fname, lambdas, true, &svv);
+ }
+
+    for (unsigned i = 0; i < size_per_proc; ++i) {
+ int ei = corpus.size() * rng->next();
+ int id = ids[ei];
+ decoder.SetId(id);
+ decoder.Decode(corpus[ei], &observer);
+ }
+ SparseVector<double> local_grad, g;
+ observer.GetGradient(&local_grad);
+#ifdef HAVE_MPI
+ reduce(world, local_grad, g, std::plus<SparseVector<double> >(), 0);
+#else
+ g.swap(local_grad);
+#endif
+ local_grad.clear();
+ if (rank == 0) {
+ g /= minibatch_size;
+ lambdas.resize(FD::NumFeats(), 0.0); // might have seen new features
+ adagrad.update(g, &lambdas);
+ Weights::SanityCheck(lambdas);
+ Weights::ShowLargestFeatures(lambdas);
+ }
+#ifdef HAVE_MPI
+ broadcast(world, lambdas, 0);
+ broadcast(world, converged, 0);
+ world.barrier();
+ if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; }
+#endif
+ }
+ cerr << "CONVERGED = " << converged << endl;
+ cerr << "EXITING...\n";
+ return 0;
+}
+
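For orientation: AdaGradOptimizer implements the standard AdaGrad update, accumulating squared gradients per coordinate and scaling each step by 1/sqrt(G_i); AdaGradL1Optimizer is the dual-averaging variant that keeps the running gradient sum u and soft-thresholds |u/t| at lambda. A self-contained sketch of the plain update (names illustrative, not from this file):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // One AdaGrad step on a dense vector: per-coordinate learning rates
    // shrink as squared gradients accumulate in G.
    void AdaGradStep(const std::vector<double>& g, std::vector<double>* x,
                     std::vector<double>* G, double eta) {
      for (size_t i = 0; i < g.size(); ++i) {
        if (g[i] == 0.0) continue;
        (*G)[i] += g[i] * g[i];
        (*x)[i] -= eta / std::sqrt((*G)[i]) * g[i];
      }
    }

    int main() {
      std::vector<double> x(2, 0.0), G(2, 0.0), g = {1.0, -0.5};
      for (int t = 0; t < 3; ++t) AdaGradStep(g, &x, &G, 1.0);
      std::printf("%f %f\n", x[0], x[1]);  // effective steps decay like 1/sqrt(t)
      return 0;
    }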
diff --git a/training/crf/mpi_batch_optimize.cc b/training/crf/mpi_batch_optimize.cc
index 2eff07e4..da1845b1 100644
--- a/training/crf/mpi_batch_optimize.cc
+++ b/training/crf/mpi_batch_optimize.cc
@@ -97,14 +97,14 @@ struct TrainingObserver : public DecoderObserver {
(*g)[it->first] = it->second.as_float();
}
- virtual void NotifyDecodingStart(const SentenceMetadata& smeta) {
+ virtual void NotifyDecodingStart(const SentenceMetadata&) {
cur_model_exp.clear();
cur_obj = 0;
state = 1;
}
// compute model expectations, denominator of objective
- virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) {
assert(state == 1);
state = 2;
const prob_t z = InsideOutside<prob_t,
@@ -149,7 +149,7 @@ struct TrainingObserver : public DecoderObserver {
trg_words += smeta.GetReference().size();
}
- virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) {
+ virtual void NotifyDecodingComplete(const SentenceMetadata&) {
if (state == 3) {
++total_complete;
} else {
diff --git a/utils/Makefile.am b/utils/Makefile.am
index c5fedb78..a22b6727 100644
--- a/utils/Makefile.am
+++ b/utils/Makefile.am
@@ -7,9 +7,10 @@ noinst_PROGRAMS = \
m_test \
weights_test \
logval_test \
- small_vector_test
+ small_vector_test \
+ sv_test
-TESTS = ts small_vector_test logval_test weights_test dict_test m_test
+TESTS = ts small_vector_test logval_test weights_test dict_test m_test sv_test
noinst_LIBRARIES = libutils.a
@@ -50,7 +51,6 @@ libutils_a_SOURCES = \
sparse_vector.h \
static_utoa.h \
stringlib.h \
- swap_pod.h \
tdict.h \
timing_stats.h \
utoa.h \
@@ -101,6 +101,8 @@ logval_test_SOURCES = logval_test.cc
logval_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS)
small_vector_test_SOURCES = small_vector_test.cc
small_vector_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS)
+sv_test_SOURCES = sv_test.cc
+sv_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS)
################################################################
# do NOT NOT NOT add any other -I includes NO NO NO NO NO ######
diff --git a/utils/small_vector.h b/utils/small_vector.h
index c8a69927..280ab72c 100644
--- a/utils/small_vector.h
+++ b/utils/small_vector.h
@@ -14,7 +14,6 @@
#include <stdint.h>
#include <new>
#include <stdint.h>
-#include "swap_pod.h"
#include <boost/functional/hash.hpp>
//sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1
@@ -278,8 +277,15 @@ public:
return !(a==b);
}
- void swap(Self& o) {
- swap_pod(*this,o);
+ inline void swap(Self& o) {
+ const unsigned s=sizeof(SmallVector<T,SV_MAX>);
+ char tmp[s];
+ void *pt=static_cast<void*>(tmp);
+ void *pa=static_cast<void*>(this);
+ void *pb=static_cast<void*>(&o);
+ std::memcpy(pt,pa,s);
+ std::memcpy(pa,pb,s);
+ std::memcpy(pb,pt,s);
}
inline std::size_t hash_impl() const {
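Inlining the old swap_pod here preserves the byte-copy swap, which SmallVector can get away with because its heap pointer and inline storage tolerate a raw byte exchange. For arbitrary types the same trick needs a trivially-copyable guard; a defensive C++11 variant (SmallVector itself would fail this check, since its destructor is non-trivial, which is why its swap stays hand-rolled):

    #include <cstring>
    #include <type_traits>

    // Byte-wise swap, statically restricted to types whose object
    // representation may be copied with memcpy.
    template <class T>
    inline void SwapBytes(T& a, T& b) {
      static_assert(std::is_trivially_copyable<T>::value,
                    "SwapBytes requires a trivially copyable type");
      char tmp[sizeof(T)];
      std::memcpy(tmp, &a, sizeof(T));
      std::memcpy(&a, &b, sizeof(T));
      std::memcpy(&b, tmp, sizeof(T));
    }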
diff --git a/utils/small_vector_test.cc b/utils/small_vector_test.cc
index cded4619..a4eb89ae 100644
--- a/utils/small_vector_test.cc
+++ b/utils/small_vector_test.cc
@@ -82,6 +82,18 @@ BOOST_AUTO_TEST_CASE(LargerThan2) {
BOOST_CHECK(cc.size() == 0);
}
+BOOST_AUTO_TEST_CASE(SwapSV) {
+ SmallVectorInt v;
+ SmallVectorInt v2(2, 10);
+ SmallVectorInt v3(2, 10);
+ BOOST_CHECK(v2 == v3);
+ BOOST_CHECK(v != v3);
+ v.swap(v2);
+ BOOST_CHECK(v == v3);
+ SmallVectorInt v4;
+ BOOST_CHECK(v4 == v2);
+}
+
BOOST_AUTO_TEST_CASE(Small) {
SmallVectorInt v;
SmallVectorInt v1(1,0);
diff --git a/utils/sv_test.cc b/utils/sv_test.cc
new file mode 100644
index 00000000..c7ac9e54
--- /dev/null
+++ b/utils/sv_test.cc
@@ -0,0 +1,24 @@
+#define BOOST_TEST_MODULE SvTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+#include "sparse_vector.h"
+
+using namespace std;
+
+BOOST_AUTO_TEST_CASE(Equality) {
+ SparseVector<double> x;
+ SparseVector<double> y;
+ x.set_value(1,-1);
+ y.set_value(1,-1);
+ BOOST_CHECK(x == y);
+}
+
+BOOST_AUTO_TEST_CASE(Division) {
+ SparseVector<double> x;
+ SparseVector<double> y;
+ x.set_value(1,1);
+ y.set_value(1,-1);
+ BOOST_CHECK(!(x == y));
+ x /= -1;
+ BOOST_CHECK(x == y);
+}
diff --git a/utils/swap_pod.h b/utils/swap_pod.h
deleted file mode 100644
index bb9a830d..00000000
--- a/utils/swap_pod.h
+++ /dev/null
@@ -1,23 +0,0 @@
-#ifndef SWAP_POD_H
-#define SWAP_POD_H
-
-//for swapping objects of the same concrete type where just swapping their bytes will work. will at least work on plain old data.
-
-#include <algorithm> // not used, but people who use this will want to bring std::swap in anyway
-#include <cstring>
-
-template <class T>
-inline void swap_pod(T &a,T &b) {
- using namespace std;
- const unsigned s=sizeof(T);
- char tmp[s];
- void *pt=(void*)tmp;
- void *pa=(void*)&a;
- void *pb=(void*)&b;
- memcpy(pt,pa,s);
- memcpy(pa,pb,s);
- memcpy(pb,pt,s);
-}
-
-
-#endif
diff --git a/utils/value_array.h b/utils/value_array.h
index 12fc9d87..e59349b5 100644
--- a/utils/value_array.h
+++ b/utils/value_array.h
@@ -1,8 +1,6 @@
#ifndef VALUE_ARRAY_H
#define VALUE_ARRAY_H
-//TODO: option for non-constructed version (type_traits pod?), option for small array optimization (if sz < N, store inline in union, see small_vector.h)
-
#define DBVALUEARRAY(x) x
#include <cstdlib>
@@ -30,8 +28,7 @@
// valarray like in that size is fixed (so saves space compared to vector), but same interface as vector (less resize/push_back/insert, of course)
template <class T, class A = std::allocator<T> >
-class ValueArray : A // private inheritance so stateless allocator adds no size.
-{
+class ValueArray : A { // private inheritance so stateless allocator adds no size.
typedef ValueArray<T,A> Self;
public:
#if VALUE_ARRAY_OSTREAM
@@ -323,14 +320,14 @@ private:
//friend class boost::serialization::access;
public:
template <class Archive>
- void save(Archive& ar, unsigned int version) const
+ void save(Archive& ar, unsigned int /*version*/) const
{
ar << sz;
for (size_type i = 0; i != sz; ++i) ar << at(i);
}
template <class Archive>
- void load(Archive& ar, unsigned int version)
+ void load(Archive& ar, unsigned int /*version*/)
{
size_type s;
ar >> s;
diff --git a/utils/weights.cc b/utils/weights.cc
index effdfc5e..1f66c441 100644
--- a/utils/weights.cc
+++ b/utils/weights.cc
@@ -127,9 +127,11 @@ void Weights::InitSparseVector(const vector<weight_t>& dv,
}
void Weights::SanityCheck(const vector<weight_t>& w) {
- for (unsigned i = 0; i < w.size(); ++i) {
- assert(!std::isnan(w[i]));
- assert(!std::isinf(w[i]));
+ for (unsigned i = 1; i < w.size(); ++i) {
+ if (std::isnan(w[i]) || std::isinf(w[i])) {
+ cerr << FD::Convert(i) << " has bad weight: " << w[i] << endl;
+ abort();
+ }
}
}
@@ -161,7 +163,7 @@ string Weights::GetString(const vector<weight_t>& w,
bool hide_zero_value_features) {
ostringstream os;
os.precision(17);
- int nf = FD::NumFeats();
+ const unsigned nf = FD::NumFeats();
for (unsigned i = 1; i < nf; i++) {
weight_t val = (i < w.size() ? w[i] : 0.0);
if (hide_zero_value_features && val == 0.0) {