diff options
-rwxr-xr-x | decoder/cfg.cc | 30 | ||||
-rwxr-xr-x | decoder/cfg.h | 6 | ||||
-rwxr-xr-x | decoder/cfg_format.h | 1 | ||||
-rwxr-xr-x | decoder/cfg_test.cc | 35 | ||||
-rwxr-xr-x | decoder/hg_test.h | 2 | ||||
-rwxr-xr-x | utils/batched_append.h | 2 | ||||
-rw-r--r-- | utils/sparse_vector.h | 6 |
7 files changed, 64 insertions, 18 deletions
diff --git a/decoder/cfg.cc b/decoder/cfg.cc index c2d96b33..c0598f16 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -182,7 +182,7 @@ void CFG::Binarize(CFGBinarize const& b) { NTs new_nts; // these will be appended at the end, so we don't have to worry about iterator invalidation Rules new_rules; //TODO: this could be factored easily into in-place (append to new_* like below) and functional (nondestructive copy) versions (copy orig to target and append to target) - int newnt=nts.size(); + int newnt=-nts.size(); int newruleid=rules.size(); BinRhs bin; for (NTs::const_iterator n=nts.begin(),nn=nts.end();n!=nn;++n) { @@ -192,21 +192,29 @@ void CFG::Binarize(CFGBinarize const& b) { if (rhs.empty()) continue; bin.second=rhs.back(); for (int r=rhs.size()-2;r>=rhsmin;--r) { // pairs from right to left (normally we leave the last pair alone) - rhs.pop_back(); bin.first=rhs[r]; if (newnt==(bin.second=(get_default(bin2lhs,bin,newnt)))) { - new_nts.push_back(NT()); - new_nts.back().ruleids.push_back(newruleid); - new_rules.push_back(Rule(newnt,bin)); + new_nts.push_back(NT(newruleid)); + new_rules.push_back(Rule(-newnt,bin)); + ++newruleid; if (b.bin_name_nts) new_nts.back().from.nt=BinName(bin,nts,new_nts); - ++newnt;++newruleid; + --newnt; } } + if (rhsmin<rhs.size()) { + rhs[rhsmin]=bin.second; + rhs.resize(rhsmin+1); + } } } +#if 0 batched_append_swap(nts,new_nts); batched_append_swap(rules,new_rules); +#else + batched_append(nts,new_nts); + batched_append(rules,new_rules); +#endif if (b.bin_topo) //TODO: more efficient (at least for l2r) maintenance of order OrderNTsTopo(); } @@ -302,3 +310,13 @@ void CFG::Print(std::ostream &o,CFGFormat const& f) const { } } } + +void CFG::Print(std::ostream &o) const { + Print(o,CFGFormat()); +} + + +std::ostream &operator<<(std::ostream &o,CFG const &x) { + x.Print(o); + return o; +} diff --git a/decoder/cfg.h b/decoder/cfg.h index 21e03e2c..b6dd6d99 100755 --- a/decoder/cfg.h +++ b/decoder/cfg.h @@ -53,7 +53,7 @@ struct CFG { o << nts[n].from << n; } - typedef std::pair<int,int> BinRhs; + typedef std::pair<WordID,WordID> BinRhs; struct Rule { std::size_t hash_impl() const { @@ -144,6 +144,8 @@ struct CFG { }; struct NT { + NT() { } + explicit NT(RuleHandle r) : ruleids(1,r) { } std::size_t hash_impl() const { using namespace boost; return hash_value(ruleids); } bool operator ==(NT const &o) const { return ruleids==o.ruleids; // don't care about from @@ -181,6 +183,7 @@ struct CFG { void Init(Hypergraph const& hg,bool target_side=true,bool copy_features=false,bool push_weights=true); void Print(std::ostream &o,CFGFormat const& format) const; // see cfg_format.h void PrintRule(std::ostream &o,RuleHandle rulei,CFGFormat const& format) const; + void Print(std::ostream &o) const; // default format void Swap(CFG &o) { // make sure this includes all fields (easier to see here than in .cc) using namespace std; swap(uninit,o.uninit); @@ -302,5 +305,6 @@ inline void swap(CFG &a,CFG &b) { a.Swap(b); } +std::ostream &operator<<(std::ostream &o,CFG const &x); #endif diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h index d56d42f2..a9b3fd9f 100755 --- a/decoder/cfg_format.h +++ b/decoder/cfg_format.h @@ -111,6 +111,7 @@ struct CFGFormat { } } + //TODO: default to no nt names (nt_span=0) void set_defaults() { identity_scfg=false; features=true; diff --git a/decoder/cfg_test.cc b/decoder/cfg_test.cc index c4c37a2c..81efa768 100755 --- a/decoder/cfg_test.cc +++ b/decoder/cfg_test.cc @@ -3,18 +3,27 @@ #include "hg_test.h" #include "cfg_options.h" +#define CSHOW_V 1 +#if CSHOW_V +# define CSHOWDO(x) x +#else +# define CSHOWDO(x) +#endif +#define CSHOW(x) CSHOWDO(cerr<<#x<<'='<<x<<endl;) + struct CFGTest : public HGSetup { CFGTest() { } ~CFGTest() { } static void JsonFN(Hypergraph hg,CFG &cfg,std::string file ,std::string const& wts="Model_0 1 EgivenF 1 f1 1") { - FeatureVector v; + FeatureVector featw; istringstream ws(wts); -// ASSERT_TRUE(ws>>v); + EXPECT_TRUE(ws>>featw); + CSHOW(featw) HGSetup::JsonTestFile(&hg,file); -// hg.Reweight(v); - cfg.Init(hg,true,false,false); + hg.Reweight(featw); + cfg.Init(hg,true,true,false); } static void SetUpTestCase() { @@ -27,10 +36,22 @@ TEST_F(CFGTest,Binarize) { Hypergraph hg; CFG cfg; JsonFN(hg,cfg,perro_json,perro_wts); + CSHOW("\nCFG Test.\n"); + CFGBinarize b; CFGFormat form; - form.features=true; - cerr<<"\nCFG Test.\n\n"; - cfg.Print(cerr,form); + form.nt_span=true; + for (int i=-1;i<16;++i) { + b.bin_l2r=i>=0; + b.bin_unary=i&1; + b.bin_name_nts=i&2; + b.bin_uniq=i&4; + b.bin_topo=i&8; + CFG cc=cfg; + EXPECT_EQ(cc,cfg); + CSHOW("\nBinarizing: "<<b); + cc.Binarize(b); + CSHOWDO(cc.Print(cerr,form);cerr<<"\n\n";); + } } int main(int argc, char **argv) { diff --git a/decoder/hg_test.h b/decoder/hg_test.h index 4f3ae251..c1bc05bd 100755 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -9,7 +9,7 @@ using namespace std; -#pragma GCC diagnostic ignore "-Wunused" +#pragma GCC diagnostic ignored "-Wunused-variable" namespace { diff --git a/utils/batched_append.h b/utils/batched_append.h index 14a6d576..842f3209 100755 --- a/utils/batched_append.h +++ b/utils/batched_append.h @@ -13,7 +13,7 @@ void batched_append(Vector &v,SRange const& s) { template <class SRange,class Vector> void batched_append_swap(Vector &v,SRange & s) { - using namespace std; // to find the right swap + using namespace std; // to find the right swap via ADL size_t i=v.size(); size_t news=i+s.size(); v.resize(news); diff --git a/utils/sparse_vector.h b/utils/sparse_vector.h index 7ac85d1d..e3904403 100644 --- a/utils/sparse_vector.h +++ b/utils/sparse_vector.h @@ -145,7 +145,7 @@ public: if (!(s>>v)) error("reading value failed"); } std::pair<iterator,bool> vi=values_.insert(value_type(k,v)); - if (vi.second) { + if (!vi.second) { T &oldv=vi.first->second; switch(dp) { case NO_DUPS: error("read duplicate key with NO_DUPS. key=" @@ -157,9 +157,11 @@ public: } } } - return; + goto good; eof: if (!s.eof()) error("reading key failed (before EOF)"); + good: + s.clear(); // we may have reached eof, but that's no error. } friend inline std::ostream & operator<<(std::ostream &o,Self const& s) { |