summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rst_parser/arc_factored.h4
-rw-r--r--rst_parser/mst_train.cc21
-rw-r--r--rst_parser/rst.cc45
-rw-r--r--rst_parser/rst.h9
4 files changed, 69 insertions, 10 deletions
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
index a95f8230..d9a0bb24 100644
--- a/rst_parser/arc_factored.h
+++ b/rst_parser/arc_factored.h
@@ -28,10 +28,12 @@ struct ArcFeatureFunction;
class ArcFactoredForest {
public:
ArcFactoredForest() : num_words_() {}
- explicit ArcFactoredForest(short num_words) {
+ explicit ArcFactoredForest(short num_words) : num_words_(num_words) {
resize(num_words);
}
+ unsigned size() const { return num_words_; }
+
void resize(unsigned num_words) {
num_words_ = num_words;
root_edges_.clear();
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index def23edb..b5114726 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -13,6 +13,7 @@
#include "picojson.h"
#include "optimize.h"
#include "weights.h"
+#include "rst.h"
using namespace std;
namespace po = boost::program_options;
@@ -173,12 +174,13 @@ int main(int argc, char** argv) {
double obj = -empirical.dot(weights);
// SparseVector<double> mfm; //DE
for (int i = 0; i < corpus.size(); ++i) {
+ const int num_words = corpus[i].ts.words.size();
forests[i].Reweight(weights);
- double logz;
- forests[i].EdgeMarginals(&logz);
- //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -logz << " OO= " << (-corpus[i].features.dot(weights) - logz) << endl;
- obj -= logz;
- int num_words = corpus[i].ts.words.size();
+ double lz;
+ forests[i].EdgeMarginals(&lz);
+ obj -= lz;
+ //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl;
+ //cerr << " ZZ = " << zz << endl;
for (int h = -1; h < num_words; ++h) {
for (int m = 0; m < num_words; ++m) {
if (h == m) continue;
@@ -198,13 +200,20 @@ int main(int argc, char** argv) {
double gnorm = 0;
for (int i = 0; i < g.size(); ++i)
gnorm += g[i]*g[i];
- cerr << "OBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm) << endl;
+ ostringstream ll;
+ ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm);
+ cerr << endl << ll.str() << endl;
obj += r;
assert(obj >= 0);
o->Optimize(obj, g, &weights);
Weights::ShowLargestFeatures(weights);
+ string sl = ll.str();
+ Weights::WriteToFile(o->HasConverged() ? "weights.final.gz" : "weights.cur.gz", weights, true, &sl);
if (o->HasConverged()) { cerr << "CONVERGED\n"; break; }
}
+ forests[0].Reweight(weights);
+ TreeSampler ts(forests[0]);
+ EdgeSubset tt; ts.SampleRandomSpanningTree(&tt);
return 0;
}
diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc
index f6b295b3..c4ce898e 100644
--- a/rst_parser/rst.cc
+++ b/rst_parser/rst.cc
@@ -2,6 +2,49 @@
using namespace std;
-StochasticForest::StochasticForest(const ArcFactoredForest& af) {
+// David B. Wilson. Generating Random Spanning Trees More Quickly than the Cover Time.
+
+TreeSampler::TreeSampler(const ArcFactoredForest& af) : forest(af), usucc(af.size() + 1) {
+ // edges are directed from modifiers to heads, to the root
+ for (int m = 1; m <= forest.size(); ++m) {
+ SampleSet<double>& ss = usucc[m];
+ for (int h = 0; h <= forest.size(); ++h)
+ ss.add(forest(h-1,m-1).edge_prob.as_float());
+ }
}
+void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree) {
+ MT19937 rng;
+ const int r = 0;
+ bool success = false;
+ while (!success) {
+ int roots = 0;
+ vector<int> next(forest.size() + 1, -1);
+ vector<char> in_tree(forest.size() + 1, 0);
+ in_tree[r] = 1;
+ for (int i = 0; i < forest.size(); ++i) {
+ int u = i;
+ if (in_tree[u]) continue;
+ while(!in_tree[u]) {
+ next[u] = rng.SelectSample(usucc[u]);
+ u = next[u];
+ }
+ u = i;
+ cerr << (u-1);
+ while(!in_tree[u]) {
+ in_tree[u] = true;
+ u = next[u];
+ cerr << " > " << (u-1);
+ if (u == r) { ++roots; }
+ }
+ cerr << endl;
+ }
+ assert(roots > 0);
+ if (roots > 1) {
+ cerr << "FAILURE\n";
+ } else {
+ success = true;
+ }
+ }
+};
+
diff --git a/rst_parser/rst.h b/rst_parser/rst.h
index 865871eb..a269ff9b 100644
--- a/rst_parser/rst.h
+++ b/rst_parser/rst.h
@@ -1,10 +1,15 @@
#ifndef _RST_H_
#define _RST_H_
+#include <vector>
+#include "sampler.h"
#include "arc_factored.h"
-struct StochasticForest {
- explicit StochasticForest(const ArcFactoredForest& af);
+struct TreeSampler {
+ explicit TreeSampler(const ArcFactoredForest& af);
+ void SampleRandomSpanningTree(EdgeSubset* tree);
+ const ArcFactoredForest& forest;
+ std::vector<SampleSet<double> > usucc;
};
#endif