summary | refs | log | tree | commit | diff
path: root/rst_parser/mst_train.cc
diff options
context:
space:
mode:
author: Chris Dyer <cdyer@cs.cmu.edu> 2012-04-16 00:18:20 -0400
committer: Chris Dyer <cdyer@cs.cmu.edu> 2012-04-16 00:18:20 -0400
commit: fa47b549e5ac7c16dce9e40d52328ffd51b60dc6 (patch)
tree: 037edacd471b3a91427db2708af1533bb6116a65 /rst_parser/mst_train.cc
parent: daa182defda1a97cb66b45b4ebf2a223948d950b (diff)
rst algorithm
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r-- rst_parser/mst_train.cc 21
1 files changed, 15 insertions, 6 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index def23edb..b5114726 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -13,6 +13,7 @@
#include "picojson.h"
#include "optimize.h"
#include "weights.h"
+#include "rst.h"
using namespace std;
namespace po = boost::program_options;
@@ -173,12 +174,13 @@ int main(int argc, char** argv) {
double obj = -empirical.dot(weights);
// SparseVector<double> mfm; //DE
for (int i = 0; i < corpus.size(); ++i) {
+ const int num_words = corpus[i].ts.words.size();
forests[i].Reweight(weights);
- double logz;
- forests[i].EdgeMarginals(&logz);
- //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -logz << " OO= " << (-corpus[i].features.dot(weights) - logz) << endl;
- obj -= logz;
- int num_words = corpus[i].ts.words.size();
+ double lz;
+ forests[i].EdgeMarginals(&lz);
+ obj -= lz;
+ //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl;
+ //cerr << " ZZ = " << zz << endl;
for (int h = -1; h < num_words; ++h) {
for (int m = 0; m < num_words; ++m) {
if (h == m) continue;
@@ -198,13 +200,20 @@ int main(int argc, char** argv) {
double gnorm = 0;
for (int i = 0; i < g.size(); ++i)
gnorm += g[i]*g[i];
- cerr << "OBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm) << endl;
+ ostringstream ll;
+ ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm);
+ cerr << endl << ll.str() << endl;
obj += r;
assert(obj >= 0);
o->Optimize(obj, g, &weights);
Weights::ShowLargestFeatures(weights);
+ string sl = ll.str();
+ Weights::WriteToFile(o->HasConverged() ? "weights.final.gz" : "weights.cur.gz", weights, true, &sl);
if (o->HasConverged()) { cerr << "CONVERGED\n"; break; }
}
+ forests[0].Reweight(weights);
+ TreeSampler ts(forests[0]);
+ EdgeSubset tt; ts.SampleRandomSpanningTree(&tt);
return 0;
}