summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorgraehl@gmail.com <graehl@gmail.com@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-10 10:02:04 +0000
committergraehl@gmail.com <graehl@gmail.com@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-10 10:02:04 +0000
commit32154b45828f05add1db7c89752ef4220c0fdf16 (patch)
treefa99e4d4847a89d41b464e9ae3c9aacf611e5500 /decoder
parent43db0573b15719d48b89b3a1ad2828036d008560 (diff)
cdec --cfg_output=-
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@499 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rwxr-xr-xdecoder/apply_fsa_models.cc55
-rwxr-xr-xdecoder/apply_fsa_models.h48
-rw-r--r--decoder/cdec.cc22
-rwxr-xr-xdecoder/cfg.cc37
-rwxr-xr-xdecoder/cfg.h24
-rwxr-xr-xdecoder/cfg_format.h74
-rw-r--r--decoder/filelib.h1
-rwxr-xr-xdecoder/fsa-hiero.ini1
-rw-r--r--decoder/grammar.cc8
-rw-r--r--decoder/hg.h46
-rwxr-xr-xdecoder/nt_span.h30
-rwxr-xr-xdecoder/oracle_bleu.h4
12 files changed, 306 insertions, 44 deletions
diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc
index 416b9323..31b2002d 100755
--- a/decoder/apply_fsa_models.cc
+++ b/decoder/apply_fsa_models.cc
@@ -12,57 +12,66 @@
using namespace std;
struct ApplyFsa {
- ApplyFsa(const Hypergraph& ih,
+ ApplyFsa(HgCFG &i,
const SentenceMetadata& smeta,
const FsaFeatureFunction& fsa,
DenseWeightVector const& weights,
ApplyFsaBy const& by,
- Hypergraph* oh)
- :ih(ih),smeta(smeta),fsa(fsa),weights(weights),by(by),oh(oh)
+ Hypergraph* oh
+ )
+ :hgcfg(i),smeta(smeta),fsa(fsa),weights(weights),by(by),oh(oh)
{
-// sparse_to_dense(weight_vector,&weights);
- Init();
}
- void Init() {
+ void Compute() {
if (by.IsBottomUp())
ApplyBottomUp();
else
ApplyEarley();
}
- void ApplyBottomUp() {
- assert(by.IsBottomUp());
- FeatureFunctionFromFsa<FsaFeatureFunctionFwd> buff(&fsa);
- buff.Init(); // mandatory to call this (normally factory would do it)
- vector<const FeatureFunction*> ffs(1,&buff);
- ModelSet models(weights, ffs);
- IntersectionConfiguration i(by.BottomUpAlgorithm(),by.pop_limit);
- ApplyModelSet(ih,smeta,models,i,oh);
- }
- void ApplyEarley() {
- CFG cfg(ih,true,false,true);
- }
+ void ApplyBottomUp();
+ void ApplyEarley();
+ CFG const& GetCFG();
private:
- const Hypergraph& ih;
+ CFG cfg;
+ HgCFG &hgcfg;
const SentenceMetadata& smeta;
const FsaFeatureFunction& fsa;
// WeightVector weight_vector;
DenseWeightVector weights;
ApplyFsaBy by;
Hypergraph* oh;
+ std::string cfg_out;
};
+void ApplyFsa::ApplyBottomUp()
+{
+ assert(by.IsBottomUp());
+ FeatureFunctionFromFsa<FsaFeatureFunctionFwd> buff(&fsa);
+ buff.Init(); // mandatory to call this (normally factory would do it)
+ vector<const FeatureFunction*> ffs(1,&buff);
+ ModelSet models(weights, ffs);
+ IntersectionConfiguration i(by.BottomUpAlgorithm(),by.pop_limit);
+ ApplyModelSet(hgcfg.ih,smeta,models,i,oh);
+}
-void ApplyFsaModels(const Hypergraph& ih,
+void ApplyFsa::ApplyEarley()
+{
+ hgcfg.GiveCFG(cfg);
+ //TODO:
+}
+
+
+void ApplyFsaModels(HgCFG &i,
const SentenceMetadata& smeta,
const FsaFeatureFunction& fsa,
DenseWeightVector const& weight_vector,
ApplyFsaBy const& by,
Hypergraph* oh)
{
- ApplyFsa a(ih,smeta,fsa,weight_vector,by,oh);
+ ApplyFsa a(i,smeta,fsa,weight_vector,by,oh);
+ a.Compute();
}
-
namespace {
char const* anames[]={
"BU_CUBE",
@@ -88,7 +97,7 @@ std::string ApplyFsaBy::all_names() {
return o.str();
}
-ApplyFsaBy::ApplyFsaBy(std::string const& n, int pop_limit) : pop_limit(pop_limit){
+ApplyFsaBy::ApplyFsaBy(std::string const& n, int pop_limit) : pop_limit(pop_limit) {
algorithm=0;
std::string uname=toupper(n);
while(anames[algorithm] && anames[algorithm] != uname) ++algorithm;
diff --git a/decoder/apply_fsa_models.h b/decoder/apply_fsa_models.h
index 3dce5e82..0227664a 100755
--- a/decoder/apply_fsa_models.h
+++ b/decoder/apply_fsa_models.h
@@ -1,8 +1,10 @@
#ifndef _APPLY_FSA_MODELS_H_
#define _APPLY_FSA_MODELS_H_
+#include <string>
#include <iostream>
#include "feature_vector.h"
+#include "cfg.h"
struct FsaFeatureFunction;
struct Hypergraph;
@@ -34,12 +36,56 @@ struct ApplyFsaBy {
static std::string all_names(); // space separated
};
+// in case you might want the CFG whether or not you apply FSA models:
+struct HgCFG {
+ HgCFG(Hypergraph const& ih) : ih(ih) { have_cfg=false; }
+ Hypergraph const& ih;
+ CFG cfg;
+ bool have_cfg;
+ void InitCFG(CFG &to) {
+ to.Init(ih,true,false,true);
+ }
+
+ CFG &GetCFG()
+ {
+ if (!have_cfg) {
+ have_cfg=true;
+ InitCFG(cfg);
+ }
+ return cfg;
+ }
+ void GiveCFG(CFG &to) {
+ if (!have_cfg)
+ InitCFG(to);
+ else {
+ have_cfg=false;
+ to.Clear();
+ to.Swap(cfg);
+ }
+ }
+ CFG const& GetCFG() const {
+ assert(have_cfg);
+ return cfg;
+ }
+};
+
-void ApplyFsaModels(const Hypergraph& in,
+void ApplyFsaModels(HgCFG &hg_or_cfg_in,
const SentenceMetadata& smeta,
const FsaFeatureFunction& fsa,
DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this)
ApplyFsaBy const& cfg,
Hypergraph* out);
+inline void ApplyFsaModels(Hypergraph const& ih,
+ const SentenceMetadata& smeta,
+ const FsaFeatureFunction& fsa,
+ DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this)
+ ApplyFsaBy const& cfg,
+ Hypergraph* out) {
+ HgCFG i(ih);
+ ApplyFsaModels(i,smeta,fsa,weights,cfg,out);
+}
+
+
#endif
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 29070a69..a179e029 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -36,6 +36,8 @@
#include "sentence_metadata.h"
#include "../vest/scorer.h"
#include "apply_fsa_models.h"
+#include "program_options.h"
+#include "cfg_format.h"
using namespace std;
using namespace std::tr1;
@@ -105,6 +107,8 @@ void print_options(std::ostream &out,po::options_description const& opts) {
}
+CFGFormat cfgf;
+
void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* confp) {
po::variables_map &conf=*confp;
po::options_description opts("Configuration options");
@@ -168,6 +172,10 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c
("forest_output,O",po::value<string>(),"Directory to write forests to")
("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
ob.AddOptions(&opts);
+ po::options_description cfgo("CFG output options");
+ cfgo.add_options()
+ ("cfg_output", po::value<string>(),"write final target CFG (before FSA rescorinn) to this file");
+ cfgf.AddOptions(&cfgo);
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(), "Configuration file(s) - latest has priority")
@@ -177,8 +185,9 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c
;
po::options_description dconfig_options, dcmdline_options;
- dconfig_options.add(opts);
- dcmdline_options.add(opts).add(clo);
+ dconfig_options.add(opts).add(cfgo);
+ //add(opts).add(cfgo)
+ dcmdline_options.add(dconfig_options).add(clo);
po::store(parse_command_line(argc, argv, dcmdline_options), conf);
if (conf.count("compgen")) {
@@ -653,7 +662,11 @@ int main(int argc, char** argv) {
maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
-
+ HgCFG hgcfg(forest);
+ if (conf.count("cfg_output")) {
+ WriteFile o(str("cfg_output",conf));
+ hgcfg.GetCFG().Print(o.get(),cfgf);
+ }
if (!fsa_ffs.empty()) {
Timer t("Target FSA rescoring:");
if (!has_late_models)
@@ -662,13 +675,12 @@ int main(int argc, char** argv) {
assert(fsa_ffs.size()==1);
ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl;
- ApplyFsaModels(forest,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest);
+ ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest);
forest.swap(fsa_forest);
forest.Reweight(feature_weights);
forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
}
-
/*Oracle Rescoring*/
if(get_oracle_forest) {
Oracle o=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>());
diff --git a/decoder/cfg.cc b/decoder/cfg.cc
index b83fc54d..ace0ebb0 100755
--- a/decoder/cfg.cc
+++ b/decoder/cfg.cc
@@ -5,6 +5,7 @@
using namespace std;
void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool push_weights) {
+ uninit=false;
hg_=&hg;
Hypergraph::NodeProbs np;
goal_inside=hg.ComputeNodeViterbi(&np);
@@ -13,8 +14,10 @@ void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool pus
nts.resize(nn);
goal_nt=nn-1;
rules.resize(ne);
- for (int i=0;i<nn;++i)
+ for (int i=0;i<nn;++i) {
nts[i].ruleids=hg.nodes_[i].in_edges_;
+ hg.SetNodeOrigin(i,nts[i].from);
+ }
for (int i=0;i<ne;++i) {
Rule &cfgr=rules[i];
Hypergraph::Edge const& e=hg.edges_[i];
@@ -43,12 +46,34 @@ void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool pus
}
}
-namespace {
+void CFG::Clear() {
+ rules.clear();
+ nts.clear();
+ goal_nt=-1;
+ hg_=0;
+}
+
+void CFG::PrintRule(std::ostream &o,RuleHandle rulei,CFGFormat const& f) const {
+ Rule const& r=rules[rulei];
+ f.print_lhs(o,*this,r.lhs);
+ f.print_rhs(o,*this,r.rhs.begin(),r.rhs.end());
+ f.print_features(o,r.p,r.f);
}
void CFG::Print(std::ostream &o,CFGFormat const& f) const {
- char const* partsep=" ||| ";
- if (!f.goal_nt_name.empty())
- o << '['<<f.goal_nt_name <<']' << partsep; // print rhs
- //TODO:
+ assert(!uninit);
+ if (!f.goal_nt_name.empty()) {
+ o << '['<<f.goal_nt_name <<']';
+ f.print_rhs(o,*this,&goal_nt,&goal_nt+1);
+ if (pushed_inside!=1)
+ f.print_features(o,pushed_inside);
+ o<<'\n';
+ }
+ for (int i=0;i<nts.size();++i) {
+ Ruleids const& ntr=nts[i].ruleids;
+ for (Ruleids::const_iterator j=ntr.begin(),jj=ntr.end();j!=jj;++j) {
+ PrintRule(o,*j,f);
+ o<<'\n';
+ }
+ }
}
diff --git a/decoder/cfg.h b/decoder/cfg.h
index 8d7a5eee..e325c4cd 100755
--- a/decoder/cfg.h
+++ b/decoder/cfg.h
@@ -25,6 +25,7 @@
#include "prob.h"
//#include "int_or_pointer.h"
#include "small_vector.h"
+#include "nt_span.h"
class Hypergraph;
class CFGFormat; // #include "cfg_format.h"
@@ -35,6 +36,10 @@ struct CFG {
typedef SmallVector<WordID> RHS; // same as in trule rhs: >0 means token, <=0 means -node index (not variable index)
typedef std::vector<RuleHandle> Ruleids;
+ void print_nt_name(std::ostream &o,NTHandle n) const {
+ o << nts[n].from;
+ }
+
struct Rule {
int lhs; // index into nts
RHS rhs;
@@ -47,17 +52,33 @@ struct CFG {
struct NT {
Ruleids ruleids; // index into CFG rules with lhs = this NT. aka in_edges_
+ NTSpan from; // optional name - still needs id to disambiguate
};
- CFG() : hg_() { }
+ CFG() : hg_() { uninit=true; }
// provided hg will have weights pushed up to root
CFG(Hypergraph const& hg,bool target_side=true,bool copy_features=false,bool push_weights=true) {
Init(hg,target_side,copy_features,push_weights);
}
+ bool Uninitialized() const { return uninit; }
+ void Clear();
+ bool Empty() const { return nts.empty(); }
void Init(Hypergraph const& hg,bool target_side=true,bool copy_features=false,bool push_weights=true);
void Print(std::ostream &o,CFGFormat const& format) const; // see cfg_format.h
+ void PrintRule(std::ostream &o,RuleHandle rulei,CFGFormat const& format) const;
+ void Swap(CFG &o) { // make sure this includes all fields (easier to see here than in .cc)
+ using namespace std;
+ swap(uninit,o.uninit);
+ swap(hg_,o.hg_);
+ swap(goal_inside,o.goal_inside);
+ swap(pushed_inside,o.pushed_inside);
+ swap(rules,o.rules);
+ swap(nts,o.nts);
+ swap(goal_nt,o.goal_nt);
+ }
protected:
+ bool uninit;
Hypergraph const* hg_; // shouldn't be used for anything, esp. after binarization
prob_t goal_inside,pushed_inside; // when we push viterbi weights to goal, we store the removed probability in pushed_inside
// rules/nts will have same index as hg edges/nodes
@@ -68,4 +89,5 @@ protected:
int goal_nt;
};
+
#endif
diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h
index 1bce3d06..169632a6 100755
--- a/decoder/cfg_format.h
+++ b/decoder/cfg_format.h
@@ -3,9 +3,19 @@
#include <program_options.h>
#include <string>
+#include "wordid.h"
+#include "feature_vector.h"
struct CFGFormat {
- bool identity_scfg;bool features;bool logprob_feat;bool cfg_comma_nt;std::string goal_nt_name;std::string nt_prefix;
+ bool identity_scfg;
+ bool features;
+ bool logprob_feat;
+ bool cfg_comma_nt;
+ bool nt_span;
+ std::string goal_nt_name;
+ std::string nt_prefix;
+ std::string logprob_feat_name;
+ std::string partsep;
template <class Opts> // template to support both printable_opts and boost nonprintable
void AddOptions(Opts *opts) {
using namespace boost::program_options;
@@ -14,19 +24,81 @@ struct CFGFormat {
("identity_scfg",defaulted_value(&identity_scfg),"output an identity SCFG: add an identity target side - '[X12] ||| [X13,1] a ||| [1] a ||| feat= ...' - the redundant target '[1] a |||' is omitted otherwise.")
("features",defaulted_value(&features),"print the CFG feature vector")
("logprob_feat",defaulted_value(&logprob_feat),"print a LogProb=-1.5 feature irrespective of --features.")
+ ("logprob_feat_name",defaulted_value(&logprob_feat_name),"alternate name for the LogProb feature")
("cfg_comma_nt",defaulted_value(&cfg_comma_nt),"if false, omit the usual [NP,1] ',1' variable index in the source side")
("goal_nt_name",defaulted_value(&goal_nt_name),"if nonempty, the first production will be '[goal_nt_name] ||| [x123] ||| LogProb=y' where x123 is the actual goal nt, and y is the pushed prob, if any")
("nt_prefix",defaulted_value(&nt_prefix),"NTs are [<nt_prefix>123] where 123 is the node number starting at 0, and the highest node (last in file) is the goal node in an acyclic hypergraph")
+ ("nt_span",defaulted_value(&nt_span),"prefix A(i,j) for NT coming from hypergraph node with category A on span [i,j). this is after --nt_prefix if any")
;
}
+
+ template<class CFG>
+ void print_source_nt(std::ostream &o,CFG const&cfg,int id,int position=1) const {
+ o<<'[';
+ print_nt_name(o,cfg,id);
+ if (cfg_comma_nt) o<<','<<position;
+ o<<']';
+ }
+
+ template <class CFG>
+ void print_nt_name(std::ostream &o,CFG const& cfg,int id) const {
+ o<<nt_prefix;
+ cfg.print_nt_name(o,id);
+ o<<id;
+ }
+
+ template <class CFG>
+ void print_lhs(std::ostream &o,CFG const& cfg,int id) const {
+ o<<'[';
+ print_nt_name(o,cfg,id);
+ o<<']';
+ }
+
+ template <class CFG,class Iter>
+ void print_rhs(std::ostream &o,CFG const&cfg,Iter begin,Iter end) const {
+ o<<partsep;
+ int pos=0;
+ for (Iter i=begin;i!=end;++i) {
+ WordID w=*i;
+ if (i!=begin) o<<' ';
+ if (w>0) o << TD::Convert(w);
+ else print_source_nt(o,cfg,-w,++pos);
+ }
+ if (identity_scfg) {
+ o<<partsep;
+ int pos=0;
+ for (Iter i=begin;i!=end;++i) {
+ WordID w=*i;
+ if (i!=begin) o<<' ';
+ if (w>0) o << TD::Convert(w);
+ else o << '['<<++pos<<']';
+ }
+ }
+ }
+
+ void print_features(std::ostream &o,prob_t p,FeatureVector const& fv=FeatureVector()) const {
+ bool logp=(logprob_feat && p!=1);
+ if (features || logp) {
+ o << partsep;
+ if (logp)
+ o << logprob_feat_name<<'='<<log(p)<<' ';
+ if (features)
+ o << fv;
+ }
+ }
+
void set_defaults() {
identity_scfg=false;
features=true;
logprob_feat=true;
cfg_comma_nt=true;
goal_nt_name="S";
+ logprob_feat_name="LogProb";
nt_prefix="";
+ partsep=" ||| ";
+ nt_span=true;
}
+
CFGFormat() {
set_defaults();
}
diff --git a/decoder/filelib.h b/decoder/filelib.h
index 4da4bc4f..b9fef9a7 100644
--- a/decoder/filelib.h
+++ b/decoder/filelib.h
@@ -30,6 +30,7 @@ struct BaseFile {
}
S* stream() { return ps_.get(); }
S* operator->() { return ps_.get(); } // compat with old ReadFile * -> new Readfile. remove?
+ S &operator *() const { return get(); }
S &get() const { return *ps_; }
bool is_std() {
return filename_=="-";
diff --git a/decoder/fsa-hiero.ini b/decoder/fsa-hiero.ini
index 3eb8b3d2..7c7d0347 100755
--- a/decoder/fsa-hiero.ini
+++ b/decoder/fsa-hiero.ini
@@ -1,5 +1,4 @@
formalism=scfg
-fsa_feature_function=LanguageModelFsa debug lm.gz -n LM
scfg_extra_glue_grammar=glue-lda.scfg
grammar=grammar.hiero
show_tree_structure=true
diff --git a/decoder/grammar.cc b/decoder/grammar.cc
index 26efaf99..bbe2f01a 100644
--- a/decoder/grammar.cc
+++ b/decoder/grammar.cc
@@ -84,7 +84,7 @@ const GrammarIter* TextGrammar::GetRoot() const {
void TextGrammar::AddRule(const TRulePtr& rule, const unsigned int ctf_level, const TRulePtr& coarse_rule) {
if (ctf_level > 0) {
// assume that coarse_rule is already in tree (would be safer to check)
- if (coarse_rule->fine_rules_ == 0)
+ if (coarse_rule->fine_rules_ == 0)
coarse_rule->fine_rules_.reset(new std::vector<TRulePtr>());
coarse_rule->fine_rules_->push_back(rule);
ctf_levels_ = std::max(ctf_levels_, ctf_level);
@@ -116,7 +116,7 @@ bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const {
GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {}
-void RefineRule(TRulePtr pt, const unsigned int ctf_level){
+void RefineRule(TRulePtr pt, const unsigned int ctf_level){
for (unsigned int i=0; i<ctf_level; ++i){
TRulePtr r(new TRule(*pt));
pt->fine_rules_.reset(new vector<TRulePtr>);
@@ -126,10 +126,10 @@ void RefineRule(TRulePtr pt, const unsigned int ctf_level){
}
GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const unsigned int ctf_level) {
- TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));
+ TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [1]"));
AddRule(stop_glue);
RefineRule(stop_glue, ctf_level);
- TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));
+ TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [1] [2] ||| Glue=1"));
AddRule(glue);
RefineRule(glue, ctf_level);
}
diff --git a/decoder/hg.h b/decoder/hg.h
index e64db837..59db6cfe 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -26,6 +26,7 @@
#include "trule.h"
#include "prob.h"
#include "indices_after.h"
+#include "nt_span.h"
// if you define this, edges_ will be sorted
// (normally, just nodes_ are - root must be nodes_.back()), but this can be quite
@@ -51,6 +52,7 @@ public:
Node() : id_(), cat_(), promise(1) {}
int id_; // equal to this object's position in the nodes_ vector
WordID cat_; // non-terminal category if <0, 0 if not set
+ WordID NT() const { return -cat_; }
EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_
EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
double promise; // set in global pruning; in [0,infty) so that mean is 1. use: e.g. scale cube poplimit. //TODO: appears to be useless, compile without this? on the other hand, pretty cheap.
@@ -66,6 +68,7 @@ public:
}
};
+
// TODO get rid of edge_prob_? (can be computed on the fly as the dot
// product of the weight vector and the feature values)
struct Edge {
@@ -81,6 +84,8 @@ public:
prob_t edge_prob_; // dot product of weights and feat_values
int id_; // equal to this object's position in the edges_ vector
+ //FIXME: these span ids belong in Node, not Edge, right? every node should have the same spans.
+
// span info. typically, i_ and j_ refer to indices in the source sentence.
// In synchronous parsing, i_ and j_ will refer to target sentence/lattice indices
// while prev_i_ prev_j_ will refer to positions in the source.
@@ -197,6 +202,47 @@ public:
}
};
+ // all this info ought to live in Node, but for some reason it's on Edges.
+ // except for stateful models that have split nt,span, this should identify the node
+ void SetNodeOrigin(int nodeid,NTSpan &r) const {
+ Node const &n=nodes_[nodeid];
+ r.nt=n.NT();
+ if (!n.in_edges_.empty()) {
+ Edge const& e=edges_[n.in_edges_.front()];
+ r.s.l=e.i_;
+ r.s.l=e.j_;
+// if (e.rule_) r.nt=-e.rule_->lhs_;
+ }
+ }
+ NTSpan NodeOrigin(int nodeid) const {
+ NTSpan r;
+ SetNodeOrigin(nodeid,r);
+ return r;
+ }
+ Span NodeSpan(int nodeid) const {
+ Span s;
+ Node const &n=nodes_[nodeid];
+ if (!n.in_edges_.empty()) {
+ Edge const& e=edges_[n.in_edges_.front()];
+ s.l=e.i_;
+ s.l=e.j_;
+ }
+ return s;
+ }
+ // 0 if none, -TD index otherwise (just like in rule)
+ WordID NodeLHS(int nodeid) const {
+ Node const &n=nodes_[nodeid];
+ return n.NT();
+ /*
+ if (!n.in_edges_.empty()) {
+ Edge const& e=edges_[n.in_edges_.front()];
+ if (e.rule_)
+ return -e.rule_->lhs_;
+ }
+ return 0;
+ */
+ }
+
typedef std::vector<prob_t> EdgeProbs;
typedef std::vector<prob_t> NodeProbs;
typedef std::vector<bool> EdgeMask;
diff --git a/decoder/nt_span.h b/decoder/nt_span.h
new file mode 100755
index 00000000..46234b07
--- /dev/null
+++ b/decoder/nt_span.h
@@ -0,0 +1,30 @@
+#ifndef NT_SPAN_H
+#define NT_SPAN_H
+
+#include <iostream>
+#include "wordid.h"
+#include "tdict.h"
+
+struct Span {
+ int l,r;
+ Span() : l(-1) { }
+ friend inline std::ostream &operator<<(std::ostream &o,Span const& s) {
+ if (s.l<0)
+ return o;
+ return o<<'<'<<s.l<<','<<s.r<<'>';
+ }
+};
+
+struct NTSpan {
+ Span s;
+ WordID nt; // awkward: this is a positive index, used in TD. but represented as negative in mixed terminal/NT space in rules/hgs.
+ NTSpan() : nt(0) { }
+ // prints as possibly empty name (whatever you set of nt,s will show)
+ friend inline std::ostream &operator<<(std::ostream &o,NTSpan const& t) {
+ if (t.nt>0)
+ o<<TD::Convert(t.nt);
+ return o << t.s;
+ }
+};
+
+#endif
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 86426ef5..81a584a7 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -91,11 +91,11 @@ struct OracleBleu {
using namespace boost::program_options;
using namespace std;
opts->add_options()
+ ("show_derivation", bool_switch(&show_derivation), "show derivation tree in kbest")
+ ("verbose",bool_switch(&verbose),"detailed logs")
("references,R", value<Refs >(&refs), "Translation reference files")
("oracle_loss", value<string>(&loss_name)->default_value("IBM_BLEU_3"), "IBM_BLEU_3 (default), IBM_BLEU etc")
("bleu_weight", value<double>(&bleu_weight)->default_value(1.), "weight to give the hope/fear loss function vs. model score")
- ("show_derivation", bool_switch(&show_derivation), "show derivation tree in kbest")
- ("verbose",bool_switch(&verbose),"detailed logs")
;
}
int order;