summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc28
1 files changed, 17 insertions, 11 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 149f87d4..83e4e440 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -1,4 +1,10 @@
#include "dtrain.h"
+#include "score.h"
+#include "kbestget.h"
+#include "ksampler.h"
+#include "pairsampling.h"
+
+using namespace dtrain;
bool
@@ -138,23 +144,23 @@ main(int argc, char** argv)
string scorer_str = cfg["scorer"].as<string>();
LocalScorer* scorer;
if (scorer_str == "bleu") {
- scorer = dynamic_cast<BleuScorer*>(new BleuScorer);
+ scorer = static_cast<BleuScorer*>(new BleuScorer);
} else if (scorer_str == "stupid_bleu") {
- scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer);
+ scorer = static_cast<StupidBleuScorer*>(new StupidBleuScorer);
} else if (scorer_str == "fixed_stupid_bleu") {
- scorer = dynamic_cast<FixedStupidBleuScorer*>(new FixedStupidBleuScorer);
+ scorer = static_cast<FixedStupidBleuScorer*>(new FixedStupidBleuScorer);
} else if (scorer_str == "smooth_bleu") {
- scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
+ scorer = static_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
} else if (scorer_str == "sum_bleu") {
- scorer = dynamic_cast<SumBleuScorer*>(new SumBleuScorer);
+ scorer = static_cast<SumBleuScorer*>(new SumBleuScorer);
} else if (scorer_str == "sumexp_bleu") {
- scorer = dynamic_cast<SumExpBleuScorer*>(new SumExpBleuScorer);
+ scorer = static_cast<SumExpBleuScorer*>(new SumExpBleuScorer);
} else if (scorer_str == "sumwhatever_bleu") {
- scorer = dynamic_cast<SumWhateverBleuScorer*>(new SumWhateverBleuScorer);
+ scorer = static_cast<SumWhateverBleuScorer*>(new SumWhateverBleuScorer);
} else if (scorer_str == "approx_bleu") {
- scorer = dynamic_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
+ scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
} else if (scorer_str == "lc_bleu") {
- scorer = dynamic_cast<LinearBleuScorer*>(new LinearBleuScorer(N));
+ scorer = static_cast<LinearBleuScorer*>(new LinearBleuScorer(N));
} else {
cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
exit(1);
@@ -166,9 +172,9 @@ main(int argc, char** argv)
MT19937 rng; // random number generator, only for forest sampling
HypSampler* observer;
if (sample_from == "kbest")
- observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type));
+ observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type));
else
- observer = dynamic_cast<KSampler*>(new KSampler(k, &rng));
+ observer = static_cast<KSampler*>(new KSampler(k, &rng));
observer->SetScorer(scorer);
// init weights