summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-10-19 14:02:34 +0200
committerPatrick Simianer <p@simianer.de>2011-10-19 14:02:34 +0200
commiteb14e36d0b29f19321d44dd7dfa73cc703838d86 (patch)
tree1285e9e56959bc3a4b506e36bbc3b49f4e938fa0 /decoder
parent68f158b11df9f4072699fe6a4c8022ea54102b28 (diff)
parent04e38a57b19ea012895ac2efb39382c2e77833a9 (diff)
merge upstream/master
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am2
-rw-r--r--decoder/aligner.cc2
-rwxr-xr-xdecoder/apply_fsa_models.cc798
-rw-r--r--decoder/apply_models.cc205
-rw-r--r--decoder/apply_models.h6
-rwxr-xr-xdecoder/cdec-fsa.ini10
-rw-r--r--decoder/cdec.cc8
-rw-r--r--decoder/cdec_ff.cc16
-rwxr-xr-xdecoder/cfg.cc2
-rwxr-xr-xdecoder/cfg_format.h2
-rw-r--r--decoder/decoder.cc124
-rw-r--r--decoder/decoder.h23
-rwxr-xr-xdecoder/feature_accum.h129
-rw-r--r--decoder/ff_factory.h2
-rwxr-xr-xdecoder/ff_from_fsa.h304
-rwxr-xr-xdecoder/ff_fsa.h401
-rwxr-xr-xdecoder/ff_fsa_data.h131
-rwxr-xr-xdecoder/ff_fsa_dynamic.h208
-rw-r--r--decoder/ff_klm.cc308
-rw-r--r--decoder/ff_lm.cc48
-rwxr-xr-xdecoder/ff_lm_fsa.h140
-rwxr-xr-xdecoder/ff_register.h38
-rw-r--r--decoder/ff_source_syntax.cc232
-rw-r--r--decoder/ff_source_syntax.h41
-rw-r--r--decoder/grammar_test.cc4
-rw-r--r--decoder/hg.cc4
-rw-r--r--decoder/hg.h8
-rw-r--r--decoder/hg_test.cc16
-rwxr-xr-xdecoder/oracle_bleu.h22
-rw-r--r--decoder/rule_lexer.l2
-rw-r--r--decoder/trule.h15
31 files changed, 766 insertions, 2485 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index e5f7505f..6b9360d8 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -42,7 +42,6 @@ libcdec_a_SOURCES = \
cfg.cc \
dwarf.cc \
ff_dwarf.cc \
- apply_fsa_models.cc \
rule_lexer.cc \
fst_translator.cc \
csplit.cc \
@@ -72,6 +71,7 @@ libcdec_a_SOURCES = \
ff_wordalign.cc \
ff_csplit.cc \
ff_tagger.cc \
+ ff_source_syntax.cc \
ff_bleu.cc \
ff_factory.cc \
freqdict.cc \
diff --git a/decoder/aligner.cc b/decoder/aligner.cc
index 292ee123..53e059fb 100644
--- a/decoder/aligner.cc
+++ b/decoder/aligner.cc
@@ -165,7 +165,7 @@ inline void WriteProbGrid(const Array2D<prob_t>& m, ostream* pos) {
if (m(i,j) == prob_t::Zero()) {
os << "\t---X---";
} else {
- snprintf(b, 1024, "%0.5f", static_cast<double>(m(i,j)));
+ snprintf(b, 1024, "%0.5f", m(i,j).as_float());
os << '\t' << b;
}
}
diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc
deleted file mode 100755
index 3e93cadd..00000000
--- a/decoder/apply_fsa_models.cc
+++ /dev/null
@@ -1,798 +0,0 @@
-//see apply_fsa_models.README for notes on the l2r earley fsa+cfg intersection
-//implementation in this file (also some comments in this file)
-#define SAFE_VALGRIND 1
-
-#include "apply_fsa_models.h"
-#include <stdexcept>
-#include <cassert>
-#include <queue>
-#include <stdint.h>
-
-#include "writer.h"
-#include "hg.h"
-#include "ff_fsa_dynamic.h"
-#include "ff_from_fsa.h"
-#include "feature_vector.h"
-#include "stringlib.h"
-#include "apply_models.h"
-#include "cfg.h"
-#include "hg_cfg.h"
-#include "utoa.h"
-#include "hash.h"
-#include "value_array.h"
-#include "d_ary_heap.h"
-#include "agenda.h"
-#include "show.h"
-#include "string_to.h"
-
-
-#define DFSA(x) x
-//fsa earley chart
-
-#define DPFSA(x) x
-//prefix trie
-
-#define DBUILDTRIE(x)
-
-#define PRINT_PREFIX 1
-#if PRINT_PREFIX
-# define IF_PRINT_PREFIX(x) x
-#else
-# define IF_PRINT_PREFIX(x)
-#endif
-// keep backpointers in prefix trie so you can print a meaningful node id
-
-static const unsigned FSA_AGENDA_RESERVE=10; // TODO: increase to 1<<24 (16M)
-
-using namespace std;
-
-//impl details (not exported). flat namespace for my ease.
-
-typedef CFG::RHS RHS;
-typedef CFG::BinRhs BinRhs;
-typedef CFG::NTs NTs;
-typedef CFG::NT NT;
-typedef CFG::NTHandle NTHandle;
-typedef CFG::Rules Rules;
-typedef CFG::Rule Rule;
-typedef CFG::RuleHandle RuleHandle;
-
-namespace {
-
-/*
-
-1) A -> x . * (trie)
-
-this is somewhat nice. cost pushed for best first, of course. similar benefit as left-branching binarization without the explicit predict/complete steps?
-
-vs. just
-
-2) * -> x . y
-
-here you have to potentially list out all A -> . x y as items * -> . x y immediately, and shared rhs seqs won't be shared except at the usual single-NT predict/complete. of course, the prediction of items -> . x y can occur lazy best-first.
-
-vs.
-
-3) * -> x . *
-
-with 3, we predict all sorts of useless items - that won't give us our goal A and may not partcipate in any parse. this is not a good option at all.
-
-I'm using option 1.
-*/
-
-// if we don't greedy-binarize, we want to encode recognized prefixes p (X -> p . rest) efficiently. if we're doing this, we may as well also push costs so we can best-first select rules in a lazy fashion. this is effectively left-branching binarization, of course.
-
-template <class K,class V,class Hash>
-struct fsa_map_type {
- typedef std::map<K,V> type; // change to HASH_MAP ?
-};
-//template typedef - and macro to make it less painful
-#define FSA_MAP(k,v) fsa_map_type<k,v,boost::hash<k> >::type
-
-struct PrefixTrieNode;
-typedef PrefixTrieNode *NodeP;
-typedef PrefixTrieNode const *NodePc;
-
-// for debugging prints only
-struct TrieBackP {
- WordID w;
- NodePc from;
- TrieBackP(WordID w=0,NodePc from=0) : w(w),from(from) { }
-};
-
-FsaFeatureFunction const* print_fsa=0;
-CFG const* print_cfg=0;
-inline ostream& print_cfg_rhs(std::ostream &o,WordID w,CFG const*pcfg=print_cfg) {
- if (pcfg)
- pcfg->print_rhs_name(o,w);
- else
- CFG::static_print_rhs_name(o,w);
- return o;
-}
-
-inline std::string nt_name(WordID n,CFG const*pcfg=print_cfg) {
- if (pcfg) return pcfg->nt_name(n);
- return CFG::static_nt_name(n);
-}
-
-template <class V>
-ostream& print_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") {
- o<<header;
- for (int i=0;i<v.size();++i)
- o << nt_name(i,pcfg) << " -> "<<v[i]<<"\n";
- return o;
-}
-
-template <class V>
-ostream& print_map_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") {
- o<<header;
- for (typename V::const_iterator i=v.begin(),e=v.end();i!=e;++i) {
- print_cfg_rhs(o,i->first,pcfg) << " -> "<<i->second<<"\n";
- }
- return o;
-}
-
-struct PrefixTrieEdge {
- PrefixTrieEdge()
- // : dest(0),w(TD::max_wordid)
- {}
- PrefixTrieEdge(WordID w,NodeP dest)
- : dest(dest),w(w)
- {}
-// explicit PrefixTrieEdge(best_t p) : p(p),dest(0) { }
-
- best_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob. note: for final edge, set this.
- //DPFSA()
- // we can probably just store deltas, but for debugging remember the full p
- // best_t delta; //
- NodeP dest;
- bool is_final() const { return dest==0; }
- best_t p_dest() const;
- WordID w; // for root and and is_final(), this will be (negated) NTHandle.
-
- // for sorting most probable first in adj; actually >(p)
- inline bool operator <(PrefixTrieEdge const& o) const {
- return o.p<p;
- }
- PRINT_SELF(PrefixTrieEdge)
- void print(std::ostream &o) const {
- print_cfg_rhs(o,w);
- o<<"{"<<p<<"}->"<<dest;
- }
-};
-
-//note: ending a rule is handled with a special final edge, so that possibility can be explored in best-first order along with the rest (alternative: always finish a final rule by putting it on queue). this edge has no symbol on it.
-struct PrefixTrieNode {
- best_t p; // viterbi (max prob) of rule this node leads to - when building. telescope later onto edges for best-first.
-// bool final; // may also have successors, of course. we don't really need to track this; a null dest edge in the adj list lets us encounter the fact in best first order.
- void p_delta(int next,best_t &p) const {
- p*=adj[next].p;
- }
- void inc_adj(int &next,best_t &p) const {
- p/=adj[next].p; //TODO: cache deltas
- ++next;
- p*=adj[next].p;
- }
-
-
- typedef TrieBackP BP;
- typedef std::vector<BP> BPs;
- void back_vec(BPs &ns) const {
- IF_PRINT_PREFIX(if(backp.from) { ns.push_back(backp); backp.from->back_vec(ns); })
- }
-
- BPs back_vec() const {
- BPs ret;
- back_vec(ret);
- return ret;
- }
-
- unsigned size() const {
- unsigned a=adj.size();
- unsigned e=edge_for.size();
- return a>e?a:e;
- }
-
- void print_back_str(std::ostream &o) const {
- BPs back=back_vec();
- unsigned i=back.size();
- if (!i) {
- o<<"PrefixTrieNode@"<<(uintptr_t)this;
- return;
- }
- bool first=true;
- while (i--<=0) {
- if (!first) o<<',';
- first=false;
- WordID w=back[i].w;
- print_cfg_rhs(o,w);
- }
- }
- std::string back_str() const {
- std::ostringstream o;
- print_back_str(o);
- return o.str();
- }
-
-// best_t p_final; // additional prob beyond what we already paid. while building, this is the total prob
-// instead of storing final, we'll say that an edge with a NULL dest is a final edge. this way it gets sorted into the list of adj.
-
- // instead of completed map, we have trie start w/ lhs.
- NTHandle lhs; // nonneg. - instead of storing this in Item.
- IF_PRINT_PREFIX(BP backp;)
-
- enum { ROOT=-1 };
- explicit PrefixTrieNode(NTHandle lhs=ROOT,best_t p=1) : p(p),lhs(lhs),IF_PRINT_PREFIX(backp()) {
- //final=false;
- }
- bool is_root() const { return lhs==ROOT; } // means adj are the nonneg lhs indices, and we have the index edge_for still available
-
- // outgoing edges will be ordered highest p to worst p
-
- typedef FSA_MAP(WordID,PrefixTrieEdge) PrefixTrieEdgeFor;
-public:
- PrefixTrieEdgeFor edge_for; //TODO: move builder elsewhere? then need 2nd hash or edge include pointer to builder. just clear this later
- bool have_adj() const {
- return adj.size()>=edge_for.size();
- }
- bool no_adj() const {
- return adj.empty();
- }
-
- void index_adj() {
- index_adj(edge_for);
- }
- template <class M>
- void index_adj(M &m) {
- assert(have_adj());
- m.clear();
- for (int i=0;i<adj.size();++i) {
- PrefixTrieEdge const& e=adj[i];
- SHOWM2(DPFSA,"index_adj",i,e);
- m[e.w]=e;
- }
- }
- template <class PV>
- void index_lhs(PV &v) {
- for (int i=0,e=adj.size();i!=e;++i) {
- PrefixTrieEdge const& edge=adj[i];
- // assert(edge.p.is_1()); // actually, after done_building, e will have telescoped dest->p/p.
- NTHandle n=-edge.w;
- assert(n>=0);
-// SHOWM3(DPFSA,"index_lhs",i,edge,n);
- v[n]=edge.dest;
- }
- }
-
- template <class PV>
- void done_root(PV &v) {
- assert(is_root());
- SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_map_by_nt,edge_for));
- done_building_r(); //sets adj
- SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_by_nt,adj));
-// SHOWM1(DBUILDTRIE,done_root,adj);
-// index_adj(); // we want an index for the root node?. don't think so - index_lhs handles it. also we stopped clearing edge_for.
- index_lhs(v); // uses adj
- }
-
- // call only once.
- void done_building_r() {
- done_building();
- for (int i=0;i<adj.size();++i)
- if (adj[i].dest) // skip final edge
- adj[i].dest->done_building_r();
- }
-
- // for done_building; compute incremental (telescoped) edge p
- PrefixTrieEdge /*const&*/ operator()(PrefixTrieEdgeFor::value_type & pair) const {
- PrefixTrieEdge &e=pair.second;//const_cast<PrefixTrieEdge&>(pair.second);
- e.p=e.p_dest()/p;
- return e;
- }
-
- // call only once.
- void done_building() {
- SHOWM3(DBUILDTRIE,"done_building",edge_for.size(),adj.size(),1);
-#if 1
- adj.reinit_map(edge_for,*this);
-#else
- adj.reinit(edge_for.size());
- SHOWM3(DBUILDTRIE,"done_building_reinit",edge_for.size(),adj.size(),2);
- Adj::iterator o=adj.begin();
- for (PrefixTrieEdgeFor::iterator i=edge_for.begin(),e=edge_for.end();i!=e;++i) {
- SHOWM3(DBUILDTRIE,"edge_for",o-adj.begin(),i->first,i->second);
- PrefixTrieEdge &edge=i->second;
- edge.p=(edge.dest->p)/p;
- *o++=edge;
-// (*this)(*i);
- }
-#endif
- SHOWM1(DBUILDTRIE,"done building adj",prange(adj.begin(),adj.end(),true));
- assert(adj.size()==edge_for.size());
-// if (final) p_final/=p;
- std::sort(adj.begin(),adj.end());
- //TODO: store adjacent differences on edges (compared to
- }
-
- typedef ValueArray<PrefixTrieEdge> Adj;
-// typedef vector<PrefixTrieEdge> Adj;
- Adj adj;
-
- typedef WordID W;
-
- // let's compute p_min so that every rule reachable from the created node has p at least this low.
- NodeP improve_edge(PrefixTrieEdge const& e,best_t rulep) {
- NodeP d=e.dest;
- maybe_improve(d->p,rulep);
- return d;
- }
-
- inline NodeP build(W w,best_t rulep) {
- return build(lhs,w,rulep);
- }
- inline NodeP build_lhs(NTHandle n,best_t rulep) {
- return build(n,-n,rulep);
- }
-
- NodeP build(NTHandle lhs_,W w,best_t rulep) {
- PrefixTrieEdgeFor::iterator i=edge_for.find(w);
- if (i!=edge_for.end())
- return improve_edge(i->second,rulep);
- NodeP r=new PrefixTrieNode(lhs_,rulep);
- IF_PRINT_PREFIX(r->backp=BP(w,this));
-// edge_for.insert(i,PrefixTrieEdgeFor::value_type(w,PrefixTrieEdge(w,r)));
- add(edge_for,w,PrefixTrieEdge(w,r));
- SHOWM4(DBUILDTRIE,"built node",this,w,*r,r);
- return r;
- }
-
- void set_final(NTHandle lhs_,best_t pf) {
- assert(no_adj());
-// final=true;
- PrefixTrieEdge &e=edge_for[null_wordid];
- e.p=pf;
- e.dest=0;
- e.w=lhs_;
- maybe_improve(p,pf);
- }
-
-private:
- void destroy_children() {
- assert(adj.size()>=edge_for.size());
- for (int i=0,e=adj.size();i<e;++i) {
- NodeP c=adj[i].dest;
- if (c) { // final state has no end
- delete c;
- }
- }
- }
-public:
- ~PrefixTrieNode() {
- destroy_children();
- }
- void print(std::ostream &o) const {
- o << "Node"<<this<< ": "<<lhs << "->" << p;
- o << ',' << size() << ',';
- print_back_str(o);
- }
- PRINT_SELF(PrefixTrieNode)
-};
-
-inline best_t PrefixTrieEdge::p_dest() const {
- return dest ? dest->p : p; // for final edge, p was set (no sentinel node)
-}
-
-
-//Trie starts with lhs (nonneg index), then continues w/ rhs (mixed >0 word, else NT)
-// trie ends with final edge, which points to a per-lhs prefix node
-struct PrefixTrie {
- void print(std::ostream &o) const {
- o << cfgp << ' ' << root;
- }
- PRINT_SELF(PrefixTrie);
- CFG *cfgp;
- Rules const* rulesp;
- Rules const& rules() const { return *rulesp; }
- CFG const& cfg() const { return *cfgp; }
- PrefixTrieNode root;
- typedef std::vector<NodeP> LhsToTrie; // will have to check lhs2[lhs].p for best cost of some rule with that lhs, then use edge deltas after? they're just caching a very cheap computation, really
- LhsToTrie lhs2; // no reason to use a map or hash table; every NT in the CFG will have some rule rhses. lhs_to_trie[i]=root.edge_for[i], i.e. we still have a root trie node conceptually, we just access through this since it's faster.
- typedef LhsToTrie LhsToComplete;
- LhsToComplete lhs2complete; // the sentinel "we're completing" node (dot at end) for that lhs. special case of suffix-set=same trie minimization (aka right branching binarization) // these will be used to track kbest completions, along with a l state (r state will be in the list)
- PrefixTrie(CFG &cfg) : cfgp(&cfg),rulesp(&cfg.rules),lhs2(cfg.nts.size(),0),lhs2complete(cfg.nts.size()) {
-// cfg.SortLocalBestFirst(); // instead we'll sort in done_building_r
- print_cfg=cfgp;
- SHOWM2(DBUILDTRIE,"PrefixTrie()",rulesp->size(),lhs2.size());
- cfg.VisitRuleIds(*this);
- root.done_root(lhs2);
- SHOWM3(DBUILDTRIE,"done w/ PrefixTrie: ",root,root.adj.size(),lhs2.size());
- DBUILDTRIE(print_by_nt(cerr,lhs2,cfgp));
- SHOWM1(DBUILDTRIE,"lhs2",OSTRF2(print_by_nt,lhs2,cfgp));
- }
-
- void operator()(int ri) {
- Rule const& r=rules()[ri];
- NTHandle lhs=r.lhs;
- best_t p=r.p;
-// NodeP n=const_cast<PrefixTrieNode&>(root).build_lhs(lhs,p);
- NodeP n=root.build_lhs(lhs,p);
- SHOWM4(DBUILDTRIE,"Prefixtrie rule id, root",ri,root,p,*n);
- for (RHS::const_iterator i=r.rhs.begin(),e=r.rhs.end();;++i) {
- SHOWM2(DBUILDTRIE,"PrefixTrie build or final",i-r.rhs.begin(),*n);
- if (i==e) {
- n->set_final(lhs,p);
- break;
- }
- n=n->build(*i,p);
- SHOWM2(DBUILDTRIE,"PrefixTrie built",*i,*n);
- }
-// root.build(lhs,r.p)->build(r.rhs,r.p);
- }
- inline NodeP lhs2_ex(NTHandle n) const {
- NodeP r=lhs2[n];
- if (!r) throw std::runtime_error("PrefixTrie: no CFG rule w/ lhs "+cfgp->nt_name(n));
- return r;
- }
-private:
- PrefixTrie(PrefixTrie const& o);
-};
-
-
-
-typedef std::size_t ItemHash;
-
-
-struct ItemKey {
- explicit ItemKey(NodeP start,Bytes const& start_state) : dot(start),q(start_state),r(start_state) { }
- explicit ItemKey(NodeP dot) : dot(dot) { }
- NodeP dot; // dot is a function of the stuff already recognized, and gives a set of suffixes y to complete to finish a rhs for lhs() -> dot y. for a lhs A -> . *, this will point to lh2[A]
- Bytes q,r; // (q->r are the fsa states; if r is empty it means
- bool operator==(ItemKey const& o) const {
- return dot==o.dot && q==o.q && r==o.r;
- }
- inline ItemHash hash() const {
- ItemHash h=GOLDEN_MEAN_FRACTION*(ItemHash)(dot-NULL); // i.e. lower order bits of ptr are nonrandom
- using namespace boost;
- hash_combine(h,q);
- hash_combine(h,r);
- return h;
- }
- template<class O>
- void print(O &o) const {
- o<<"lhs="<<lhs();
- if (dot)
- dot->print_back_str(o);
- if (print_fsa) {
- o<<'/';
- print_fsa->print_state(o,&q[0]);
- o<<"->";
- print_fsa->print_state(o,&r[0]);
- }
- }
- NTHandle lhs() const { return dot->lhs; }
- PRINT_SELF(ItemKey)
-};
-inline ItemHash hash_value(ItemKey const& x) {
- return x.hash();
-}
-ItemKey null_item((PrefixTrieNode*)0);
-
-struct Item;
-typedef Item *ItemP;
-
-/* we use a single type of item so it can live in a single best-first queue. we hold them by pointer so they can have mutable state, e.g. priority/location, but also lists of predictions and kbest completions (i.e. completions[L,r] = L -> * (r,s), by 1best for each possible s. we may discover more s later. we could use different subtypes since we hold by pointer, but for now everything will be packed as variants of Item */
-#undef INIT_LOCATION
-#if D_ARY_TRACK_OUT_OF_HEAP
-# define INIT_LOCATION , location(D_ARY_HEAP_NULL_INDEX)
-#elif !defined(NDEBUG) || SAFE_VALGRIND
- // avoid spurious valgrind warning - FIXME: still complains???
-# define INIT_LOCATION , location()
-#else
-# define INIT_LOCATION
-#endif
-
-// these should go in a global best-first queue
-struct ItemPrio {
- // NOTE: sum = viterbi (max)
- ItemPrio() : priority(init_0()),inner(init_0()) { }
- explicit ItemPrio(best_t priority) : priority(priority),inner(init_0()) { }
- best_t priority; // includes inner prob. (forward)
- /* The forward probability alpha_i(X[k]->x.y) is the sum of the probabilities of all
- constrained paths of length i that end in state X[k]->x.y*/
- best_t inner;
- /* The inner probability beta_i(X[k]->x.y) is the sum of the probabilities of all
- paths of length i-k that start in state X[k,k]->.xy and end in X[k,i]->x.y, and generate the input symbols x[k,...,i-1] */
- template<class O>
- void print(O &o) const {
- o<<priority; // TODO: show/use inner?
- }
- typedef ItemPrio self_type;
- SELF_TYPE_PRINT
-};
-
-#define ITEM_TYPE(X,t) \
- X(t,ADJ,=-1) \
-
-#define ITEM_TYPE_TYPE ItemType
-
-DECLARE_NAMED_ENUM(ITEM_TYPE)
-DEFINE_NAMED_ENUM(ITEM_TYPE)
-
-struct Item : ItemPrio,ItemKey {
-/* explicit Item(NodeP dot,best_t prio,int next) : ItemPrio(prio),ItemKey(dot),trienext(next),from(0)
- INIT_LOCATION
- { }*/
-// ItemType t;
- // lazy queueing of succesors item:
- bool is_trie_adj() const {
- return trienext>=0;
- }
- explicit Item(FFState const& state,NodeP dot,best_t prio,int next=0) : ItemPrio(prio),ItemKey(dot,state),trienext(next),from(0)
- INIT_LOCATION
- {
-// t=ADJ;
-// if (dot->adj.size())
- dot->p_delta(next,priority);
-// SHOWM1(DFSA,"Item(state,dot,prio)",prio);
- }
- typedef std::queue<ItemP> Predicted;
-// Predicted predicted; // this is empty, unless this is a predicted L -> .asdf item, or a to-complete L -> asdf .
- int trienext; // index of dot->adj to complete (if dest==0), or predict (if NT), or scan (if word). note: we could store pointer inside adj since it and trie are @ fixed addrs. less pointer arith, more space.
- ItemP from; //backpointer - 0 for L -> . asdf for the rest; L -> a .sdf, it's the L -> .asdf item.
- ItemP predicted_from() const {
- ItemP p=(ItemP)this;
- while(p->from) p=p->from;
- return p;
- }
- template<class O>
- void print(O &o) const {
- o<< '[';
- o<<this<<": ";
- ItemKey::print(o);
- o<<' ';
- ItemPrio::print(o);
- o<<" next="<<trienext;
- o<< ']';
- }
- PRINT_SELF(Item)
- unsigned location;
-};
-
-struct GetItemKey {
- typedef Item argument_type;
- typedef ItemKey result_type;
- result_type const& operator()(Item const& i) const { return i; }
- template <class T>
- T const& operator()(T const& t) const { return t; }
-};
-
-/* here's what i imagine (best first):
- all of these are looked up in a chart which includes the fsa states as part of the identity
-
- perhaps some items are ephemeral and never reused (e.g. edge items of a cube, where we delay traversing trie based on probabilities), but in all ohter cases we make entirely new objects derived from the original one (memoizing). let's ignore lazier edge items for now and always push all successors onto heap.
-
- initial item (predicted): GOAL_NT -> . * (trie root for that lhs), start, start (fsa start states). has a list of
-
- completing item ( L -> * . needs to queue all the completions immediately. when predicting before a completion happens, add to prediction list. after it happens, immediately use the completed bests. this is confusing to me: the completions for an original NT w/ a given r state may end up in several different ones. we don't only care about the 1 best cost r item but all the different r.
-
- the prediction's left/right uses the predictor's right
-
- */
-template <class FsaFF=FsaFeatureFunction>
-struct Chart {
- //typedef HASH_MAP<Item,ItemP,boost::hash<Item> > Items;
- //typedef Items::iterator FindItem;
- //typedef std::pair<FindItem,bool> InsertItem;
-// Items items;
- CFG &cfg; // TODO: remove this from Chart
- SentenceMetadata const& smeta;
- FsaFF const& fsa;
- NTHandle goal_nt;
- PrefixTrie trie;
- typedef Agenda<Item,BetterP,GetItemKey> A;
- A a;
-
- /* had to stop working on this for now - it's garbage/useless in this form - see NOTES.earley */
-
- // p_partial is priority*p(rule) - excluding the FSA model score, and predicted
- void succ(Item const& from,int adji,best_t p_partial) {
- PrefixTrieEdge const& te=from.dot->adj[adji];
- NodeP dest=te.dest;
- if (te.is_final()) {
- // complete
- return;
- }
- WordID w=te.w;
- if (w<0) {
- NTHandle lhs=-w;
- } else {
-
- }
- }
-
- void extend1() {
- BetterP better;
- Item &t=*a.top();
- best_t tp=t.priority;
- if (t.is_trie_adj()) {
- best_t pstop=a.second_best(); // remember; best_t a<b means a better than (higher prob) than b
-// NodeP d=t.dot;
- PrefixTrieNode::Adj const& adj=t.dot->adj;
- int n=t.trienext,m=adj.size();
- SHOWM3(DFSA,"popped",t,tp,pstop);
- for (;n<m;++n) { // cube corner
- PrefixTrieEdge const& te=adj[n];
- SHOWM3(DFSA,"maybe try trie next",n,te.p,pstop);
- if (better(te.p,pstop)) { // can get some improvement
- SHOWM2(DFSA,"trying adj ",m,te);
- } else {
- goto done;
- }
- }
- a.pop();
- done:;
- }
- }
-
- void best_first(unsigned kbest=1) {
- assert(kbest==1); //TODO: k-best via best-first requires revisiting best things again and adjusting desc. tricky.
- while(!a.empty()) {
- extend1();
- }
- }
-
- Chart(CFG &cfg,SentenceMetadata const& smeta,FsaFF const& fsa,unsigned reserve=FSA_AGENDA_RESERVE)
- : cfg(cfg),smeta(smeta),fsa(fsa),trie(cfg),a(reserve) {
- assert(fsa.state_bytes());
- print_fsa=&fsa;
- goal_nt=cfg.goal_nt;
- best_t prio=init_1();
- SHOW1(DFSA,prio);
- a.add(a.construct(fsa.start,trie.lhs2_ex(goal_nt),prio));
- }
-};
-
-
-}//anon ns
-
-
-DEFINE_NAMED_ENUM(FSA_BY)
-
-template <class FsaFF=FsaFeatureFunction>
-struct ApplyFsa {
- ApplyFsa(HgCFG &i,
- const SentenceMetadata& smeta,
- const FsaFeatureFunction& fsa,
- DenseWeightVector const& weights,
- ApplyFsaBy const& by,
- Hypergraph* oh
- )
- :hgcfg(i),smeta(smeta),fsa(fsa),weights(weights),by(by),oh(oh)
- {
- stateless=!fsa.state_bytes();
- }
- void Compute() {
- if (by.IsBottomUp() || stateless)
- ApplyBottomUp();
- else
- ApplyEarley();
- }
- void ApplyBottomUp();
- void ApplyEarley();
- CFG const& GetCFG();
-private:
- CFG cfg;
- HgCFG &hgcfg;
- SentenceMetadata const& smeta;
- FsaFF const& fsa;
-// WeightVector weight_vector;
- DenseWeightVector weights;
- ApplyFsaBy by;
- Hypergraph* oh;
- std::string cfg_out;
- bool stateless;
-};
-
-template <class F>
-void ApplyFsa<F>::ApplyBottomUp()
-{
- assert(by.IsBottomUp());
- FeatureFunctionFromFsa<FsaFeatureFunctionFwd> buff(&fsa);
- buff.Init(); // mandatory to call this (normally factory would do it)
- vector<const FeatureFunction*> ffs(1,&buff);
- ModelSet models(weights, ffs);
- IntersectionConfiguration i(stateless ? BU_FULL : by.BottomUpAlgorithm(),by.pop_limit);
- ApplyModelSet(hgcfg.ih,smeta,models,i,oh);
-}
-
-template <class F>
-void ApplyFsa<F>::ApplyEarley()
-{
- hgcfg.GiveCFG(cfg);
- print_cfg=&cfg;
- print_fsa=&fsa;
- Chart<F> chart(cfg,smeta,fsa);
- // don't need to uniq - option to do that already exists in cfg_options
- //TODO:
- chart.best_first();
- *oh=hgcfg.ih;
-}
-
-
-void ApplyFsaModels(HgCFG &i,
- const SentenceMetadata& smeta,
- const FsaFeatureFunction& fsa,
- DenseWeightVector const& weight_vector,
- ApplyFsaBy const& by,
- Hypergraph* oh)
-{
- ApplyFsa<FsaFeatureFunction> a(i,smeta,fsa,weight_vector,by,oh);
- a.Compute();
-}
-
-/*
-namespace {
-char const* anames[]={
- "BU_CUBE",
- "BU_FULL",
- "EARLEY",
- 0
-};
-}
-*/
-
-//TODO: named enum type in boost?
-
-std::string ApplyFsaBy::name() const {
-// return anames[algorithm];
- return GetName(algorithm);
-}
-
-std::string ApplyFsaBy::all_names() {
- return FsaByNames(" ");
- /*
- std::ostringstream o;
- for (int i=0;i<N_ALGORITHMS;++i) {
- assert(anames[i]);
- if (i) o<<' ';
- o<<anames[i];
- }
- return o.str();
- */
-}
-
-ApplyFsaBy::ApplyFsaBy(std::string const& n, int pop_limit) : pop_limit(pop_limit) {
- std::string uname=toupper(n);
- algorithm=GetFsaBy(uname);
-/*anames=0;
- while(anames[algorithm] && anames[algorithm] != uname) ++algorithm;
- if (!anames[algorithm])
- throw std::runtime_error("Unknown ApplyFsaBy type: "+n+" - legal types: "+all_names());
-*/
-}
-
-ApplyFsaBy::ApplyFsaBy(FsaBy i, int pop_limit) : pop_limit(pop_limit) {
-/* if (i<0 || i>=N_ALGORITHMS)
- throw std::runtime_error("Unknown ApplyFsaBy type id: "+itos(i)+" - legal types: "+all_names());
-*/
- GetName(i); // checks validity
- algorithm=i;
-}
-
-int ApplyFsaBy::BottomUpAlgorithm() const {
- assert(IsBottomUp());
- return algorithm==BU_CUBE ?
- IntersectionConfiguration::CUBE
- :IntersectionConfiguration::FULL;
-}
-
-void ApplyFsaModels(Hypergraph const& ih,
- const SentenceMetadata& smeta,
- const FsaFeatureFunction& fsa,
- DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this)
- ApplyFsaBy const& cfg,
- Hypergraph* out)
-{
- HgCFG i(ih);
- ApplyFsaModels(i,smeta,fsa,weights,cfg,out);
-}
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9390c809..40fd27e4 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -17,6 +17,10 @@
#include "hg.h"
#include "ff.h"
+#define NORMAL_CP 1
+#define FAST_CP 2
+#define FAST_CP_2 3
+
using namespace std;
using namespace std::tr1;
@@ -164,13 +168,15 @@ public:
const SentenceMetadata& sm,
const Hypergraph& i,
int pop_limit,
- Hypergraph* o) :
+ Hypergraph* o,
+ int s = NORMAL_CP ) :
models(m),
smeta(sm),
in(i),
out(*o),
D(in.nodes_.size()),
- pop_limit_(pop_limit) {
+ pop_limit_(pop_limit),
+ strategy_(s){
if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
node_states_.reserve(kRESERVE_NUM_NODES);
}
@@ -184,9 +190,21 @@ public:
if (num_nodes > 100) every = 10;
assert(in.nodes_[pregoal].out_edges_.size() == 1);
if (!SILENT) cerr << " ";
+ int has = 0;
for (int i = 0; i < in.nodes_.size(); ++i) {
- if (!SILENT && i % every == 0) cerr << '.';
- KBest(i, i == goal_id);
+ if (!SILENT) {
+ int needs = (50 * i / in.nodes_.size());
+ while (has < needs) { cerr << '.'; ++has; }
+ }
+ if (strategy_==NORMAL_CP){
+ KBest(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP){
+ KBestFast(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP_2){
+ KBestFast2(i, i == goal_id);
+ }
}
if (!SILENT) {
cerr << endl;
@@ -258,8 +276,7 @@ public:
make_heap(cand.begin(), cand.end(), HeapCandCompare());
State2Node state2node; // "buf" in Figure 2
int pops = 0;
- int pop_limit_eff=max(1,int(v.promise*pop_limit_));
- while(!cand.empty() && pops < pop_limit_eff) {
+ while(!cand.empty() && pops < pop_limit_) {
pop_heap(cand.begin(), cand.end(), HeapCandCompare());
Candidate* item = cand.back();
cand.pop_back();
@@ -283,6 +300,114 @@ public:
delete freelist[i];
}
+ void KBestFast(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast(*item, is_goal, &cand);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void KBestFast2(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_accepted;
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ assert(unique_accepted.insert(item).second); // these should all be unique!
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast2(*item, is_goal, &cand, &unique_accepted);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
CandidateHeap& cand = *pcand;
for (int i = 0; i < item.j_.size(); ++i) {
@@ -300,6 +425,54 @@ public:
}
}
+ //PushSucc following unique ancestor generation function
+ void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ if(item.j_[i]!=0){
+ return;
+ }
+ }
+ }
+
+ //PushSucc only if all ancest Cand are added
+ void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (HasAllAncestors(&query_unique,ps)) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ }
+ }
+ }
+
+ bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){
+ for (int i = 0; i < item->j_.size(); ++i) {
+ JVector j = item->j_;
+ --j[i];
+ if (j[i] >=0) {
+ Candidate query_unique(*item->in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
const ModelSet& models;
const SentenceMetadata& smeta;
const Hypergraph& in;
@@ -311,6 +484,7 @@ public:
FFStates node_states_; // for each node in the out-HG what is
// its q function value?
const int pop_limit_;
+ const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010)
};
struct NoPruningRescorer {
@@ -412,15 +586,28 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE) {
+ } else if (config.algorithm == IntersectionConfiguration::CUBE
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
const int max_pl_for_large=50;
if (pl > max_pl_for_large && in.nodes_.size() > 80000) {
pl = max_pl_for_large;
cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
}
- CubePruningRescorer ma(models, smeta, in, pl, out);
- ma.Apply();
+ if (config.algorithm == IntersectionConfiguration::CUBE) {
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2);
+ ma.Apply();
+ }
+
} else {
cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
exit(1);
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index a85694aa..19a4c7be 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -13,6 +13,8 @@ struct IntersectionConfiguration {
enum {
FULL,
CUBE,
+ FAST_CUBE_PRUNING,
+ FAST_CUBE_PRUNING_2,
N_ALGORITHMS
};
@@ -25,7 +27,9 @@ enum {
inline std::ostream& operator<<(std::ostream& os, const IntersectionConfiguration& c) {
if (c.algorithm == 0) { os << "FULL"; }
else if (c.algorithm == 1) { os << "CUBE:k=" << c.pop_limit; }
- else if (c.algorithm == 2) { os << "N_ALGORITHMS"; }
+ else if (c.algorithm == 2) { os << "FAST_CUBE_PRUNING"; }
+ else if (c.algorithm == 3) { os << "FAST_CUBE_PRUNING_2"; }
+ else if (c.algorithm == 4) { os << "N_ALGORITHMS"; }
else os << "OTHER";
return os;
}
diff --git a/decoder/cdec-fsa.ini b/decoder/cdec-fsa.ini
deleted file mode 100755
index 05aaefd4..00000000
--- a/decoder/cdec-fsa.ini
+++ /dev/null
@@ -1,10 +0,0 @@
-cubepruning_pop_limit=200
-feature_function=WordPenalty
-feature_function=ArityPenalty
-feature_function=WordPenaltyFsa
-#feature_function=LongerThanPrev
-feature_function=ShorterThanPrev debug
-add_pass_through_rules=true
-formalism=scfg
-grammar=mt09.grammar.gz
-weights=weights-fsa
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 5c40f56e..c671af57 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -19,11 +19,19 @@ int main(int argc, char** argv) {
assert(*in);
string buf;
+#ifdef CP_TIME
+ clock_t time_cp(0);//, end_cp;
+#endif
while(*in) {
getline(*in, buf);
if (buf.empty()) continue;
decoder.Decode(buf);
}
+#ifdef CP_TIME
+ cerr << "Time required for Cube Pruning execution: "
+ << CpTime::Get()
+ << " seconds." << "\n\n";
+#endif
if (show_feature_dictionary) {
int num = FD::NumFeats();
for (int i = 1; i < num; ++i) {
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 588842f1..4ce5749e 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -12,8 +12,7 @@
#include "ff_rules.h"
#include "ff_ruleshape.h"
#include "ff_bleu.h"
-#include "ff_lm_fsa.h"
-#include "ff_sample_fsa.h"
+#include "ff_source_syntax.h"
#include "ff_register.h"
#include "ff_charset.h"
#include "ff_wordset.h"
@@ -30,15 +29,6 @@ void register_feature_functions() {
}
registered = true;
- //TODO: these are worthless example target FSA ffs. remove later
- RegisterFsaImpl<SameFirstLetter>(true);
- RegisterFsaImpl<LongerThanPrev>(true);
- RegisterFsaImpl<ShorterThanPrev>(true);
-// ff_registry.Register("LanguageModelFsaDynamic",new FFFactory<FeatureFunctionFromFsa<FsaFeatureFunctionDynamic<LanguageModelFsa> > >); // to test correctness of FsaFeatureFunctionDynamic erasure
- RegisterFsaDynToFF<LanguageModelFsa>();
- RegisterFsaImpl<LanguageModelFsa>(true); // same as LM but using fsa wrapper
- RegisterFsaDynToFF<SameFirstLetter>();
-
RegisterFF<LanguageModel>();
RegisterFF<WordPenalty>();
@@ -46,8 +36,6 @@ void register_feature_functions() {
RegisterFF<ArityPenalty>();
RegisterFF<BLEUModel>();
- ff_registry.Register(new FFFactory<WordPenaltyFromFsa>); // same as WordPenalty, but implemented using ff_fsa
-
//TODO: use for all features the new Register which requires static FF::usage(false,false) give name
#ifdef HAVE_RANDLM
ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>);
@@ -55,6 +43,8 @@ void register_feature_functions() {
ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
+ ff_registry.Register("SourceSyntaxFeatures", new FFFactory<SourceSyntaxFeatures>);
+ ff_registry.Register("SourceSpanSizeFeatures", new FFFactory<SourceSpanSizeFeatures>);
ff_registry.Register("RuleNgramFeatures", new FFFactory<RuleNgramFeatures>());
ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>());
ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
diff --git a/decoder/cfg.cc b/decoder/cfg.cc
index 651978d2..cd7e66e9 100755
--- a/decoder/cfg.cc
+++ b/decoder/cfg.cc
@@ -639,7 +639,7 @@ void CFG::Print(std::ostream &o,CFGFormat const& f) const {
o << '['<<f.goal_nt_name <<']';
WordID rhs=-goal_nt;
f.print_rhs(o,*this,&rhs,&rhs+1);
- if (pushed_inside!=1)
+ if (pushed_inside!=prob_t::One())
f.print_features(o,pushed_inside);
o<<'\n';
}
diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h
index c6a594b8..2f40d483 100755
--- a/decoder/cfg_format.h
+++ b/decoder/cfg_format.h
@@ -101,7 +101,7 @@ struct CFGFormat {
}
void print_features(std::ostream &o,prob_t p,FeatureVector const& fv=FeatureVector()) const {
- bool logp=(logprob_feat && p!=1);
+ bool logp=(logprob_feat && p!=prob_t::One());
if (features || logp) {
o << partsep;
if (logp)
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 55d9f1d7..b93925d1 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -46,6 +46,13 @@
#include "cfg_options.h"
#endif
+#ifdef CP_TIME
+ clock_t CpTime::time_;
+ void CpTime::Add(clock_t x){time_+=x;}
+ void CpTime::Sub(clock_t x){time_-=x;}
+ double CpTime::Get(){return (double)(time_)/CLOCKS_PER_SEC;}
+#endif
+
static const double kMINUS_EPSILON = -1e-6; // don't be too strict
using namespace std;
@@ -152,8 +159,7 @@ struct RescoringPass {
shared_ptr<ModelSet> models;
shared_ptr<IntersectionConfiguration> inter_conf;
vector<const FeatureFunction*> ffs;
- shared_ptr<Weights> w; // null == use previous weights
- vector<double> weight_vector;
+ shared_ptr<vector<weight_t> > weight_vector;
int fid_summary; // 0 == no summary feature
double density_prune; // 0 == don't density prune
double beam_prune; // 0 == don't beam prune
@@ -162,7 +168,7 @@ struct RescoringPass {
ostream& operator<<(ostream& os, const RescoringPass& rp) {
os << "[num_fn=" << rp.ffs.size();
if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; }
- if (rp.w) os << " new_weights";
+ //if (rp.weight_vector.size() > 0) os << " new_weights";
if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary);
if (rp.density_prune) os << " density_prune=" << rp.density_prune;
if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune;
@@ -174,13 +180,8 @@ struct DecoderImpl {
DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg);
~DecoderImpl();
bool Decode(const string& input, DecoderObserver*);
- void SetWeights(const vector<double>& weights) {
- init_weights = weights;
- for (int i = 0; i < rescoring_passes.size(); ++i) {
- if (rescoring_passes[i].models)
- rescoring_passes[i].models->SetWeights(weights);
- rescoring_passes[i].weight_vector = weights;
- }
+ vector<weight_t>& CurrentWeightVector() {
+ return (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector);
}
void SetId(int next_sent_id) { sent_id = next_sent_id - 1; }
@@ -293,8 +294,7 @@ struct DecoderImpl {
OracleBleu oracle;
string formalism;
shared_ptr<Translator> translator;
- Weights w_init_weights; // used with initial parse
- vector<double> init_weights; // weights used with initial parse
+ shared_ptr<vector<weight_t> > init_weights; // weights used with initial parse
vector<shared_ptr<FeatureFunction> > pffs;
#ifdef FSA_RESCORING
CFGOptions cfg_options;
@@ -321,10 +321,11 @@ struct DecoderImpl {
bool write_gradient; // TODO Observer
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
+ bool remove_intersected_rule_annotations;
static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
- trg->set_value(it->first, it->second);
+ trg->set_value(it->first, it->second.as_float());
}
};
@@ -354,10 +355,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")
("list_feature_functions,L","List available feature functions")
+#ifdef HAVE_CMPH
+ ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features")
+#endif
("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
- ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
+ ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2")
("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")
("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob")
("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
@@ -416,6 +420,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
+ ("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -425,7 +430,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
- ("forest_output,O",po::value<string>(),"Directory to write forests to");
+ ("forest_output,O",po::value<string>(),"Directory to write forests to")
+ ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
+
// ob.AddOptions(&opts);
#ifdef FSA_RESCORING
po::options_description cfgo(cfg_options.description());
@@ -434,7 +441,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
- ("help,h", "Print this help message and exit")
+ ("help,?", "Print this help message and exit")
("usage,u", po::value<string>(), "Describe a feature function type")
("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'")
;
@@ -543,13 +550,18 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
exit(1);
}
- // load initial feature weights (and possibly freeze feature set)
- if (conf.count("weights")) {
- w_init_weights.InitFromFile(str("weights",conf));
- w_init_weights.InitVector(&init_weights);
- init_weights.resize(FD::NumFeats());
+ // load perfect hash function for features
+ if (conf.count("cmph_perfect_feature_hash")) {
+ cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
+ FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
+ cerr << " " << FD::NumFeats() << " features in map\n";
}
+ // load initial feature weights (and possibly freeze feature set)
+ init_weights.reset(new vector<weight_t>);
+ if (conf.count("weights"))
+ Weights::InitFromFile(str("weights",conf), init_weights.get());
+
// cube pruning pop-limit: we may want to configure this on a per-pass basis
pop_limit = conf["cubepruning_pop_limit"].as<int>();
@@ -568,9 +580,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
RescoringPass& rp = rescoring_passes.back();
// only configure new weights if pass > 0, otherwise we reuse the initial chart weights
if (nth_pass_condition && conf.count(ws)) {
- rp.w.reset(new Weights);
- rp.w->InitFromFile(str(ws.c_str(), conf));
- rp.w->InitVector(&rp.weight_vector);
+ rp.weight_vector.reset(new vector<weight_t>());
+ Weights::InitFromFile(str(ws.c_str(), conf), rp.weight_vector.get());
}
bool has_stateful = false;
if (conf.count(ff)) {
@@ -595,6 +606,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (LowercaseString(str(isn.c_str(),conf)) == "full") {
palg = 0;
}
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning") {
+ palg = 2;
+ cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning_2") {
+ palg = 3;
+ cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit));
} else {
break; // TODO alert user if there are any future configurations
@@ -602,11 +621,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
// set up weight vectors since later phases may reuse weights from earlier phases
- const vector<double>* prev = &init_weights;
+ shared_ptr<vector<weight_t> > prev_weights = init_weights;
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
RescoringPass& rp = rescoring_passes[pass];
- if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; }
- rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs));
+ if (!rp.weight_vector) {
+ rp.weight_vector = prev_weights;
+ } else {
+ prev_weights = rp.weight_vector;
+ }
+ rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs));
string ps = "Pass1 "; ps[4] += pass;
if (!SILENT) show_models(conf,*rp.models,ps.c_str());
}
@@ -657,7 +680,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
if (!fsa_ffs.empty()) {
cerr<<"FSA: ";
- show_all_features(fsa_ffs,init_weights,cerr,cerr,true,true);
+ show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true);
}
#endif
@@ -677,6 +700,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
kbest = conf.count("k_best");
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
+ oracle.show_derivation=conf.count("show_derivations");
+ remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
#ifdef FSA_RESCORING
cfg_options.Validate();
@@ -703,7 +728,8 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {
if (del) delete o;
return res;
}
-void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); }
+vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); }
+const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); }
void Decoder::SetSupplementalGrammar(const std::string& grammar_string) {
assert(pimpl_->translator->GetDecoderType() == "SCFG");
static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string);
@@ -748,7 +774,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
translator->ProcessMarkupHints(smeta.sgml_);
Timer t("Translation");
const bool translation_successful =
- translator->Translate(to_translate, &smeta, init_weights, &forest);
+ translator->Translate(to_translate, &smeta, *init_weights, &forest);
translator->SentenceComplete();
if (!translation_successful) {
@@ -766,10 +792,15 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
const bool show_tree_structure=conf.count("show_tree_structure");
if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation);
if (conf.count("show_expected_length")) {
- const PRPair<double, double> res =
- Inside<PRPair<double, double>,
- PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
- cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
+ const PRPair<prob_t, prob_t> res =
+ Inside<PRPair<prob_t, prob_t>,
+ PRWeightFunction<prob_t, EdgeProb, prob_t, ELengthWeightFunction> >(forest);
+ cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl;
+ }
+
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " Partition log(Z): " << log(z) << endl;
}
SummaryFeature summary_feature_type = kNODE_RISK;
@@ -786,7 +817,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
const RescoringPass& rp = rescoring_passes[pass];
- const vector<double>& cur_weights = rp.weight_vector;
+ const vector<weight_t>& cur_weights = *rp.weight_vector;
if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl;
#ifdef FSA_RESCORING
cfg_options.maybe_output_source(forest);
@@ -799,11 +830,17 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Timer t("Forest rescoring:");
rp.models->PrepareForInput(smeta);
Hypergraph rescored_forest;
+#ifdef CP_TIME
+ CpTime::Sub(clock());
+#endif
ApplyModelSet(forest,
smeta,
*rp.models,
*rp.inter_conf,
&rescored_forest);
+#ifdef CP_TIME
+ CpTime::Add(clock());
+#endif
forest.swap(rescored_forest);
forest.Reweight(cur_weights);
if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation);
@@ -901,7 +938,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
#endif
}
- const vector<double>& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector);
+ const vector<double>& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector);
// Oracle Rescoring
if(get_oracle_forest) {
@@ -942,7 +979,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
} else {
if (kbest && !has_ref) {
//TODO: does this work properly?
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -989,6 +1027,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
// for (int i = 0; i < forest.edges_.size(); ++i)
// forest.edges_[i].edge_prob_=prob_t::One(); }
+ if (remove_intersected_rule_annotations) {
+ for (unsigned i = 0; i < forest.edges_.size(); ++i)
+ if (forest.edges_[i].rule_ &&
+ forest.edges_[i].rule_->parent_rule_)
+ forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_;
+ }
forest.Reweight(last_weights);
if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation);
if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
@@ -1059,8 +1103,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
}
if (conf.count("graphviz")) forest.PrintGraphviz();
- if (kbest)
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ if (kbest) {
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ }
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
cout << (log(ref_z) - log(first_z)) << endl << flush;
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 236eacd5..6b2f7b16 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -7,7 +7,21 @@
#include <boost/shared_ptr.hpp>
#include <boost/program_options/variables_map.hpp>
-#include "grammar.h"
+#include "weights.h" // weight_t
+
+#undef CP_TIME
+//#define CP_TIME
+#ifdef CP_TIME
+#include <time.h>
+struct CpTime{
+public:
+ static void Add(clock_t x);
+ static void Sub(clock_t x);
+ static double Get();
+private:
+ static clock_t time_;
+};
+#endif
class SentenceMetadata;
struct Hypergraph;
@@ -27,7 +41,12 @@ struct Decoder {
Decoder(int argc, char** argv);
Decoder(std::istream* config_file);
bool Decode(const std::string& input, DecoderObserver* observer = NULL);
- void SetWeights(const std::vector<double>& weights);
+
+ // access this to either *read* or *write* to the decoder's last
+ // weight vector (i.e., the weights of the finest past)
+ std::vector<weight_t>& CurrentWeightVector();
+ const std::vector<weight_t>& CurrentWeightVector() const;
+
void SetId(int id);
~Decoder();
const boost::program_options::variables_map& GetConf() const { return conf; }
diff --git a/decoder/feature_accum.h b/decoder/feature_accum.h
deleted file mode 100755
index 4b8338eb..00000000
--- a/decoder/feature_accum.h
+++ /dev/null
@@ -1,129 +0,0 @@
-#ifndef FEATURE_ACCUM_H
-#define FEATURE_ACCUM_H
-
-#include "ff.h"
-#include "sparse_vector.h"
-#include "value_array.h"
-
-struct SparseFeatureAccumulator : public FeatureVector {
- typedef FeatureVector State;
- SparseFeatureAccumulator() { assert(!"this code is disabled"); }
- template <class FF>
- FeatureVector const& describe(FF const& ) { return *this; }
- void Store(FeatureVector *fv) const {
-//NO fv->set_from(*this);
- }
- template <class FF>
- void Store(FF const& /* ff */,FeatureVector *fv) const {
-//NO fv->set_from(*this);
- }
- template <class FF>
- void Add(FF const& /* ff */,FeatureVector const& fv) {
- (*this)+=fv;
- }
- void Add(FeatureVector const& fv) {
- (*this)+=fv;
- }
- /*
- SparseFeatureAccumulator(FeatureVector const& fv) : State(fv) {}
- FeatureAccumulator(Features const& fids) {}
- FeatureAccumulator(Features const& fids,FeatureVector const& fv) : State(fv) {}
- void Add(Features const& fids,FeatureVector const& fv) {
- *this += fv;
- }
- */
- void Add(int i,Featval v) {
-//NO (*this)[i]+=v;
- }
- void Add(Features const& fids,int i,Featval v) {
-//NO (*this)[i]+=v;
- }
-};
-
-struct SingleFeatureAccumulator {
- typedef Featval State;
- typedef SingleFeatureAccumulator Self;
- State v;
- /*
- void operator +=(State const& o) {
- v+=o;
- }
- */
- void operator +=(Self const& s) {
- v+=s.v;
- }
- SingleFeatureAccumulator() : v() {}
- template <class FF>
- State const& describe(FF const& ) const { return v; }
-
- template <class FF>
- void Store(FF const& ff,FeatureVector *fv) const {
- fv->set_value(ff.fid_,v);
- }
- void Store(Features const& fids,FeatureVector *fv) const {
- assert(fids.size()==1);
- fv->set_value(fids[0],v);
- }
- /*
- SingleFeatureAccumulator(Features const& fids) { assert(fids.size()==1); }
- SingleFeatureAccumulator(Features const& fids,FeatureVector const& fv)
- {
- assert(fids.size()==1);
- v=fv.get_singleton();
- }
- */
-
- template <class FF>
- void Add(FF const& ff,FeatureVector const& fv) {
- v+=fv.get(ff.fid_);
- }
- void Add(FeatureVector const& fv) {
- v+=fv.get_singleton();
- }
-
- void Add(Features const& fids,FeatureVector const& fv) {
- v += fv.get(fids[0]);
- }
- void Add(Featval dv) {
- v+=dv;
- }
- void Add(int,Featval dv) {
- v+=dv;
- }
- void Add(FeatureVector const& fids,int i,Featval dv) {
- assert(fids.size()==1 && i==0);
- v+=dv;
- }
-};
-
-
-#if 0
-// omitting this so we can default construct an accum. might be worth resurrecting in the future
-struct ArrayFeatureAccumulator : public ValueArray<Featval> {
- typedef ValueArray<Featval> State;
- template <class Fsa>
- ArrayFeatureAccumulator(Fsa const& fsa) : State(fsa.features_.size()) { }
- ArrayFeatureAccumulator(Features const& fids) : State(fids.size()) { }
- ArrayFeatureAccumulator(Features const& fids) : State(fids.size()) { }
- ArrayFeatureAccumulator(Features const& fids,FeatureVector const& fv) : State(fids.size()) {
- for (int i=0,e=i<fids.size();i<e;++i)
- (*this)[i]=fv.get(i);
- }
- State const& describe(Features const& fids) const { return *this; }
- void Store(Features const& fids,FeatureVector *fv) const {
- assert(fids.size()==size());
- for (int i=0,e=i<fids.size();i<e;++i)
- fv->set_value(fids[i],(*this)[i]);
- }
- void Add(Features const& fids,FeatureVector const& fv) {
- for (int i=0,e=i<fids.size();i<e;++i)
- (*this)[i]+=fv.get(i);
- }
- void Add(FeatureVector const& fids,int i,Featval v) {
- (*this)[i]+=v;
- }
-};
-#endif
-
-
-#endif
diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h
index 92334396..5eb68c8b 100644
--- a/decoder/ff_factory.h
+++ b/decoder/ff_factory.h
@@ -20,8 +20,6 @@
#include <boost/shared_ptr.hpp>
-#include "ff_fsa_dynamic.h"
-
class FeatureFunction;
class FsaFeatureFunction;
diff --git a/decoder/ff_from_fsa.h b/decoder/ff_from_fsa.h
deleted file mode 100755
index f8d79e03..00000000
--- a/decoder/ff_from_fsa.h
+++ /dev/null
@@ -1,304 +0,0 @@
-#ifndef FF_FROM_FSA_H
-#define FF_FROM_FSA_H
-
-#include "ff_fsa.h"
-
-#ifndef TD__none
-// replacing dependency on SRILM
-#define TD__none -1
-#endif
-
-#ifndef FSA_FF_DEBUG
-# define FSA_FF_DEBUG 0
-#endif
-#if FSA_FF_DEBUG
-# define FSAFFDBG(e,x) FSADBGif(debug(),e,x)
-# define FSAFFDBGnl(e) FSADBGif_nl(debug(),e)
-#else
-# define FSAFFDBG(e,x)
-# define FSAFFDBGnl(e)
-#endif
-
-/* regular bottom up scorer from Fsa feature
- uses guarantee about markov order=N to score ASAP
- encoding of state: if less than N-1 (ctxlen) words
-
- usage:
- typedef FeatureFunctionFromFsa<LanguageModelFsa> LanguageModelFromFsa;
-*/
-
-template <class Impl>
-class FeatureFunctionFromFsa : public FeatureFunction {
- typedef void const* SP;
- typedef WordID *W;
- typedef WordID const* WP;
-public:
- template <class I>
- FeatureFunctionFromFsa(I const& param) : ff(param) {
- debug_=true; // because factory won't set until after we construct.
- }
- template <class I>
- FeatureFunctionFromFsa(I & param) : ff(param) {
- debug_=true; // because factory won't set until after we construct.
- }
-
- static std::string usage(bool args,bool verbose) {
- return Impl::usage(args,verbose);
- }
- void init_name_debug(std::string const& n,bool debug) {
- FeatureFunction::init_name_debug(n,debug);
- ff.init_name_debug(n,debug);
- }
-
- // this should override
- Features features() const {
- DBGINIT("FeatureFunctionFromFsa features() name="<<ff.name()<<" features="<<FD::Convert(ff.features()));
- return ff.features();
- }
-
- // Log because it potentially stores info in edge. otherwise the same as regular TraversalFeatures.
- void TraversalFeaturesLog(const SentenceMetadata& smeta,
- Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- FeatureVector* features,
- FeatureVector* estimated_features,
- void* out_state) const
- {
- TRule const& rule=*edge.rule_;
- Sentence const& e = rule.e(); // items in target side of rule
- typename Impl::Accum accum,h_accum;
- if (!ssz) { // special case for no state - but still build up longer phrases to score in case FSA overrides ScanPhraseAccum
- if (Impl::simple_phrase_score) {
- // save the effort of building up the contiguous rule phrases - probably can just use the else branch, now that phrases aren't copied but are scanned off e directly.
- for (int j=0,ee=e.size();j<ee;++j) {
- if (e[j]>=1) // token
- ff.ScanAccum(smeta,edge,(WordID)e[j],NULL,NULL,&accum);
- FSAFFDBG(edge," "<<TD::Convert(e[j]));
- }
- } else {
-#undef RHS_WORD
-#define RHS_WORD(j) (e[j]>=1)
- for (int j=0,ee=e.size();;++j) { // items in target side of rule
- for(;;++j) {
- if (j>=ee) goto rhs_done; // j may go 1 past ee due to k possibly getting to end
- if (RHS_WORD(j)) break;
- }
- // word @j
- int k=j;
- while(k<ee) if (!RHS_WORD(++k)) break;
- //end or nonword @k - [j,k) is phrase
- FSAFFDBG(edge," ["<<TD::GetString(&e[j],&e[k])<<']');
- ff.ScanPhraseAccum(smeta,edge,&e[j],&e[k],0,0,&accum);
- j=k;
- }
- }
- rhs_done:
- accum.Store(ff,features);
- FSAFFDBG(edge,"="<<accum.describe(ff));
- FSAFFDBGnl(edge);
- return;
- }
-
- // bear with me, because this is hard to understand. reminder: ant_contexts and out_state are left-words first (up to M, TD__none padded). if all M words are present, then FSA state follows. otherwise 0 bytes to keep memcmp/hash happy.
-
-//why do we compute heuristic in so many places? well, because that's how we know what state we should score words in once we're full on our left context (because of markov order bound, we know the score will be the same no matter what came before that left context)
- // these left_* refer to our output (out_state):
- W left_begin=(W)out_state;
- W left_out=left_begin; // [left,fsa_state) = left ctx words. if left words aren't full, then null wordid
- WP left_full=left_end_full(out_state);
- FsaScanner<Impl> fsa(ff,smeta,edge);
- /* fsa holds our current state once we've seen our first M rule or child left-context words. that state scores up the rest of the words at the time, and is replaced by the right state of any full child. at the end, if we've got at least M left words in all, it becomes our right state (otherwise, we don't bother storing the partial state, which might seem useful any time we're built on by a rule that has our variable in the initial position - but without also storing the heuristic for that case, we just end up rescanning from scratch anyway to produce the heuristic. so we just store all 0 bytes if we have less than M left words at the end. */
- for (int j = 0,ee=e.size(); j < ee; ++j) { // items in target side of rule
- s_rhs_next:
- if (!RHS_WORD(j)) { // variable
- // variables a* are referring to this child derivation state.
- SP a = ant_contexts[-e[j]];
- WP al=(WP)a,ale=left_end(a); // the child left words
- int anw=ale-al;
- FSAFFDBG(edge,' '<<describe_state(a));
-// anw left words in child. full if == M. we will use them to fill our left words, and then score the rest fully, knowing what state we're in based on h_state -> our left words -> any number of interior words which are scored then hidden
- if (left_out+anw<left_full) { // still short of M after adding - nothing to score (not even our heuristic)
- wordcpy(left_out,al,anw);
- left_out+=anw;
- } else if (left_out<left_full) { // we had less than M before, and will have a tleast M after adding. so score heuristic and the rest M+1,... score inside.
- int ntofill=left_full-left_out;
- assert(ntofill==M-(left_out-left_begin));
- wordcpy(left_out,al,ntofill);
- left_out=(W)left_full;
- // heuristic known now
- fsa.reset(ff.heuristic_start_state());
- fsa.scan(left_begin,left_full,&h_accum); // save heuristic (happens once only)
- fsa.scan(al+ntofill,ale,&accum); // because of markov order, fully filled left words scored starting at h_start put us in the right state to score the extra words (which are forgotten)
- al+=ntofill; // we used up the first ntofill words of al to end up in some known state via exactly M words total (M-ntofill were there beforehand). now we can scan the remaining al words of this child
- } else { // more to score / state to update (left already full)
- fsa.scan(al,ale,&accum);
- }
- if (anw==M)
- fsa.reset(fsa_state(a));
- // if child had full state already, we must assume there was a gap and use its right state (note: if the child derivation was exactly M words, then we still use its state even though it will be equal to our current; there's no way to distinguish between such an M word item and an e.g. 2*M+k word item, although it's extremely unlikely that you'd have a >M word item that happens to have the same left and right boundary words).
- assert(anw<=M); // of course, we never store more than M left words in an item.
- } else { // single word
- WordID ew=e[j];
- // some redundancy: non-vectorized version of above handling of left words of child item
- if (left_out<left_full) {
- *left_out++=ew;
- if (left_out==left_full) { // handle heuristic, once only, establish state
- fsa.reset(ff.heuristic_start_state());
- fsa.scan(left_begin,left_full,&h_accum); // save heuristic (happens only once)
- }
- } else {
- if (Impl::simple_phrase_score) {
- fsa.scan(ew,&accum); // single word scan isn't optimal if phrase is different
- FSAFFDBG(edge,' '<<TD::Convert(ew));
- } else {
- int k=j;
- while(k<ee) if (!RHS_WORD(++k)) break;
- FSAFFDBG(edge," rule-phrase["<<TD::GetString(&e[j],&e[k])<<']');
- fsa.scan(&e[j],&e[k],&accum);
- if (k==ee) goto s_rhs_done;
- j=k;
- goto s_rhs_next;
- }
- }
- }
- }
-#undef RHS_WORD
- s_rhs_done:
- void *out_fsa_state=fsa_state(out_state);
- if (left_out<left_full) { // finally: partial heuristic for unfilled items
-// fsa.reset(ff.heuristic_start_state()); fsa.scan(left_begin,left_out,&h_accum);
- ff.ScanPhraseAccumOnly(smeta,edge,left_begin,left_out,ff.heuristic_start_state(),&h_accum);
- do { *left_out++=TD__none; } while(left_out<left_full); // none-terminate so left_end(out_state) will know how many words
- ff.state_zero(out_fsa_state); // so we compare / hash correctly. don't know state yet because left context isn't full
- } else // or else store final right-state. heuristic was already assigned
- ff.state_copy(out_fsa_state,fsa.cs);
- accum.Store(ff,features);
- h_accum.Store(ff,estimated_features);
- FSAFFDBG(edge," = " << describe_state(out_state)<<" "<<name<<"="<<accum.describe(ff)<<" h="<<h_accum.describe(ff)<<")");
- FSAFFDBGnl(edge);
- }
-
- void print_state(std::ostream &o,void const*ant) const {
- WP l=(WP)ant,le=left_end(ant),lf=left_end_full(ant);
- o<<'['<<Sentence(l,le);
- if (le==lf) {
- o<<" : ";
- ff.print_state(o,lf);
- }
- o << ']';
- }
-
- std::string describe_state(void const*ant) const {
- std::ostringstream o;
- print_state(o,ant);
- return o.str();
- }
-
- //FIXME: it's assumed that the final rule is just a unary no-target-terminal rewrite (same as ff_lm)
- virtual void FinalTraversalFeatures(const SentenceMetadata& smeta,
- Hypergraph::Edge& edge,
- const void* residual_state,
- FeatureVector* final_features) const
- {
- Sentence const& ends=ff.end_phrase();
- typename Impl::Accum accum;
- if (!ssz) {
- FSAFFDBG(edge," (final,0state,end="<<ends<<")");
- ff.ScanPhraseAccumOnly(smeta,edge,begin(ends),end(ends),0,&accum);
- } else {
- SP ss=ff.start_state();
- WP l=(WP)residual_state,lend=left_end(residual_state);
- SP rst=fsa_state(residual_state);
- FSAFFDBG(edge," (final");// "<<name);//<< " before="<<*final_features);
- if (lend==rst) { // implying we have an fsa state
- ff.ScanPhraseAccumOnly(smeta,edge,l,lend,ss,&accum); // e.g. <s> score(full left unscored phrase)
- FSAFFDBG(edge," start="<<ff.describe_state(ss)<<"->{"<<Sentence(l,lend)<<"}");
- ff.ScanPhraseAccumOnly(smeta,edge,begin(ends),end(ends),rst,&accum); // e.g. [ctx for last M words] score("</s>")
- FSAFFDBG(edge," end="<<ff.describe_state(rst)<<"->{"<<ends<<"}");
- } else { // all we have is a single short phrase < M words before adding ends
- int nl=lend-l;
- Sentence whole(ends.size()+nl);
- WordID *wb=begin(whole);
- wordcpy(wb,l,nl);
- wordcpy(wb+nl,begin(ends),ends.size());
- FSAFFDBG(edge," whole={"<<whole<<"}");
- // whole = left-words + end-phrase
- ff.ScanPhraseAccumOnly(smeta,edge,wb,end(whole),ss,&accum);
- }
- }
- FSAFFDBG(edge,' '<<name<<"="<<accum.describe(ff));
- FSAFFDBGnl(edge);
- accum.Store(ff,final_features);
- }
-
- bool rule_feature() const {
- return StateSize()==0; // Fsa features don't get info about span
- }
-
- static void test() {
- WordID w1[1],w1b[1],w2[2];
- w1[0]=w2[0]=TD::Convert("hi");
- w2[1]=w1b[0]=TD__none;
- assert(left_end(w1,w1+1)==w1+1);
- assert(left_end(w1b,w1b+1)==w1b);
- assert(left_end(w2,w2+2)==w2+1);
- }
-
- // override from FeatureFunction; should be called by factory after constructor. we'll also call in our own ctor
- void Init() {
- ff.Init();
- ff.sync();
- DBGINIT("base (single feature) FsaFeatureFunctionBase::Init name="<<name_<<" features="<<FD::Convert(features()));
-// FeatureFunction::name_=Impl::usage(false,false); // already achieved by ff_factory.cc
- M=ff.markov_order();
- ssz=ff.state_bytes();
- state_offset=sizeof(WordID)*M;
- SetStateSize(ssz+state_offset);
- assert(!ssz == !M); // no fsa state <=> markov order 0
- }
-
-private:
- Impl ff;
- int M; // markov order (ctx len)
- FeatureFunctionFromFsa(); // not allowed.
-
- int state_offset; // NOTE: in bytes (add to char* only). store left-words first, then fsa state
- int ssz; // bytes in fsa state
- /*
- state layout: left WordIds, followed by fsa state
- left words have never been scored. last ones remaining will be scored on FinalTraversalFeatures only.
- right state is unknown until we have all M left words (less than M means TD__none will pad out right end). unk right state will be zeroed out for proper hash/equal recombination.
- */
-
- static inline WordID const* left_end(WordID const* left, WordID const* e) {
- for (;e>left;--e)
- if (e[-1]!=TD__none) break;
- //post: [left,e] are the seen left words
- return e;
- }
- inline WP left_end(SP ant) const {
- return left_end((WP)ant,(WP)fsa_state(ant));
- }
- inline WP left_end_full(SP ant) const {
- return (WP)fsa_state(ant);
- }
- inline SP fsa_state(SP ant) const {
- return ((char const*)ant+state_offset);
- }
- inline void *fsa_state(void * ant) const {
- return ((char *)ant+state_offset);
- }
-};
-
-#ifdef TEST_FSA
-# include "tdict.cc"
-# include "ff_sample_fsa.h"
-int main() {
- std::cerr<<"Testing left_end...\n";
- std::cerr<<"sizeof(FeatureVector)="<<sizeof(FeatureVector)<<"\n";
- WordPenaltyFromFsa::test();
- return 0;
-}
-#endif
-
-#endif
diff --git a/decoder/ff_fsa.h b/decoder/ff_fsa.h
deleted file mode 100755
index 18e90bf1..00000000
--- a/decoder/ff_fsa.h
+++ /dev/null
@@ -1,401 +0,0 @@
-#ifndef FF_FSA_H
-#define FF_FSA_H
-
-/*
- features whose score is just some PFSA over target string. however, PFSA can use edge and smeta info (e.g. spans on edge) - not usually useful.
-
-//SEE ALSO: ff_fsa_dynamic.h, ff_from_fsa.h
-
- state is some fixed width byte array. could actually be a void *, WordID sequence, whatever.
-
- TODO: specify Scan return code or feature value = -inf for failure state (e.g. for hard intersection with desired target lattice?)
-
- TODO: maybe ff that wants to know about SentenceMetadata should store a ref to
- it permanently rather than get passed it for every operation. we're never
- decoding more than 1 sentence at once and it's annoying to pass it. same
- could apply for result edge as well since so far i only use it for logging
- when USE_INFO_EDGE 1 - would make the most sense if the same change happened
- to ff.h at the same time.
-
- TODO: there are a confusing array of default-implemented supposedly slightly more efficient overrides enabled; however, the two key differences are: do you score a phrase, or just word at a time (the latter constraining you to obey markov_order() everywhere. you have to implement the word case no matter what.
-
- TODO: considerable simplification of implementation if Scan implementors are required to update state in place (using temporary copy if they need it), or e.g. using memmove (copy from end to beginning) to rotate state right.
-
- TODO: at what sizes is memcpy/memmove better than looping over 2-3 ints and assigning?
-
- TODO: fsa ff scores phrases not just words
- TODO: fsa feature aggregator that presents itself as a single fsa; benefit: when wrapped in ff_from_fsa, only one set of left words is stored. downside: compared to separate ff, the inside portion of lower-order models is incorporated later. however, the full heuristic is already available and exact for those words. so don't sweat it.
-
- TODO: state (+ possibly span-specific) custom heuristic, e.g. in "longer than previous word" model, you can expect a higher outside if your state is a word of 2 letters. this is on top of the nice heuristic for the unscored words, of course. in ngrams, the avg prob will be about the same, but if the words possible for a source span are summarized, maybe it's possible to predict. probably not worth the effort.
-*/
-
-#define FSA_DEBUG 0
-
-#if USE_INFO_EDGE
-#define FSA_DEBUG_CERR 0
-#else
-#define FSA_DEBUG_CERR 1
-#endif
-
-#define FSA_DEBUG_DEBUG 0
-# define FSADBGif(i,e,x) do { if (i) { if (FSA_DEBUG_CERR){std::cerr<<x;} INFO_EDGE(e,x); if (FSA_DEBUG_DEBUG){std::cerr<<"FSADBGif edge.info "<<&e<<" = "<<e.info()<<std::endl;}} } while(0)
-# define FSADBGif_nl(i,e) do { if (i) { if (FSA_DEBUG_CERR) std::cerr<<std::endl; INFO_EDGE(e,"; "); } } while(0)
-#if FSA_DEBUG
-# include <iostream>
-# define FSADBG(e,x) FSADBGif(d().debug(),e,x)
-# define FSADBGnl(e) FSADBGif_nl(d().debug(),e,x)
-#else
-# define FSADBG(e,x)
-# define FSADBGnl(e)
-#endif
-
-#include "fast_lexical_cast.hpp"
-#include <sstream>
-#include <string>
-#include "ff.h"
-#include "sparse_vector.h"
-#include "tdict.h"
-#include "hg.h"
-#include "ff_fsa_data.h"
-
-/*
-usage: see ff_sample_fsa.h or ff_lm_fsa.h
-
- then, to decode, see ff_from_fsa.h (or TODO: left->right target-earley style rescoring)
-
- */
-
-
-template <class Impl>
-struct FsaFeatureFunctionBase : public FsaFeatureFunctionData {
- Impl const& d() const { return static_cast<Impl const&>(*this); }
- Impl & d() { return static_cast<Impl &>(*this); }
-
- // this will get called by factory - override if you have multiple or dynamically named features. note: may be called repeatedly
- void Init() {
- Init(name());
- DBGINIT("base (single feature) FsaFeatureFunctionBase::Init name="<<name()<<" features="<<FD::Convert(features_));
- }
- void Init(std::string const& fname) {
- fid_=FD::Convert(fname);
- InitHaveFid();
- }
- void InitHaveFid() {
- features_=FeatureFunction::single_feature(fid_);
- }
- Features features() const {
- DBGINIT("FeatureFunctionBase::features() name="<<name()<<" features="<<FD::Convert(features_));
- return features_;
- }
-
-public:
- int fid_; // you can have more than 1 feature of course.
-
- std::string describe() const {
- std::ostringstream o;
- o<<*this;
- return o.str();
- }
-
- // can override to different return type, e.g. just return feats:
- Featval describe_features(FeatureVector const& feats) const {
- return feats.get(fid_);
- }
-
- bool debug() const { return true; }
- int fid() const { return fid_; } // return the one most important feature (for debugging)
- std::string name() const {
- return Impl::usage(false,false);
- }
-
- void print_state(std::ostream &o,void const*state) const {
- char const* i=(char const*)state;
- char const* e=i+ssz;
- for (;i!=e;++i)
- print_hex_byte(o,*i);
- }
-
- std::string describe_state(void const* state) const {
- std::ostringstream o;
- d().print_state(o,state);
- return o.str();
- }
- typedef SingleFeatureAccumulator Accum;
-
- // return m: all strings x with the same final m+1 letters must end in this state
- /* markov chain of order m: P(xn|xn-1...x1)=P(xn|xn-1...xn-m) */
- int markov_order() const { return 0; } // override if you use state. order 0 implies state_bytes()==0 as well, as far as scoring/splitting is concerned (you can still track state, though)
- //TODO: if we wanted, we could mark certain states as maximal-context, but this would lose our fixed amount of left context in ff_from_fsa, and lose also our vector operations (have to scan left words 1 at a time, checking always to see where you change from h to inside - BUT, could detect equivalent LM states, which would be nice).
-
-
-
- // if [i,end) are unscored words of length <= markov_order, score some of them on the right, and return the number scored, i.e. [end-r,end) will have been scored for return r. CAREFUL: for ngram you have to sometimes remember to pay all of the backoff once you see a few more words to the left.
- template <class Accum>
- int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const {
- return 0;
- }
-
- // this isn't currently used at all. this left-shortening is not recommended (wasn't worth the computation expense for ngram): specifically for bottom up scoring (ff_from_fsa), you can return a shorter left-words context - but this means e.g. for ngram tracking that a backoff occurred where the final BO cost isn't yet known. you would also have to remember any necessary info in your own state - in the future, ff_from_fsa on a list of fsa features would only shorten it to the max
-
-
- // override this (static)
- static std::string usage(bool param,bool verbose) {
- return FeatureFunction::usage_helper("unnamed_fsa_feature","","",param,verbose);
- }
-
- // move from state to next_state after seeing word x, while emitting features->set_value(fid,val) possibly with duplicates. state and next_state will never be the same memory.
- //TODO: decide if we want to require you to support dest same as src, since that's how we use it most often in ff_from_fsa bottom-up wrapper (in l->r scoring, however, distinct copies will be the rule), and it probably wouldn't be too hard for most people to support. however, it's good to hide the complexity here, once (see overly clever FsaScan loop that swaps src/dest addresses repeatedly to scan a sequence by effectively swapping)
-
-protected:
- // overrides have different name because of inheritance method hiding;
-
- // simple/common case; 1 fid. these need not be overriden if you have multiple feature ids
- Featval Scan1(WordID w,void const* state,void *next_state) const {
- assert(0);
- return 0;
- }
- Featval Scan1Meta(SentenceMetadata const& /* smeta */,Hypergraph::Edge const& /* edge */,
- WordID w,void const* state,void *next_state) const {
- return d().Scan1(w,state,next_state);
- }
-public:
-
- // must override this or Scan1Meta or Scan1
- template <class Accum>
- inline void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID w,void const* state,void *next_state,Accum *a) const {
- Add(d().Scan1Meta(smeta,edge,w,state,next_state),a);
- }
-
- // bounce back and forth between two state vars starting at cs, returning end state location. if we required src=dest addr safe state updating, this concept wouldn't need to exist.
- // required that you override this if you score phrases differently than word-by-word, however, you can just use the SCAN_PHRASE_ACCUM_OVERRIDE macro to do that in terms of ScanPhraseAccum
- template <class Accum>
- void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const {
- // extra code - IT'S FOR EFFICIENCY, MAN! IT'S OK! definitely no bugs here.
- if (!ssz) {
- for (;i<end;++i)
- d().ScanAccum(smeta,edge,*i,0,0,accum);
- return 0;
- }
- void *os,*es;
- if ((end-i)&1) { // odd # of words
- os=cs;
- es=ns;
- goto odd;
- } else {
- i+=1;
- es=cs;
- os=ns;
- }
- for (;i<end;i+=2) {
- d().ScanAccum(smeta,edge,i[-1],es,os,accum); // e->o
- odd:
- d().ScanAccum(smeta,edge,i[0],os,es,accum); // o->e
- }
- return es;
- }
-
-
- static const bool simple_phrase_score=true; // if d().simple_phrase_score_, then you should expect different Phrase scores for phrase length > M. so, set this false if you provide ScanPhraseAccum (SCAN_PHRASE_ACCUM_OVERRIDE macro does this)
-
- // override this (and use SCAN_PHRASE_ACCUM_OVERRIDE ) if you want e.g. maximum possible order ngram scores with markov_order < n-1. in the future SparseFeatureAccumulator will probably be the only option for type-erased FSA ffs.
- // note you'll still have to override ScanAccum
- template <class Accum>
- void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge,
- WordID const* i, WordID const* end,
- void const* state,void *next_state,Accum *accum) const {
- if (!ssz) {
- for (;i<end;++i)
- d().ScanAccum(smeta,edge,*i,0,0,accum);
- return;
- }
- char tstate[ssz];
- void *tst=tstate;
- bool odd=(end-i)&1;
- void *cs,*ns;
- // we're going to use Bounce (word by word alternating of states) such that the final place is next_state
- if (odd) {
- cs=tst;
- ns=next_state;
- } else {
- cs=next_state;
- ns=tst;
- }
- state_copy(cs,state);
- void *est=d().ScanPhraseAccumBounce(smeta,edge,i,end,cs,ns,accum);
- assert(est==next_state);
- }
-
-
-
- // could replace this with a CRTP subclass providing these impls.
- // the d() subclass dispatch is not needed because these will be defined in the subclass
-#define SCAN_PHRASE_ACCUM_OVERRIDE \
- static const bool simple_phrase_score=false; \
- template <class Accum> \
- void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const { \
- ScanPhraseAccum(smeta,edge,i,end,cs,ns,accum); \
- return ns; \
- } \
- template <class Accum> \
- void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, \
- WordID const* i, WordID const* end, \
- void const* state,Accum *accum) const { \
- char s2[ssz]; ScanPhraseAccum(smeta,edge,i,end,state,(void*)s2,accum); \
- }
-
- // override this or bounce along with above. note: you can just call ScanPhraseAccum
- // doesn't set state (for heuristic in ff_from_fsa)
- template <class Accum>
- void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID const* i, WordID const* end,
- void const* state,Accum *accum) const {
- char s1[ssz];
- char s2[ssz];
- state_copy(s1,state);
- d().ScanPhraseAccumBounce(smeta,edge,i,end,(void*)s1,(void*)s2,accum);
- }
-
- // for single-feat only. but will work for different accums
- template <class Accum>
- inline void Add(Featval v,Accum *a) const {
- a->Add(fid_,v);
- }
- inline void set_feat(FeatureVector *features,Featval v) const {
- features->set_value(fid_,v);
- }
-
- // don't set state-bytes etc. in ctor because it may depend on parsing param string
- FsaFeatureFunctionBase(int statesz=0,Sentence const& end_sentence_phrase=Sentence())
- : FsaFeatureFunctionData(statesz,end_sentence_phrase)
- {
- name_=name(); // should allow FsaDynamic wrapper to get name copied to it with sync
- }
-
-};
-
-template <class Impl>
-struct MultipleFeatureFsa : public FsaFeatureFunctionBase<Impl> {
- typedef SparseFeatureAccumulator Accum;
-};
-
-
-
-
-// if State is pod. sets state size and allocs start, h_start
-// usage:
-// struct ShorterThanPrev : public FsaTypedBase<int,ShorterThanPrev>
-// i.e. Impl is a CRTP
-template <class St,class Impl>
-struct FsaTypedBase : public FsaFeatureFunctionBase<Impl> {
- Impl const& d() const { return static_cast<Impl const&>(*this); }
- Impl & d() { return static_cast<Impl &>(*this); }
-protected:
- typedef FsaFeatureFunctionBase<Impl> Base;
- typedef St State;
- static inline State & state(void *state) {
- return *(State*)state;
- }
- static inline State const& state(void const* state) {
- return *(State const*)state;
- }
- void set_starts(State const& s,State const& heuristic_s) {
- if (0) { // already in ctor
- Base::start.resize(sizeof(State));
- Base::h_start.resize(sizeof(State));
- }
- assert(Base::start.size()==sizeof(State));
- assert(Base::h_start.size()==sizeof(State));
- state(Base::start.begin())=s;
- state(Base::h_start.begin())=heuristic_s;
- }
- FsaTypedBase(St const& start_st=St()
- ,St const& h_start_st=St()
- ,Sentence const& end_sentence_phrase=Sentence())
- : Base(sizeof(State),end_sentence_phrase) {
- set_starts(start_st,h_start_st);
- }
-public:
- void print_state(std::ostream &o,void const*st) const {
- o<<state(st);
- }
- int markov_order() const { return 1; }
-
- // override this
- Featval ScanT1S(WordID w,St const& /* from */ ,St & /* to */) const {
- return 0;
- }
-
- // or this
- Featval ScanT1(SentenceMetadata const& /* smeta */,Hypergraph::Edge const& /* edge */,WordID w,St const& from ,St & to) const {
- return d().ScanT1S(w,from,to);
- }
-
- // or this (most general)
- template <class Accum>
- inline void ScanT(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID w,St const& prev_st,St &new_st,Accum *a) const {
- Add(d().ScanT1(smeta,edge,w,prev_st,new_st),a);
- }
-
- // note: you're on your own when it comes to Phrase overrides. see FsaFeatureFunctionBase. sorry.
-
- template <class Accum>
- inline void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID w,void const* st,void *next_state,Accum *a) const {
- Impl const& im=d();
- FSADBG(edge,"Scan "<<FD::Convert(im.fid_)<<" = "<<a->describe(im)<<" "<<im.state(st)<<"->"<<TD::Convert(w)<<" ");
- im.ScanT(smeta,edge,w,state(st),state(next_state),a);
- FSADBG(edge,state(next_state)<<" = "<<a->describe(im));
- FSADBGnl(edge);
- }
-};
-
-
-// keep a "current state" (bouncing back and forth)
-template <class FF>
-struct FsaScanner {
-// enum {ALIGN=8};
- static const int ALIGN=8;
- FF const& ff;
- SentenceMetadata const& smeta;
- int ssz;
- Bytes states; // first is at begin, second is at (char*)begin+stride
- void *st0; // states
- void *st1; // states+stride
- void *cs; // initially st0, alternates between st0 and st1
- inline void *nexts() const {
- return (cs==st0)?st1:st0;
- }
- Hypergraph::Edge const& edge;
- FsaScanner(FF const& ff,SentenceMetadata const& smeta,Hypergraph::Edge const& edge) : ff(ff),smeta(smeta),edge(edge)
- {
- ssz=ff.state_bytes();
- int stride=((ssz+ALIGN-1)/ALIGN)*ALIGN; // round up to multiple of ALIGN
- states.resize(stride+ssz);
- st0=states.begin();
- st1=(char*)st0+stride;
-// for (int i=0;i<2;++i) st[i]=cs+(i*stride);
- }
- void reset(void const* state) {
- cs=st0;
- std::memcpy(st0,state,ssz);
- }
- template <class Accum>
- void scan(WordID w,Accum *a) {
- void *ns=nexts();
- ff.ScanAccum(smeta,edge,w,cs,ns,a);
- cs=ns;
- }
- template <class Accum>
- void scan(WordID const* i,WordID const* end,Accum *a) {
- // faster. and allows greater-order excursions
- cs=ff.ScanPhraseAccumBounce(smeta,edge,i,end,cs,nexts(),a);
- }
-};
-
-
-//TODO: combine 2 FsaFeatures typelist style (can recurse for more)
-
-
-
-
-#endif
diff --git a/decoder/ff_fsa_data.h b/decoder/ff_fsa_data.h
deleted file mode 100755
index d215e940..00000000
--- a/decoder/ff_fsa_data.h
+++ /dev/null
@@ -1,131 +0,0 @@
-#ifndef FF_FSA_DATA_H
-#define FF_FSA_DATA_H
-
-#include <stdint.h> //C99
-#include <sstream>
-#include "sentences.h"
-#include "feature_accum.h"
-#include "value_array.h"
-#include "ff.h" //debug
-typedef ValueArray<uint8_t> Bytes;
-
-// stuff I see no reason to have virtual. but because it's impossible (w/o virtual inheritance to have dynamic fsa ff know where the impl's data starts, implemented a sync (copy) method that needs to be called. init_name_debug was already necessary to keep state in sync between ff and ff_from_fsa, so no sync should be needed after it. supposing all modifications were through setters, then no explicit sync call would ever be needed; updates could be mirrored.
-struct FsaFeatureFunctionData
-{
- void init_name_debug(std::string const& n,bool debug) {
- name_=n;
- debug_=debug;
- }
- //HACK for diamond inheritance (w/o costing performance)
- FsaFeatureFunctionData *sync_to_;
-
- void sync() const { // call this if you modify any fields after your constructor is done
- if (sync_to_) {
- DBGINIT("sync to "<<*sync_to_);
- *sync_to_=*this;
- DBGINIT("synced result="<<*sync_to_<< " from this="<<*this);
- } else {
- DBGINIT("nobody to sync to - from FeatureFunctionData this="<<*this);
- }
- }
-
- friend std::ostream &operator<<(std::ostream &o,FsaFeatureFunctionData const& d) {
- o << "[FSA "<<d.name_<<" features="<<FD::Convert(d.features_)<<" state_bytes="<<d.state_bytes()<<" end='"<<d.end_phrase()<<"' start=";
- d.print_state(o,d.start_state());
- o<<"]";
- return o;
- }
-
- FsaFeatureFunctionData(int statesz=0,Sentence const& end_sentence_phrase=Sentence()) : start(statesz),h_start(statesz),end_phrase_(end_sentence_phrase),ssz(statesz) {
- debug_=true;
- sync_to_=0;
- }
-
- std::string name_;
- std::string name() const {
- return name_;
- }
- typedef SparseFeatureAccumulator Accum;
- bool debug_;
- bool debug() const { return debug_; }
- void state_copy(void *to,void const*from) const {
- if (ssz)
- std::memcpy(to,from,ssz);
- }
- void state_zero(void *st) const { // you should call this if you don't know the state yet and want it to be hashed/compared properly
- std::memset(st,0,ssz);
- }
- Features features() const {
- return features_;
- }
- int n_features() const {
- return features_.size();
- }
- int state_bytes() const { return ssz; }
- void const* start_state() const {
- return start.begin();
- }
- void const * heuristic_start_state() const {
- return h_start.begin();
- }
- Sentence const& end_phrase() const { return end_phrase_; }
- template <class T>
- static inline T* state_as(void *p) { return (T*)p; }
- template <class T>
- static inline T const* state_as(void const* p) { return (T*)p; }
- std::string describe_features(FeatureVector const& feats) {
- std::ostringstream o;
- o<<feats;
- return o.str();
- }
- void print_state(std::ostream &o,void const*state) const {
- char const* i=(char const*)state;
- char const* e=i+ssz;
- for (;i!=e;++i)
- print_hex_byte(o,*i);
- }
-
- Features features_;
- Bytes start,h_start; // start state and estimated-features (heuristic) start state. set these. default empty.
- Sentence end_phrase_; // words appended for final traversal (final state cost is assessed using Scan) e.g. "</s>" for lm.
-protected:
- int ssz; // don't forget to set this. default 0 (it may depend on params of course)
- // this can be called instead or after constructor (also set bytes and end_phrase_)
- void set_state_bytes(int sb=0) {
- if (start.size()!=sb) start.resize(sb);
- if (h_start.size()!=sb) h_start.resize(sb);
- ssz=sb;
- }
- void set_end_phrase(WordID single) {
- end_phrase_=singleton_sentence(single);
- }
-
- inline void static to_state(void *state,char const* begin,char const* end) {
- std::memcpy(state,begin,end-begin);
- }
- inline void static to_state(void *state,char const* begin,int n) {
- std::memcpy(state,begin,n);
- }
- template <class T>
- inline void static to_state(void *state,T const* begin,int n=1) {
- to_state(state,(char const*)begin,n*sizeof(T));
- }
- template <class T>
- inline void static to_state(void *state,T const* begin,T const* end) {
- to_state(state,(char const*)begin,(char const*)end);
- }
- inline static char hexdigit(int i) {
- int j=i-10;
- return j>=0?'a'+j:'0'+i;
- }
- inline static void print_hex_byte(std::ostream &o,unsigned c) {
- o<<hexdigit(c>>4);
- o<<hexdigit(c&0x0f);
- }
- inline static void Add(Featval v,SingleFeatureAccumulator *a) {
- a->Add(v);
- }
-
-};
-
-#endif
diff --git a/decoder/ff_fsa_dynamic.h b/decoder/ff_fsa_dynamic.h
deleted file mode 100755
index 6f75bbe5..00000000
--- a/decoder/ff_fsa_dynamic.h
+++ /dev/null
@@ -1,208 +0,0 @@
-#ifndef FF_FSA_DYNAMIC_H
-#define FF_FSA_DYNAMIC_H
-
-struct SentenceMetadata;
-
-#include "ff_fsa_data.h"
-#include "hg.h" // can't forward declare nested Hypergraph::Edge class
-#include <sstream>
-
-// the type-erased interface
-
-//FIXME: diamond inheritance problem. make a copy of the fixed data? or else make the dynamic version not wrap but rather be templated CRTP base (yuck)
-struct FsaFeatureFunction : public FsaFeatureFunctionData {
- static const bool simple_phrase_score=false;
- virtual int markov_order() const = 0;
-
- // see ff_fsa.h - FsaFeatureFunctionBase<Impl> gives you reasonable impls of these if you override just ScanAccum
- virtual void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID w,void const* state,void *next_state,Accum *a) const = 0;
- virtual void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge,
- WordID const* i, WordID const* end,
- void const* state,void *next_state,Accum *accum) const = 0;
- virtual void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID const* i, WordID const* end,
- void const* state,Accum *accum) const = 0;
- virtual void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const = 0;
-
- virtual int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const { return 0; }
- // called after constructor, before use
- virtual void Init() = 0;
- virtual std::string usage_v(bool param,bool verbose) const {
- return FeatureFunction::usage_helper("unnamed_dynamic_fsa_feature","","",param,verbose);
- }
- virtual void init_name_debug(std::string const& n,bool debug) {
- FsaFeatureFunctionData::init_name_debug(n,debug);
- }
-
- virtual void print_state(std::ostream &o,void const*state) const {
- FsaFeatureFunctionData::print_state(o,state);
- }
- virtual std::string describe() const { return "[FSA unnamed_dynamic_fsa_feature]"; }
-
- //end_phrase()
- virtual ~FsaFeatureFunction() {}
-
- // no need to override:
- std::string describe_state(void const* state) const {
- std::ostringstream o;
- print_state(o,state);
- return o.str();
- }
-};
-
-// conforming to above interface, type erases FsaImpl
-// you might be wondering: why do this? answer: it's cool, and it means that the bottom-up ff over ff_fsa wrapper doesn't go through multiple layers of dynamic dispatch
-// usage: typedef FsaFeatureFunctionDynamic<MyFsa> MyFsaDyn;
-template <class Impl>
-struct FsaFeatureFunctionDynamic : public FsaFeatureFunction {
- static const bool simple_phrase_score=Impl::simple_phrase_score;
- Impl& d() { return impl;//static_cast<Impl&>(*this);
- }
- Impl const& d() const { return impl;
- //static_cast<Impl const&>(*this);
- }
- int markov_order() const { return d().markov_order(); }
-
- std::string describe() const {
- return d().describe();
- }
-
- virtual void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID w,void const* state,void *next_state,Accum *a) const {
- return d().ScanAccum(smeta,edge,w,state,next_state,a);
- }
-
- virtual void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge,
- WordID const* i, WordID const* end,
- void const* state,void *next_state,Accum *a) const {
- return d().ScanPhraseAccum(smeta,edge,i,end,state,next_state,a);
- }
-
- virtual void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID const* i, WordID const* end,
- void const* state,Accum *a) const {
- return d().ScanPhraseAccumOnly(smeta,edge,i,end,state,a);
- }
-
- virtual void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *a) const {
- return d().ScanPhraseAccumBounce(smeta,edge,i,end,cs,ns,a);
- }
-
- virtual int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const {
- return d().early_score_words(smeta,edge,i,end,accum);
- }
-
- static std::string usage(bool param,bool verbose) {
- return Impl::usage(param,verbose);
- }
-
- std::string usage_v(bool param,bool verbose) const {
- return Impl::usage(param,verbose);
- }
-
- virtual void print_state(std::ostream &o,void const*state) const {
- return d().print_state(o,state);
- }
-
- void init_name_debug(std::string const& n,bool debug) {
- FsaFeatureFunction::init_name_debug(n,debug);
- d().init_name_debug(n,debug);
- }
-
- virtual void Init() {
- d().sync_to_=(FsaFeatureFunctionData*)this;
- d().Init();
- d().sync();
- }
-
- template <class I>
- FsaFeatureFunctionDynamic(I const& param) : impl(param) {
- Init();
- }
-private:
- Impl impl;
-};
-
-// constructor takes ptr or shared_ptr to Impl, otherwise same as above - note: not virtual
-template <class Impl>
-struct FsaFeatureFunctionPimpl : public FsaFeatureFunctionData {
- typedef boost::shared_ptr<Impl const> Pimpl;
- static const bool simple_phrase_score=Impl::simple_phrase_score;
- Impl const& d() const { return *p_; }
- int markov_order() const { return d().markov_order(); }
-
- std::string describe() const {
- return d().describe();
- }
-
- void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID w,void const* state,void *next_state,Accum *a) const {
- return d().ScanAccum(smeta,edge,w,state,next_state,a);
- }
-
- void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge,
- WordID const* i, WordID const* end,
- void const* state,void *next_state,Accum *a) const {
- return d().ScanPhraseAccum(smeta,edge,i,end,state,next_state,a);
- }
-
- void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,
- WordID const* i, WordID const* end,
- void const* state,Accum *a) const {
- return d().ScanPhraseAccumOnly(smeta,edge,i,end,state,a);
- }
-
- void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *a) const {
- return d().ScanPhraseAccumBounce(smeta,edge,i,end,cs,ns,a);
- }
-
- int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const {
- return d().early_score_words(smeta,edge,i,end,accum);
- }
-
- static std::string usage(bool param,bool verbose) {
- return Impl::usage(param,verbose);
- }
-
- std::string usage_v(bool param,bool verbose) const {
- return Impl::usage(param,verbose);
- }
-
- void print_state(std::ostream &o,void const*state) const {
- return d().print_state(o,state);
- }
-
-#if 0
- // this and Init() don't touch p_ because we want to leave the original alone.
- void init_name_debug(std::string const& n,bool debug) {
- FsaFeatureFunctionData::init_name_debug(n,debug);
- }
-#endif
- void Init() {
- p_=hold_pimpl_.get();
-#if 0
- d().sync_to_=static_cast<FsaFeatureFunctionData*>(this);
- d().Init();
-#endif
- *static_cast<FsaFeatureFunctionData*>(this)=d();
- }
-
- FsaFeatureFunctionPimpl(Impl const* const p) : hold_pimpl_(p,null_deleter()) {
- Init();
- }
- FsaFeatureFunctionPimpl(Pimpl const& p) : hold_pimpl_(p) {
- Init();
- }
-private:
- Impl const* p_;
- Pimpl hold_pimpl_;
-};
-
-typedef FsaFeatureFunctionPimpl<FsaFeatureFunction> FsaFeatureFunctionFwd; // allow ff_from_fsa for an existing dynamic-type ff (as opposed to usual register a wrapped known-type FSA in ff_register, which is more efficient)
-//typedef FsaFeatureFunctionDynamic<FsaFeatureFunctionFwd> DynamicFsaFeatureFunctionFwd; //if you really need to have a dynamic fsa facade that's also a dynamic fsa
-
-//TODO: combine 2 (or N) FsaFeatureFunction (type erased)
-
-
-#endif
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 24dcb9c3..28bcb6b9 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -12,11 +12,9 @@
#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
-using namespace std;
+#include "lm/left.hh"
-static const unsigned char HAS_FULL_CONTEXT = 1;
-static const unsigned char HAS_EOS_ON_RIGHT = 2;
-static const unsigned char MASK = 7;
+using namespace std;
// -x : rules include <s> and </s>
// -n NAME : feature id is NAME
@@ -70,6 +68,8 @@ string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
return "KLanguageModel";
}
+namespace {
+
struct VMapper : public lm::ngram::EnumerateVocab {
VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
void Add(lm::WordIndex index, const StringPiece &str) {
@@ -82,181 +82,122 @@ struct VMapper : public lm::ngram::EnumerateVocab {
const lm::WordIndex kLM_UNKNOWN_TOKEN;
};
-template <class Model>
-class KLanguageModelImpl {
-
- // returns the number of unscored words at the left edge of a span
- inline int UnscoredSize(const void* state) const {
- return *(static_cast<const char*>(state) + unscored_size_offset_);
- }
+#pragma pack(push)
+#pragma pack(1)
- inline void SetUnscoredSize(int size, void* state) const {
- *(static_cast<char*>(state) + unscored_size_offset_) = size;
- }
+struct BoundaryAnnotatedState {
+ lm::ngram::ChartState state;
+ bool seen_bos, seen_eos;
+};
- static inline const lm::ngram::State& RemnantLMState(const void* state) {
- return *static_cast<const lm::ngram::State*>(state);
- }
+#pragma pack(pop)
+
+template <class Model> class BoundaryRuleScore {
+ public:
+ BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) :
+ back_(m, state.state),
+ bos_(state.seen_bos),
+ eos_(state.seen_eos),
+ penalty_(0.0),
+ end_sentence_(m.GetVocabulary().EndSentence()) {
+ bos_ = false;
+ eos_ = false;
+ }
- inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const {
- // if we were clever, we could use the memory pointed to by state to do all
- // the work, avoiding this copy
- memcpy(state, &lmstate, ngram_->StateSize());
- }
+ void BeginSentence() {
+ back_.BeginSentence();
+ bos_ = true;
+ }
- lm::WordIndex IthUnscoredWord(int i, const void* state) const {
- const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_);
- return mem[i];
- }
+ void BeginNonTerminal(const BoundaryAnnotatedState &sub) {
+ back_.BeginNonTerminal(sub.state, 0.0f);
+ bos_ = sub.seen_bos;
+ eos_ = sub.seen_eos;
+ }
- void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const {
- lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_);
- mem[i] = index;
- }
+ void NonTerminal(const BoundaryAnnotatedState &sub) {
+ back_.NonTerminal(sub.state, 0.0f);
+ // cdec only calls this if there's content.
+ if (sub.seen_bos) {
+ bos_ = true;
+ penalty_ -= 100.0f;
+ }
+ if (eos_) penalty_ -= 100.0f;
+ eos_ |= sub.seen_eos;
+ }
- inline bool GetFlag(const void *state, unsigned char flag) const {
- return (*(static_cast<const char*>(state) + is_complete_offset_) & flag);
- }
+ void Terminal(lm::WordIndex word) {
+ back_.Terminal(word);
+ if (eos_) penalty_ -= 100.0f;
+ if (word == end_sentence_) eos_ = true;
+ }
- inline void SetFlag(bool on, unsigned char flag, void *state) const {
- if (on) {
- *(static_cast<char*>(state) + is_complete_offset_) |= flag;
- } else {
- *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag);
+ float Finish() {
+ return penalty_ + back_.Finish();
}
- }
- inline bool HasFullContext(const void *state) const {
- return GetFlag(state, HAS_FULL_CONTEXT);
- }
+ private:
+ lm::ngram::RuleScore<Model> back_;
+ bool &bos_, &eos_;
- inline void SetHasFullContext(bool flag, void *state) const {
- SetFlag(flag, HAS_FULL_CONTEXT, state);
- }
+ float penalty_;
+ lm::WordIndex end_sentence_;
+};
+
+} // namespace
+
+template <class Model>
+class KLanguageModelImpl {
public:
- double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) {
- double sum = 0.0;
- double est_sum = 0.0;
- int num_scored = 0;
- int num_estimated = 0;
- if (oovs) *oovs = 0;
- if (est_oovs) *est_oovs = 0;
- bool saw_eos = false;
- bool has_some_history = false;
- lm::ngram::State state = ngram_->NullContextState();
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
+ *oovs = 0;
const vector<WordID>& e = rule.e();
- bool context_complete = false;
- for (int j = 0; j < e.size(); ++j) {
- if (e[j] < 1) { // handle non-terminal substitution
- const void* astate = (ant_states[-e[j]]);
- int unscored_ant_len = UnscoredSize(astate);
- for (int k = 0; k < unscored_ant_len; ++k) {
- const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
- const bool is_oov = (cur_word == 0);
- double p = 0;
- if (cur_word == kSOS_) {
- state = ngram_->BeginSentenceState();
- if (has_some_history) { // this is immediately fully scored, and bad
- p = -100;
- context_complete = true;
- } else { // this might be a real <s>
- num_scored = max(0, order_ - 2);
- }
- } else {
- const lm::ngram::State scopy(state);
- p = ngram_->Score(scopy, cur_word, state);
- if (saw_eos) { p = -100; }
- saw_eos = (cur_word == kEOS_);
- }
- has_some_history = true;
- ++num_scored;
- if (!context_complete) {
- if (num_scored >= order_) context_complete = true;
- }
- if (context_complete) {
- sum += p;
- if (oovs && is_oov) (*oovs)++;
- } else {
- if (remnant)
- SetIthUnscoredWord(num_estimated, cur_word, remnant);
- ++num_estimated;
- est_sum += p;
- if (est_oovs && is_oov) (*est_oovs)++;
- }
- }
- saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
- if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
- state = RemnantLMState(astate);
- context_complete = true;
- }
- } else { // handle terminal
- const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future,
+ BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant));
+ unsigned i = 0;
+ if (e.size()) {
+ if (e[i] == kCDEC_SOS) {
+ ++i;
+ ruleScore.BeginSentence();
+ } else if (e[i] <= 0) { // special case for left-edge NT
+ ruleScore.BeginNonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]));
+ ++i;
+ }
+ }
+ for (; i < e.size(); ++i) {
+ if (e[i] <= 0) {
+ ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]));
+ } else {
+ const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future,
// maybe handle emission
const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
- double p = 0;
- const bool is_oov = (cur_word == 0);
- if (cur_word == kSOS_) {
- state = ngram_->BeginSentenceState();
- if (has_some_history) { // this is immediately fully scored, and bad
- p = -100;
- context_complete = true;
- } else { // this might be a real <s>
- num_scored = max(0, order_ - 2);
- }
- } else {
- const lm::ngram::State scopy(state);
- p = ngram_->Score(scopy, cur_word, state);
- if (saw_eos) { p = -100; }
- saw_eos = (cur_word == kEOS_);
- }
- has_some_history = true;
- ++num_scored;
- if (!context_complete) {
- if (num_scored >= order_) context_complete = true;
- }
- if (context_complete) {
- sum += p;
- if (oovs && is_oov) (*oovs)++;
- } else {
- if (remnant)
- SetIthUnscoredWord(num_estimated, cur_word, remnant);
- ++num_estimated;
- est_sum += p;
- if (est_oovs && is_oov) (*est_oovs)++;
- }
+ if (cur_word == 0) (*oovs) += 1.0;
+ ruleScore.Terminal(cur_word);
}
}
- if (pest_sum) *pest_sum = est_sum;
- if (remnant) {
- state.ZeroRemaining();
- SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant);
- SetRemnantLMState(state, remnant);
- SetUnscoredSize(num_estimated, remnant);
- SetHasFullContext(context_complete || (num_scored >= order_), remnant);
- }
- return sum;
+ double ret = ruleScore.Finish();
+ static_cast<BoundaryAnnotatedState*>(remnant)->state.ZeroRemaining();
+ return ret;
}
// this assumes no target words on final unary -> goal rule. is that ok?
// for <s> (n-1 left words) and (n-1 right words) </s>
- double FinalTraversalCost(const void* state, double* oovs) {
+ double FinalTraversalCost(const void* state_void, double* oovs) {
+ const BoundaryAnnotatedState &annotated = *static_cast<const BoundaryAnnotatedState*>(state_void);
if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
- SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_);
- SetHasFullContext(1, dummy_state_);
- SetUnscoredSize(0, dummy_state_);
- dummy_ants_[1] = state;
- *oovs = 0;
- return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL);
+ assert(!annotated.seen_bos);
+ assert(!annotated.seen_eos);
+ lm::ngram::ChartState cstate;
+ lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate);
+ ruleScore.BeginSentence();
+ ruleScore.NonTerminal(annotated.state, 0.0f);
+ ruleScore.Terminal(kEOS_);
+ return ruleScore.Finish();
} else { // rules DO produce <s> ... </s>
- double p = 0;
- if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
- if (UnscoredSize(state) > 0) { // are there unscored words
- if (kSOS_ != IthUnscoredWord(0, state)) {
- p -= 100 * UnscoredSize(state);
- }
- }
- return p;
+ double ret = 0.0;
+ if (!annotated.seen_bos) ret -= 100.0;
+ if (!annotated.seen_eos) ret -= 100.0;
+ return ret;
}
}
@@ -282,6 +223,7 @@ class KLanguageModelImpl {
public:
KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) :
kCDEC_UNK(TD::Convert("<unk>")) ,
+ kCDEC_SOS(TD::Convert("<s>")) ,
add_sos_eos_(!explicit_markers) {
{
VMapper vm(&cdec2klm_map_);
@@ -291,18 +233,9 @@ class KLanguageModelImpl {
}
order_ = ngram_->Order();
cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n";
- state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
- unscored_size_offset_ = ngram_->StateSize();
- is_complete_offset_ = unscored_size_offset_ + 1;
- unscored_words_offset_ = is_complete_offset_ + 1;
// special handling of beginning / ending sentence markers
- dummy_state_ = new char[state_size_];
- memset(dummy_state_, 0, state_size_);
- dummy_ants_.push_back(dummy_state_);
- dummy_ants_.push_back(NULL);
- dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
- kSOS_ = MapWord(TD::Convert("<s>"));
+ kSOS_ = MapWord(kCDEC_SOS);
assert(kSOS_ > 0);
kEOS_ = MapWord(TD::Convert("</s>"));
assert(kEOS_ > 0);
@@ -350,13 +283,13 @@ class KLanguageModelImpl {
~KLanguageModelImpl() {
delete ngram_;
- delete[] dummy_state_;
}
- int ReserveStateSize() const { return state_size_; }
+ int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); }
private:
const WordID kCDEC_UNK;
+ const WordID kCDEC_SOS;
lm::WordIndex kSOS_; // <s> - requires special handling.
lm::WordIndex kEOS_; // </s>
Model* ngram_;
@@ -367,15 +300,8 @@ class KLanguageModelImpl {
// the sentence) with 0, and anything else with -100
int order_;
- int state_size_;
- int unscored_size_offset_;
- int is_complete_offset_;
- int unscored_words_offset_;
- char* dummy_state_;
- vector<const void*> dummy_ants_;
vector<lm::WordIndex> cdec2klm_map_;
vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
- TRulePtr dummy_rule_;
};
template <class Model>
@@ -393,7 +319,7 @@ KLanguageModel<Model>::KLanguageModel(const string& param) {
}
fid_ = FD::Convert(featname);
oov_fid_ = FD::Convert(featname+"_OOV");
- cerr << "FID: " << oov_fid_ << endl;
+ // cerr << "FID: " << oov_fid_ << endl;
SetStateSize(pimpl_->ReserveStateSize());
}
@@ -416,13 +342,9 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme
void* state) const {
double est = 0;
double oovs = 0;
- double est_oovs = 0;
- features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state));
- estimated_features->set_value(fid_, est);
- if (oov_fid_) {
- if (oovs) features->set_value(oov_fid_, oovs);
- if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs);
- }
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state));
+ if (oovs && oov_fid_)
+ features->set_value(oov_fid_, oovs);
}
template <class Model>
@@ -469,3 +391,23 @@ boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string par
std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
}
+
+ switch (m) {
+ case HASH_PROBING:
+ return CreateModel<ProbingModel>(param);
+ case TRIE_SORTED:
+ return CreateModel<TrieModel>(param);
+ case ARRAY_TRIE_SORTED:
+ return CreateModel<ArrayTrieModel>(param);
+ case QUANT_TRIE_SORTED:
+ return CreateModel<QuantTrieModel>(param);
+ case QUANT_ARRAY_TRIE_SORTED:
+ return CreateModel<QuantArrayTrieModel>(param);
+ default:
+ UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m);
+ }
+}
+
+std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
+ return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
+}
diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc
index afa36b96..5e16d4e3 100644
--- a/decoder/ff_lm.cc
+++ b/decoder/ff_lm.cc
@@ -46,7 +46,6 @@ char const* usage_verbose="-n determines the name of the feature (and its weight
#endif
#include "ff_lm.h"
-#include "ff_lm_fsa.h"
#include <sstream>
#include <unistd.h>
@@ -69,10 +68,6 @@ char const* usage_verbose="-n determines the name of the feature (and its weight
using namespace std;
-string LanguageModelFsa::usage(bool param,bool verbose) {
- return FeatureFunction::usage_helper("LanguageModelFsa",usage_short,usage_verbose,param,verbose);
-}
-
string LanguageModel::usage(bool param,bool verbose) {
return FeatureFunction::usage_helper(usage_name,usage_short,usage_verbose,param,verbose);
}
@@ -524,49 +519,6 @@ LanguageModel::LanguageModel(const string& param) {
SetStateSize(LanguageModelImpl::OrderToStateSize(order));
}
-//TODO: decide whether to waste a word of space so states are always none-terminated for SRILM. otherwise we have to copy
-void LanguageModelFsa::set_ngram_order(int i) {
- assert(i>0);
- ngram_order_=i;
- ctxlen_=i-1;
- set_state_bytes(ctxlen_*sizeof(WordID));
- WordID *ss=(WordID*)start.begin();
- WordID *hs=(WordID*)h_start.begin();
- if (ctxlen_) { // avoid segfault in case of unigram lm (0 state)
- set_end_phrase(TD::Convert("</s>"));
-// se is pretty boring in unigram case, just adds constant prob. check that this is what we want
- ss[0]=TD::Convert("<s>"); // start-sentence context (length 1)
- hs[0]=0; // empty context
- for (int i=1;i<ctxlen_;++i) {
- ss[i]=hs[i]=0; // need this so storage is initialized for hashing.
- //TODO: reevaluate whether state space comes cleared by allocator or not.
- }
- }
- sync(); // for dynamic markov_order copy etc
-}
-
-LanguageModelFsa::LanguageModelFsa(string const& param) {
- int lmorder;
- pimpl_ = make_lm_impl(param,&lmorder,&fid_);
- Init();
- floor_=pimpl_->floor_;
- set_ngram_order(lmorder);
-}
-
-void LanguageModelFsa::print_state(ostream &o,void const* st) const {
- WordID const *wst=(WordID const*)st;
- o<<'[';
- bool sp=false;
- for (int i=ctxlen_;i>0;sp=true) {
- --i;
- WordID w=wst[i];
- if (w==0) continue;
- if (sp) o<<' ';
- o << TD::Convert(w);
- }
- o<<']';
-}
-
Features LanguageModel::features() const {
return single_feature(fid_);
}
diff --git a/decoder/ff_lm_fsa.h b/decoder/ff_lm_fsa.h
deleted file mode 100755
index 85b7ef44..00000000
--- a/decoder/ff_lm_fsa.h
+++ /dev/null
@@ -1,140 +0,0 @@
-#ifndef FF_LM_FSA_H
-#define FF_LM_FSA_H
-
-//FIXME: when FSA_LM_PHRASE 1, 3gram fsa has differences, especially with unk words, in about the 4th decimal digit (about .05%), compared to regular ff_lm. this is USUALLY a bug (there's way more actual precision in there). this was with #define LM_FSA_SHORTEN_CONTEXT 1 and 0 (so it's not that). also, LM_FSA_SHORTEN_CONTEXT gives identical scores with FSA_LM_PHRASE 0
-
-// enabling for now - retest unigram+ more, solve above puzzle
-
-// some impls in ff_lm.cc
-
-#define FSA_LM_PHRASE 1
-
-#define FSA_LM_DEBUG 0
-#if FSA_LM_DEBUG
-# define FSALMDBG(e,x) FSADBGif(debug(),e,x)
-# define FSALMDBGnl(e) FSADBGif_nl(debug(),e)
-#else
-# define FSALMDBG(e,x)
-# define FSALMDBGnl(e)
-#endif
-
-#include "ff_fsa.h"
-#include "ff_lm.h"
-
-#ifndef TD__none
-// replacing dependency on SRILM
-#define TD__none -1
-#endif
-
-namespace {
-WordID empty_context=TD__none;
-}
-
-struct LanguageModelFsa : public FsaFeatureFunctionBase<LanguageModelFsa> {
- typedef WordID * W;
- typedef WordID const* WP;
-
- // overrides; implementations in ff_lm.cc
- typedef SingleFeatureAccumulator Accum;
- static std::string usage(bool,bool);
- LanguageModelFsa(std::string const& param);
- int markov_order() const { return ctxlen_; }
- void print_state(std::ostream &,void const *) const;
- inline Featval floored(Featval p) const {
- return p<floor_?floor_:p;
- }
- static inline WordID const* left_end(WordID const* left, WordID const* e) {
- for (;e>left;--e)
- if (e[-1]!=TD__none) break;
- //post: [left,e] are the seen left words
- return e;
- }
-
- template <class Accum>
- void ScanAccum(SentenceMetadata const& /* smeta */,Hypergraph::Edge const& edge,WordID w,void const* old_st,void *new_st,Accum *a) const {
-#if USE_INFO_EDGE
- Hypergraph::Edge &de=(Hypergraph::Edge &)edge;
-#endif
- if (!ctxlen_) {
- Add(floored(pimpl_->WordProb(w,&empty_context)),a);
- } else {
- WordID ctx[ngram_order_]; //alloca if you don't have C99
- state_copy(ctx,old_st);
- ctx[ctxlen_]=TD__none;
- Featval p=floored(pimpl_->WordProb(w,ctx));
- FSALMDBG(de,"p("<<TD::Convert(w)<<"|"<<TD::Convert(ctx,ctx+ctxlen_)<<")="<<p);FSALMDBGnl(de);
- // states are srilm contexts so are in reverse order (most recent word is first, then 1-back comes next, etc.).
- WordID *nst=(WordID *)new_st;
- nst[0]=w; // new most recent word
- to_state(nst+1,ctx,ctxlen_-1); // rotate old words right
-#if LM_FSA_SHORTEN_CONTEXT
- p+=pimpl_->ShortenContext(nst,ctxlen_);
-#endif
- Add(p,a);
- }
- }
-
-#if FSA_LM_PHRASE
- //FIXME: there is a bug in here somewhere, or else the 3gram LM we use gives different scores for phrases (impossible? BOW nonzero when shortening context past what LM has?)
- template <class Accum>
- void ScanPhraseAccum(SentenceMetadata const& /* smeta */,const Hypergraph::Edge&edge,WordID const* begin,WordID const* end,void const* old_st,void *new_st,Accum *a) const {
- Hypergraph::Edge &de=(Hypergraph::Edge &)edge;(void)de;
- if (begin==end) return; // otherwise w/ shortening it's possible to end up with no words at all.
- /* // this is forcing unigram prob always. we will instead build the phrase
- if (!ctxlen_) {
- Featval p=0;
- for (;i<end;++i)
- p+=floored(pimpl_->WordProb(*i,e&mpty_context));
- Add(p,a);
- return;
- } */
- int nw=end-begin;
- WP st=(WP)old_st;
- WP st_end=st+ctxlen_; // may include some null already (or none if full)
- int nboth=nw+ctxlen_;
- WordID ctx[nboth+1];
- ctx[nboth]=TD__none;
- // reverse order - state at very end of context, then [i,end) in rev order ending at ctx[0]
- W ctx_score_end=wordcpy_reverse(ctx,begin,end);
- wordcpy(ctx_score_end,st,st_end); // st already reversed.
- assert(ctx_score_end==ctx+nw);
- // we could just copy the filled state words, but it probably doesn't save much time (and might cost some to scan to find the nones. most contexts are full except for the shortest source spans.
- FSALMDBG(de," scan.r->l("<<TD::GetString(ctx,ctx_score_end)<<"|"<<TD::GetString(ctx_score_end,ctx+nboth)<<')');
- Featval p=0;
- FSALMDBGnl(edge);
- for(;ctx_score_end>ctx;--ctx_score_end)
- p+=floored(pimpl_->WordProb(ctx_score_end[-1],ctx_score_end));
- //TODO: look for score discrepancy -
- // i had some idea that maybe shortencontext would return a different prob if the length provided was > ctxlen_; however, since the same disagreement happens with LM_FSA_SHORTEN_CONTEXT 0 anyway, it's not that. perhaps look to SCAN_PHRASE_ACCUM_OVERRIDE - make sure they do the right thing.
-#if LM_FSA_SHORTEN_CONTEXT
- p+=pimpl_->ShortenContext(ctx,nboth<ctxlen_?nboth:ctxlen_);
-#endif
- state_copy(new_st,ctx);
- FSALMDBG(de," lm.Scan("<<TD::GetString(begin,end)<<"|"<<describe_state(old_st)<<")"<<"="<<p<<","<<describe_state(new_st));
- FSALMDBGnl(edge);
- Add(p,a);
- }
-
- SCAN_PHRASE_ACCUM_OVERRIDE
-#endif
-
- // impl details:
- void set_ngram_order(int i); // if you build ff_from_fsa first, then increase this, you will get memory overflows. otherwise, it's the same as a "-o i" argument to constructor
- // note: if you adjust ngram_order, ff_from_fsa won't notice.
-
- double floor_; // log10prob minimum used (e.g. unk words)
-
- // because we might have a custom fid due to lm name option:
- void Init() {
- InitHaveFid();
- }
-
-private:
- int ngram_order_;
- int ctxlen_; // 1 less than above
- LanguageModelInterface *pimpl_;
-
-};
-
-
-#endif
diff --git a/decoder/ff_register.h b/decoder/ff_register.h
index eff23537..80b1457e 100755
--- a/decoder/ff_register.h
+++ b/decoder/ff_register.h
@@ -2,50 +2,12 @@
#define FF_FSA_REGISTER_H
#include "ff_factory.h"
-#include "ff_from_fsa.h"
-#include "ff_fsa_dynamic.h"
-
-inline std::string prefix_fsa(std::string const& name,bool fsa_prefix_ff) {
- return fsa_prefix_ff ? "Fsa"+name : name;
-}
-
-//FIXME: problem with FeatureFunctionFromFsa<FsaFeatureFunction> - need to use factory rather than ctor.
-#if 0
-template <class DynFsa>
-inline void RegisterFsa(bool ff_also=true,bool fsa_prefix_ff=true) {
- assert(!ff_also);
-// global_fsa_ff_registry->RegisterFsa<DynFsa>();
-//if (ff_also) ff_registry.RegisterFF<FeatureFunctionFromFsa<DynFsa> >(prefix_fsa(DynFsa::usage(false,false)),fsa_prefix_ff);
-}
-#endif
-
-//TODO: ff from fsa that uses pointer to fsa impl? e.g. in LanguageModel we share underlying lm file by recognizing same param, but without that effort, otherwise stateful ff may duplicate state if we enable both fsa and ff_from_fsa
-template <class FsaImpl>
-inline void RegisterFsaImpl(bool ff_also=true,bool fsa_prefix_ff=false) {
- typedef FsaFeatureFunctionDynamic<FsaImpl> DynFsa;
- typedef FeatureFunctionFromFsa<FsaImpl> FFFrom;
- std::string name=FsaImpl::usage(false,false);
- fsa_ff_registry.Register(new FsaFactory<DynFsa>);
- if (ff_also)
- ff_registry.Register(prefix_fsa(name,fsa_prefix_ff),new FFFactory<FFFrom>);
-}
template <class Impl>
inline void RegisterFF() {
ff_registry.Register(new FFFactory<Impl>);
}
-template <class FsaImpl>
-inline void RegisterFsaDynToFF(std::string name,bool prefix=true) {
- typedef FsaFeatureFunctionDynamic<FsaImpl> DynFsa;
- ff_registry.Register(prefix?"DynamicFsa"+name:name,new FFFactory<FeatureFunctionFromFsa<DynFsa> >);
-}
-
-template <class FsaImpl>
-inline void RegisterFsaDynToFF(bool prefix=true) {
- RegisterFsaDynToFF<FsaImpl>(FsaImpl::usage(false,false),prefix);
-}
-
void register_feature_functions();
#endif
diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc
new file mode 100644
index 00000000..035132b4
--- /dev/null
+++ b/decoder/ff_source_syntax.cc
@@ -0,0 +1,232 @@
+#include "ff_source_syntax.h"
+
+#include <sstream>
+#include <stack>
+
+#include "sentence_metadata.h"
+#include "array2d.h"
+#include "filelib.h"
+
+using namespace std;
+
+// implements the source side syntax features described in Blunsom et al. (EMNLP 2008)
+// source trees must be represented in Penn Treebank format, e.g.
+// (S (NP John) (VP (V left)))
+
+// log transform to make long spans cluster together
+// but preserve differences
+inline int SpanSizeTransform(unsigned span_size) {
+ if (!span_size) return 0;
+ return static_cast<int>(log(span_size+1) / log(1.39)) - 1;
+}
+
+struct SourceSyntaxFeaturesImpl {
+ SourceSyntaxFeaturesImpl() {}
+
+ void InitializeGrids(const string& tree, unsigned src_len) {
+ assert(tree.size() > 0);
+ //fids_cat.clear();
+ fids_ef.clear();
+ src_tree.clear();
+ //fids_cat.resize(src_len, src_len + 1);
+ fids_ef.resize(src_len, src_len + 1);
+ src_tree.resize(src_len, src_len + 1, TD::Convert("XX"));
+ ParseTreeString(tree, src_len);
+ }
+
+ void ParseTreeString(const string& tree, unsigned src_len) {
+ stack<pair<int, WordID> > stk; // first = i, second = category
+ pair<int, WordID> cur_cat; cur_cat.first = -1;
+ unsigned i = 0;
+ unsigned p = 0;
+ while(p < tree.size()) {
+ const char cur = tree[p];
+ if (cur == '(') {
+ stk.push(cur_cat);
+ ++p;
+ unsigned k = p + 1;
+ while (k < tree.size() && tree[k] != ' ') { ++k; }
+ cur_cat.first = i;
+ cur_cat.second = TD::Convert(tree.substr(p, k - p));
+ // cerr << "NT: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n";
+ p = k + 1;
+ } else if (cur == ')') {
+ unsigned k = p;
+ while (k < tree.size() && tree[k] == ')') { ++k; }
+ const unsigned num_closes = k - p;
+ for (unsigned ci = 0; ci < num_closes; ++ci) {
+ // cur_cat.second spans from cur_cat.first to i
+ // cerr << TD::Convert(cur_cat.second) << " from " << cur_cat.first << " to " << i << endl;
+ // NOTE: unary rule chains end up being labeled with the top-most category
+ src_tree(cur_cat.first, i) = cur_cat.second;
+ cur_cat = stk.top();
+ stk.pop();
+ }
+ p = k;
+ while (p < tree.size() && (tree[p] == ' ' || tree[p] == '\t')) { ++p; }
+ } else if (cur == ' ' || cur == '\t') {
+ cerr << "Unexpected whitespace in: " << tree << endl;
+ abort();
+ } else { // terminal symbol
+ unsigned k = p + 1;
+ do {
+ while (k < tree.size() && tree[k] != ')' && tree[k] != ' ') { ++k; }
+ // cerr << "TERM: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n";
+ ++i;
+ assert(i <= src_len);
+ while (k < tree.size() && tree[k] == ' ') { ++k; }
+ p = k;
+ } while (p < tree.size() && tree[p] != ')');
+ }
+ }
+ // cerr << "i=" << i << " src_len=" << src_len << endl;
+ assert(i == src_len); // make sure tree specified in src_tree is
+ // the same length as the source sentence
+ }
+
+ WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) {
+ //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl;
+ const WordID lhs = src_tree(i,j);
+ //int& fid_cat = fids_cat(i,j);
+ int& fid_ef = fids_ef(i,j)[&rule];
+ if (fid_ef <= 0) {
+ ostringstream os;
+ //ostringstream os2;
+ os << "SYN:" << TD::Convert(lhs);
+ //os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i);
+ //fid_cat = FD::Convert(os2.str());
+ os << ':';
+ unsigned ntc = 0;
+ for (unsigned k = 0; k < rule.f_.size(); ++k) {
+ if (k > 0) os << '_';
+ int fj = rule.f_[k];
+ if (fj <= 0) {
+ os << '[' << TD::Convert(ants[ntc++]) << ']';
+ } else {
+ os << TD::Convert(fj);
+ }
+ }
+ os << ':';
+ for (unsigned k = 0; k < rule.e_.size(); ++k) {
+ const int ei = rule.e_[k];
+ if (k > 0) os << '_';
+ if (ei <= 0)
+ os << '[' << (1-ei) << ']';
+ else
+ os << TD::Convert(ei);
+ }
+ fid_ef = FD::Convert(os.str());
+ }
+ //if (fid_cat > 0)
+ // feats->set_value(fid_cat, 1.0);
+ if (fid_ef > 0)
+ feats->set_value(fid_ef, 1.0);
+ return lhs;
+ }
+
+ Array2D<WordID> src_tree; // src_tree(i,j) NT = type
+ // mutable Array2D<int> fids_cat; // this tends to overfit baddly
+ mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
+};
+
+SourceSyntaxFeatures::SourceSyntaxFeatures(const string& param) :
+ FeatureFunction(sizeof(WordID)) {
+ impl = new SourceSyntaxFeaturesImpl;
+}
+
+SourceSyntaxFeatures::~SourceSyntaxFeatures() {
+ delete impl;
+ impl = NULL;
+}
+
+void SourceSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ WordID ants[8];
+ for (unsigned i = 0; i < ant_contexts.size(); ++i)
+ ants[i] = *static_cast<const WordID*>(ant_contexts[i]);
+
+ *static_cast<WordID*>(context) =
+ impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
+}
+
+void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
+}
+
+struct SourceSpanSizeFeaturesImpl {
+ SourceSpanSizeFeaturesImpl() {}
+
+ void InitializeGrids(unsigned src_len) {
+ fids.clear();
+ fids.resize(src_len, src_len + 1);
+ }
+
+ int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) {
+ if (rule.Arity() > 0) {
+ int& fid = fids(i,j)[&rule];
+ if (fid <= 0) {
+ ostringstream os;
+ os << "SSS:";
+ unsigned ntc = 0;
+ for (unsigned k = 0; k < rule.f_.size(); ++k) {
+ if (k > 0) os << '_';
+ int fj = rule.f_[k];
+ if (fj <= 0) {
+ os << '[' << TD::Convert(-fj) << ants[ntc++] << ']';
+ } else {
+ os << TD::Convert(fj);
+ }
+ }
+ os << ':';
+ for (unsigned k = 0; k < rule.e_.size(); ++k) {
+ const int ei = rule.e_[k];
+ if (k > 0) os << '_';
+ if (ei <= 0)
+ os << '[' << (1-ei) << ']';
+ else
+ os << TD::Convert(ei);
+ }
+ fid = FD::Convert(os.str());
+ }
+ if (fid > 0)
+ feats->set_value(fid, 1.0);
+ }
+ return SpanSizeTransform(j - i);
+ }
+
+ mutable Array2D<map<const TRule*, int> > fids;
+};
+
+SourceSpanSizeFeatures::SourceSpanSizeFeatures(const string& param) :
+ FeatureFunction(sizeof(char)) {
+ impl = new SourceSpanSizeFeaturesImpl;
+}
+
+SourceSpanSizeFeatures::~SourceSpanSizeFeatures() {
+ delete impl;
+ impl = NULL;
+}
+
+void SourceSpanSizeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ int ants[8];
+ for (unsigned i = 0; i < ant_contexts.size(); ++i)
+ ants[i] = *static_cast<const char*>(ant_contexts[i]);
+
+ *static_cast<char*>(context) =
+ impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
+}
+
+void SourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ impl->InitializeGrids(smeta.GetSourceLength());
+}
+
+
diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h
new file mode 100644
index 00000000..279563e1
--- /dev/null
+++ b/decoder/ff_source_syntax.h
@@ -0,0 +1,41 @@
+#ifndef _FF_SOURCE_TOOLS_H_
+#define _FF_SOURCE_TOOLS_H_
+
+#include "ff.h"
+
+struct SourceSyntaxFeaturesImpl;
+
+class SourceSyntaxFeatures : public FeatureFunction {
+ public:
+ SourceSyntaxFeatures(const std::string& param);
+ ~SourceSyntaxFeatures();
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ SourceSyntaxFeaturesImpl* impl;
+};
+
+struct SourceSpanSizeFeaturesImpl;
+class SourceSpanSizeFeatures : public FeatureFunction {
+ public:
+ SourceSpanSizeFeatures(const std::string& param);
+ ~SourceSpanSizeFeatures();
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ SourceSpanSizeFeaturesImpl* impl;
+};
+
+#endif
diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc
index 62b8f958..cde00efa 100644
--- a/decoder/grammar_test.cc
+++ b/decoder/grammar_test.cc
@@ -15,12 +15,12 @@ using namespace std;
class GrammarTest : public testing::Test {
public:
GrammarTest() {
- wts.InitFromFile("test_data/weights.gt");
+ Weights::InitFromFile("test_data/weights.gt", &wts);
}
protected:
virtual void SetUp() { }
virtual void TearDown() { }
- Weights wts;
+ vector<weight_t> wts;
};
TEST_F(GrammarTest,TestTextGrammar) {
diff --git a/decoder/hg.cc b/decoder/hg.cc
index 3ad17f1a..180986d7 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -157,14 +157,14 @@ prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector<prob_t>* posts) co
const ScaledEdgeProb weight(scale);
const ScaledTransitionEventWeightFunction w2(scale);
SparseVector<prob_t> pv;
- const double inside = InsideOutside<prob_t,
+ const prob_t inside = InsideOutside<prob_t,
ScaledEdgeProb,
SparseVector<prob_t>,
ScaledTransitionEventWeightFunction>(*this, &pv, weight, w2);
posts->resize(edges_.size());
for (int i = 0; i < edges_.size(); ++i)
(*posts)[i] = prob_t(pv.value(i));
- return prob_t(inside);
+ return inside;
}
prob_t Hypergraph::ComputeBestPathThroughEdges(vector<prob_t>* post) const {
diff --git a/decoder/hg.h b/decoder/hg.h
index 70bc4995..52a18601 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -49,16 +49,14 @@ public:
// TODO get rid of cat_?
// TODO keep cat_ and add span and/or state? :)
struct Node {
- Node() : id_(), cat_(), promise(1) {}
+ Node() : id_(), cat_() {}
int id_; // equal to this object's position in the nodes_ vector
WordID cat_; // non-terminal category if <0, 0 if not set
WordID NT() const { return -cat_; }
EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_
EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
- double promise; // set in global pruning; in [0,infty) so that mean is 1. use: e.g. scale cube poplimit. //TODO: appears to be useless, compile without this? on the other hand, pretty cheap.
void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting
cat_=o.cat_;
- promise=o.promise;
}
void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) {
copy_fixed(o);
@@ -81,7 +79,7 @@ public:
int head_node_; // refers to a position in nodes_
TailNodeVector tail_nodes_; // contents refer to positions in nodes_
TRulePtr rule_;
- FeatureVector feature_values_;
+ SparseVector<weight_t> feature_values_;
prob_t edge_prob_; // dot product of weights and feat_values
int id_; // equal to this object's position in the edges_ vector
@@ -470,7 +468,7 @@ public:
/// drop edge i if edge_margin[i] < prune_below, unless preserve_mask[i]
void MarginPrune(EdgeProbs const& edge_margin,prob_t prune_below,EdgeMask const* preserve_mask=0,bool safe_inside=false,bool verbose=false);
- //TODO: in my opinion, looking at the ratio of logprobs (features \dot weights) rather than the absolute difference generalizes more nicely across sentence lengths and weight vectors that are constant multiples of one another. at least make that an option. i worked around this a little in cdec by making "beam alpha per source word" but that's not helping with different tuning runs. this would also make me more comfortable about allocating Node.promise
+ //TODO: in my opinion, looking at the ratio of logprobs (features \dot weights) rather than the absolute difference generalizes more nicely across sentence lengths and weight vectors that are constant multiples of one another. at least make that an option. i worked around this a little in cdec by making "beam alpha per source word" but that's not helping with different tuning runs.
// beam_alpha=0 means don't beam prune, otherwise drop things that are e^beam_alpha times worse than best - // prunes any edge whose prob_t on the best path taking that edge is more than e^alpha times
//density=0 means don't density prune: // for density>=1.0, keep this many times the edges needed for the 1best derivation
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 3be5b82d..5d1910fb 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -57,7 +57,7 @@ TEST_F(HGTest,Union) {
c3 = ViterbiESentence(hg1, &t3);
int l3 = ViterbiPathLength(hg1);
cerr << c3 << "\t" << TD::GetString(t3) << endl;
- EXPECT_FLOAT_EQ(c2, c3);
+ EXPECT_FLOAT_EQ(c2.as_float(), c3.as_float());
EXPECT_EQ(TD::GetString(t2), TD::GetString(t3));
EXPECT_EQ(l2, l3);
@@ -117,7 +117,7 @@ TEST_F(HGTest,InsideScore) {
cerr << "cost: " << cost << "\n";
hg.PrintGraphviz();
prob_t inside = Inside<prob_t, EdgeProb>(hg);
- EXPECT_FLOAT_EQ(1.7934048, inside); // computed by hand
+ EXPECT_FLOAT_EQ(1.7934048, inside.as_float()); // computed by hand
vector<prob_t> post;
inside = hg.ComputeBestPathThroughEdges(&post);
EXPECT_FLOAT_EQ(-0.3, log(inside)); // computed by hand
@@ -282,13 +282,13 @@ TEST_F(HGTest, TestGenericInside) {
hg.Reweight(wts);
vector<prob_t> inside;
prob_t ins = Inside<prob_t, EdgeProb>(hg, &inside);
- EXPECT_FLOAT_EQ(1.7934048, ins); // computed by hand
+ EXPECT_FLOAT_EQ(1.7934048, ins.as_float()); // computed by hand
vector<prob_t> outside;
Outside<prob_t, EdgeProb>(hg, inside, &outside);
EXPECT_EQ(3, outside.size());
- EXPECT_FLOAT_EQ(1.7934048, outside[0]);
- EXPECT_FLOAT_EQ(1.3114071, outside[1]);
- EXPECT_FLOAT_EQ(1.0, outside[2]);
+ EXPECT_FLOAT_EQ(1.7934048, outside[0].as_float());
+ EXPECT_FLOAT_EQ(1.3114071, outside[1].as_float());
+ EXPECT_FLOAT_EQ(1.0, outside[2].as_float());
}
TEST_F(HGTest,TestGenericInside2) {
@@ -327,8 +327,8 @@ TEST_F(HGTest,TestAddExpectations) {
SparseVector<prob_t> feat_exps;
prob_t z = InsideOutside<prob_t, EdgeProb,
SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(hg, &feat_exps);
- EXPECT_FLOAT_EQ(-2.5439765, feat_exps.value(FD::Convert("f1")) / z);
- EXPECT_FLOAT_EQ(-2.6357865, feat_exps.value(FD::Convert("f2")) / z);
+ EXPECT_FLOAT_EQ(-2.5439765, (feat_exps.value(FD::Convert("f1")) / z).as_float());
+ EXPECT_FLOAT_EQ(-2.6357865, (feat_exps.value(FD::Convert("f2")) / z).as_float());
cerr << feat_exps << endl;
cerr << "Z=" << z << endl;
}
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 15d48588..b603e27a 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -272,23 +272,31 @@ struct OracleBleu {
}
kbest_out<<endl<<flush;
if (show_derivation) {
- deriv_out<<"\nsent_id="<<sent_id<<"\n";
+ deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k
+ deriv_out<<log(d->score)<<"\n";
deriv_out<<kbest.derivation_tree(*d,true);
- deriv_out<<flush;
+ deriv_out<<"\n"<<flush;
}
}
}
// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
- void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) {
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) {
WriteFile ko(kbest_out_filename_);
- std::cerr << "Output kbest to " << kbest_out_filename_<<std::endl;
+ std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl;
+ std::ostringstream sderiv;
+ sderiv << deriv_out_filename_;
+ if (show_derivation) {
+ sderiv << "/derivs." << sent_id;
+ std::cerr << "Output derivations to " << deriv_out_filename_ << std::endl;
+ }
+ WriteFile oderiv(sderiv.str());
if (!unique)
- kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),std::cerr);
+ kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get());
else {
- kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),std::cerr);
+ kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get());
}
}
@@ -296,7 +304,7 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo
{
std::ostringstream kbest_string_stream;
kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str());
+ DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-");
}
};
diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l
index 9331d8ed..083a5bb1 100644
--- a/decoder/rule_lexer.l
+++ b/decoder/rule_lexer.l
@@ -220,6 +220,8 @@ NT [^\t \[\],]+
std::cerr << "Line " << lex_line << ": LHS and RHS arity mismatch!\n";
abort();
}
+ // const bool ignore_grammar_features = false;
+ // if (ignore_grammar_features) scfglex_num_feats = 0;
TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als));
check_and_update_ctf_stack(rp);
TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
diff --git a/decoder/trule.h b/decoder/trule.h
index 4df4ec90..8eb2a059 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -5,7 +5,9 @@
#include <vector>
#include <cassert>
#include <iostream>
-#include <boost/shared_ptr.hpp>
+
+#include "boost/shared_ptr.hpp"
+#include "boost/functional/hash.hpp"
#include "sparse_vector.h"
#include "wordid.h"
@@ -162,4 +164,15 @@ class TRule {
bool SanityCheck() const;
};
+inline size_t hash_value(const TRule& r) {
+ size_t h = boost::hash_value(r.e_);
+ boost::hash_combine(h, -r.lhs_);
+ boost::hash_combine(h, boost::hash_value(r.f_));
+ return h;
+}
+
+inline bool operator==(const TRule& a, const TRule& b) {
+ return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_);
+}
+
#endif