diff options
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r-- | rst_parser/mst_train.cc | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index def23edb..b5114726 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -13,6 +13,7 @@ #include "picojson.h" #include "optimize.h" #include "weights.h" +#include "rst.h" using namespace std; namespace po = boost::program_options; @@ -173,12 +174,13 @@ int main(int argc, char** argv) { double obj = -empirical.dot(weights); // SparseVector<double> mfm; //DE for (int i = 0; i < corpus.size(); ++i) { + const int num_words = corpus[i].ts.words.size(); forests[i].Reweight(weights); - double logz; - forests[i].EdgeMarginals(&logz); - //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -logz << " OO= " << (-corpus[i].features.dot(weights) - logz) << endl; - obj -= logz; - int num_words = corpus[i].ts.words.size(); + double lz; + forests[i].EdgeMarginals(&lz); + obj -= lz; + //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl; + //cerr << " ZZ = " << zz << endl; for (int h = -1; h < num_words; ++h) { for (int m = 0; m < num_words; ++m) { if (h == m) continue; @@ -198,13 +200,20 @@ int main(int argc, char** argv) { double gnorm = 0; for (int i = 0; i < g.size(); ++i) gnorm += g[i]*g[i]; - cerr << "OBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm) << endl; + ostringstream ll; + ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm); + cerr << endl << ll.str() << endl; obj += r; assert(obj >= 0); o->Optimize(obj, g, &weights); Weights::ShowLargestFeatures(weights); + string sl = ll.str(); + Weights::WriteToFile(o->HasConverged() ? "weights.final.gz" : "weights.cur.gz", weights, true, &sl); if (o->HasConverged()) { cerr << "CONVERGED\n"; break; } } + forests[0].Reweight(weights); + TreeSampler ts(forests[0]); + EdgeSubset tt; ts.SampleRandomSpanningTree(&tt); return 0; } |