summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/Makefile.am3
-rw-r--r--decoder/cdec.cc17
-rw-r--r--decoder/dict_test.cc17
-rw-r--r--decoder/fdict.cc124
-rw-r--r--decoder/fdict.h3
-rw-r--r--decoder/ff_wordalign.cc79
-rw-r--r--decoder/ff_wordalign.h6
-rw-r--r--decoder/lexalign.cc34
-rw-r--r--decoder/lextrans.cc12
-rw-r--r--decoder/lextrans.h12
-rw-r--r--decoder/stringlib.cc1
-rw-r--r--tests/system_tests/unsup-align/cdec.ini4
-rw-r--r--training/Makefile.am14
-rwxr-xr-xtraining/cluster-em.pl64
-rwxr-xr-xword-aligner/support/supplement_weights_file.pl2
15 files changed, 279 insertions, 113 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index d4e2a77c..81cd43e7 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -65,7 +65,8 @@ libcdec_a_SOURCES = \
ff_csplit.cc \
ff_tagger.cc \
freqdict.cc \
- lexcrf.cc \
+ lexalign.cc \
+ lextrans.cc \
tagger.cc \
bottom_up_parser.cc \
phrasebased_translator.cc \
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index b130e7fd..811a0d04 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -18,7 +18,8 @@
#include "sampler.h"
#include "sparse_vector.h"
#include "tagger.h"
-#include "lexcrf.h"
+#include "lextrans.h"
+#include "lexalign.h"
#include "csplit.h"
#include "weights.h"
#include "tdict.h"
@@ -50,7 +51,7 @@ void ShowBanner() {
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
- ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting), Tagger (sequence labeling)")
+ ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexTrans (lexical translation model, also disc training), CSplit (compound splitting), Tagger (sequence labeling), LexAlign (alignment only, or EM training)")
("input,i",po::value<string>()->default_value("-"),"Source file")
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("weights,w",po::value<string>(),"Feature weights file")
@@ -72,7 +73,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("show_expected_length", "Show the expected translation length under the model")
("show_partition,z", "Compute and show the partition (inside score)")
("beam_prune", po::value<double>(), "Prune paths from +LM forest")
- ("lexcrf_use_null", "Support source-side null words in lexical translation")
+ ("lexalign_use_null", "Support source-side null words in lexical translation")
("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
@@ -117,8 +118,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
const string formalism = LowercaseString((*conf)["formalism"].as<string>());
- if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit" && formalism != "tagger") {
- cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lexcrf', or 'tagger'\n";
+ if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -273,8 +274,10 @@ int main(int argc, char** argv) {
translator.reset(new PhraseBasedTranslator(conf));
else if (formalism == "csplit")
translator.reset(new CompoundSplit(conf));
- else if (formalism == "lexcrf")
- translator.reset(new LexicalCRF(conf));
+ else if (formalism == "lextrans")
+ translator.reset(new LexicalTrans(conf));
+ else if (formalism == "lexalign")
+ translator.reset(new LexicalAlign(conf));
else if (formalism == "tagger")
translator.reset(new Tagger(conf));
else
diff --git a/decoder/dict_test.cc b/decoder/dict_test.cc
index 5c5d84f0..2049ec27 100644
--- a/decoder/dict_test.cc
+++ b/decoder/dict_test.cc
@@ -1,8 +1,13 @@
#include "dict.h"
+#include "fdict.h"
+
+#include <iostream>
#include <gtest/gtest.h>
#include <cassert>
+using namespace std;
+
class DTest : public testing::Test {
public:
DTest() {}
@@ -23,6 +28,18 @@ TEST_F(DTest, Convert) {
EXPECT_EQ(d.Convert(b), "bar");
}
+// Checks that a feature name converts to a positive id and back again, and
+// that FD::Escape does not return "=" or ";" unchanged.
+TEST_F(DTest, FDictTest) {
+  const int first_id = FD::Convert("First");
+  EXPECT_GT(first_id, 0);
+  EXPECT_EQ(FD::Convert(first_id), "First");
+
+  const string escaped_eq = FD::Escape("=");
+  cerr << escaped_eq << endl;
+  EXPECT_NE(escaped_eq, "=");
+
+  const string escaped_semi = FD::Escape(";");
+  cerr << escaped_semi << endl;
+  EXPECT_NE(escaped_semi, ";");
+}
+
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
diff --git a/decoder/fdict.cc b/decoder/fdict.cc
index 8218a5d3..7e1b0e1f 100644
--- a/decoder/fdict.cc
+++ b/decoder/fdict.cc
@@ -1,5 +1,129 @@
#include "fdict.h"
+#include <string>
+
+using namespace std;
+
Dict FD::dict_;
bool FD::frozen_ = false;
+// Interprets code[0] and code[1] as a pair of hexadecimal digits (either
+// case) and returns their combined value, 0..255. Returns -1 as soon as a
+// character that is not a hex digit is seen; code[1] is not inspected when
+// code[0] is already invalid.
+static int HexPairValue(const char * code) {
+  int value = 0;
+  for (int k = 0; k < 2; ++k) {
+    const int c = code[k];
+    int nibble;
+    if (c >= '0' && c <= '9')
+      nibble = c - '0';
+    else if (c >= 'A' && c <= 'F')
+      nibble = c - 'A' + 10;
+    else if (c >= 'a' && c <= 'f')
+      nibble = c - 'a' + 10;
+    else
+      return -1;
+    // Shift the previous digit into the high nibble, then add this one.
+    value = (value << 4) | nibble;
+  }
+  return value;
+}
+
+// Decodes the NUL-terminated string `source` into `dest`:
+//   '+'            -> ' '
+//   '%' + 2 hex    -> the byte with that value
+//   any other '%'  -> '?' (the characters after it are then handled normally)
+//   anything else  -> copied unchanged
+// Each input byte yields at most one output byte, so `dest` needs room for
+// strlen(source) + 1 bytes. The output is NUL-terminated; the return value
+// is the number of bytes written, not counting that terminator.
+int UrlDecode(const char *source, char *dest)
+{
+  char * out = dest;
+
+  for (const char * in = source; *in; ++in) {
+    if (*in == '+') {
+      *out++ = ' ';
+      continue;
+    }
+    if (*in != '%') {
+      *out++ = *in;
+      continue;
+    }
+    // Only look at the hex pair when both following characters exist.
+    const int byte = (in[1] && in[2]) ? HexPairValue(in + 1) : -1;
+    if (byte < 0) {
+      *out++ = '?';
+      continue;
+    }
+    *out++ = byte;
+    in += 2;  // skip the two hex digits; the loop increment skips the '%'
+  }
+
+  *out = 0;
+  return out - dest;
+}
+
+// Percent-encodes the NUL-terminated string `source` into `dest`.
+//
+// A space becomes '+'; each of the characters = : ; , _ | % becomes '%'
+// followed by two uppercase hex digits; every other byte is copied as-is.
+// A literal '+' is copied unchanged, so UrlDecode() turns it into a space:
+// the encoding is not an exact round trip for input containing '+'.
+//
+// `max` is the capacity of `dest` in bytes, including the terminating NUL.
+// Encoding stops early (truncating the output) once fewer than four bytes
+// remain, i.e. when a three-byte escape plus the NUL might no longer fit.
+// If `max` is 0 nothing is written at all.
+//
+// Returns the number of bytes written to `dest`, not counting the NUL.
+int UrlEncode(const char *source, char *dest, unsigned max) {
+  static const char *digits = "0123456789ABCDEF";
+  char *start = dest;
+  unsigned len = 0;  // bytes written so far
+
+  if (max == 0)
+    return 0;  // no room even for the terminator
+
+  // Written as `len + 4 <= max` rather than `len < max - 4` so that a small
+  // `max` cannot wrap around in unsigned arithmetic.
+  while (*source && len + 4 <= max) {
+    const unsigned char ch = static_cast<unsigned char>(*source);
+    if (ch == ' ') {
+      *dest++ = '+';
+      len += 1;
+    }
+    else if (strchr("=:;,_| %", ch)) {
+      *dest++ = '%';
+      *dest++ = digits[(ch >> 4) & 0x0F];
+      *dest++ = digits[ ch & 0x0F];
+      len += 3;
+    }
+    else {
+      *dest++ = *source;
+      len += 1;
+    }
+    source++;
+  }
+  *dest = 0;
+  return dest - start;
+}
+
+// Returns the result of running UrlDecode over `encoded`.
+//
+// The scratch buffer is a std::string rather than a malloc'ed block, so it
+// is released on every path (including exceptions) and there is no
+// unchecked allocation result.
+//
+// The result is built from the decoded bytes as a C string, so it ends at
+// the first NUL byte that decoding produces (for example from "%00").
+std::string UrlDecodeString(const std::string & encoded) {
+  // UrlDecode writes at most one byte per input byte, plus the terminator.
+  std::string buf(encoded.length() + 1, '\0');
+  UrlDecode(encoded.c_str(), &buf[0]);
+  return std::string(buf.c_str());
+}
+
+// Returns the result of running UrlEncode over `decoded`.
+//
+// The scratch buffer is a std::string rather than a malloc'ed block, so it
+// is released on every path (including exceptions) and there is no
+// unchecked allocation result.
+std::string UrlEncodeString(const std::string & decoded) {
+  // Worst case every input byte becomes a three-byte escape; the extra
+  // bytes cover the terminator and the margin UrlEncode keeps before `max`.
+  const size_t capacity = decoded.length() * 3 + 3;
+  std::string buf(capacity, '\0');
+  UrlEncode(decoded.c_str(), &buf[0], capacity);
+  return std::string(buf.c_str());
+}
+
+// Escapes `s` by passing it through UrlEncodeString.
+string FD::Escape(const string& s) {
+  string escaped = UrlEncodeString(s);
+  return escaped;
+}
+
diff --git a/decoder/fdict.h b/decoder/fdict.h
index d05f1706..c4236580 100644
--- a/decoder/fdict.h
+++ b/decoder/fdict.h
@@ -20,6 +20,9 @@ struct FD {
static inline const std::string& Convert(const WordID& w) {
return dict_.Convert(w);
}
+ // Escape any string to a form that can be used as the name
+ // of a weight in a weights file
+ static std::string Escape(const std::string& s);
static Dict dict_;
private:
static bool frozen_;
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index fb90df62..669aa530 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -26,7 +26,7 @@ Model2BinaryFeatures::Model2BinaryFeatures(const string& param) :
val = -1;
if (j < i) {
ostringstream os;
- os << "M2_FL:" << i << "_SI:" << j << "_TI:" << k;
+ os << "M2FL:" << i << ":TI:" << k << "_SI:" << j;
val = FD::Convert(os.str());
}
}
@@ -181,32 +181,27 @@ void MarkovJumpFClass::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+// std::vector<std::map<int, int> > flen2jump2fid_;
MarkovJump::MarkovJump(const string& param) :
FeatureFunction(1),
fid_(FD::Convert("MarkovJump")),
- individual_params_per_jumpsize_(false),
- condition_on_flen_(false) {
+ binary_params_(false) {
cerr << " MarkovJump";
vector<string> argv;
int argc = SplitOnWhitespace(param, &argv);
- if (argc > 0) {
- if (argv[0] == "--fclasses") {
- argc--;
- assert(argc > 0);
- const string f_class_file = argv[1];
- }
- if (argc != 1 || !(argv[0] == "-f" || argv[0] == "-i" || argv[0] == "-if")) {
- cerr << "MarkovJump: expected parameters to be -f, -i, or -if\n";
- exit(1);
- }
- individual_params_per_jumpsize_ = (argv[0][1] == 'i');
- condition_on_flen_ = (argv[0][argv[0].size() - 1] == 'f');
- if (individual_params_per_jumpsize_) {
- template_ = "Jump:000";
- cerr << ", individual jump parameters";
- if (condition_on_flen_) {
- template_ += ":F00";
- cerr << " (split by f-length)";
+ if (argc != 1 || !(argv[0] == "-b" || argv[0] == "+b")) {
+ cerr << "MarkovJump: expected parameters to be -b or +b\n";
+ exit(1);
+ }
+ binary_params_ = argv[0] == "+b";
+ if (binary_params_) {
+ flen2jump2fid_.resize(MAX_SENTENCE_SIZE);
+ for (int i = 1; i < MAX_SENTENCE_SIZE; ++i) {
+ map<int, int>& jump2fid = flen2jump2fid_[i];
+ for (int jump = -i; jump <= i; ++jump) {
+ ostringstream os;
+ os << "Jump:FLen:" << i << "_J:" << jump;
+ jump2fid[jump] = FD::Convert(os.str());
}
}
} else {
@@ -215,6 +210,7 @@ MarkovJump::MarkovJump(const string& param) :
cerr << endl;
}
+// TODO handle NULLs according to Och 2000
void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
const vector<const void*>& ant_states,
@@ -222,8 +218,24 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
SparseVector<double>* estimated_features,
void* state) const {
unsigned char& dpstate = *((unsigned char*)state);
+ const int flen = smeta.GetSourceLength();
if (edge.Arity() == 0) {
dpstate = static_cast<unsigned int>(edge.i_);
+ if (edge.prev_i_ == 0) {
+ if (binary_params_) {
+ // NULL will be tricky
+ // TODO initial state distribution, not normal jumps
+ const int fid = flen2jump2fid_[flen].find(edge.i_ + 1)->second;
+ features->set_value(fid, 1.0);
+ }
+ } else if (edge.prev_i_ == smeta.GetTargetLength() - 1) {
+ // NULL will be tricky
+ if (binary_params_) {
+ int jumpsize = flen - edge.i_;
+ const int fid = flen2jump2fid_[flen].find(jumpsize)->second;
+ features->set_value(fid, 1.0);
+ }
+ }
} else if (edge.Arity() == 1) {
dpstate = *((unsigned char*)ant_states[0]);
} else if (edge.Arity() == 2) {
@@ -234,27 +246,12 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
else
dpstate = static_cast<unsigned int>(right_index);
const int jumpsize = right_index - left_index;
- features->set_value(fid_, fabs(jumpsize - 1)); // Blunsom and Cohn def
- if (individual_params_per_jumpsize_) {
- string fname = template_;
- int param = jumpsize;
- if (jumpsize < 0) {
- param *= -1;
- fname[5]='L';
- } else if (jumpsize > 0) {
- fname[5]='R';
- }
- if (param) {
- fname[6] = '0' + (param / 10);
- fname[7] = '0' + (param % 10);
- }
- if (condition_on_flen_) {
- const int flen = smeta.GetSourceLength();
- fname[10] = '0' + (flen / 10);
- fname[11] = '0' + (flen % 10);
- }
- features->set_value(FD::Convert(fname), 1.0);
+ if (binary_params_) {
+ const int fid = flen2jump2fid_[flen].find(jumpsize)->second;
+ features->set_value(fid, 1.0);
+ } else {
+ features->set_value(fid_, fabs(jumpsize - 1)); // Blunsom and Cohn def
}
} else {
assert(!"something really unexpected is happening");
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index 688750de..c44ad26b 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -49,10 +49,8 @@ class MarkovJump : public FeatureFunction {
void* out_context) const;
private:
const int fid_;
- bool individual_params_per_jumpsize_;
- bool condition_on_flen_;
- bool condition_on_fclass_;
- std::string template_;
+ bool binary_params_;
+ std::vector<std::map<int, int> > flen2jump2fid_;
};
class MarkovJumpFClass : public FeatureFunction {
diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc
index ee3b5fe0..8dd77c53 100644
--- a/decoder/lexalign.cc
+++ b/decoder/lexalign.cc
@@ -31,17 +31,24 @@ struct LexicalAlignImpl {
const WordID& e_i = target[i][0].label;
Hypergraph::Node* node = forest->AddNode(kXCAT);
const int new_node_id = node->id_;
+ int num_srcs = 0;
for (int j = f_start; j < f_len; ++j) { // for each word in the source
const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label);
- TRulePtr& rule = LexRule(src_sym, e_i);
- Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
- edge->i_ = j;
- edge->j_ = j+1;
- edge->prev_i_ = i;
- edge->prev_j_ = i+1;
- edge->feature_values_ += edge->rule_->GetFeatureValues();
- forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+ const TRulePtr& rule = LexRule(src_sym, e_i);
+ if (rule) {
+ Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+ edge->i_ = j;
+ edge->j_ = j+1;
+ edge->prev_i_ = i;
+ edge->prev_j_ = i+1;
+ edge->feature_values_ += edge->rule_->GetFeatureValues();
+ ++num_srcs;
+ forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+ } else {
+ cerr << TD::Convert(src_sym) << " does not translate to " << TD::Convert(e_i) << endl;
+ }
}
+ assert(num_srcs > 0);
if (prev_node_id >= 0) {
const int comb_node_id = forest->AddNode(kXCAT)->id_;
Hypergraph::TailNodeVector tail(2, prev_node_id);
@@ -66,21 +73,23 @@ struct LexicalAlignImpl {
return it->second;
int& fid = e2fid[e];
if (f == 0) {
- fid = FD::Convert("Lx_<eps>_" + FD::Escape(TD::Convert(e)));
+ fid = FD::Convert("Lx:<eps>_" + FD::Escape(TD::Convert(e)));
} else {
- fid = FD::Convert("Lx_" + FD::Escape(TD::Convert(f)) + "_" + FD::Escape(TD::Convert(e)));
+ fid = FD::Convert("Lx:" + FD::Escape(TD::Convert(f)) + "_" + FD::Escape(TD::Convert(e)));
}
return fid;
}
- inline TRulePtr& LexRule(const WordID& f, const WordID& e) {
+ inline const TRulePtr& LexRule(const WordID& f, const WordID& e) {
+ const int fid = LexFeatureId(f, e);
+ if (!fid) { return kNULL_PTR; }
map<int, TRulePtr>& e2rule = f2e2rule[f];
map<int, TRulePtr>::iterator it = e2rule.find(e);
if (it != e2rule.end())
return it->second;
TRulePtr& tr = e2rule[e];
tr.reset(TRule::CreateLexicalRule(f, e));
- tr->scores_.set_value(LexFeatureId(f, e), 1.0);
+ tr->scores_.set_value(fid, 1.0);
return tr;
}
@@ -90,6 +99,7 @@ struct LexicalAlignImpl {
const WordID kNULL;
const TRulePtr kBINARY;
const TRulePtr kGOAL_RULE;
+ const TRulePtr kNULL_PTR;
map<int, map<int, TRulePtr> > f2e2rule;
map<int, map<int, int> > f2e2fid;
GrammarPtr grammar;
diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc
index b0e03c69..e7fa1aa1 100644
--- a/decoder/lextrans.cc
+++ b/decoder/lextrans.cc
@@ -1,4 +1,4 @@
-#include "lexcrf.h"
+#include "lextrans.h"
#include <iostream>
@@ -10,8 +10,8 @@
using namespace std;
-struct LexicalCRFImpl {
- LexicalCRFImpl(const boost::program_options::variables_map& conf) :
+struct LexicalTransImpl {
+ LexicalTransImpl(const boost::program_options::variables_map& conf) :
use_null(conf.count("lexcrf_use_null") > 0),
kXCAT(TD::Convert("X")*-1),
kNULL(TD::Convert("<eps>")),
@@ -95,10 +95,10 @@ struct LexicalCRFImpl {
GrammarPtr grammar;
};
-LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) :
- pimpl_(new LexicalCRFImpl(conf)) {}
+LexicalTrans::LexicalTrans(const boost::program_options::variables_map& conf) :
+ pimpl_(new LexicalTransImpl(conf)) {}
-bool LexicalCRF::Translate(const string& input,
+bool LexicalTrans::Translate(const string& input,
SentenceMetadata* smeta,
const vector<double>& weights,
Hypergraph* forest) {
diff --git a/decoder/lextrans.h b/decoder/lextrans.h
index 99362c81..9920f79c 100644
--- a/decoder/lextrans.h
+++ b/decoder/lextrans.h
@@ -1,18 +1,18 @@
-#ifndef _LEXCRF_H_
-#define _LEXCRF_H_
+#ifndef _LEXTrans_H_
+#define _LEXTrans_H_
#include "translator.h"
#include "lattice.h"
-struct LexicalCRFImpl;
-struct LexicalCRF : public Translator {
- LexicalCRF(const boost::program_options::variables_map& conf);
+struct LexicalTransImpl;
+struct LexicalTrans : public Translator {
+ LexicalTrans(const boost::program_options::variables_map& conf);
bool Translate(const std::string& input,
SentenceMetadata* smeta,
const std::vector<double>& weights,
Hypergraph* forest);
private:
- boost::shared_ptr<LexicalCRFImpl> pimpl_;
+ boost::shared_ptr<LexicalTransImpl> pimpl_;
};
#endif
diff --git a/decoder/stringlib.cc b/decoder/stringlib.cc
index 3ed74bef..3e52ae87 100644
--- a/decoder/stringlib.cc
+++ b/decoder/stringlib.cc
@@ -1,5 +1,6 @@
#include "stringlib.h"
+#include <cstring>
#include <cstdlib>
#include <cassert>
#include <iostream>
diff --git a/tests/system_tests/unsup-align/cdec.ini b/tests/system_tests/unsup-align/cdec.ini
index 37a37214..885338a6 100644
--- a/tests/system_tests/unsup-align/cdec.ini
+++ b/tests/system_tests/unsup-align/cdec.ini
@@ -1,6 +1,6 @@
aligner=true
grammar=unsup-align.lex-grammar
intersection_strategy=full
-formalism=lexcrf
+formalism=lextrans
feature_function=RelativeSentencePosition
-feature_function=MarkovJump
+feature_function=MarkovJump -b
diff --git a/training/Makefile.am b/training/Makefile.am
index 6427fcba..490de774 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -1,10 +1,12 @@
bin_PROGRAMS = \
model1 \
+ mr_em_map_adapter \
+ mr_em_adapted_reduce \
+ mr_reduce_to_weights \
mr_optimize_reduce \
grammar_convert \
atools \
plftools \
- mr_em_train \
collapse_weights
noinst_PROGRAMS = \
@@ -32,8 +34,14 @@ lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc optimize.cc
mr_optimize_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
-mr_em_train_SOURCES = mr_em_train.cc
-mr_em_train_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
+mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc
+mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
+
+mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc
+mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
+
+mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc
+mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
plftools_SOURCES = plftools.cc
plftools_LDADD = $(top_srcdir)/decoder/libcdec.a -lz
diff --git a/training/cluster-em.pl b/training/cluster-em.pl
index 175870da..267ab642 100755
--- a/training/cluster-em.pl
+++ b/training/cluster-em.pl
@@ -3,44 +3,46 @@
use strict;
my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR; }
use Getopt::Long;
-my $parallel = 1;
+my $parallel = 0;
my $CWD=`pwd`; chomp $CWD;
-my $BIN_DIR = "/chomes/redpony/cdyer-svn-repo/cdec/src";
-my $OPTIMIZER = "$BIN_DIR/mr_em_train";
-my $DECODER = "$BIN_DIR/cdec";
-my $COMBINER_CACHE_SIZE = 150;
+my $BIN_DIR = "$CWD/..";
+my $REDUCER = "$BIN_DIR/training/mr_em_adapted_reduce";
+my $REDUCE2WEIGHTS = "$BIN_DIR/training/mr_reduce_to_weights";
+my $ADAPTER = "$BIN_DIR/training/mr_em_map_adapter";
+my $DECODER = "$BIN_DIR/decoder/cdec";
+my $COMBINER_CACHE_SIZE = 10000000;
my $PARALLEL = "/chomes/redpony/svn-trunk/sa-utils/parallelize.pl";
-die "Can't find $OPTIMIZER" unless -f $OPTIMIZER;
-die "Can't execute $OPTIMIZER" unless -x $OPTIMIZER;
+die "Can't find $REDUCER" unless -f $REDUCER;
+die "Can't execute $REDUCER" unless -x $REDUCER;
+die "Can't find $REDUCE2WEIGHTS" unless -f $REDUCE2WEIGHTS;
+die "Can't execute $REDUCE2WEIGHTS" unless -x $REDUCE2WEIGHTS;
+die "Can't find $ADAPTER" unless -f $ADAPTER;
+die "Can't execute $ADAPTER" unless -x $ADAPTER;
die "Can't find $DECODER" unless -f $DECODER;
die "Can't execute $DECODER" unless -x $DECODER;
-die "Can't find $PARALLEL" unless -f $PARALLEL;
-die "Can't execute $PARALLEL" unless -x $PARALLEL;
my $restart = '';
if ($ARGV[0] && $ARGV[0] eq '--restart') { shift @ARGV; $restart = 1; }
-die "Usage: $0 [--restart] training.corpus weights.init grammar.file [grammar2.file] ...\n" unless (scalar @ARGV >= 3);
+die "Usage: $0 [--restart] training.corpus cdec.ini\n" unless (scalar @ARGV == 2);
my $training_corpus = shift @ARGV;
-my $initial_weights = shift @ARGV;
-my @in_grammar_files = @ARGV;
+my $config = shift @ARGV;
my $pmem="2500mb";
my $nodes = 40;
my $max_iteration = 1000;
my $CFLAG = "-C 1";
-unless ($parallel) { $CFLAG = "-C 500"; }
-my @grammar_files;
-for my $g (@in_grammar_files) {
- unless ($g =~ /^\//) { $g = $CWD . '/' . $g; }
- die "Can't find $g" unless -f $g;
- push @grammar_files, $g;
-}
+if ($parallel) {
+ die "Can't find $PARALLEL" unless -f $PARALLEL;
+ die "Can't execute $PARALLEL" unless -x $PARALLEL;
+} else { $CFLAG = "-C 500"; }
+
+my $initial_weights = '';
print STDERR <<EOT;
EM TRAIN CONFIGURATION INFORMATION
- Grammar file(s): @grammar_files
+ Config file: $config
Training corpus: $training_corpus
Initial weights: $initial_weights
Decoder memory: $pmem
@@ -68,11 +70,13 @@ if ($restart) {
die "$dir already exists!\n" if -e $dir;
mkdir $dir or die "Can't create $dir: $!";
- unless ($initial_weights =~ /\.gz$/) {
- `cp $initial_weights $dir/weights.1`;
- `gzip -9 $dir/weights.1`;
- } else {
- `cp $initial_weights $dir/weights.1.gz`;
+ if ($initial_weights) {
+ unless ($initial_weights =~ /\.gz$/) {
+ `cp $initial_weights $dir/weights.1`;
+ `gzip -9 $dir/weights.1`;
+ } else {
+ `cp $initial_weights $dir/weights.1.gz`;
+ }
}
}
@@ -82,14 +86,14 @@ while ($iter < $max_iteration) {
print STDERR " time: $cur_time\n";
my $start = time;
my $next_iter = $iter + 1;
- my $gfile = '-g' . (join ' -g ', @grammar_files);
- my $dec_cmd="$DECODER --feature_expectations -S 999 $CFLAG $gfile -n -w $dir/weights.$iter.gz < $training_corpus 2> $dir/deco.log.$iter";
- my $opt_cmd = "$OPTIMIZER $gfile -o $dir/weights.$next_iter.gz";
+ my $WSTR = "-w $dir/weights.$iter.gz";
+ if ($iter == 1) { $WSTR = ''; }
+ my $dec_cmd="$DECODER --feature_expectations -c $config $WSTR $CFLAG < $training_corpus 2> $dir/deco.log.$iter";
my $pcmd = "$PARALLEL -e $dir/err -p $pmem --nodelist \"$nodelist\" -- ";
my $cmd = "";
if ($parallel) { $cmd = $pcmd; }
- $cmd .= "$dec_cmd | $opt_cmd";
-
+ $cmd .= "$dec_cmd";
+ $cmd .= "| $ADAPTER | sort -k1 | $REDUCER | $REDUCE2WEIGHTS -o $dir/weights.$next_iter.gz";
print STDERR "EXECUTING: $cmd\n";
my $result = `$cmd`;
if ($? != 0) {
diff --git a/word-aligner/support/supplement_weights_file.pl b/word-aligner/support/supplement_weights_file.pl
index 7f804b90..06876043 100755
--- a/word-aligner/support/supplement_weights_file.pl
+++ b/word-aligner/support/supplement_weights_file.pl
@@ -55,7 +55,7 @@ for (my $ss=1; $ss < 100; $ss++) {
# M2_FL:8_SI:3_TI:2=1
for (my $i = 0; $i < $ss; $i++) {
for (my $j = 0; $j < 100; $j++) {
- print "M2_FL:${ss}_SI:${i}_TI:${j} 0\n";
+ print "M2FL:${ss}:TI:${j}_SI:${i} 0\n";
$added++;
}
}