summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <prguest11@taipan.cs>2012-04-23 03:11:26 +0100
committerChris Dyer <prguest11@taipan.cs>2012-04-23 03:11:26 +0100
commit217c4aaeba1c9f19b3420b526235bffd86c7a92b (patch)
treea551d277413faf1feb64c23786aa4dc9bf3827eb
parent06718177056fe5262262e00d98dc89f67cefb193 (diff)
mst train
-rw-r--r--rst_parser/mst_train.cc15
1 files changed, 12 insertions, 3 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index b3711aba..6332693e 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -28,6 +28,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("weights,w",po::value<string>(), "Optional starting weights")
("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations")
("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength")
+#ifdef HAVE_CMPH
+ ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features")
+#endif
#if HAVE_THREAD
("threads,T",po::value<unsigned>()->default_value(1), "Number of threads")
#endif
@@ -119,11 +122,19 @@ int main(int argc, char** argv) {
int size = 1;
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
+ if (conf.count("cmph_perfect_feature_hash")) {
+ cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
+ FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
+ cerr << " " << FD::NumFeats() << " features in map\n";
+ }
ArcFeatureFunctions ffs;
vector<TrainingInstance> corpus;
TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus, rank, size);
+ vector<weight_t> weights;
+ Weights::InitFromFile(conf["weights"].as<string>(), &weights);
vector<ArcFactoredForest> forests(corpus.size());
SparseVector<double> empirical;
+ cerr << "Extracting features...\n";
bool flag = false;
for (int i = 0; i < corpus.size(); ++i) {
TrainingInstance& cur = corpus[i];
@@ -149,9 +160,7 @@ int main(int argc, char** argv) {
}
if (flag) cerr << endl;
//cerr << "EMP: " << empirical << endl; //DE
- vector<weight_t> weights(FD::NumFeats(), 0.0);
- if (conf.count("weights"))
- Weights::InitFromFile(conf["weights"].as<string>(), &weights);
+ weights.resize(FD::NumFeats(), 0.0);
vector<weight_t> g(FD::NumFeats(), 0.0);
cerr << "features initialized\noptimizing...\n";
boost::shared_ptr<BatchOptimizer> o;