diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 20:08:35 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 20:08:35 +0000 |
commit | 6bacbfcbe191ec898e43f4f03e570283b156a8ca (patch) | |
tree | b81ddcf798cc7008b09d504687d319f429cd5bb3 | |
parent | 15a587e247dc0954de27e2627f5511126243943d (diff) |
vest: oracle_loss argument bugfix
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@287 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | decoder/small_vector.h | 9 | ||||
-rwxr-xr-x | vest/line_mediator.pl | 3 | ||||
-rw-r--r-- | vest/mr_vest_generate_mapper_input.cc | 29 |
3 files changed, 25 insertions, 16 deletions
diff --git a/decoder/small_vector.h b/decoder/small_vector.h index 800c1df1..86d3b0b3 100644 --- a/decoder/small_vector.h +++ b/decoder/small_vector.h @@ -3,6 +3,7 @@ #include <streambuf> // std::max - where to get this? #include <cstring> #include <cassert> +#include <limits.h> #define __SV_MAX_STATIC 2 @@ -77,10 +78,10 @@ class SmallVector { bool empty() const { return size_ == 0; } size_t size() const { return size_; } - inline void ensure_capacity(unsigned char min_size) { + inline void ensure_capacity(uint16_t min_size) { assert(min_size > __SV_MAX_STATIC); if (min_size < capacity_) return; - unsigned char new_cap = std::max(static_cast<unsigned char>(capacity_ << 1), min_size); + uint16_t new_cap = std::max(static_cast<uint16_t>(capacity_ << 1), min_size); int* tmp = new int[new_cap]; std::memcpy(tmp, data_.ptr, capacity_ * sizeof(int)); delete[] data_.ptr; @@ -170,8 +171,8 @@ class SmallVector { } private: - unsigned char capacity_; // only defined when size_ >= __SV_MAX_STATIC - unsigned char size_; + uint16_t capacity_; // only defined when size_ >= __SV_MAX_STATIC + uint16_t size_; union StorageType { int vals[__SV_MAX_STATIC]; int* ptr; diff --git a/vest/line_mediator.pl b/vest/line_mediator.pl index f3c6dbf1..0a9af82e 100755 --- a/vest/line_mediator.pl +++ b/vest/line_mediator.pl @@ -47,12 +47,9 @@ if ($ENV{SERIAL}) { my @rw2=POSIX::pipe(); my $pid=undef; $SIG{CHLD} = sub { wait }; -# close STDIN; -# close STDOUT; while (not defined ($pid=fork())) { sleep 1; } -# info(STDOUT_FILENO,STDIN_FILENO,"\n"); POSIX::close(STDOUT_FILENO); POSIX::close(STDIN_FILENO); if ($pid) { diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index 01e93f61..5e208aa0 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -78,12 +78,7 @@ struct oracle_directions { void AddOptions(po::options_description *opts) { oracle.AddOptions(opts); - } - - void InitCommandLine(int argc, char *argv[], 
po::variables_map *conf) { - po::options_description opts("Configuration options"); - OracleBleu::AddOptions(&opts); - opts.add_options() + opts->add_options() ("dev_set_size,s",po::value<unsigned>(&dev_set_size),"[REQD] Development set size (# of parallel sentences)") ("forest_repository,r",po::value<string>(&forest_repository),"[REQD] Path to forest repository") ("weights,w",po::value<string>(&weights_file),"[REQD] Current feature weights file") @@ -96,8 +91,13 @@ struct oracle_directions { ("max_similarity,m",po::value<double>(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?") ("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)") ("no_old_to_hope,n","don't emit the usual old -> hope oracle") - ("decoder_translations",po::value<string>(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. 
sentences-seen-so-far BLEU"); + } + void InitCommandLine(int argc, char *argv[], po::variables_map *conf) { + po::options_description opts("Configuration options"); + AddOptions(&opts); + opts.add_options()("help,h", "Help"); + po::options_description dcmdline_options; dcmdline_options.add(opts); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); @@ -176,9 +176,13 @@ struct oracle_directions { oracle_directions() { } Sentences model_hyps; + bool have_doc; void Init() { - if (!decoder_translations_file.empty()) + have_doc=!decoder_translations_file.empty(); + if (have_doc) { model_hyps.Load(decoder_translations_file); + //TODO: compute doc bleu stats for each sentence, then when getting oracle temporarily exclude stats for that sentence (skip regular score updating) + } start_random=false; assert(DirectoryExists(forest_repository)); vector<string> features; @@ -205,6 +209,9 @@ struct oracle_directions { Oracle const& ComputeOracle(unsigned i) { Oracle &o=oracles[i]; if (o.is_null()) { + if (have_doc) { + //TODO: + } ReadFile rf(forest_file(i)); Hypergraph hg; { @@ -212,6 +219,10 @@ struct oracle_directions { HypergraphIO::ReadFromJSON(rf.stream(), &hg); } o=oracle.ComputeOracle(oracle.MakeMetadata(hg,i),&hg,origin,&cerr); + if (have_doc) { + //TODO: + } else + oracle.IncludeLastScore(); } return o; } |