summary | refs | log | tree | commit | diff
path: root/rst_parser/mst_train.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r-- rst_parser/mst_train.cc 37
1 files changed, 14 insertions, 23 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index f0403d7e..0709e7c9 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -6,7 +6,6 @@
#include <boost/program_options/variables_map.hpp>
#include "arc_ff.h"
-#include "arc_ff_factory.h"
#include "stringlib.h"
#include "filelib.h"
#include "tdict.h"
@@ -22,7 +21,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
string cfg_file;
opts.add_options()
("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)")
- ("feature_function,F",po::value<vector<string> >()->composing(), "feature function (multiple permitted)")
("weights,w",po::value<string>(), "Optional starting weights")
("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations")
("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength")
@@ -74,12 +72,8 @@ int main(int argc, char** argv) {
int size = 1;
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
- ArcFactoredForest af(5);
- ArcFFRegistry reg;
- reg.Register("DistancePenalty", new ArcFFFactory<DistancePenalty>);
+ ArcFeatureFunctions ffs;
vector<TrainingInstance> corpus;
- vector<boost::shared_ptr<ArcFeatureFunction> > ffs;
- ffs.push_back(boost::shared_ptr<ArcFeatureFunction>(new DistancePenalty("")));
TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size);
vector<ArcFactoredForest> forests(corpus.size());
SparseVector<double> empirical;
@@ -88,22 +82,19 @@ int main(int argc, char** argv) {
TrainingInstance& cur = corpus[i];
if (rank == 0 && (i+1) % 10 == 0) { cerr << '.' << flush; flag = true; }
if (rank == 0 && (i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; }
- for (int fi = 0; fi < ffs.size(); ++fi) {
- ArcFeatureFunction& ff = *ffs[fi];
- ff.PrepareForInput(cur.ts);
- SparseVector<weight_t> efmap;
- for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) {
- efmap.clear();
- ff.EgdeFeatures(cur.ts, cur.tree.h_m_pairs[j].first,
- cur.tree.h_m_pairs[j].second,
- &efmap);
- cur.features += efmap;
- }
- for (int j = 0; j < cur.tree.roots.size(); ++j) {
- efmap.clear();
- ff.EgdeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap);
- cur.features += efmap;
- }
+ ffs.PrepareForInput(cur.ts);
+ SparseVector<weight_t> efmap;
+ for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first,
+ cur.tree.h_m_pairs[j].second,
+ &efmap);
+ cur.features += efmap;
+ }
+ for (int j = 0; j < cur.tree.roots.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap);
+ cur.features += efmap;
}
empirical += cur.features;
forests[i].resize(cur.ts.words.size());