summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 20:08:35 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 20:08:35 +0000
commit6bacbfcbe191ec898e43f4f03e570283b156a8ca (patch)
treeb81ddcf798cc7008b09d504687d319f429cd5bb3
parent15a587e247dc0954de27e2627f5511126243943d (diff)
vest: oracle_loss argument bugfix
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@287 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--decoder/small_vector.h9
-rwxr-xr-xvest/line_mediator.pl3
-rw-r--r--vest/mr_vest_generate_mapper_input.cc29
3 files changed, 25 insertions, 16 deletions
diff --git a/decoder/small_vector.h b/decoder/small_vector.h
index 800c1df1..86d3b0b3 100644
--- a/decoder/small_vector.h
+++ b/decoder/small_vector.h
@@ -3,6 +3,7 @@
#include <streambuf> // std::max - where to get this?
#include <cstring>
#include <cassert>
+#include <limits.h>
#define __SV_MAX_STATIC 2
@@ -77,10 +78,10 @@ class SmallVector {
bool empty() const { return size_ == 0; }
size_t size() const { return size_; }
- inline void ensure_capacity(unsigned char min_size) {
+ inline void ensure_capacity(uint16_t min_size) {
assert(min_size > __SV_MAX_STATIC);
if (min_size < capacity_) return;
- unsigned char new_cap = std::max(static_cast<unsigned char>(capacity_ << 1), min_size);
+ uint16_t new_cap = std::max(static_cast<uint16_t>(capacity_ << 1), min_size);
int* tmp = new int[new_cap];
std::memcpy(tmp, data_.ptr, capacity_ * sizeof(int));
delete[] data_.ptr;
@@ -170,8 +171,8 @@ class SmallVector {
}
private:
- unsigned char capacity_; // only defined when size_ >= __SV_MAX_STATIC
- unsigned char size_;
+ uint16_t capacity_; // only defined when size_ >= __SV_MAX_STATIC
+ uint16_t size_;
union StorageType {
int vals[__SV_MAX_STATIC];
int* ptr;
diff --git a/vest/line_mediator.pl b/vest/line_mediator.pl
index f3c6dbf1..0a9af82e 100755
--- a/vest/line_mediator.pl
+++ b/vest/line_mediator.pl
@@ -47,12 +47,9 @@ if ($ENV{SERIAL}) {
my @rw2=POSIX::pipe();
my $pid=undef;
$SIG{CHLD} = sub { wait };
-# close STDIN;
-# close STDOUT;
while (not defined ($pid=fork())) {
sleep 1;
}
-# info(STDOUT_FILENO,STDIN_FILENO,"\n");
POSIX::close(STDOUT_FILENO);
POSIX::close(STDIN_FILENO);
if ($pid) {
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc
index 01e93f61..5e208aa0 100644
--- a/vest/mr_vest_generate_mapper_input.cc
+++ b/vest/mr_vest_generate_mapper_input.cc
@@ -78,12 +78,7 @@ struct oracle_directions {
void AddOptions(po::options_description *opts) {
oracle.AddOptions(opts);
- }
-
- void InitCommandLine(int argc, char *argv[], po::variables_map *conf) {
- po::options_description opts("Configuration options");
- OracleBleu::AddOptions(&opts);
- opts.add_options()
+ opts-?add_options()
("dev_set_size,s",po::value<unsigned>(&dev_set_size),"[REQD] Development set size (# of parallel sentences)")
("forest_repository,r",po::value<string>(&forest_repository),"[REQD] Path to forest repository")
("weights,w",po::value<string>(&weights_file),"[REQD] Current feature weights file")
@@ -96,8 +91,13 @@ struct oracle_directions {
("max_similarity,m",po::value<double>(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?")
("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)")
("no_old_to_hope,n","don't emit the usual old -> hope oracle")
- ("decoder_translations",po::value<string>(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU")
- ("help,h", "Help");
+ ("decoder_translations",po::value<string>(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU");
+ }
+ void InitCommandLine(int argc, char *argv[], po::variables_map *conf) {
+ po::options_description opts("Configuration options");
+ AddOptions(&opts);
+ opts.add_options()("help,h", "Help");
+
po::options_description dcmdline_options;
dcmdline_options.add(opts);
po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
@@ -176,9 +176,13 @@ struct oracle_directions {
oracle_directions() { }
Sentences model_hyps;
+ bool have_doc;
void Init() {
- if (!decoder_translations_file.empty())
+ have_doc=!decoder_translations_file.empty();
+ if (have_doc) {
model_hyps.Load(decoder_translations_file);
+ //TODO: compute doc bleu stats for each sentence, then when getting oracle temporarily exclude stats for that sentence (skip regular score updating)
+ }
start_random=false;
assert(DirectoryExists(forest_repository));
vector<string> features;
@@ -205,6 +209,9 @@ struct oracle_directions {
Oracle const& ComputeOracle(unsigned i) {
Oracle &o=oracles[i];
if (o.is_null()) {
+ if (have_doc) {
+ //TODO:
+ }
ReadFile rf(forest_file(i));
Hypergraph hg;
{
@@ -212,6 +219,10 @@ struct oracle_directions {
HypergraphIO::ReadFromJSON(rf.stream(), &hg);
}
o=oracle.ComputeOracle(oracle.MakeMetadata(hg,i),&hg,origin,&cerr);
+ if (have_doc) {
+ //TODO:
+ } else
+ oracle.IncludeLastScore();
}
return o;
}