diff options
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r-- | rst_parser/mst_train.cc | 37 |
1 files changed, 14 insertions, 23 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index f0403d7e..0709e7c9 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -6,7 +6,6 @@ #include <boost/program_options/variables_map.hpp> #include "arc_ff.h" -#include "arc_ff_factory.h" #include "stringlib.h" #include "filelib.h" #include "tdict.h" @@ -22,7 +21,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { string cfg_file; opts.add_options() ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)") - ("feature_function,F",po::value<vector<string> >()->composing(), "feature function (multiple permitted)") ("weights,w",po::value<string>(), "Optional starting weights") ("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations") ("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength") @@ -74,12 +72,8 @@ int main(int argc, char** argv) { int size = 1; po::variables_map conf; InitCommandLine(argc, argv, &conf); - ArcFactoredForest af(5); - ArcFFRegistry reg; - reg.Register("DistancePenalty", new ArcFFFactory<DistancePenalty>); + ArcFeatureFunctions ffs; vector<TrainingInstance> corpus; - vector<boost::shared_ptr<ArcFeatureFunction> > ffs; - ffs.push_back(boost::shared_ptr<ArcFeatureFunction>(new DistancePenalty(""))); TrainingInstance::ReadTraining(conf["training_data"].as<string>(), &corpus, rank, size); vector<ArcFactoredForest> forests(corpus.size()); SparseVector<double> empirical; @@ -88,22 +82,19 @@ int main(int argc, char** argv) { TrainingInstance& cur = corpus[i]; if (rank == 0 && (i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } if (rank == 0 && (i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } - for (int fi = 0; fi < ffs.size(); ++fi) { - ArcFeatureFunction& ff = *ffs[fi]; - ff.PrepareForInput(cur.ts); - SparseVector<weight_t> efmap; - for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, - cur.tree.h_m_pairs[j].second, - &efmap); - cur.features += efmap; - } - for (int j = 0; j < cur.tree.roots.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); - cur.features += efmap; - } + ffs.PrepareForInput(cur.ts); + SparseVector<weight_t> efmap; + for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, + cur.tree.h_m_pairs[j].second, + &efmap); + cur.features += efmap; + } + for (int j = 0; j < cur.tree.roots.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); + cur.features += efmap; } empirical += cur.features; forests[i].resize(cur.ts.words.size()); |