From b9d7da0413403805f035479a0a426c27102032f6 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 3 Jun 2011 20:49:52 -0400 Subject: Add exception catcher around constructor --- decoder/ff_klm.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 35b35d36..71ba9f30 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -385,7 +385,12 @@ KLanguageModel::KLanguageModel(const string& param) { if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { abort(); } - pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + try { + pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } fid_ = FD::Convert(featname); oov_fid_ = FD::Convert(featname+"_OOV"); cerr << "FID: " << oov_fid_ << endl; -- cgit v1.2.3 From dcf4447590277887d65b0bdec7e6818081869a9a Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 3 Jun 2011 20:56:37 -0400 Subject: Code cleanup for vocabulary mapping --- decoder/ff_klm.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 71ba9f30..a3bd0c5f 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -282,11 +282,10 @@ class KLanguageModelImpl { KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : kCDEC_UNK(TD::Convert("")) , add_sos_eos_(!explicit_markers) { - if (true) { - boost::scoped_ptr vm; - vm.reset(new VMapper(&cdec2klm_map_)); + { + VMapper vm(&cdec2klm_map_); lm::ngram::Config conf; - conf.enumerate_vocab = vm.get(); + conf.enumerate_vocab = &vm; ngram_ = new Model(filename.c_str(), conf); } order_ = ngram_->Order(); -- cgit v1.2.3 From 74d3ac177d70b77646f6a0b3b4095d725f893a36 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 8 Jun 2011 00:39:09 -0400 Subject: external MT evaluator client code. most logic in place, needs to be integrated. actually, the whole evaluation architecture needs to be trashed and rewritten from scratch. what a disaster it is --- mteval/Makefile.am | 2 +- mteval/external_scorer.cc | 150 ++++++++++++++++++++++++++++++++++++++++++++++ mteval/external_scorer.h | 35 +++++++++++ 3 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 mteval/external_scorer.cc create mode 100644 mteval/external_scorer.h diff --git a/mteval/Makefile.am b/mteval/Makefile.am index f9277779..95845090 100644 --- a/mteval/Makefile.am +++ b/mteval/Makefile.am @@ -10,7 +10,7 @@ endif noinst_LIBRARIES = libmteval.a -libmteval_a_SOURCES = ter.cc comb_scorer.cc aer_scorer.cc scorer.cc +libmteval_a_SOURCES = ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc fast_score_SOURCES = fast_score.cc fast_score_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a -lz diff --git a/mteval/external_scorer.cc b/mteval/external_scorer.cc new file mode 100644 index 00000000..4327ce9b --- /dev/null +++ b/mteval/external_scorer.cc @@ -0,0 +1,150 @@ +#include "external_scorer.h" + +#include // popen +#include +#include +#include +#include + +#include "tdict.h" + +using namespace std; + +ScoreServer::ScoreServer(const string& cmd) : pipe_() { + cerr << "Invoking " << cmd << " ..." << endl; + pipe_ = popen(cmd.c_str(), "r+"); + assert(pipe_); + string dummy; + RequestResponse("EVAL ||| Reference initialization string . ||| Testing initialization string .\n", &dummy); + assert(dummy.size() > 0); + cerr << "Connection established.\n"; +} + +ScoreServer::~ScoreServer() { + pclose(pipe_); +} + +double ScoreServer::ComputeScore(const vector& fields) { + ostringstream os; + os << "EVAL"; + for (unsigned i = 0; i < fields.size(); ++i) + os << ' ' << fields[i]; + os << endl; + string sres; + RequestResponse(os.str(), &sres); + return strtod(sres.c_str(), NULL); +} + +void ScoreServer::Evaluate(const vector >& refs, const vector& hyp, vector* fields) { + ostringstream os; + os << "SCORE"; + for (unsigned i = 0; i < refs.size(); ++i) { + os << " |||"; + for (unsigned j = 0; j < refs[i].size(); ++j) { + os << ' ' << TD::Convert(refs[i][j]); + } + } + os << " |||"; + for (unsigned i = 0; i < hyp.size(); ++i) { + os << ' ' << TD::Convert(hyp[i]); + } + os << endl; + string sres; + RequestResponse(os.str(), &sres); + istringstream is(sres); + double val; + fields->clear(); + while(is >> val) { + fields->push_back(val); + } +} + +#define MAX_BUF 16000 + +void ScoreServer::RequestResponse(const string& request, string* response) { + fprintf(pipe_, "%s", request.c_str()); + fflush(pipe_); + char buf[MAX_BUF]; + size_t cr = fread(buf, 1, MAX_BUF, pipe_); + if (cr == 0) { + cerr << "Read error. Request: " << request << endl; + abort(); + } + while (buf[cr-1] != '\n') { + size_t n = fread(&buf[cr], 1, MAX_BUF-cr, pipe_); + assert(n > 0); + cr += n; + assert(cr < MAX_BUF); + } + buf[cr - 1] = 0; + *response = buf; +} + +struct ExternalScore : public ScoreBase { + ExternalScore() : score_server() {} + explicit ExternalScore(const ScoreServer* s) : score_server(s), fields() {} + ExternalScore(const ScoreServer* s, const vector& f) : score_server(s), fields(f) {} + float ComputePartialScore() const { return 0.0;} + float ComputeScore() const { + // TODO make EVAL call + assert(!"not implemented"); + } + void ScoreDetails(string* details) const { + ostringstream os; + os << "EXT=" << ComputeScore() << " <"; + for (unsigned i = 0; i < fields.size(); ++i) + os << (i ? " " : "") << fields[i]; + os << '>'; + *details = os.str(); + } + void PlusPartialEquals(const Score&, int, int, int){ + assert(!"not implemented"); // no idea + } + void PlusEquals(const Score& delta, const float scale) { + assert(!"not implemented"); // don't even know what this is + } + void PlusEquals(const Score& delta) { + if (static_cast(delta).score_server) score_server = static_cast(delta).score_server; + if (fields.size() != static_cast(delta).fields.size()) + fields.resize(max(fields.size(), static_cast(delta).fields.size())); + for (unsigned i = 0; i < static_cast(delta).fields.size(); ++i) + fields[i] += static_cast(delta).fields[i]; + } + ScoreP GetZero() const { + return ScoreP(new ExternalScore(score_server)); + } + ScoreP GetOne() const { + return ScoreP(new ExternalScore(score_server)); + } + void Subtract(const Score& rhs, Score* res) const { + static_cast(res)->score_server = score_server; + vector& rf = static_cast(res)->fields; + rf.resize(max(fields.size(), static_cast(rhs).fields.size())); + for (unsigned i = 0; i < rf.size(); ++i) { + rf[i] = (i < fields.size() ? fields[i] : 0.0f) - + (i < static_cast(rhs).fields.size() ? static_cast(rhs).fields[i] : 0.0f); + } + } + void Encode(string* out) const { + ostringstream os; + } + bool IsAdditiveIdentity() const { + for (int i = 0; i < fields.size(); ++i) + if (fields[i]) return false; + return true; + } + + const ScoreServer* score_server; + vector fields; +}; + +ScoreP ExternalSentenceScorer::ScoreCandidate(const Sentence& hyp) const { + ExternalScore* res = new ExternalScore(eval_server); + eval_server->Evaluate(refs, hyp, &res->fields); + return ScoreP(res); +} + +ScoreP ExternalSentenceScorer::ScoreCCandidate(const Sentence& hyp) const { + assert(!"not implemented"); +} + diff --git a/mteval/external_scorer.h b/mteval/external_scorer.h new file mode 100644 index 00000000..a2c91960 --- /dev/null +++ b/mteval/external_scorer.h @@ -0,0 +1,35 @@ +#ifndef _EXTERNAL_SCORER_H_ +#define _EXTERNAL_SCORER_H_ + +#include +#include + +#include "scorer.h" + +class ScoreServer { + public: + explicit ScoreServer(const std::string& cmd); + virtual ~ScoreServer(); + + double ComputeScore(const std::vector& fields); + void Evaluate(const std::vector >& refs, const std::vector& hyp, std::vector* fields); + + private: + void RequestResponse(const std::string& request, std::string* response); + FILE* pipe_; +}; + +class ExternalSentenceScorer : public SentenceScorer { + public: + virtual ScoreP ScoreCandidate(const Sentence& hyp) const = 0; + virtual ScoreP ScoreCCandidate(const Sentence& hyp) const =0; + protected: + ScoreServer* eval_server; +}; + +class METEORServer : public ScoreServer { + public: + METEORServer() : ScoreServer("java -Xmx1024m -jar meteor-1.3.jar - - -mira -lower") {} +}; + +#endif -- cgit v1.2.3 From c456e5b4470a244de811bf8c070532f8012f5731 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 8 Jun 2011 22:35:06 -0400 Subject: rudimentary support for meteor via an external process. still needs configuration of path, but it should work --- mteval/external_scorer.cc | 79 ++++++++++++++++++++++++++++++++--------------- mteval/external_scorer.h | 28 +++++++++++------ mteval/scorer.cc | 10 ++++-- mteval/scorer.h | 4 +-- vest/dist-vest.pl | 2 ++ 5 files changed, 85 insertions(+), 38 deletions(-) diff --git a/mteval/external_scorer.cc b/mteval/external_scorer.cc index 4327ce9b..3757064b 100644 --- a/mteval/external_scorer.cc +++ b/mteval/external_scorer.cc @@ -2,20 +2,42 @@ #include // popen #include +#include #include #include #include +#include "stringlib.h" #include "tdict.h" using namespace std; +map > ScoreServerManager::servers_; + +class METEORServer : public ScoreServer { + public: + METEORServer() : ScoreServer("java -Xmx1024m -jar /Users/cdyer/software/meteor/meteor-1.3.jar - - -mira -lower -t tune -l en") {} +}; + +ScoreServer* ScoreServerManager::Instance(const string& score_type) { + boost::shared_ptr& s = servers_[score_type]; + if (!s) { + if (score_type == "meteor") { + s.reset(new METEORServer); + } else { + cerr << "Don't know how to create score server for type '" << score_type << "'\n"; + abort(); + } + } + return s.get(); +} + ScoreServer::ScoreServer(const string& cmd) : pipe_() { cerr << "Invoking " << cmd << " ..." << endl; pipe_ = popen(cmd.c_str(), "r+"); - assert(pipe_); + if (!pipe_) { perror("popen"); abort(); } string dummy; - RequestResponse("EVAL ||| Reference initialization string . ||| Testing initialization string .\n", &dummy); + RequestResponse("SCORE ||| Reference initialization string . ||| Testing initialization string .", &dummy); assert(dummy.size() > 0); cerr << "Connection established.\n"; } @@ -24,12 +46,11 @@ ScoreServer::~ScoreServer() { pclose(pipe_); } -double ScoreServer::ComputeScore(const vector& fields) { +float ScoreServer::ComputeScore(const vector& fields) { ostringstream os; - os << "EVAL"; + os << "EVAL |||"; for (unsigned i = 0; i < fields.size(); ++i) os << ' ' << fields[i]; - os << endl; string sres; RequestResponse(os.str(), &sres); return strtod(sres.c_str(), NULL); @@ -48,46 +69,42 @@ void ScoreServer::Evaluate(const vector >& refs, const vectorclear(); - while(is >> val) { + while(is >> val) fields->push_back(val); - } } #define MAX_BUF 16000 void ScoreServer::RequestResponse(const string& request, string* response) { - fprintf(pipe_, "%s", request.c_str()); + //cerr << "@SERVER: " << request << endl; + fputs(request.c_str(), pipe_); + fputc('\n', pipe_); fflush(pipe_); char buf[MAX_BUF]; - size_t cr = fread(buf, 1, MAX_BUF, pipe_); - if (cr == 0) { + if (NULL == fgets(buf, MAX_BUF, pipe_)) { cerr << "Read error. Request: " << request << endl; abort(); } - while (buf[cr-1] != '\n') { - size_t n = fread(&buf[cr], 1, MAX_BUF-cr, pipe_); - assert(n > 0); - cr += n; - assert(cr < MAX_BUF); + size_t len = strlen(buf); + if (len < 2) { + cerr << "Malformed response: " << buf << endl; } - buf[cr - 1] = 0; - *response = buf; + *response = Trim(buf, " \t\n"); + //cerr << "@RESPONSE: '" << *response << "'\n"; } struct ExternalScore : public ScoreBase { ExternalScore() : score_server() {} - explicit ExternalScore(const ScoreServer* s) : score_server(s), fields() {} - ExternalScore(const ScoreServer* s, const vector& f) : score_server(s), fields(f) {} + explicit ExternalScore(ScoreServer* s) : score_server(s), fields() {} + ExternalScore(ScoreServer* s, const vector& f) : score_server(s), fields(f) {} float ComputePartialScore() const { return 0.0;} float ComputeScore() const { - // TODO make EVAL call - assert(!"not implemented"); + return score_server->ComputeScore(fields); } void ScoreDetails(string* details) const { ostringstream os; @@ -127,14 +144,17 @@ struct ExternalScore : public ScoreBase { } void Encode(string* out) const { ostringstream os; + for (unsigned i = 0; i < fields.size(); ++i) + os << (i == 0 ? "" : " ") << fields[i]; + *out = os.str(); } bool IsAdditiveIdentity() const { - for (int i = 0; i < fields.size(); ++i) + for (unsigned i = 0; i < fields.size(); ++i) if (fields[i]) return false; return true; } - const ScoreServer* score_server; + mutable ScoreServer* score_server; vector fields; }; @@ -148,3 +168,12 @@ ScoreP ExternalSentenceScorer::ScoreCCandidate(const Sentence& hyp) const { assert(!"not implemented"); } +ScoreP ExternalSentenceScorer::ScoreFromString(ScoreServer* s, const string& data) { + istringstream is(data); + vector fields; + float val; + while(is >> val) + fields.push_back(val); + return ScoreP(new ExternalScore(s, fields)); +} + diff --git a/mteval/external_scorer.h b/mteval/external_scorer.h index a2c91960..59ece269 100644 --- a/mteval/external_scorer.h +++ b/mteval/external_scorer.h @@ -3,15 +3,20 @@ #include #include +#include +#include +#include #include "scorer.h" class ScoreServer { - public: + friend class ScoreServerManager; + protected: explicit ScoreServer(const std::string& cmd); virtual ~ScoreServer(); - double ComputeScore(const std::vector& fields); + public: + float ComputeScore(const std::vector& fields); void Evaluate(const std::vector >& refs, const std::vector& hyp, std::vector* fields); private: @@ -19,17 +24,22 @@ class ScoreServer { FILE* pipe_; }; +struct ScoreServerManager { + static ScoreServer* Instance(const std::string& score_type); + private: + static std::map > servers_; +}; + class ExternalSentenceScorer : public SentenceScorer { public: - virtual ScoreP ScoreCandidate(const Sentence& hyp) const = 0; - virtual ScoreP ScoreCCandidate(const Sentence& hyp) const =0; + ExternalSentenceScorer(ScoreServer* server, const std::vector >& r) : + SentenceScorer("External", r), eval_server(server) {} + virtual ScoreP ScoreCandidate(const Sentence& hyp) const; + virtual ScoreP ScoreCCandidate(const Sentence& hyp) const; + static ScoreP ScoreFromString(ScoreServer* s, const std::string& data); + protected: ScoreServer* eval_server; }; -class METEORServer : public ScoreServer { - public: - METEORServer() : ScoreServer("java -Xmx1024m -jar meteor-1.3.jar - - -mira -lower") {} -}; - #endif diff --git a/mteval/scorer.cc b/mteval/scorer.cc index 64ce63af..2daa0daa 100644 --- a/mteval/scorer.cc +++ b/mteval/scorer.cc @@ -17,11 +17,12 @@ #include "comb_scorer.h" #include "tdict.h" #include "stringlib.h" +#include "external_scorer.h" using boost::shared_ptr; using namespace std; -void Score::TimesEquals(float scale) { +void Score::TimesEquals(float /*scale*/) { cerr<<"UNIMPLEMENTED except for BLEU (for MIRA): Score::TimesEquals"< Sentences; std::string desc; Sentences refs; - SentenceScorer(std::string desc="SentenceScorer_unknown", Sentences const& refs=Sentences()) : desc(desc),refs(refs) { } + explicit SentenceScorer(std::string desc="SentenceScorer_unknown", Sentences const& refs=Sentences()) : desc(desc),refs(refs) { } std::string verbose_desc() const; virtual float ComputeRefLength(const Sentence& hyp) const; // default: avg of refs.length virtual ~SentenceScorer(); diff --git a/vest/dist-vest.pl b/vest/dist-vest.pl index 789b5b14..b7a862c4 100755 --- a/vest/dist-vest.pl +++ b/vest/dist-vest.pl @@ -118,6 +118,8 @@ if ($usefork) { $usefork = "--use-fork"; } else { $usefork = ''; } if ($metric =~ /^(combi|ter)$/i) { $lines_per_mapper = 40; +} elsif ($metric =~ /^meteor$/i) { + $lines_per_mapper = 2000; # start up time is really high } ($iniFile) = @ARGV; -- cgit v1.2.3 From 9366fc1ce04385290722bd703933bf0c1c166671 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 10 Jun 2011 16:20:17 -0400 Subject: proper use of pipes --- mteval/external_scorer.cc | 52 ++++++++++++++++++++++++++++++++--------------- mteval/external_scorer.h | 4 ++-- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/mteval/external_scorer.cc b/mteval/external_scorer.cc index 3757064b..1c09c2a1 100644 --- a/mteval/external_scorer.cc +++ b/mteval/external_scorer.cc @@ -3,6 +3,7 @@ #include // popen #include #include +#include #include #include #include @@ -16,7 +17,7 @@ map > ScoreServerManager::servers_; class METEORServer : public ScoreServer { public: - METEORServer() : ScoreServer("java -Xmx1024m -jar /Users/cdyer/software/meteor/meteor-1.3.jar - - -mira -lower -t tune -l en") {} + METEORServer() : ScoreServer("java -Xmx1024m -jar /usr0/cdyer/meteor/meteor-1.3.jar - - -mira -lower -t tune -l en") {} }; ScoreServer* ScoreServerManager::Instance(const string& score_type) { @@ -32,10 +33,30 @@ ScoreServer* ScoreServerManager::Instance(const string& score_type) { return s.get(); } -ScoreServer::ScoreServer(const string& cmd) : pipe_() { +ScoreServer::ScoreServer(const string& cmd) { cerr << "Invoking " << cmd << " ..." << endl; - pipe_ = popen(cmd.c_str(), "r+"); - if (!pipe_) { perror("popen"); abort(); } + if (pipe(p2c) < 0) { perror("pipe"); exit(1); } + if (pipe(c2p) < 0) { perror("pipe"); exit(1); } + pid_t cpid = fork(); + if (cpid < 0) { perror("fork"); exit(1); } + if (cpid == 0) { // child + close(p2c[1]); + close(c2p[0]); + dup2(p2c[0], 0); + close(p2c[0]); + dup2(c2p[1], 1); + close(c2p[1]); + cerr << "Exec'ing from child " << cmd << endl; + vector vargs; + SplitOnWhitespace(cmd, &vargs); + const char** cargv = static_cast(malloc(sizeof(const char*) * vargs.size())); + for (unsigned i = 1; i < vargs.size(); ++i) cargv[i-1] = vargs[i].c_str(); + cargv[vargs.size() - 1] = NULL; + execvp(vargs[0].c_str(), (char* const*)cargv); + } else { // parent + close(c2p[1]); + close(p2c[0]); + } string dummy; RequestResponse("SCORE ||| Reference initialization string . ||| Testing initialization string .", &dummy); assert(dummy.size() > 0); @@ -43,7 +64,7 @@ ScoreServer::ScoreServer(const string& cmd) : pipe_() { } ScoreServer::~ScoreServer() { - pclose(pipe_); + // TODO close stuff, join stuff } float ScoreServer::ComputeScore(const vector& fields) { @@ -81,21 +102,20 @@ void ScoreServer::Evaluate(const vector >& refs, const vector { diff --git a/mteval/external_scorer.h b/mteval/external_scorer.h index 59ece269..a28fb920 100644 --- a/mteval/external_scorer.h +++ b/mteval/external_scorer.h @@ -2,7 +2,6 @@ #define _EXTERNAL_SCORER_H_ #include -#include #include #include #include @@ -21,7 +20,8 @@ class ScoreServer { private: void RequestResponse(const std::string& request, std::string* response); - FILE* pipe_; + int p2c[2]; + int c2p[2]; }; struct ScoreServerManager { -- cgit v1.2.3 From 205893513c8343fdc55789e427fab4c8b536dc12 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 26 Jun 2011 18:40:15 -0400 Subject: Quantization --- BUILDING | 5 +- decoder/cdec_ff.cc | 1 + decoder/ff_klm.cc | 1 + klm/compile.sh | 4 +- klm/lm/Makefile.am | 1 + klm/lm/binary_format.cc | 17 +++- klm/lm/binary_format.hh | 10 ++- klm/lm/blank.hh | 4 + klm/lm/build_binary.cc | 83 +++++++++++++---- klm/lm/config.cc | 2 + klm/lm/config.hh | 5 ++ klm/lm/model.cc | 25 +++--- klm/lm/model.hh | 21 +++-- klm/lm/model_test.cc | 13 +-- klm/lm/quantize.cc | 84 ++++++++++++++++++ klm/lm/quantize.hh | 207 +++++++++++++++++++++++++++++++++++++++++++ klm/lm/search_hashed.cc | 38 ++++++-- klm/lm/search_hashed.hh | 122 ++++++++++++------------- klm/lm/search_trie.cc | 121 ++++++++++++++++++++----- klm/lm/search_trie.hh | 132 +++++++++++++++------------ klm/lm/trie.cc | 123 ++++++++++--------------- klm/lm/trie.hh | 33 +++---- klm/lm/vocab.cc | 18 ++-- klm/lm/vocab.hh | 35 +++++--- klm/util/bit_packing.cc | 4 +- klm/util/bit_packing.hh | 39 +++++--- klm/util/bit_packing_test.cc | 25 ++++-- klm/util/sorted_uniform.hh | 120 +++++++++++++++++-------- 28 files changed, 921 insertions(+), 372 deletions(-) create mode 100644 klm/lm/quantize.cc create mode 100644 klm/lm/quantize.hh diff --git a/BUILDING b/BUILDING index dcb3d45b..b7535d70 100644 --- a/BUILDING +++ b/BUILDING @@ -1,6 +1,5 @@ To build cdec, you'll need: - * SRILM (register and download from http://www.speech.sri.com/projects/srilm/) * Google c++ testing framework (http://code.google.com/p/googletest/) * boost headers & boost program_options (you may need to install a package like boost-devel) @@ -9,7 +8,7 @@ To build cdec, you'll need: Instructions for building ----------------------------------- - 1) Download and build SRILM + 1) Optional: Download and build SRILM 2) Download, build, and install Google Test (optional, this is necessary to build unit tests that may be useful in development; system tests @@ -22,7 +21,7 @@ Instructions for building 4) Configure and build. Your command will look something like this. - ./configure --with-srilm=/home/me/software/srilm-1.5.9 --disable-gtest + ./configure --disable-gtest make If you get errors during configure about missing BOOST macros, then step 3 diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 37aa655b..31f88a4f 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,6 +55,7 @@ void register_feature_functions() { ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new FFFactory >()); ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); + ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index a3bd0c5f..9b7fe2d3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -437,4 +437,5 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, // instantiate templates template class KLanguageModel; template class KLanguageModel; +template class KLanguageModel; diff --git a/klm/compile.sh b/klm/compile.sh index 49e04db8..6ca85e1f 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -3,11 +3,9 @@ #If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU #I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh -#don't need to use if compiling with moses Makefiles already - set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 61d98d97..395494bc 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -15,6 +15,7 @@ libklm_a_SOURCES = \ binary_format.cc \ config.cc \ lm_exception.cc \ + quantize.cc \ model.cc \ ngram_query.cc \ read_arpa.cc \ diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 34d9ffca..92b1008b 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -80,6 +80,14 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace +void SeekOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + +void AdvanceOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; @@ -156,7 +164,7 @@ bool IsBinaryFormat(int fd) { } void ReadHeader(int fd, Parameters &out) { - if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file"); + SeekOrThrow(fd, sizeof(Sanity)); ReadLoop(fd, &out.fixed, sizeof(out.fixed)); if (out.fixed.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); @@ -173,6 +181,10 @@ void MatchCheck(ModelType model_type, const Parameters ¶ms) { } } +void SeekPastHeader(int fd, const Parameters ¶ms) { + SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +} + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { const off_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. @@ -186,8 +198,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); if (config.enumerate_vocab) { - if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); + SeekOrThrow(backing.file.get(), total_map); } return reinterpret_cast(backing.search.get()) + TotalHeaderSize(params.counts.size()); } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 1fc71be4..2b32b450 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,7 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType; +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -48,6 +48,10 @@ struct Backing { util::scoped_memory search; }; +void SeekOrThrow(int fd, off_t off); +// Seek forward +void AdvanceOrThrow(int fd, off_t off); + // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. @@ -65,6 +69,8 @@ void ReadHeader(int fd, Parameters ¶ms); void MatchCheck(ModelType model_type, const Parameters ¶ms); +void SeekPastHeader(int fd, const Parameters ¶ms); + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -83,6 +89,8 @@ template void LoadLM(const char *file, const Config &config, To &to) // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; + detail::SeekPastHeader(backing.file.get(), params); + To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); std::size_t memory_size = To::Size(params.counts, new_config); uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); to.InitializeFromBinary(start, params, new_config, backing.file.get()); diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 4615a09e..162411a9 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -22,6 +22,8 @@ namespace ngram { */ const float kNoExtensionBackoff = -0.0; const float kExtensionBackoff = 0.0; +const uint64_t kNoExtensionQuant = 0; +const uint64_t kExtensionQuant = 1; inline void SetExtension(float &backoff) { if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff; @@ -47,6 +49,8 @@ inline bool HasExtension(const float &backoff) { */ const float kBlankProb = -std::numeric_limits::infinity(); const float kBlankBackoff = kNoExtensionBackoff; +const uint32_t kBlankProbQuant = 0; +const uint32_t kBlankBackoffQuant = 0; } // namespace ngram } // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 91ad2fb9..4552c419 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,22 +15,21 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" "-u sets the default log10 probability for if the ARPA file does not have\n" "one.\n" "-s allows models to be built even if they do not have and .\n" "-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is one of probing, trie, or sorted:\n\n" +"type is either probing or trie:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" "-t is the temporary directory prefix. Default is the output file name.\n" -"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n" -/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" -"It uses more memory than trie and is also slower, so there's no real reason to\n" -"use it.\n\n"*/ +"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" +"-q turns quantization on and sets the number of bits (e.g. -q 8).\n" +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" "See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" "Passing only an input file will print memory usage of each data structure.\n" "If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; @@ -51,19 +50,53 @@ unsigned long int ParseUInt(const char *from) { return ret; } +uint8_t ParseBitCount(const char *from) { + unsigned long val = ParseUInt(from); + if (val > 25) { + util::ParseNumberException e(from); + e << " bit counts are limited to 256."; + } + return val; +} + void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t probing_size = ProbingModel::Size(counts, config); - // probing is always largest so use it to determine number of columns. - long int length = std::max(5, lrint(ceil(log10(probing_size)))); + std::size_t sizes[3]; + sizes[0] = ProbingModel::Size(counts, config); + sizes[1] = TrieModel::Size(counts, config); + sizes[2] = QuantTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + 3); + std::size_t min_length = *std::max_element(sizes, sizes + 3); + std::size_t divide; + char prefix; + if (min_length < (1 << 10) * 10) { + prefix = ' '; + divide = 1; + } else if (min_length < (1 << 20) * 10) { + prefix = 'k'; + divide = 1 << 10; + } else if (min_length < (1ULL << 30) * 10) { + prefix = 'M'; + divide = 1 << 20; + } else { + prefix = 'G'; + divide = 1 << 30; + } + long int length = std::max(2, lrint(ceil(log10(max_length / divide)))); std::cout << "Memory estimate:\ntype "; // right align bytes. - for (long int i = 0; i < length - 5; ++i) std::cout << ' '; - std::cout << "bytes\n" - "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"; + for (long int i = 0; i < length - 2; ++i) std::cout << ' '; + std::cout << prefix << "B\n" + "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; +} + +void ProbingQuantizationUnsupported() { + std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; + exit(1); } } // namespace ngram @@ -73,11 +106,21 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { int main(int argc, char *argv[]) { using namespace lm::ngram; + bool quantize = false, set_backoff_bits = false; try { lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { switch(opt) { + case 'q': + config.prob_bits = ParseBitCount(optarg); + if (!set_backoff_bits) config.backoff_bits = config.prob_bits; + quantize = true; + break; + case 'b': + config.backoff_bits = ParseBitCount(optarg); + set_backoff_bits = true; + break; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -100,19 +143,29 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } + if (!quantize && set_backoff_bits) { + std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl; + abort(); + } if (optind + 1 == argc) { ShowSizes(argv[optind], config); } else if (optind + 2 == argc) { config.write_mmap = argv[optind + 1]; + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(argv[optind], config); } else if (optind + 3 == argc) { const char *model_type = argv[optind]; const char *from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { - TrieModel(from_file, config); + if (quantize) { + QuantTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } else { Usage(argv[0]); } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index cee8fce2..08e1af5c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -18,6 +18,8 @@ Config::Config() : arpa_complain(ALL), write_mmap(NULL), include_vocab(true), + prob_bits(8), + backoff_bits(8), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 6c7fe39b..dcc7cf35 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -71,6 +71,11 @@ struct Config { // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; + // Quantization options. Only effective for QuantTrieModel. One value is + // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used + // to quantize. + uint8_t prob_bits, backoff_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index f0579c0c..a1d10b3d 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -44,17 +44,13 @@ template GenericModel::Ge begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; State null_context = State(); null_context.valid_length_ = 0; - P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2); + P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } template void GenericModel::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { SetupMemory(start, params.counts, config); vocab_.LoadedBinary(fd, config.enumerate_vocab); - search_.unigram.LoadedBinary(); - for (typename std::vector::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) { - i->LoadedBinary(); - } - search_.longest.LoadedBinary(); + search_.LoadedBinary(); } template void GenericModel::InitializeFromARPA(const char *file, const Config &config) { @@ -116,8 +112,9 @@ template FullScoreReturn GenericModel void GenericModel FullScoreReturn GenericModel::const_iterator mid_iter = search_.middle.begin(); + const typename Search::Middle *mid_iter = search_.MiddleBegin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. @@ -192,7 +189,7 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel; -template class GenericModel; -template class GenericModel; +template class GenericModel; // HASH_PROBING +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index b85ccdcc..1f49a382 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -5,6 +5,7 @@ #include "lm/config.hh" #include "lm/facade.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" #include "lm/vocab.hh" @@ -70,9 +71,10 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: - // Get the size of memory that will be mapped given ngram counts. This - // does not include small non-mapped control structures, such as this class - // itself. + /* Get the size of memory that will be mapped given ngram counts. This + * does not include small non-mapped control structures, such as this class + * itself. + */ static size_t Size(const std::vector &counts, const Config &config = Config()); /* Load the model from a file. It may be an ARPA or binary file. Binary @@ -111,6 +113,11 @@ template class GenericModel : public base::Mod private: friend void LoadLM<>(const char *file, const Config &config, GenericModel &to); + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); + Search::UpdateConfigFromBinary(fd, counts, config); + } + float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const; FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; @@ -130,9 +137,7 @@ template class GenericModel : public base::Mod VocabularyT vocab_; - typedef typename Search::Unigram Unigram; typedef typename Search::Middle Middle; - typedef typename Search::Longest Longest; Search search_; }; @@ -141,13 +146,15 @@ template class GenericModel : public base::Mod // These must also be instantiated in the cc file. typedef ::lm::ngram::ProbingVocabulary Vocabulary; -typedef detail::GenericModel ProbingModel; +typedef detail::GenericModel ProbingModel; // HASH_PROBING // Default implementation. No real reason for it to be the default. typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel TrieModel; +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED + +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 548c098d..8bf040ff 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -243,13 +243,14 @@ BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } -/*BOOST_AUTO_TEST_CASE(sorted) { - LoadingTest(); -}*/ BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(quant) { + LoadingTest(); +} + template void BinaryTest() { Config config; config.write_mmap = "test.binary"; @@ -275,12 +276,12 @@ template void BinaryTest() { BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest(); } -/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) { - BinaryTest(); -}*/ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc new file mode 100644 index 00000000..b4d76893 --- /dev/null +++ b/klm/lm/quantize.cc @@ -0,0 +1,84 @@ +#include "lm/quantize.hh" + +#include "lm/lm_exception.hh" + +#include +#include + +#include + +namespace lm { +namespace ngram { + +/* Quantize into bins of equal size as described in + * M. Federico and N. Bertoldi. 2006. How many bits are needed + * to store probabilities for phrase-based translation? In Proc. + * of the Workshop on Statistical Machine Translation, pages + * 94–101, New York City, June. Association for Computa- + * tional Linguistics. + */ + +namespace { + +void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { + std::sort(values, values_end); + const float *start = values, *finish; + for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) { + finish = values + (((values_end - values) * static_cast(i + 1)) / bins); + if (finish == start) { + // zero length bucket. + *centers = i ? *(centers - 1) : -std::numeric_limits::infinity(); + } else { + *centers = std::accumulate(start, finish, 0.0) / static_cast(finish - start); + } + } +} + +const char kSeparatelyQuantizeVersion = 1; + +} // namespace + +void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector &/*counts*/, Config &config) { + char version; + if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1) + UTIL_THROW(util::ErrnoException, "Failed to read header for quantization."); + if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); +} + +void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { + // Reserve 8 byte header for bit counts. + start_ = reinterpret_cast(static_cast(start) + 8); + prob_bits_ = config.prob_bits; + backoff_bits_ = config.backoff_bits; + // We need the reserved values. + if (config.prob_bits == 0) UTIL_THROW(ConfigException, "You can't quantize probability to zero"); + if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero"); + if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast(config.prob_bits) << " bits."); + if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast(config.backoff_bits) << " bits."); +} + +void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vector &backoff) { + TrainProb(order, prob); + + // Backoff + float *centers = start_ + TableStart(order) + ProbTableLength(); + *(centers++) = kNoExtensionBackoff; + *(centers++) = kExtensionBackoff; + MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2); +} + +void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { + float *centers = start_ + TableStart(order); + *(centers++) = kBlankProb; + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); +} + +void SeparatelyQuantize::FinishedLoading(const Config &config) { + uint8_t *actual_base = reinterpret_cast(start_) - 8; + *(actual_base++) = kSeparatelyQuantizeVersion; // version + *(actual_base++) = config.prob_bits; + *(actual_base++) = config.backoff_bits; +} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh new file mode 100644 index 00000000..aae72b34 --- /dev/null +++ b/klm/lm/quantize.hh @@ -0,0 +1,207 @@ +#ifndef LM_QUANTIZE_H__ +#define LM_QUANTIZE_H__ + +#include "lm/binary_format.hh" // for ModelType +#include "lm/blank.hh" +#include "lm/config.hh" +#include "util/bit_packing.hh" + +#include +#include + +#include + +#include + +namespace lm { +namespace ngram { + +class Config; + +/* Store values directly and don't quantize. */ +class DontQuantize { + public: + static const ModelType kModelType = TRIE_SORTED; + static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} + static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint8_t MiddleBits(const Config &/*config*/) { return 63; } + static uint8_t LongestBits(const Config &/*config*/) { return 31; } + + struct Middle { + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + util::WriteFloat32(base, bit_offset + 31, backoff); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + backoff = util::ReadFloat32(base, bit_offset + 31); + } + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = util::ReadFloat32(base, bit_offset + 31); + } + uint8_t TotalBits() const { return 63; } + }; + + struct Longest { + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + } + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } + uint8_t TotalBits() const { return 31; } + }; + + DontQuantize() {} + + void SetupMemory(void * /*start*/, const Config & /*config*/) {} + + static const bool kTrain = false; + // These should never be called because kTrain is false. + void Train(uint8_t /*order*/, std::vector &/*prob*/, std::vector &/*backoff*/) {} + void TrainProb(uint8_t, std::vector &/*prob*/) {} + + void FinishedLoading(const Config &) {} + + Middle Mid(uint8_t /*order*/) const { return Middle(); } + Longest Long(uint8_t /*order*/) const { return Longest(); } +}; + +class SeparatelyQuantize { + private: + class Bins { + public: + // Sigh C++ default constructor + Bins() {} + + Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + + uint64_t EncodeProb(float value) const { + return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + } + + uint64_t EncodeBackoff(float value) const { + if (value == 0.0) { + return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant; + } + return Encode(value, 2); + } + + float Decode(std::size_t off) const { return begin_[off]; } + + uint8_t Bits() const { return bits_; } + + uint64_t Mask() const { return mask_; } + + private: + uint64_t Encode(float value, size_t reserved) const { + const float *above = std::lower_bound(begin_ + reserved, end_, value); + if (above == begin_ + reserved) return reserved; + if (above == end_) return end_ - begin_ - 1; + return above - begin_ - (value - *(above - 1) < *above - value); + } + + const float *begin_; + const float *end_; + uint8_t bits_; + uint64_t mask_; + }; + + public: + static const ModelType kModelType = QUANT_TRIE_SORTED; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); + + static std::size_t Size(uint8_t order, const Config &config) { + size_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); + size_t middle_table = (static_cast(1) << static_cast(config.backoff_bits)) * sizeof(float) + longest_table; + // unigrams are currently not quantized so no need for a table. + return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; + } + + static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } + static uint8_t LongestBits(const Config &config) { return config.prob_bits; } + + class Middle { + public: + Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : + total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteInt57(base, bit_offset, total_bits_, + (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); + prob = prob_.Decode(both >> backoff_.Bits()); + backoff = backoff_.Decode(both & backoff_.Mask()); + } + + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); + } + + uint8_t TotalBits() const { + return total_bits_; + } + + private: + const uint8_t total_bits_; + const uint64_t total_mask_; + const Bins prob_; + const Bins backoff_; + }; + + class Longest { + public: + // Sigh C++ default constructor + Longest() {} + + Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); + } + + uint8_t TotalBits() const { return prob_.Bits(); } + + private: + Bins prob_; + }; + + SeparatelyQuantize() {} + + void SetupMemory(void *start, const Config &config); + + static const bool kTrain = true; + // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + void Train(uint8_t order, std::vector &prob, std::vector &backoff); + // Train just probabilities (for longest order). + void TrainProb(uint8_t order, std::vector &prob); + + void FinishedLoading(const Config &config); + + Middle Mid(uint8_t order) const { + const float *table = start_ + TableStart(order); + return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); + } + + Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } + + private: + size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast(order - 2); } + size_t ProbTableLength() const { return (1ULL << prob_bits_); } + + float *start_; + uint8_t prob_bits_, backoff_bits_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_QUANTIZE_H__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index eaad59ab..c56ba7b8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -80,6 +80,21 @@ template void ReadNGrams( } // namespace namespace detail { + +template uint8_t *TemplateHashedSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + std::size_t allocated = Unigram::Size(counts[0]); + unigram = Unigram(start, allocated); + start += allocated; + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle_.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest = Longest(start, allocated); + start += allocated; + return start; +} template template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. @@ -92,15 +107,15 @@ template template void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0], warn); + ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle(middle[n-3]), middle[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle(middle.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -108,13 +123,18 @@ template template void TemplateHashe ReadEnd(f); } -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, SortedVocabulary &vocab, Backing &backing); - -SortedHashedSearch::SortedHashedSearch() { - UTIL_THROW(util::Exception, "Sorted is broken at the moment, sorry"); +template void TemplateHashedSearch::LoadedBinary() { + unigram.LoadedBinary(); + for (typename std::vector::iterator i = middle_.begin(); i != middle_.end(); ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); } +template class TemplateHashedSearch; + +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); + } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 6dc11fb3..f3acdefc 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -8,7 +8,6 @@ #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" -#include "util/sorted_uniform.hh" #include #include @@ -62,73 +61,71 @@ struct HashedSearch { } }; -template struct TemplateHashedSearch : public HashedSearch { - typedef MiddleT Middle; - std::vector middle; +template class TemplateHashedSearch : public HashedSearch { + public: + typedef MiddleT Middle; - typedef LongestT Longest; - Longest longest; + typedef LongestT Longest; + Longest longest; - static std::size_t Size(const std::vector &counts, const Config &config) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char n = 1; n < counts.size() - 1; ++n) { - ret += Middle::Size(counts[n], config.probing_multiplier); - } - return ret + Longest::Size(counts.back(), config.probing_multiplier); - } + // TODO: move probing_multiplier here with next binary file format update. + static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} - uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { - std::size_t allocated = Unigram::Size(counts[0]); - unigram = Unigram(start, allocated); - start += allocated; - for (unsigned int n = 2; n < counts.size(); ++n) { - allocated = Middle::Size(counts[n - 1], config.probing_multiplier); - middle.push_back(Middle(start, allocated)); - start += allocated; + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Unigram::Size(counts[0]); + for (unsigned char n = 1; n < counts.size() - 1; ++n) { + ret += Middle::Size(counts[n], config.probing_multiplier); + } + return ret + Longest::Size(counts.back(), config.probing_multiplier); } - allocated = Longest::Size(counts.back(), config.probing_multiplier); - longest = Longest(start, allocated); - start += allocated; - return start; - } - template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; - backoff = found->GetValue().backoff; - return true; - } + template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); - bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - backoff = found->GetValue().backoff; - return true; - } + const Middle *MiddleBegin() const { return &*middle_.begin(); } + const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupLongest(WordIndex word, float &prob, Node &node) const { - node = CombineWordHash(node, word); - typename Longest::ConstIterator found; - if (!longest.Find(node, found)) return false; - prob = found->GetValue().prob; - return true; - } + bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + prob = found->GetValue().prob; + backoff = found->GetValue().backoff; + return true; + } + + void LoadedBinary(); - // Geenrate a node without necessarily checking that it actually exists. - // Optionally return false if it's know to not exist. - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - assert(begin != end); - node = static_cast(*begin); - for (const WordIndex *i = begin + 1; i < end; ++i) { - node = CombineWordHash(node, *i); + bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + backoff = found->GetValue().backoff; + return true; } - return true; - } + + bool LookupLongest(WordIndex word, float &prob, Node &node) const { + node = CombineWordHash(node, word); + typename Longest::ConstIterator found; + if (!longest.Find(node, found)) return false; + prob = found->GetValue().prob; + return true; + } + + // Geenrate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + assert(begin != end); + node = static_cast(*begin); + for (const WordIndex *i = begin + 1; i < end; ++i) { + node = CombineWordHash(node, *i); + } + return true; + } + + private: + std::vector middle_; }; // std::identity is an SGI extension :-( @@ -143,15 +140,6 @@ struct ProbingHashedSearch : public TemplateHashedSearch< static const ModelType kModelType = HASH_PROBING; }; -struct SortedHashedSearch : public TemplateHashedSearch< - util::SortedUniformMap >, - util::SortedUniformMap > > { - - SortedHashedSearch(); - - static const ModelType kModelType = HASH_SORTED; -}; - } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 7c57072b..1ce4d278 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -4,6 +4,7 @@ #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" #include "lm/vocab.hh" @@ -21,6 +22,7 @@ #include #include #include +#include #include #include @@ -579,7 +581,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W // Phase to count n-grams, including blanks inserted because they were pruned but have extensions class JustCount { public: - JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) : counts_(counts), longest_counts_(counts + order - 1) {} void Unigrams(WordIndex begin, WordIndex end) { @@ -608,9 +610,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -647,14 +649,14 @@ class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; - BitPackedLongest &longest_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; BitPacked &bigram_pack_; }; template class RecursiveInsert { public: - RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { } @@ -775,7 +777,51 @@ void SanityCheckCounts(const std::vector &initial, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + +template void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + ProbBackoff weights; + std::vector probs, backoffs; + probs.reserve(count); + backoffs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; + } + } + quant.Train(order, probs, backoffs); +} + +template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + Prob weights; + std::vector probs, backoffs; + probs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + ++progress; + } + } + quant.TrainProb(order, probs); +} + +} // namespace + +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -791,7 +837,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { @@ -800,7 +846,16 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + + if (Quant::kTrain) { + util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + for (unsigned char i = 2; i < counts.size(); ++i) { + TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + } + TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); + quant.FinishedLoading(config); + } for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -808,7 +863,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -845,23 +900,44 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. - if (!out.middle.empty()) { - for (size_t i = 0; i < out.middle.size() - 1; ++i) { - out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); + if (out.middle_begin_ != out.middle_end_) { + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex()); } - out.middle.back().FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); } } -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + quant_.SetupMemory(start, config); + start += Quant::Size(counts.size(), config); + unigram.Init(start); + start += Unigram::Size(counts[0]); + FreeMiddles(); + middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); + middle_end_ = middle_begin_ + (counts.size() - 2); + for (unsigned char i = counts.size() - 1; i >= 2; --i) { + new (middle_begin_ + i - 2) Middle( + start, + quant_.Mid(i), + counts[0], + counts[i], + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + longest.Init(start, quant_.Long(counts.size()), counts[0]); + return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -} // namespace +template void TrieSearch::LoadedBinary() { + unigram.LoadedBinary(); + for (Middle *i = middle_begin_; i != middle_end_; ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); +} -void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -885,12 +961,15 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } +template class TrieSearch; +template class TrieSearch; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0f720217..0a52acb5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,72 +13,88 @@ struct Backing; class SortedVocabulary; namespace trie { -struct TrieSearch { - typedef NodeRange Node; +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; +template class TrieSearch { + public: + typedef NodeRange Node; - typedef trie::BitPackedMiddle Middle; - std::vector middle; + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram; - typedef trie::BitPackedLongest Longest; - Longest longest; + typedef trie::BitPackedMiddle Middle; - static const ModelType kModelType = TRIE_SORTED; + typedef trie::BitPackedLongest Longest; + Longest longest; - static std::size_t Size(const std::vector &counts, const Config &/*config*/) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(counts[i], counts[0], counts[i+1]); + static const ModelType kModelType = Quant::kModelType; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + Quant::UpdateConfigFromBinary(fd, counts, config); } - return ret + Longest::Size(counts.back(), counts[0]); - } - - uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &/*config*/) { - unigram.Init(start); - start += Unigram::Size(counts[0]); - middle.resize(counts.size() - 2); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - middle[i-1].Init( - start, - counts[0], - counts[i+1], - (i == counts.size() - 2) ? static_cast(longest) : static_cast(middle[i])); - start += Middle::Size(counts[i], counts[0], counts[i+1]); + + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + } + return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } - longest.Init(start, counts[0]); - return start + Longest::Size(counts.back(), counts[0]); - } - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); - } - - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); - } - - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); - } - - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); - } - - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. - assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); - for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; + + TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} + + ~TrieSearch() { FreeMiddles(); } + + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); + + void LoadedBinary(); + + const Middle *MiddleBegin() const { return middle_begin_; } + const Middle *MiddleEnd() const { return middle_end_; } + + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + return unigram.Find(word, prob, backoff, node); + } + + bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { + return mid.Find(word, prob, backoff, node); } - return true; - } + + bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { + return mid.FindNoProb(word, backoff, node); + } + + bool LookupLongest(WordIndex word, float &prob, const Node &node) const { + return longest.Find(word, prob, node); + } + + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + // TODO: don't decode backoff. + assert(begin != end); + float ignored_prob, ignored_backoff; + LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + for (const WordIndex *i = begin + 1; i < end; ++i) { + if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + } + return true; + } + + private: + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + + // Middles are managed manually so we can delay construction and they don't have to be copyable. + void FreeMiddles() { + for (const Middle *i = middle_begin_; i != middle_end_; ++i) { + i->~Middle(); + } + free(middle_begin_); + } + + Middle *middle_begin_, *middle_end_; + Quant quant_; }; } // namespace trie diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 2c633613..63c2a612 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,8 +1,8 @@ #include "lm/trie.hh" +#include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" -#include "util/proxy_iterator.hh" #include "util/sorted_uniform.hh" #include @@ -12,53 +12,32 @@ namespace ngram { namespace trie { namespace { -// Assumes key is first. -class JustKeyProxy { +class KeyAccessor { public: - JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {} + KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) + : base_(reinterpret_cast(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - operator uint64_t() const { return GetKey(); } + typedef uint64_t Key; - uint64_t GetKey() const { - uint64_t bit_off = inner_ * static_cast(total_bits_); - return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_); + Key operator()(uint64_t index) const { + return util::ReadInt57(base_, index * static_cast(total_bits_), key_bits_, key_mask_); } private: - friend class util::ProxyIterator; - friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); - - JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) - : inner_(index), base_(static_cast(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - - // This is a read-only iterator. - JustKeyProxy &operator=(const JustKeyProxy &other); - - typedef uint64_t value_type; - - typedef uint64_t InnerIterator; - uint64_t &Inner() { return inner_; } - const uint64_t &Inner() const { return inner_; } - - // The address in bits is base_ * 8 + inner_ * total_bits_. - uint64_t inner_; const uint8_t *const base_; - const uint64_t key_mask_; + const WordIndex key_mask_; const uint8_t key_bits_, total_bits_; }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { - util::ProxyIterator begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); - util::ProxyIterator end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); - util::ProxyIterator out; - if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; - at_index = out.Inner(); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) { + KeyAccessor accessor(base, key_mask, key_bits, total_bits); + if (!util::BoundedSortedUniformFind::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false; return true; } } // namespace std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { - uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits; + uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits; // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes // +sizeof(uint64_t) so that ReadInt57 etc don't go segfault. @@ -71,100 +50,96 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) word_bits_ = util::RequiredBits(max_vocab); word_mask_ = (1ULL << word_bits_) - 1ULL; if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions."); - prob_bits_ = 31; - total_bits_ = word_bits_ + prob_bits_ + remaining_bits; + total_bits_ = word_bits_ + remaining_bits; base_ = static_cast(base); insert_index_ = 0; + max_vocab_ = max_vocab; } -std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { + return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); } -void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) { - next_source_ = &next_source; - backoff_bits_ = 32; - next_bits_ = util::RequiredBits(max_next); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - next_mask_ = (1ULL << next_bits_) - 1; - - BaseInit(base, max_vocab, backoff_bits_ + next_bits_); + BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); } -void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); + util::WriteInt57(base_, at_pointer, word_bits_, word); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); - at_pointer += prob_bits_; - util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); - at_pointer += backoff_bits_; + quant_.Write(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); assert(next <= next_mask_); - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); + util::WriteInt57(base_, at_pointer, next_bits_, next); ++insert_index_; } -bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) { + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } at_pointer *= total_bits_; at_pointer += word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.Read(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); + + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.ReadBackoff(base_, at_pointer, backoff); + at_pointer += quant_.TotalBits(); + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -void BitPackedMiddle::FinishedLoading(uint64_t next_end) { +template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { assert(next_end <= next_mask_); uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); + util::WriteInt57(base_, last_next_write, next_bits_, next_end); } -void BitPackedLongest::Insert(WordIndex index, float prob) { +template void BitPackedLongest::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index); + util::WriteInt57(base_, at_pointer, word_bits_, index); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); + quant_.Write(base_, at_pointer, prob); ++insert_index_; } -bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { +template bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer = at_pointer * total_bits_ + word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); + quant_.Read(base_, at_pointer, prob); return true; } +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedLongest; +template class BitPackedLongest; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 6aef050c..8fa21aaf 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -74,23 +74,21 @@ class BitPacked { void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); - uint8_t word_bits_, prob_bits_; + uint8_t word_bits_; uint8_t total_bits_; uint64_t word_mask_; uint8_t *base_; - uint64_t insert_index_; + uint64_t insert_index_, max_vocab_; }; -class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - BitPackedMiddle() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); // next_source need not be initialized. - void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); void Insert(WordIndex word, float prob, float backoff); @@ -101,28 +99,33 @@ class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end); private: - uint8_t backoff_bits_, next_bits_; + Quant quant_; + uint8_t next_bits_; uint64_t next_mask_; const BitPacked *next_source_; }; -class BitPackedLongest : public BitPacked { +template class BitPackedLongest : public BitPacked { public: - BitPackedLongest() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab) { - return BaseSize(entries, max_vocab, 0); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + return BaseSize(entries, max_vocab, quant_bits); } - void Init(void *base, uint64_t max_vocab) { - return BaseInit(base, max_vocab, 0); + BitPackedLongest() {} + + void Init(void *base, const Quant &quant, uint64_t max_vocab) { + quant_ = quant; + BaseInit(base, max_vocab, quant_.TotalBits()); } void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; + + private: + Quant quant_; }; } // namespace trie diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 515af5db..7defd5c1 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -28,8 +28,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("", 5); // Sadly some LMs have . const uint64_t kUnknownCapHash = detail::HashForVocab("", 5); -void ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return; +WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { + if (!enumerate) return std::numeric_limits::max(); const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); @@ -38,7 +38,7 @@ void ReadWords(int fd, EnumerateVocab *enumerate) { while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return; + if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; @@ -87,13 +87,13 @@ SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { // Lead with the number of entries. - return sizeof(uint64_t) + sizeof(Entry) * entries; + return sizeof(uint64_t) + sizeof(uint64_t) * entries; } void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) { assert(allocated >= Size(entries, config)); // Leave space for number of entries. - begin_ = reinterpret_cast(reinterpret_cast(start) + 1); + begin_ = reinterpret_cast(start) + 1; end_ = begin_; saw_unk_ = false; } @@ -112,7 +112,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } - end_->key = hashed; + *end_ = hashed; if (enumerate_) { strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); } @@ -134,8 +134,10 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { util::JointSort(begin_, end_, reorder_vocab + 1); } SetSpecial(Index(""), Index(""), 0); - // Save size. + // Save size. Excludes UNK. *(reinterpret_cast(begin_) - 1) = end_ - begin_; + // Includes UNK. + bound_ = end_ - begin_ + 1; } void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { @@ -183,7 +185,7 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { lookup_.LoadedBinary(); - ReadWords(fd, to); + available_ = ReadWords(fd, to); SetSpecial(Index(""), Index(""), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 546c1649..c92518e4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -9,6 +9,7 @@ #include "util/sorted_uniform.hh" #include "util/string_piece.hh" +#include #include #include @@ -44,22 +45,16 @@ class WriteWordsWrapper : public EnumerateVocab { // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. class SortedVocabulary : public base::Vocabulary { - private: - // Sorted uniform requires a GetKey function. - struct Entry { - uint64_t GetKey() const { return key; } - uint64_t key; - bool operator<(const Entry &other) const { - return key < other.key; - } - }; - public: SortedVocabulary(); WordIndex Index(const StringPiece &str) const { - const Entry *found; - if (util::SortedUniformFind(begin_, end_, detail::HashForVocab(str), found)) { + const uint64_t *found; + if (util::BoundedSortedUniformFind, util::Pivot64>( + util::IdentityAccessor(), + begin_ - 1, 0, + end_, std::numeric_limits::max(), + detail::HashForVocab(str), found)) { return found - begin_ + 1; // +1 because is 0 and does not appear in the lookup table. } else { return 0; @@ -68,6 +63,10 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. + // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. + WordIndex Bound() const { return bound_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -83,7 +82,11 @@ class SortedVocabulary : public base::Vocabulary { void LoadedBinary(int fd, EnumerateVocab *to); private: - Entry *begin_, *end_; + uint64_t *begin_, *end_; + + WordIndex bound_; + + WordIndex highest_value_; bool saw_unk_; @@ -105,6 +108,12 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()). + // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. + // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. + // Specifically, the binary file format does not currently indicate whether is in count or not. + WordIndex Bound() const { return available_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index 681da5f2..41999b72 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -28,10 +28,10 @@ void BitPackingSanity() { memset(mem, 0, sizeof(mem)); const uint64_t test57 = 0x123456789abcdefULL; for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + b / 8, b % 8, 57, test57); + WriteInt57(mem, b, 57, test57); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)) + if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1)) UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler."); } // TODO: more checks. diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 5c71c792..b35d80c8 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -42,47 +42,62 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { #error "Bit packing code isn't written for your byte order." #endif +inline uint64_t ReadOff(const void *base, uint64_t bit_off) { + return *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)); +} + /* Pack integers up to 57 bits using their least significant digits. * The length is specified using mask: * Assumes mask == (1 << length) - 1 where length <= 57. */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { - return (*reinterpret_cast(base) >> BitPackShift(bit, length)) & mask; +inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) { + return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask; } /* Assumes value < (1 << length) and length <= 57. * Assumes the memory is zero initially. */ -inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { - *reinterpret_cast(base) |= (value << BitPackShift(bit, length)); +inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { + *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); +} + +/* Same caveats as above, but for a 25 bit limit. */ +inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { + return (*reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +} + +inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { + *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); } typedef union { float f; uint32_t i; } FloatEnc; -inline float ReadFloat32(const void *base, uint8_t bit) { +inline float ReadFloat32(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 32); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32); return encoded.f; } -inline void WriteFloat32(void *base, uint8_t bit, float value) { +inline void WriteFloat32(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; - WriteInt57(base, bit, 32, encoded.i); + WriteInt57(base, bit_off, 32, encoded.i); } const uint32_t kSignBit = 0x80000000; -inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { +inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 31); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); // Sign bit set means negative. encoded.i |= kSignBit; return encoded.f; } -inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { +inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; encoded.i &= ~kSignBit; - WriteInt57(base, bit, 31, encoded.i); + WriteInt57(base, bit_off, 31, encoded.i); } void BitPackingSanity(); diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc index c578ddd1..4edc2004 100644 --- a/klm/util/bit_packing_test.cc +++ b/klm/util/bit_packing_test.cc @@ -9,15 +9,16 @@ namespace util { namespace { const uint64_t test57 = 0x123456789abcdefULL; +const uint32_t test25 = 0x1234567; -BOOST_AUTO_TEST_CASE(ZeroBit) { +BOOST_AUTO_TEST_CASE(ZeroBit57) { char mem[16]; memset(mem, 0, sizeof(mem)); WriteInt57(mem, 0, 57, test57); BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1)); } -BOOST_AUTO_TEST_CASE(EachBit) { +BOOST_AUTO_TEST_CASE(EachBit57) { char mem[16]; for (uint8_t b = 0; b < 8; ++b) { memset(mem, 0, sizeof(mem)); @@ -26,15 +27,27 @@ BOOST_AUTO_TEST_CASE(EachBit) { } } -BOOST_AUTO_TEST_CASE(Consecutive) { +BOOST_AUTO_TEST_CASE(Consecutive57) { char mem[57+8]; memset(mem, 0, sizeof(mem)); for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + (b / 8), b % 8, 57, test57); - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + WriteInt57(mem, b, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); + } +} + +BOOST_AUTO_TEST_CASE(Consecutive25) { + char mem[25+8]; + memset(mem, 0, sizeof(mem)); + for (uint64_t b = 0; b < 25 * 8; b += 25) { + WriteInt25(mem, b, 25, test25); + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); + } + for (uint64_t b = 0; b < 25 * 8; b += 25) { + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); } } diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 05826b51..84d7aa02 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -9,52 +9,96 @@ namespace util { -inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { - std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); - // Cap for floating point rounding - return (ret < width) ? ret : width - 1; -} -/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { - return static_cast(static_cast(off) * static_cast(width) / static_cast(range)); +template class IdentityAccessor { + public: + typedef T Key; + T operator()(const uint64_t *in) const { return *in; } +}; + +struct Pivot64 { + static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) { + std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); + // Cap for floating point rounding + return (ret < width) ? ret : width - 1; + } +}; + +// Use when off * width is <2^64. This is guaranteed when each of them is actually a 32-bit value. +struct Pivot32 { + static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) { + return static_cast((off * width) / (range + 1)); + } +}; + +// Usage: PivotSelect::T +template struct PivotSelect; +template <> struct PivotSelect<8> { typedef Pivot64 T; }; +template <> struct PivotSelect<4> { typedef Pivot32 T; }; +template <> struct PivotSelect<2> { typedef Pivot32 T; }; + +/* Binary search. */ +template bool BinaryFind( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key, Iterator &out) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { - return static_cast(static_cast(off) * width / static_cast(range)); + +// Search the range [before_it + 1, after_it - 1] for key. +// Preconditions: +// before_v <= key <= after_v +// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v +// range is sorted. +template bool BoundedSortedUniformFind( + const Accessor &accessor, + Iterator before_it, typename Accessor::Key before_v, + Iterator after_it, typename Accessor::Key after_v, + const typename Accessor::Key key, Iterator &out) { + while (after_it - before_it > 1) { + Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1))); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + before_it = pivot; + before_v = mid; + } else if (mid > key) { + after_it = pivot; + after_v = mid; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { - return static_cast(static_cast(off) * width / static_cast(range)); -}*/ -template bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { +template bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) { if (begin == end) return false; - Key below(begin->GetKey()); + typename Accessor::Key below(accessor(begin)); if (key <= below) { if (key == below) { out = begin; return true; } return false; } // Make the range [begin, end]. --end; - Key above(end->GetKey()); + typename Accessor::Key above(accessor(end)); if (key >= above) { if (key == above) { out = end; return true; } return false; } - - // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. - while (end - begin > 1) { - Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast(end - begin - 1)))); - Key mid(pivot->GetKey()); - if (mid < key) { - begin = pivot; - below = mid; - } else if (mid > key) { - end = pivot; - above = mid; - } else { - out = pivot; - return true; - } - } - return false; + return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } // To use this template, you need to define a Pivot function to match Key. @@ -64,7 +108,13 @@ template class SortedUniformMap { typedef typename Packing::ConstIterator ConstIterator; typedef typename Packing::MutableIterator MutableIterator; - public: + struct Accessor { + public: + typedef typename Packing::Key Key; + const Key &operator()(const ConstIterator &i) const { return i->GetKey(); } + Key &operator()(const MutableIterator &i) const { return i->GetKey(); } + }; + // Offer consistent API with probing hash. static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) { return sizeof(uint64_t) + entries * Packing::kBytes; @@ -120,7 +170,7 @@ template class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind(begin_, end_, key, out); + return SortedUniformFind(begin_, end_, key, out); } // Do not call before FinishedInserting. @@ -129,7 +179,7 @@ template class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind(ConstIterator(begin_), ConstIterator(end_), key, out); + return SortedUniformFind(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out); } ConstIterator begin() const { return begin_; } -- cgit v1.2.3 From 59932be2de387ecfcaa81a8387e8f21d5123c050 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 27 Jun 2011 17:50:41 -0400 Subject: Fix binary format for trie --- klm/lm/quantize.cc | 2 +- klm/lm/search_trie.cc | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index b4d76893..4bb6b1b8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -34,7 +34,7 @@ void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { } } -const char kSeparatelyQuantizeVersion = 1; +const char kSeparatelyQuantizeVersion = 2; } // namespace diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1ce4d278..91f87f1c 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -916,14 +916,19 @@ template uint8_t *TrieSearch::SetupMemory(uint8_t *start, c FreeMiddles(); middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); middle_end_ = middle_begin_ + (counts.size() - 2); + std::vector middle_starts(counts.size() - 2); + for (unsigned char i = 2; i < counts.size(); ++i) { + middle_starts[i-2] = start; + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + // Crazy backwards thing so we initialize in the correct order. for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( - start, + middle_starts[i-2], quant_.Mid(i), counts[0], counts[i], (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); -- cgit v1.2.3 From f91319978f6e74e5c4e5701da8fbbacb96a3161e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 5 Jul 2011 23:19:43 -0400 Subject: fast phrasinator that uses DPs instead of PYPs --- phrasinator/Makefile.am | 8 +- phrasinator/ccrp_nt.h | 154 +++++++++++++++ phrasinator/gibbs_train_plm.notables.cc | 335 ++++++++++++++++++++++++++++++++ phrasinator/train-phrasinator.pl | 2 +- utils/sampler.h | 2 +- 5 files changed, 498 insertions(+), 3 deletions(-) create mode 100644 phrasinator/ccrp_nt.h create mode 100644 phrasinator/gibbs_train_plm.notables.cc diff --git a/phrasinator/Makefile.am b/phrasinator/Makefile.am index 0b15a250..95a603df 100644 --- a/phrasinator/Makefile.am +++ b/phrasinator/Makefile.am @@ -1,6 +1,12 @@ -bin_PROGRAMS = gibbs_train_plm +bin_PROGRAMS = gibbs_train_plm head_bigram_model gibbs_train_plm_notables + +gibbs_train_plm_notables_SOURCES = gibbs_train_plm.notables.cc +gibbs_train_plm_notables_LDADD = $(top_srcdir)/utils/libutils.a -lz gibbs_train_plm_SOURCES = gibbs_train_plm.cc gibbs_train_plm_LDADD = $(top_srcdir)/utils/libutils.a -lz +head_bigram_model_SOURCES = head_bigram_model.cc +head_bigram_model_LDADD = $(top_srcdir)/utils/libutils.a -lz + AM_CPPFLAGS = -funroll-loops -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/phrasinator/ccrp_nt.h b/phrasinator/ccrp_nt.h new file mode 100644 index 00000000..163b643a --- /dev/null +++ b/phrasinator/ccrp_nt.h @@ -0,0 +1,154 @@ +#ifndef _CCRP_NT_H_ +#define _CCRP_NT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "sampler.h" +#include "slice_sampler.h" + +// Chinese restaurant process (Pitman-Yor parameters) with table tracking. + +template > +class CCRP_NoTable { + public: + explicit CCRP_NoTable(double conc) : + num_customers_(), + concentration_(conc), + concentration_prior_shape_(std::numeric_limits::quiet_NaN()), + concentration_prior_rate_(std::numeric_limits::quiet_NaN()) {} + + CCRP_NoTable(double c_shape, double c_rate, double c = 10.0) : + num_customers_(), + concentration_(c), + concentration_prior_shape_(c_shape), + concentration_prior_rate_(c_rate) {} + + double concentration() const { return concentration_; } + + bool has_concentration_prior() const { + return !std::isnan(concentration_prior_shape_); + } + + void clear() { + num_customers_ = 0; + custs_.clear(); + } + + unsigned num_customers() const { + return num_customers_; + } + + unsigned num_customers(const Dish& dish) const { + const typename std::tr1::unordered_map::const_iterator it = custs_.find(dish); + if (it == custs_.end()) return 0; + return it->second; + } + + void increment(const Dish& dish) { + ++custs_[dish]; + ++num_customers_; + } + + void decrement(const Dish& dish) { + if ((--custs_[dish]) == 0) + custs_.erase(dish); + --num_customers_; + } + + double prob(const Dish& dish, const double& p0) const { + const unsigned at_table = num_customers(dish); + return (at_table + p0 * concentration_) / (num_customers_ + concentration_); + } + + double log_crp_prob() const { + return log_crp_prob(concentration_); + } + + static double log_gamma_density(const double& x, const double& shape, const double& rate) { + assert(x >= 0.0); + assert(shape > 0.0); + assert(rate > 0.0); + const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape); + return lp; + } + + // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process + // does not include P_0's + double log_crp_prob(const double& concentration) const { + double lp = 0.0; + if (has_concentration_prior()) + lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_); + assert(lp <= 0.0); + if (num_customers_) { + lp += lgamma(concentration) - lgamma(concentration + num_customers_) + + custs_.size() * log(concentration); + assert(std::isfinite(lp)); + for (typename std::tr1::unordered_map::const_iterator it = custs_.begin(); + it != custs_.end(); ++it) { + lp += lgamma(it->second); + } + } + assert(std::isfinite(lp)); + return lp; + } + + void resample_hyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) { + assert(has_concentration_prior()); + ConcentrationResampler cr(*this); + for (int iter = 0; iter < nloop; ++iter) { + concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0, + std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); + } + } + + struct ConcentrationResampler { + ConcentrationResampler(const CCRP_NoTable& crp) : crp_(crp) {} + const CCRP_NoTable& crp_; + double operator()(const double& proposed_concentration) const { + return crp_.log_crp_prob(proposed_concentration); + } + }; + + void Print(std::ostream* out) const { + (*out) << "DP(alpha=" << concentration_ << ") customers=" << num_customers_ << std::endl; + int cc = 0; + for (typename std::tr1::unordered_map::const_iterator it = custs_.begin(); + it != custs_.end(); ++it) { + (*out) << " " << it->first << "(" << it->second << " eating)"; + ++cc; + if (cc > 10) { (*out) << " ..."; break; } + } + (*out) << std::endl; + } + + unsigned num_customers_; + std::tr1::unordered_map custs_; + + typedef typename std::tr1::unordered_map::const_iterator const_iterator; + const_iterator begin() const { + return custs_.begin(); + } + const_iterator end() const { + return custs_.end(); + } + + double concentration_; + + // optional gamma prior on concentration_ (NaN if no prior) + double concentration_prior_shape_; + double concentration_prior_rate_; +}; + +template +std::ostream& operator<<(std::ostream& o, const CCRP_NoTable& c) { + c.Print(&o); + return o; +} + +#endif diff --git a/phrasinator/gibbs_train_plm.notables.cc b/phrasinator/gibbs_train_plm.notables.cc new file mode 100644 index 00000000..4b431b90 --- /dev/null +++ b/phrasinator/gibbs_train_plm.notables.cc @@ -0,0 +1,335 @@ +#include +#include + +#include +#include + +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp.h" +#include "ccrp_nt.h" + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +Dict d; // global dictionary + +string Join(char joiner, const vector& phrase) { + ostringstream os; + for (int i = 0; i < phrase.size(); ++i) { + if (i > 0) os << joiner; + os << d.Convert(phrase[i]); + } + return os.str(); +} + +template +void WriteSeg(const vector& line, const vector& label, const Dict& d) { + assert(line.size() == label.size()); + assert(label.back()); + int prev = 0; + int cur = 0; + while (cur < line.size()) { + if (label[cur]) { + if (prev) cout << ' '; + cout << "{{"; + for (int i = prev; i <= cur; ++i) + cout << (i == prev ? "" : " ") << d.Convert(line[i]); + cout << "}}:" << label[cur]; + prev = cur + 1; + } + ++cur; + } + cout << endl; +} + +ostream& operator<<(ostream& os, const vector& phrase) { + for (int i = 0; i < phrase.size(); ++i) + os << (i == 0 ? "" : " ") << d.Convert(phrase[i]); + return os; +} + +struct UnigramLM { + explicit UnigramLM(const string& fname) { + ifstream in(fname.c_str()); + assert(in); + } + + double logprob(int word) const { + assert(word < freqs_.size()); + return freqs_[word]; + } + + vector freqs_; +}; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read file from") + ("random_seed,S",po::value(), "Random seed") + ("write_cdec_grammar,g", po::value(), "Write cdec grammar to this file") + ("write_cdec_weights,w", po::value(), "Write cdec weights to this file") + ("poisson_length,p", "Use a Poisson distribution as the length of a phrase in the base distribuion") + ("no_hyperparameter_inference,N", "Disable hyperparameter inference"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadCorpus(const string& filename, vector >* c, set* vocab) { + c->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + c->push_back(vector()); + vector& v = c->back(); + d.ConvertWhitespaceDelimitedLine(line, &v); + for (int i = 0; i < v.size(); ++i) vocab->insert(v[i]); + } + if (in != &cin) delete in; +} + +double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +struct UniphraseLM { + UniphraseLM(const vector >& corpus, + const set& vocab, + const po::variables_map& conf) : + phrases_(1,1), + gen_(1,1), + corpus_(corpus), + uniform_word_(1.0 / vocab.size()), + gen_p0_(0.5), + p_end_(0.5), + use_poisson_(conf.count("poisson_length") > 0) {} + + double p0(const vector& phrase) const { + static vector p0s(10000, 0.0); + assert(phrase.size() < 10000); + double& p = p0s[phrase.size()]; + if (p) return p; + p = exp(log_p0(phrase)); + if (!p) { + cerr << "0 prob phrase: " << phrase << "\nAssigning std::numeric_limits::min()\n"; + p = std::numeric_limits::min(); + } + return p; + } + + double log_p0(const vector& phrase) const { + double len_logprob; + if (use_poisson_) + len_logprob = log_poisson(phrase.size(), 1.0); + else + len_logprob = log(1 - p_end_) * (phrase.size() -1) + log(p_end_); + return log(uniform_word_) * phrase.size() + len_logprob; + } + + double llh() const { + double llh = gen_.log_crp_prob(); + llh += log(gen_p0_) + log(1 - gen_p0_); + double llhr = phrases_.log_crp_prob(); + for (CCRP_NoTable >::const_iterator it = phrases_.begin(); it != phrases_.end(); ++it) { + llhr += log_p0(it->first); + //llhr += log_p0(it->first); + if (!isfinite(llh)) { + cerr << it->first << endl; + cerr << log_p0(it->first) << endl; + abort(); + } + } + return llh + llhr; + } + + void Sample(unsigned int samples, bool hyp_inf, MT19937* rng) { + cerr << "Initializing...\n"; + z_.resize(corpus_.size()); + int tc = 0; + for (int i = 0; i < corpus_.size(); ++i) { + const vector& line = corpus_[i]; + const int ls = line.size(); + const int last_pos = ls - 1; + vector& z = z_[i]; + z.resize(ls); + int prev = 0; + for (int j = 0; j < ls; ++j) { + z[j] = rng->next() < 0.5; + if (j == last_pos) z[j] = true; // break phrase at the end of the sentence + if (z[j]) { + const vector p(line.begin() + prev, line.begin() + j + 1); + phrases_.increment(p); + //cerr << p << ": " << p0(p) << endl; + prev = j + 1; + gen_.increment(false); + ++tc; // remove + } + } + ++tc; + gen_.increment(true); // end of utterance + } + cerr << "TC: " << tc << endl; + cerr << "Initial LLH: " << llh() << endl; + cerr << "Sampling...\n"; + cerr << gen_ << endl; + for (int s = 1; s < samples; ++s) { + cerr << '.'; + if (s % 10 == 0) { + cerr << " [" << s; + if (hyp_inf) ResampleHyperparameters(rng); + cerr << " LLH=" << llh() << "]\n"; + vector z(z_[0].size(), 0); + //for (int j = 0; j < z.size(); ++j) z[j] = z_[0][j]; + //SegCorpus::Write(corpus_[0], z, d); + } + for (int i = 0; i < corpus_.size(); ++i) { + const vector& line = corpus_[i]; + const int ls = line.size(); + const int last_pos = ls - 1; + vector& z = z_[i]; + int prev = 0; + for (int j = 0; j < last_pos; ++j) { // don't resample last position + int next = j+1; while(!z[next]) { ++next; } + const vector p1p2(line.begin() + prev, line.begin() + next + 1); + const vector p1(line.begin() + prev, line.begin() + j + 1); + const vector p2(line.begin() + j + 1, line.begin() + next + 1); + + if (z[j]) { + phrases_.decrement(p1); + phrases_.decrement(p2); + gen_.decrement(false); + gen_.decrement(false); + } else { + phrases_.decrement(p1p2); + gen_.decrement(false); + } + + const double d1 = phrases_.prob(p1p2, p0(p1p2)) * gen_.prob(false, gen_p0_); + double d2 = phrases_.prob(p1, p0(p1)) * gen_.prob(false, gen_p0_); + phrases_.increment(p1); + gen_.increment(false); + d2 *= phrases_.prob(p2, p0(p2)) * gen_.prob(false, gen_p0_); + phrases_.decrement(p1); + gen_.decrement(false); + z[j] = rng->SelectSample(d1, d2); + + if (z[j]) { + phrases_.increment(p1); + phrases_.increment(p2); + gen_.increment(false); + gen_.increment(false); + prev = j + 1; + } else { + phrases_.increment(p1p2); + gen_.increment(false); + } + } + } + } +// cerr << endl << endl << gen_ << endl << phrases_ << endl; + cerr << gen_.prob(false, gen_p0_) << " " << gen_.prob(true, 1 - gen_p0_) << endl; + } + + void WriteCdecGrammarForCurrentSample(ostream* os) const { + CCRP_NoTable >::const_iterator it = phrases_.begin(); + for (; it != phrases_.end(); ++it) { + (*os) << "[X] ||| " << Join(' ', it->first) << " ||| " + << Join('_', it->first) << " ||| C=1 P=" + << log(phrases_.prob(it->first, p0(it->first))) << endl; + } + } + + double OOVUnigramLogProb() const { + vector x(1,99999999); + return log(phrases_.prob(x, p0(x))); + } + + void ResampleHyperparameters(MT19937* rng) { + phrases_.resample_hyperparameters(rng); + gen_.resample_hyperparameters(rng); + cerr << " " << phrases_.concentration(); + } + + CCRP_NoTable > phrases_; + CCRP_NoTable gen_; + vector > z_; // z_[i] is there a phrase boundary after the ith word + const vector >& corpus_; + const double uniform_word_; + const double gen_p0_; + const double p_end_; // in base length distribution, p of the end of a phrase + const bool use_poisson_; +}; + + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + shared_ptr prng; + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpus; + set vocab; + ReadCorpus(conf["input"].as(), &corpus, &vocab); + cerr << "Corpus size: " << corpus.size() << " sentences\n"; + cerr << "Vocabulary size: " << vocab.size() << " types\n"; + + UniphraseLM ulm(corpus, vocab, conf); + ulm.Sample(conf["samples"].as(), conf.count("no_hyperparameter_inference") == 0, &rng); + cerr << "OOV unigram prob: " << ulm.OOVUnigramLogProb() << endl; + + for (int i = 0; i < corpus.size(); ++i) + WriteSeg(corpus[i], ulm.z_[i], d); + + if (conf.count("write_cdec_grammar")) { + string fname = conf["write_cdec_grammar"].as(); + cerr << "Writing model to " << fname << " ...\n"; + WriteFile wf(fname); + ulm.WriteCdecGrammarForCurrentSample(wf.stream()); + } + + if (conf.count("write_cdec_weights")) { + string fname = conf["write_cdec_weights"].as(); + cerr << "Writing weights to " << fname << " .\n"; + WriteFile wf(fname); + ostream& os = *wf.stream(); + os << "# make C smaller to use more phrases\nP 1\nPassThrough " << ulm.OOVUnigramLogProb() << "\nC -3\n"; + } + + + + return 0; +} + diff --git a/phrasinator/train-phrasinator.pl b/phrasinator/train-phrasinator.pl index de258caf..c50b8e68 100755 --- a/phrasinator/train-phrasinator.pl +++ b/phrasinator/train-phrasinator.pl @@ -5,7 +5,7 @@ use Getopt::Long; use File::Spec qw (rel2abs); my $DECODER = "$script_dir/../decoder/cdec"; -my $TRAINER = "$script_dir/gibbs_train_plm"; +my $TRAINER = "$script_dir/gibbs_train_plm_notables"; die "Can't find $TRAINER" unless -f $TRAINER; die "Can't execute $TRAINER" unless -x $TRAINER; diff --git a/utils/sampler.h b/utils/sampler.h index a14f6e2f..153e7ef1 100644 --- a/utils/sampler.h +++ b/utils/sampler.h @@ -105,7 +105,7 @@ class SampleSet { const F& operator[](int i) const { return m_scores[i]; } F& operator[](int i) { return m_scores[i]; } bool empty() const { return m_scores.empty(); } - void add(const prob_t& s) { m_scores.push_back(s); } + void add(const F& s) { m_scores.push_back(s); } void clear() { m_scores.clear(); } size_t size() const { return m_scores.size(); } void resize(int size) { m_scores.resize(size); } -- cgit v1.2.3 From 164d32f02604ee5bff5de94ad669fb2b4d12d34a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 5 Jul 2011 23:35:48 -0400 Subject: build bug --- phrasinator/Makefile.am | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/phrasinator/Makefile.am b/phrasinator/Makefile.am index 95a603df..aba98601 100644 --- a/phrasinator/Makefile.am +++ b/phrasinator/Makefile.am @@ -1,4 +1,6 @@ -bin_PROGRAMS = gibbs_train_plm head_bigram_model gibbs_train_plm_notables +bin_PROGRAMS = gibbs_train_plm gibbs_train_plm_notables + +#head_bigram_model gibbs_train_plm_notables_SOURCES = gibbs_train_plm.notables.cc gibbs_train_plm_notables_LDADD = $(top_srcdir)/utils/libutils.a -lz @@ -6,7 +8,7 @@ gibbs_train_plm_notables_LDADD = $(top_srcdir)/utils/libutils.a -lz gibbs_train_plm_SOURCES = gibbs_train_plm.cc gibbs_train_plm_LDADD = $(top_srcdir)/utils/libutils.a -lz -head_bigram_model_SOURCES = head_bigram_model.cc -head_bigram_model_LDADD = $(top_srcdir)/utils/libutils.a -lz +#head_bigram_model_SOURCES = head_bigram_model.cc +#head_bigram_model_LDADD = $(top_srcdir)/utils/libutils.a -lz AM_CPPFLAGS = -funroll-loops -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -- cgit v1.2.3 From fe4b60f8669f0bdfcc67832e5487b33bd4b28938 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 6 Jul 2011 19:54:58 -0400 Subject: ngram count features --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_ngrams.cc | 319 +++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_ngrams.h | 29 +++++ 4 files changed, 351 insertions(+) create mode 100644 decoder/ff_ngrams.cc create mode 100644 decoder/ff_ngrams.h diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 244da2de..d884c431 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -65,6 +65,7 @@ libcdec_a_SOURCES = \ ff_charset.cc \ ff_lm.cc \ ff_klm.cc \ + ff_ngrams.cc \ ff_spans.cc \ ff_ruleshape.cc \ ff_wordalign.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 31f88a4f..3451c9fb 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -4,6 +4,7 @@ #include "ff_spans.h" #include "ff_lm.h" #include "ff_klm.h" +#include "ff_ngrams.h" #include "ff_csplit.h" #include "ff_wordalign.h" #include "ff_tagger.h" @@ -51,6 +52,7 @@ void register_feature_functions() { ff_registry.Register("RandLM", new FFFactory); #endif ff_registry.Register("SpanFeatures", new FFFactory()); + ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new FFFactory >()); diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc new file mode 100644 index 00000000..54b394ae --- /dev/null +++ b/decoder/ff_ngrams.cc @@ -0,0 +1,319 @@ +#include "ff_ngrams.h" + +#include +#include + +#include + +#include "filelib.h" +#include "stringlib.h" +#include "hg.h" +#include "tdict.h" + +using namespace std; + +static const unsigned char HAS_FULL_CONTEXT = 1; +static const unsigned char HAS_EOS_ON_RIGHT = 2; +static const unsigned char MASK = 7; + +namespace { +template +struct State { + explicit State() { + memset(state, 0, sizeof(state)); + } + explicit State(int order) { + memset(state, 0, (order - 1) * sizeof(WordID)); + } + State(char order, const WordID* mem) { + memcpy(state, mem, (order - 1) * sizeof(WordID)); + } + State(const State& other) { + memcpy(state, other.state, sizeof(state)); + } + const State& operator=(const State& other) { + memcpy(state, other.state, sizeof(state)); + } + explicit State(const State& other, unsigned order, WordID extend) { + char om1 = order - 1; + assert(om1 > 0); + for (char i = 1; i < om1; ++i) state[i - 1]= other.state[i]; + state[om1 - 1] = extend; + } + const WordID& operator[](size_t i) const { return state[i]; } + WordID& operator[](size_t i) { return state[i]; } + WordID state[MAX_ORDER]; +}; +} + +class NgramDetectorImpl { + + // returns the number of unscored words at the left edge of a span + inline int UnscoredSize(const void* state) const { + return *(static_cast(state) + unscored_size_offset_); + } + + inline void SetUnscoredSize(int size, void* state) const { + *(static_cast(state) + unscored_size_offset_) = size; + } + + inline State<5> RemnantLMState(const void* cstate) const { + return State<5>(order_, static_cast(cstate)); + } + + inline const State<5> BeginSentenceState() const { + State<5> state(order_); + state.state[0] = kSOS_; + return state; + } + + inline void SetRemnantLMState(const State<5>& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, lmstate.state, (order_-1) * sizeof(WordID)); + } + + WordID IthUnscoredWord(int i, const void* state) const { + const WordID* const mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); + return mem[i]; + } + + void SetIthUnscoredWord(int i, const WordID index, void *state) const { + WordID* mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); + mem[i] = index; + } + + inline bool GetFlag(const void *state, unsigned char flag) const { + return (*(static_cast(state) + is_complete_offset_) & flag); + } + + inline void SetFlag(bool on, unsigned char flag, void *state) const { + if (on) { + *(static_cast(state) + is_complete_offset_) |= flag; + } else { + *(static_cast(state) + is_complete_offset_) &= (MASK ^ flag); + } + } + + inline bool HasFullContext(const void *state) const { + return GetFlag(state, HAS_FULL_CONTEXT); + } + + inline void SetHasFullContext(bool flag, void *state) const { + SetFlag(flag, HAS_FULL_CONTEXT, state); + } + + void FireFeatures(const State<5>& state, const WordID cur, SparseVector* feats) { + assert(order_ == 2); + if (cur >= unimap_.size()) + unimap_.resize(cur + 10, 0); + int& uf = unimap_[cur]; + if (!uf) { + ostringstream os; + os << "U:" << TD::Convert(cur); + uf = FD::Convert(os.str()); + } + feats->set_value(uf, 1.0); + if (state.state[0]) { + if (state.state[0] >= bimap_.size()) + bimap_.resize(state.state[0] + 10); + int& bf = bimap_[state.state[0]][cur]; + if (!bf) { + ostringstream os; + os << "B:" << TD::Convert(state[0]) << '_' << TD::Convert(cur); + bf = FD::Convert(os.str()); + } + feats->set_value(bf, 1.0); + } + } + + public: + void LookupWords(const TRule& rule, const vector& ant_states, SparseVector* feats, SparseVector* est_feats, void* remnant) { + double sum = 0.0; + double est_sum = 0.0; + int num_scored = 0; + int num_estimated = 0; + bool saw_eos = false; + bool has_some_history = false; + State<5> state; + const vector& e = rule.e(); + bool context_complete = false; + for (int j = 0; j < e.size(); ++j) { + if (e[j] < 1) { // handle non-terminal substitution + const void* astate = (ant_states[-e[j]]); + int unscored_ant_len = UnscoredSize(astate); + for (int k = 0; k < unscored_ant_len; ++k) { + const WordID cur_word = IthUnscoredWord(k, astate); + const bool is_oov = (cur_word == 0); + SparseVector p; + if (cur_word == kSOS_) { + state = BeginSentenceState(); + if (has_some_history) { // this is immediately fully scored, and bad + p.set_value(FD::Convert("Malformed"), 1.0); + context_complete = true; + } else { // this might be a real + num_scored = max(0, order_ - 2); + } + } else { + FireFeatures(state, cur_word, &p); + const State<5> scopy = State<5>(state, order_, cur_word); + state = scopy; + if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); } + saw_eos = (cur_word == kEOS_); + } + has_some_history = true; + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + (*feats) += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + (*est_feats) += p; + } + } + saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); + if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 + state = RemnantLMState(astate); + context_complete = true; + } + } else { // handle terminal + const WordID cur_word = e[j]; + SparseVector p; + if (cur_word == kSOS_) { + state = BeginSentenceState(); + if (has_some_history) { // this is immediately fully scored, and bad + p.set_value(FD::Convert("Malformed"), -100); + context_complete = true; + } else { // this might be a real + num_scored = max(0, order_ - 2); + } + } else { + FireFeatures(state, cur_word, &p); + const State<5> scopy = State<5>(state, order_, cur_word); + state = scopy; + if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); } + saw_eos = (cur_word == kEOS_); + } + has_some_history = true; + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + (*feats) += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + (*est_feats) += p; + } + } + } + if (remnant) { + SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); + SetRemnantLMState(state, remnant); + SetUnscoredSize(num_estimated, remnant); + SetHasFullContext(context_complete || (num_scored >= order_), remnant); + } + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for (n-1 left words) and (n-1 right words) + void FinalTraversal(const void* state, SparseVector* feats) { + if (add_sos_eos_) { // rules do not produce , so do it here + SetRemnantLMState(BeginSentenceState(), dummy_state_); + SetHasFullContext(1, dummy_state_); + SetUnscoredSize(0, dummy_state_); + dummy_ants_[1] = state; + LookupWords(*dummy_rule_, dummy_ants_, feats, NULL, NULL); + } else { // rules DO produce ... +#if 0 + double p = 0; + if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + if (UnscoredSize(state) > 0) { // are there unscored words + if (kSOS_ != IthUnscoredWord(0, state)) { + p -= 100 * UnscoredSize(state); + } + } + return p; +#endif + } + } + + public: + explicit NgramDetectorImpl(bool explicit_markers) : + kCDEC_UNK(TD::Convert("")) , + add_sos_eos_(!explicit_markers) { + order_ = 2; + state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID); + unscored_size_offset_ = (order_ - 1) * sizeof(WordID); + is_complete_offset_ = unscored_size_offset_ + 1; + unscored_words_offset_ = is_complete_offset_ + 1; + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); + kSOS_ = TD::Convert(""); + kEOS_ = TD::Convert(""); + } + + ~NgramDetectorImpl() { + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + WordID kSOS_; // - requires special handling. + WordID kEOS_; // + const bool add_sos_eos_; // flag indicating whether the hypergraph produces and + // if this is true, FinalTransitionFeatures will "add" and + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + int unscored_size_offset_; + int is_complete_offset_; + int unscored_words_offset_; + char* dummy_state_; + vector dummy_ants_; + TRulePtr dummy_rule_; + mutable std::vector unimap_; // [left][right] + mutable std::vector > bimap_; // [left][right] +}; + +NgramDetector::NgramDetector(const string& param) { + string filename, mapfile, featname; + bool explicit_markers = (param == "-x"); + pimpl_ = new NgramDetectorImpl(explicit_markers); + SetStateSize(pimpl_->ReserveStateSize()); +} + +NgramDetector::~NgramDetector() { + delete pimpl_; +} + +void NgramDetector::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + pimpl_->LookupWords(*edge.rule_, ant_states, features, estimated_features, state); +} + +void NgramDetector::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + pimpl_->FinalTraversal(ant_state, features); +} + diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h new file mode 100644 index 00000000..82f61b33 --- /dev/null +++ b/decoder/ff_ngrams.h @@ -0,0 +1,29 @@ +#ifndef _NGRAMS_FF_H_ +#define _NGRAMS_FF_H_ + +#include +#include +#include + +#include "ff.h" + +struct NgramDetectorImpl; +class NgramDetector : public FeatureFunction { + public: + // param = "filename.lm [-o n]" + NgramDetector(const std::string& param); + ~NgramDetector(); + virtual void FinalTraversalFeatures(const void* context, + SparseVector* features) const; + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + NgramDetectorImpl* pimpl_; +}; + +#endif -- cgit v1.2.3 From 3b004be48979da652cc64e7a01e685190eb79498 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 6 Jul 2011 20:41:52 -0400 Subject: tool to compute feature expectations in translation charts --- training/Makefile.am | 4 + training/feature_expectations.cc | 232 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 training/feature_expectations.cc diff --git a/training/Makefile.am b/training/Makefile.am index 0d9085e4..e075e417 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -14,6 +14,7 @@ bin_PROGRAMS = \ mpi_batch_optimize \ mpi_em_optimize \ compute_cllh \ + feature_expectations \ augment_grammar noinst_PROGRAMS = \ @@ -28,6 +29,9 @@ mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +feature_expectations_SOURCES = feature_expectations.cc +feature_expectations_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + mpi_em_optimize_SOURCES = mpi_em_optimize.cc optimize.cc mpi_em_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/training/feature_expectations.cc b/training/feature_expectations.cc new file mode 100644 index 00000000..f1a85495 --- /dev/null +++ b/training/feature_expectations.cc @@ -0,0 +1,232 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "online_optimizer.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" +#include "sampler.h" + +#ifdef HAVE_MPI +#include +#include +namespace mpi = boost::mpi; +#endif + +using namespace std; +namespace po = boost::program_options; + +struct FComp { + const vector& w_; + FComp(const vector& w) : w_(w) {} + bool operator()(int a, int b) const { + return fabs(w_[a]) > fabs(w_[b]); + } +}; + +void ShowFeatures(const vector& w) { + vector fnums(w.size()); + for (int i = 0; i < w.size(); ++i) + fnums[i] = i; + sort(fnums.begin(), fnums.end(), FComp(w)); + for (vector::iterator i = fnums.begin(); i != fnums.end(); ++i) { + if (w[*i]) cout << FD::Convert(*i) << ' ' << w[*i] << endl; + } +} + +void ReadConfig(const string& ini, vector* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (!in) continue; + out->push_back(line); + } +} + +void StoreConfig(const vector& cfg, istringstream* o) { + ostringstream os; + for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } + o->str(os.str()); +} + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i",po::value(),"Corpus of source language sentences") + ("weights,w",po::value(),"Input feature weights file") + ("decoder_config,c",po::value(), "cdec.ini file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("input") || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* order) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int id = 0; + while(in) { + getline(in, line); + if (!in) break; + if (id % size == rank) { + c->push_back(line); + order->push_back(id); + } + ++id; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + acc_exp.clear(); + total_complete = 0; + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + cur_model_exp.clear(); + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + state = 2; + const prob_t z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_model_exp /= z; + acc_exp += cur_model_exp; + } + + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + cerr << "IGNORING ALIGNMENT FOREST!\n"; + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + if (state == 2) { + ++total_complete; + } + } + + void GetExpectations(SparseVector* g) const { + g->clear(); + for (SparseVector::const_iterator it = acc_exp.begin(); it != acc_exp.end(); ++it) + g->set_value(it->first, it->second); + } + + int total_complete; + SparseVector cur_model_exp; + SparseVector acc_exp; + int state; +}; + +#ifdef HAVE_MPI +namespace boost { namespace mpi { + template<> + struct is_commutative >, SparseVector > + : mpl::true_ { }; +} } // end namespace boost::mpi +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return 1; + + // load initial weights + Weights weights; + if (conf.count("weights")) + weights.InitFromFile(conf["weights"].as()); + + vector corpus; + vector ids; + ReadTrainingCorpus(conf["input"].as(), rank, size, &corpus, &ids); + assert(corpus.size() > 0); + + vector cdec_ini; + ReadConfig(conf["decoder_config"].as(), &cdec_ini); + istringstream ini; + StoreConfig(cdec_ini, &ini); + Decoder decoder(&ini); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + return 1; + } + + SparseVector x; + weights.InitSparseVector(&x); + TrainingObserver observer; + + weights.InitFromVector(x); + vector lambdas; + weights.InitVector(&lambdas); + decoder.SetWeights(lambdas); + observer.Reset(); + for (unsigned i = 0; i < corpus.size(); ++i) { + int id = ids[i]; + decoder.SetId(id); + decoder.Decode(corpus[i], &observer); + } + SparseVector local_exps, exps; + observer.GetExpectations(&local_exps); +#ifdef HAVE_MPI + reduce(world, local_exps, exps, std::plus >(), 0); +#else + exps.swap(local_exps); +#endif + + weights.InitFromVector(exps); + weights.InitVector(&lambdas); + ShowFeatures(lambdas); + + return 0; +} -- cgit v1.2.3 From 75b814cb246052746134f32c723cf6d278b148df Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 6 Jul 2011 23:32:53 -0400 Subject: better handling of ngram features --- decoder/ff_ngrams.cc | 49 +++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc index 54b394ae..d52667cd 100644 --- a/decoder/ff_ngrams.cc +++ b/decoder/ff_ngrams.cc @@ -103,27 +103,29 @@ class NgramDetectorImpl { SetFlag(flag, HAS_FULL_CONTEXT, state); } - void FireFeatures(const State<5>& state, const WordID cur, SparseVector* feats) { - assert(order_ == 2); - if (cur >= unimap_.size()) - unimap_.resize(cur + 10, 0); - int& uf = unimap_[cur]; - if (!uf) { - ostringstream os; - os << "U:" << TD::Convert(cur); - uf = FD::Convert(os.str()); - } - feats->set_value(uf, 1.0); - if (state.state[0]) { - if (state.state[0] >= bimap_.size()) - bimap_.resize(state.state[0] + 10); - int& bf = bimap_[state.state[0]][cur]; - if (!bf) { + void FireFeatures(const State<5>& state, WordID cur, SparseVector* feats) { + FidTree* ft = &fidroot_; + int n = 0; + WordID buf[10]; + int ci = order_ - 1; + WordID curword = cur; + while(curword) { + buf[n] = curword; + int& fid = ft->fids[curword]; + ++n; + if (!fid) { + const char* code="_UBT456789"; ostringstream os; - os << "B:" << TD::Convert(state[0]) << '_' << TD::Convert(cur); - bf = FD::Convert(os.str()); + os << code[n] << ':'; + for (int i = n-1; i >= 0; --i) + os << (i != n-1 ? "_" : "") << TD::Convert(buf[i]); + fid = FD::Convert(os.str()); } - feats->set_value(bf, 1.0); + feats->set_value(fid, 1); + ft = &ft->levels[curword]; + --ci; + if (ci < 0) break; + curword = state[ci]; } } @@ -248,7 +250,7 @@ class NgramDetectorImpl { explicit NgramDetectorImpl(bool explicit_markers) : kCDEC_UNK(TD::Convert("")) , add_sos_eos_(!explicit_markers) { - order_ = 2; + order_ = 3; state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID); unscored_size_offset_ = (order_ - 1) * sizeof(WordID); is_complete_offset_ = unscored_size_offset_ + 1; @@ -288,8 +290,11 @@ class NgramDetectorImpl { char* dummy_state_; vector dummy_ants_; TRulePtr dummy_rule_; - mutable std::vector unimap_; // [left][right] - mutable std::vector > bimap_; // [left][right] + struct FidTree { + map fids; + map levels; + }; + mutable FidTree fidroot_; }; NgramDetector::NgramDetector(const string& param) { -- cgit v1.2.3 From 71daf4bf0b91a247d0d1663ae7850a3db85a378d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 7 Jul 2011 18:39:38 -0400 Subject: support for extracting k-best derivation trees --- decoder/decoder.cc | 12 +++++++++--- decoder/oracle_bleu.h | 22 +++++++++++++++------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/decoder/decoder.cc b/decoder/decoder.cc index ff068be9..2c3a06de 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value(), "Extract the rules used in translation (de-duped) to this file") + ("show_derivations", po::value(), "Directory to print the derivation structures to") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value(), "Beam approximation to get max translation from the chart") ("max_translation_sample,X", po::value(), "Sample the max translation from the chart") @@ -426,6 +427,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value(),"Directory to write forests to"); + // ob.AddOptions(&opts); #ifdef FSA_RESCORING po::options_description cfgo(cfg_options.description()); @@ -677,6 +679,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream kbest = conf.count("k_best"); unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); + oracle.show_derivation=conf.count("show_derivations"); #ifdef FSA_RESCORING cfg_options.Validate(); @@ -938,7 +941,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } else { if (kbest && !has_ref) { //TODO: does this work properly? - oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-"); + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-", deriv_fname); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { @@ -1055,8 +1059,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } } if (conf.count("graphviz")) forest.PrintGraphviz(); - if (kbest) - oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-"); + if (kbest) { + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-", deriv_fname); + } if (conf.count("show_conditional_prob")) { const prob_t ref_z = Inside(forest); cout << (log(ref_z) - log(first_z)) << endl << flush; diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 15d48588..b603e27a 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -272,23 +272,31 @@ struct OracleBleu { } kbest_out<score)<<"\n"; deriv_out< > >(sent_id,forest,k,ko.get(),std::cerr); + kbest > >(sent_id,forest,k,ko.get(),oderiv.get()); else { - kbest(sent_id,forest,k,ko.get(),std::cerr); + kbest(sent_id,forest,k,ko.get(),oderiv.get()); } } @@ -296,7 +304,7 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"< Date: Thu, 7 Jul 2011 21:01:38 -0400 Subject: Exception update --- klm/util/exception.cc | 28 +++++++++++++++++++++++++ klm/util/exception.hh | 56 ++++++++++++++++++++++++++++++++++++++++++++++---- klm/util/file_piece.cc | 42 ++++++++++++++++++------------------- klm/util/file_piece.hh | 34 +++++++++++++++--------------- 4 files changed, 117 insertions(+), 43 deletions(-) diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 84f9fe7c..62280970 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -1,5 +1,9 @@ #include "util/exception.hh" +#ifdef __GXX_RTTI +#include +#endif + #include #include @@ -22,6 +26,30 @@ const char *Exception::what() const throw() { return text_.c_str(); } +void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) { + /* The child class might have set some text, but we want this to come first. + * Another option would be passing this information to the constructor, but + * then child classes would have to accept constructor arguments and pass + * them down. + */ + text_ = stream_.str(); + stream_.str(""); + stream_ << file << ':' << line; + if (func) stream_ << " in " << func << " threw "; + if (child_name) { + stream_ << child_name; + } else { +#ifdef __GXX_RTTI + stream_ << typeid(this).name(); +#else + stream_ << "an exception"; +#endif + } + if (condition) stream_ << " because `" << condition; + stream_ << "'.\n"; + stream_ << text_; +} + namespace { // The XOPEN version. const char *HandleStrerror(int ret, const char *buf) { diff --git a/klm/util/exception.hh b/klm/util/exception.hh index c6936914..81675a57 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -1,8 +1,6 @@ #ifndef UTIL_EXCEPTION__ #define UTIL_EXCEPTION__ -#include "util/string_piece.hh" - #include #include #include @@ -22,6 +20,14 @@ class Exception : public std::exception { // Not threadsafe, but probably doesn't matter. FWIW, Boost's exception guidance implies that what() isn't threadsafe. const char *what() const throw(); + // For use by the UTIL_THROW macros. + void SetLocation( + const char *file, + unsigned int line, + const char *func, + const char *child_name, + const char *condition); + private: template friend typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data); @@ -43,7 +49,49 @@ template typename Except::template ExceptionTag(); } -double FilePiece::ReadDouble() throw(GZException, EndOfFileException, ParseNumberException) { +double FilePiece::ReadDouble() { return ReadNumber(); } -long int FilePiece::ReadLong() throw(GZException, EndOfFileException, ParseNumberException) { +long int FilePiece::ReadLong() { return ReadNumber(); } -unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, ParseNumberException) { +unsigned long int FilePiece::ReadULong() { return ReadNumber(); } -void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) { +void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) { #ifdef HAVE_ZLIB gz_file_ = NULL; #endif @@ -163,7 +163,7 @@ void ParseNumber(const char *begin, char *&end, unsigned long int &out) { } } // namespace -template T FilePiece::ReadNumber() throw(GZException, EndOfFileException, ParseNumberException) { +template T FilePiece::ReadNumber() { SkipSpaces(); while (last_space_ < position_) { if (at_end_) { @@ -186,7 +186,7 @@ template T FilePiece::ReadNumber() throw(GZException, EndOfFileExcepti return ret; } -const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException, EndOfFileException) { +const char *FilePiece::FindDelimiterOrEOF(const bool *delim) { size_t skip = 0; while (true) { for (const char *i = position_ + skip; i < position_end_; ++i) { @@ -201,7 +201,7 @@ const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException, } } -void FilePiece::Shift() throw(GZException, EndOfFileException) { +void FilePiece::Shift() { if (at_end_) { progress_.Finished(); throw EndOfFileException(); @@ -217,7 +217,7 @@ void FilePiece::Shift() throw(GZException, EndOfFileException) { } } -void FilePiece::MMapShift(off_t desired_begin) throw() { +void FilePiece::MMapShift(off_t desired_begin) { // Use mmap. off_t ignore = desired_begin % page_; // Duplicate request for Shift means give more data. @@ -259,25 +259,23 @@ void FilePiece::MMapShift(off_t desired_begin) throw() { progress_.Set(desired_begin); } -void FilePiece::TransitionToRead() throw (GZException) { +void FilePiece::TransitionToRead() { assert(!fallback_to_read_); fallback_to_read_ = true; data_.reset(); data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); - if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); + UTIL_THROW_IF(!data_.get(), ErrnoException, "malloc failed for " << default_map_size_); position_ = data_.begin(); position_end_ = position_; #ifdef HAVE_ZLIB assert(!gz_file_); gz_file_ = gzdopen(file_.get(), "r"); - if (!gz_file_) { - UTIL_THROW(GZException, "zlib failed to open " << file_name_); - } + UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_); #endif } -void FilePiece::ReadShift() throw(GZException, EndOfFileException) { +void FilePiece::ReadShift() { assert(fallback_to_read_); // Bytes [data_.begin(), position_) have been consumed. // Bytes [position_, position_end_) have been read into the buffer. @@ -297,7 +295,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) { std::size_t valid_length = position_end_ - position_; default_map_size_ *= 2; data_.call_realloc(default_map_size_); - if (!data_.get()) UTIL_THROW(ErrnoException, "realloc failed for " << default_map_size_); + UTIL_THROW_IF(!data_.get(), ErrnoException, "realloc failed for " << default_map_size_); position_ = data_.begin(); position_end_ = position_ + valid_length; } else { @@ -320,7 +318,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) { } #else read_return = read(file_.get(), static_cast(data_.get()) + already_read, default_map_size_ - already_read); - if (read_return == -1) UTIL_THROW(ErrnoException, "read failed"); + UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed"); progress_.Set(mapped_offset_); #endif if (read_return == 0) { diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 870ae5a3..a5c00910 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -45,13 +45,13 @@ off_t SizeFile(int fd); class FilePiece { public: // 32 MB default. - explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); // Takes ownership of fd. name is used for messages. - explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); + explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); ~FilePiece(); - char get() throw(GZException, EndOfFileException) { + char get() { if (position_ == position_end_) { Shift(); if (at_end_) throw EndOfFileException(); @@ -60,22 +60,22 @@ class FilePiece { } // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). - StringPiece ReadDelimited(const bool *delim = kSpaces) throw(GZException, EndOfFileException) { + StringPiece ReadDelimited(const bool *delim = kSpaces) { SkipSpaces(delim); return Consume(FindDelimiterOrEOF(delim)); } // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. // It is similar to getline in that way. - StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException); + StringPiece ReadLine(char delim = '\n'); - float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException); - double ReadDouble() throw(GZException, EndOfFileException, ParseNumberException); - long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException); - unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException); + float ReadFloat(); + double ReadDouble(); + long int ReadLong(); + unsigned long int ReadULong(); // Skip spaces defined by isspace. - void SkipSpaces(const bool *delim = kSpaces) throw (GZException, EndOfFileException) { + void SkipSpaces(const bool *delim = kSpaces) { for (; ; ++position_) { if (position_ == position_end_) Shift(); if (!delim[static_cast(*position_)]) return; @@ -89,9 +89,9 @@ class FilePiece { const std::string &FileName() const { return file_name_; } private: - void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException); + void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); - template T ReadNumber() throw(GZException, EndOfFileException, ParseNumberException); + template T ReadNumber(); StringPiece Consume(const char *to) { StringPiece ret(position_, to - position_); @@ -99,14 +99,14 @@ class FilePiece { return ret; } - const char *FindDelimiterOrEOF(const bool *delim = kSpaces) throw (GZException, EndOfFileException); + const char *FindDelimiterOrEOF(const bool *delim = kSpaces); - void Shift() throw (EndOfFileException, GZException); + void Shift(); // Backends to Shift(). - void MMapShift(off_t desired_begin) throw (); + void MMapShift(off_t desired_begin); - void TransitionToRead() throw (GZException); - void ReadShift() throw (GZException, EndOfFileException); + void TransitionToRead(); + void ReadShift(); const char *position_, *last_space_, *position_end_; -- cgit v1.2.3 From 3396d8de52872e47ec61be942e4b50170a789950 Mon Sep 17 00:00:00 2001 From: andrea gesmundo Date: Fri, 8 Jul 2011 13:56:42 +0200 Subject: add Fast Cube Pruning --- decoder/apply_models.cc | 196 ++++++++++++++++++++++++++++++++++++++++++++++-- decoder/apply_models.h | 6 +- decoder/decoder.cc | 10 ++- 3 files changed, 204 insertions(+), 8 deletions(-) diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9390c809..62eff262 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -17,6 +17,10 @@ #include "hg.h" #include "ff.h" +#define NORMAL_CP 1 +#define FAST_CP 2 +#define FAST_CP_2 3 + using namespace std; using namespace std::tr1; @@ -164,13 +168,15 @@ public: const SentenceMetadata& sm, const Hypergraph& i, int pop_limit, - Hypergraph* o) : + Hypergraph* o, + int s = NORMAL_CP ) : models(m), smeta(sm), in(i), out(*o), D(in.nodes_.size()), - pop_limit_(pop_limit) { + pop_limit_(pop_limit), + strategy_(s){ if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; node_states_.reserve(kRESERVE_NUM_NODES); } @@ -186,7 +192,15 @@ public: if (!SILENT) cerr << " "; for (int i = 0; i < in.nodes_.size(); ++i) { if (!SILENT && i % every == 0) cerr << '.'; - KBest(i, i == goal_id); + if (strategy_==NORMAL_CP){ + KBest(i, i == goal_id); + } + if (strategy_==FAST_CP){ + KBestFast(i, i == goal_id); + } + if (strategy_==FAST_CP_2){ + KBestFast2(i, i == goal_id); + } } if (!SILENT) { cerr << endl; @@ -283,6 +297,114 @@ public: delete freelist[i]; } + void KBestFast(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + assert(unique_accepted.insert(item).second); // these should all be unique! + // cerr << "POPPED: " << *item << endl; + + PushSuccFast2(*item, is_goal, &cand, &unique_accepted); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } + } + + //PushSucc only if all ancest Cand are added + void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } + } + + bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; + } + const ModelSet& models; const SentenceMetadata& smeta; const Hypergraph& in; @@ -311,6 +481,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -412,15 +583,28 @@ void ApplyModelSet(const Hypergraph& in, if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state ma.Apply(); - } else if (config.algorithm == IntersectionConfiguration::CUBE) { + } else if (config.algorithm == IntersectionConfiguration::CUBE + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { int pl = config.pop_limit; const int max_pl_for_large=50; if (pl > max_pl_for_large && in.nodes_.size() > 80000) { pl = max_pl_for_large; cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + if (config.algorithm == IntersectionConfiguration::CUBE) { + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); + } + } else { cerr << "Don't understand intersection algorithm " << config.algorithm << endl; exit(1); diff --git a/decoder/apply_models.h b/decoder/apply_models.h index a85694aa..19a4c7be 100644 --- a/decoder/apply_models.h +++ b/decoder/apply_models.h @@ -13,6 +13,8 @@ struct IntersectionConfiguration { enum { FULL, CUBE, + FAST_CUBE_PRUNING, + FAST_CUBE_PRUNING_2, N_ALGORITHMS }; @@ -25,7 +27,9 @@ enum { inline std::ostream& operator<<(std::ostream& os, const IntersectionConfiguration& c) { if (c.algorithm == 0) { os << "FULL"; } else if (c.algorithm == 1) { os << "CUBE:k=" << c.pop_limit; } - else if (c.algorithm == 2) { os << "N_ALGORITHMS"; } + else if (c.algorithm == 2) { os << "FAST_CUBE_PRUNING"; } + else if (c.algorithm == 3) { os << "FAST_CUBE_PRUNING_2"; } + else if (c.algorithm == 4) { os << "N_ALGORITHMS"; } else os << "OTHER"; return os; } diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 2c3a06de..8a4a1485 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -357,7 +357,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("weights,w",po::value(),"Feature weights file (initial forest / pass 1)") ("feature_function,F",po::value >()->composing(), "Pass 1 additional feature function(s) (-L for list)") - ("intersection_strategy,I",po::value()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") + ("intersection_strategy,I",po::value()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2") ("summary_feature", po::value(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z") ("summary_feature_type", po::value()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob") ("density_prune", po::value(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") @@ -597,6 +597,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (LowercaseString(str(isn.c_str(),conf)) == "full") { palg = 0; } + if (LowercaseString(conf["intersection_strategy"].as()) == "fast_cube_pruning") { + palg = 2; + cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n"; + } + if (LowercaseString(conf["intersection_strategy"].as()) == "fast_cube_pruning_2") { + palg = 3; + cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n"; + } rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit)); } else { break; // TODO alert user if there are any future configurations -- cgit v1.2.3 From ed8a6e81d87f6e917ecffc290cde0a340b6aa03b Mon Sep 17 00:00:00 2001 From: andrea gesmundo Date: Fri, 8 Jul 2011 15:33:47 +0200 Subject: add cp time measure (def macro) --- decoder/cdec.cc | 8 ++++++++ decoder/decoder.cc | 13 +++++++++++++ decoder/decoder.h | 14 ++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 5c40f56e..c671af57 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -19,11 +19,19 @@ int main(int argc, char** argv) { assert(*in); string buf; +#ifdef CP_TIME + clock_t time_cp(0);//, end_cp; +#endif while(*in) { getline(*in, buf); if (buf.empty()) continue; decoder.Decode(buf); } +#ifdef CP_TIME + cerr << "Time required for Cube Pruning execution: " + << CpTime::Get() + << " seconds." << "\n\n"; +#endif if (show_feature_dictionary) { int num = FD::NumFeats(); for (int i = 1; i < num; ++i) { diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 8a4a1485..76f31352 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -46,6 +46,13 @@ #include "cfg_options.h" #endif +#ifdef CP_TIME + clock_t CpTime::time_; + void CpTime::Add(clock_t x){time_+=x;} + void CpTime::Sub(clock_t x){time_-=x;} + double CpTime::Get(){return (double)(time_)/CLOCKS_PER_SEC;} +#endif + static const double kMINUS_EPSILON = -1e-6; // don't be too strict using namespace std; @@ -806,11 +813,17 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Timer t("Forest rescoring:"); rp.models->PrepareForInput(smeta); Hypergraph rescored_forest; +#ifdef CP_TIME + CpTime::Sub(clock()); +#endif ApplyModelSet(forest, smeta, *rp.models, *rp.inter_conf, &rescored_forest); +#ifdef CP_TIME + CpTime::Add(clock()); +#endif forest.swap(rescored_forest); forest.Reweight(cur_weights); if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation); diff --git a/decoder/decoder.h b/decoder/decoder.h index 813400e3..5491369f 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -7,6 +7,20 @@ #include #include +#undef CP_TIME +//#define CP_TIME +#ifdef CP_TIME +#include +struct CpTime{ +public: + static void Add(clock_t x); + static void Sub(clock_t x); + static double Get(); +private: + static clock_t time_; +}; +#endif + class SentenceMetadata; struct Hypergraph; struct DecoderImpl; -- cgit v1.2.3 From f80150140b5273fd1eb0dfb34bdd789c4cbd35e6 Mon Sep 17 00:00:00 2001 From: andrea gesmundo Date: Fri, 8 Jul 2011 15:38:52 +0200 Subject: add exp log file with time measures --- expLog | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 expLog diff --git a/expLog b/expLog new file mode 100644 index 00000000..2070ac98 --- /dev/null +++ b/expLog @@ -0,0 +1,60 @@ +TIME MEASURES AFTER MERGE WITH cdec: +8/July/2011 +commit ed8a6e81d87f6e917ecf + +./runEval +Fri Jul 8 13:28:23 CEST 2011 +Fri Jul 8 13:30:24 CEST 2011 +Loading references (4 files) +Loaded reference translations for 919 sentences. +Loaded 919 references for scoring with ibm_bleu +BLEU = 32.25, 76.5|43.1|24.3|13.9 (brev=0.993) +0.322487 +Fri Jul 8 13:30:24 CEST 2011 +------------ +Fri Jul 8 15:04:00 CEST 2011 +Fri Jul 8 15:05:58 CEST 2011 +Time required for Cube Pruning execution: 77.61 seconds. +------------ +Fri Jul 8 15:24:39 CEST 2011 +Fri Jul 8 15:26:36 CEST 2011 +Time required for Cube Pruning execution: 79.01 seconds. +------------ + +./runEvalFCP +Fri Jul 8 13:33:17 CEST 2011 +Fri Jul 8 13:35:06 CEST 2011 +Loading references (4 files) +Loaded reference translations for 919 sentences. +Loaded 919 references for scoring with ibm_bleu +BLEU = 32.39, 76.5|43.1|24.5|14.0 (brev=0.994) +0.323857 +Fri Jul 8 13:35:07 CEST 2011 +------------ +Fri Jul 8 15:08:17 CEST 2011 +Fri Jul 8 15:10:05 CEST 2011 +Time required for Cube Pruning execution: 69.36 seconds. +------------ +Fri Jul 8 15:21:48 CEST 2011 +Fri Jul 8 15:23:35 CEST 2011 +Time required for Cube Pruning execution: 69.71 seconds. +------------ + +./runEvalFCP2 +Fri Jul 8 13:53:38 CEST 2011 +Fri Jul 8 13:55:29 CEST 2011 +Loading references (4 files) +Loaded reference translations for 919 sentences. +Loaded 919 references for scoring with ibm_bleu +BLEU = 32.49, 76.6|43.2|24.5|14.1 (brev=0.994) +0.324901 +Fri Jul 8 13:55:29 CEST 2011 +------------ +Fri Jul 8 15:12:52 CEST 2011 +Fri Jul 8 15:14:42 CEST 2011 +Time required for Cube Pruning execution: 72.66 seconds. +------------ +Fri Jul 8 15:19:13 CEST 2011 +Fri Jul 8 15:21:03 CEST 2011 +Time required for Cube Pruning execution: 72.06 seconds. +------------ -- cgit v1.2.3 From 95deb840699f9b6f8fe499b374bd726bce97365c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 10 Jul 2011 23:00:21 -0400 Subject: starting implementation of Hopkins&May (2011) optimizer --- Makefile.am | 2 +- configure.ac | 2 +- pro-train/Makefile.am | 13 + pro-train/README.shared-mem | 9 + pro-train/dist-pro.pl | 735 ++++++++++++++++++++++++++++++++++++++++++++ pro-train/mr_pro_map.cc | 111 +++++++ pro-train/mr_pro_reduce.cc | 81 +++++ 7 files changed, 951 insertions(+), 2 deletions(-) create mode 100644 pro-train/Makefile.am create mode 100644 pro-train/README.shared-mem create mode 100755 pro-train/dist-pro.pl create mode 100644 pro-train/mr_pro_map.cc create mode 100644 pro-train/mr_pro_reduce.cc diff --git a/Makefile.am b/Makefile.am index f5397d0b..98b4bac7 100644 --- a/Makefile.am +++ b/Makefile.am @@ -1,7 +1,7 @@ # warning - the subdirectories in the following list should # be kept in topologically sorted order. Also, DO NOT introduce # cyclic dependencies between these directories! -SUBDIRS = utils mteval klm/util klm/lm decoder phrasinator training mira vest extools +SUBDIRS = utils mteval klm/util klm/lm decoder phrasinator training mira vest pro-train extools #gi/pyp-topics/src gi/clda/src gi/posterior-regularisation/prjava diff --git a/configure.ac b/configure.ac index db2b4afc..4e708073 100644 --- a/configure.ac +++ b/configure.ac @@ -92,4 +92,4 @@ then AM_CONDITIONAL([GLC], true) fi -AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile) +AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile pro-train/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile) diff --git a/pro-train/Makefile.am b/pro-train/Makefile.am new file mode 100644 index 00000000..945ed5c3 --- /dev/null +++ b/pro-train/Makefile.am @@ -0,0 +1,13 @@ +bin_PROGRAMS = \ + mr_pro_map \ + mr_pro_reduce + +TESTS = lo_test + +mr_pro_map_SOURCES = mr_pro_map.cc +mr_pro_map_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +mr_pro_reduce_SOURCES = mr_pro_reduce.cc +mr_pro_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/pro-train/README.shared-mem b/pro-train/README.shared-mem new file mode 100644 index 00000000..7728efc0 --- /dev/null +++ b/pro-train/README.shared-mem @@ -0,0 +1,9 @@ +If you want to run dist-vest.pl on a very large shared memory machine, do the +following: + + ./dist-vest.pl --use-make I --decode-nodes J --weights weights.init --source-file=dev.src --ref-files=dev.ref.* cdec.ini + +This will use I jobs for doing the line search and J jobs to run the decoder. Typically, since the +decoder must load grammars, language models, etc., J should be smaller than I, but this will depend +on the system you are running on and the complexity of the models used for decoding. + diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl new file mode 100755 index 00000000..35bccea4 --- /dev/null +++ b/pro-train/dist-pro.pl @@ -0,0 +1,735 @@ +#!/usr/bin/env perl +use strict; +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../environment"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use IPC::Open2; +use POSIX ":sys_wait_h"; +my $QSUB_CMD = qsub_args(mert_memory()); + +my $VEST_DIR="$SCRIPT_DIR/../vest"; +require "$VEST_DIR/libcall.pl"; + +# Default settings +my $srcFile; +my $refFiles; +my $bin_dir = $SCRIPT_DIR; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $MAPINPUT = "$bin_dir/mr_pro_generate_mapper_input"; +my $MAPPER = "$bin_dir/mr_pro_map"; +my $REDUCER = "$bin_dir/mr_pro_reduce"; +my $parallelize = "$VEST_DIR/parallelize.pl"; +my $libcall = "$VEST_DIR/libcall.pl"; +my $sentserver = "$VEST_DIR/sentserver"; +my $sentclient = "$VEST_DIR/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +die "Can't find $MAPPER" unless -x $MAPPER; +my $cdec = "$bin_dir/../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $lines_per_mapper = 400; +my $rand_directions = 15; +my $iteration = 1; +my $run_local = 0; +my $best_weights; +my $max_iterations = 15; +my $optimization_iters = 6; +my $decode_nodes = 15; # number of decode nodes +my $pmem = "9g"; +my $disable_clean = 0; +my %seen_weights; +my $normalize; +my $help = 0; +my $epsilon = 0.0001; +my $interval = 5; +my $dryrun = 0; +my $last_score = -10000000; +my $metric = "ibm_bleu"; +my $dir; +my $iniFile; +my $weights; +my $initialWeights; +my $decoderOpt; +my $noprimary; +my $maxsim=0; +my $oraclen=0; +my $oracleb=20; +my $bleu_weight=1; +my $use_make; # use make to parallelize line search +my $dirargs=''; +my $density_prune; +my $usefork; +my $pass_suffix = ''; +my $cpbin=1; +# Process command-line options +Getopt::Long::Configure("no_auto_abbrev"); +if (GetOptions( + "decoder=s" => \$decoderOpt, + "decode-nodes=i" => \$decode_nodes, + "density-prune=f" => \$density_prune, + "dont-clean" => \$disable_clean, + "pass-suffix=s" => \$pass_suffix, + "use-fork" => \$usefork, + "dry-run" => \$dryrun, + "epsilon=s" => \$epsilon, + "help" => \$help, + "interval" => \$interval, + "iteration=i" => \$iteration, + "local" => \$run_local, + "use-make=i" => \$use_make, + "max-iterations=i" => \$max_iterations, + "normalize=s" => \$normalize, + "pmem=s" => \$pmem, + "cpbin!" => \$cpbin, + "rand-directions=i" => \$rand_directions, + "random_directions=i" => \$rand_directions, + "bleu_weight=s" => \$bleu_weight, + "no-primary!" => \$noprimary, + "max-similarity=s" => \$maxsim, + "oracle-directions=i" => \$oraclen, + "n-oracle=i" => \$oraclen, + "oracle-batch=i" => \$oracleb, + "directions-args=s" => \$dirargs, + "ref-files=s" => \$refFiles, + "metric=s" => \$metric, + "source-file=s" => \$srcFile, + "weights=s" => \$initialWeights, + "workdir=s" => \$dir, + "opt-iterations=i" => \$optimization_iters, +) == 0 || @ARGV!=1 || $help) { + print_help(); + exit; +} + +if (defined $density_prune) { + die "--density_prune n: n must be greater than 1.0\n" unless $density_prune > 1.0; +} + +if ($usefork) { $usefork = "--use-fork"; } else { $usefork = ''; } + +if ($metric =~ /^(combi|ter)$/i) { + $lines_per_mapper = 40; +} elsif ($metric =~ /^meteor$/i) { + $lines_per_mapper = 2000; # start up time is really high +} + +($iniFile) = @ARGV; + + +sub write_config; +sub enseg; +sub print_help; + +my $nodelist; +my $host =check_output("hostname"); chomp $host; +my $bleu; +my $interval_count = 0; +my $logfile; +my $projected_score; + +# used in sorting scores +my $DIR_FLAG = '-r'; +if ($metric =~ /^ter$|^aer$/i) { + $DIR_FLAG = ''; +} + +my $refs_comma_sep = get_comma_sep_refs('r',$refFiles); + +unless ($dir){ + $dir = "vest"; +} +unless ($dir =~ /^\//){ # convert relative path to absolute path + my $basedir = check_output("pwd"); + chomp $basedir; + $dir = "$basedir/$dir"; +} + +if ($decoderOpt){ $decoder = $decoderOpt; } + + +# Initializations and helper functions +srand; + +my @childpids = (); +my @cleanupcmds = (); + +sub cleanup { + print STDERR "Cleanup...\n"; + for my $pid (@childpids){ unchecked_call("kill $pid"); } + for my $cmd (@cleanupcmds){ unchecked_call("$cmd"); } + exit 1; +}; +# Always call cleanup, no matter how we exit +*CORE::GLOBAL::exit = + sub{ cleanup(); }; +$SIG{INT} = "cleanup"; +$SIG{TERM} = "cleanup"; +$SIG{HUP} = "cleanup"; + +my $decoderBase = check_output("basename $decoder"); chomp $decoderBase; +my $newIniFile = "$dir/$decoderBase.ini"; +my $inputFileName = "$dir/input"; +my $user = $ENV{"USER"}; + + +# process ini file +-e $iniFile || die "Error: could not open $iniFile for reading\n"; +open(INI, $iniFile); + +use File::Basename qw(basename); +#pass bindir, refs to vars holding bin +sub modbin { + local $_; + my $bindir=shift; + check_call("mkdir -p $bindir"); + -d $bindir || die "couldn't make bindir $bindir"; + for (@_) { + my $src=$$_; + $$_="$bindir/".basename($src); + check_call("cp -p $src $$_"); + } +} +sub dirsize { + opendir ISEMPTY,$_[0]; + return scalar(readdir(ISEMPTY))-1; +} +if ($dryrun){ + write_config(*STDERR); + exit 0; +} else { + if (-e $dir && dirsize($dir)>1 && -e "$dir/hgs" ){ # allow preexisting logfile, binaries, but not dist-vest.pl outputs + die "ERROR: working dir $dir already exists\n\n"; + } else { + -e $dir || mkdir $dir; + mkdir "$dir/hgs"; + modbin("$dir/bin",\$LocalConfig,\$cdec,\$SCORER,\$MAPINPUT,\$MAPPER,\$REDUCER,\$parallelize,\$sentserver,\$sentclient,\$libcall) if $cpbin; + mkdir "$dir/scripts"; + my $cmdfile="$dir/rerun-vest.sh"; + open CMD,'>',$cmdfile; + print CMD "cd ",&getcwd,"\n"; +# print CMD &escaped_cmdline,"\n"; #buggy - last arg is quoted. + my $cline=&cmdline."\n"; + print CMD $cline; + close CMD; + print STDERR $cline; + chmod(0755,$cmdfile); + unless (-e $initialWeights) { + print STDERR "Please specify an initial weights file with --initial-weights\n"; + print_help(); + exit; + } + check_call("cp $initialWeights $dir/weights.0"); + die "Can't find weights.0" unless (-e "$dir/weights.0"); + } + write_config(*STDERR); +} + + +# Generate initial files and values +check_call("cp $iniFile $newIniFile"); +$iniFile = $newIniFile; + +my $newsrc = "$dir/dev.input"; +enseg($srcFile, $newsrc); +$srcFile = $newsrc; +my $devSize = 0; +open F, "<$srcFile" or die "Can't read $srcFile: $!"; +while() { $devSize++; } +close F; + +unless($best_weights){ $best_weights = $weights; } +unless($projected_score){ $projected_score = 0.0; } +$seen_weights{$weights} = 1; + +my $random_seed = int(time / 1000); +my $lastWeightsFile; +my $lastPScore = 0; +# main optimization loop +while (1){ + print STDERR "\n\nITERATION $iteration\n==========\n"; + + if ($iteration > $max_iterations){ + print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; + last; + } + # iteration-specific files + my $runFile="$dir/run.raw.$iteration"; + my $onebestFile="$dir/1best.$iteration"; + my $logdir="$dir/logs.$iteration"; + my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; + my $scorerLog="$logdir/scorer.log.$iteration"; + check_call("mkdir -p $logdir"); + + + #decode + print STDERR "RUNNING DECODER AT "; + print STDERR unchecked_output("date"); + my $im1 = $iteration - 1; + my $weightsFile="$dir/weights.$im1"; + my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; + if ($density_prune) { + $decoder_cmd .= " --density_prune $density_prune"; + } + my $pcmd; + if ($run_local) { + $pcmd = "cat $srcFile |"; + } elsif ($use_make) { + # TODO: Throw error when decode_nodes is specified along with use_make + $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $use_make --"; + } else { + $pcmd = "cat $srcFile | $parallelize $usefork -p $pmem -e $logdir -j $decode_nodes --"; + } + my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + my $num_hgs; + my $num_topbest; + my $retries = 0; + while($retries < 5) { + $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); + $num_topbest = check_output("wc -l < $runFile"); + print STDERR "NUMBER OF HGs: $num_hgs\n"; + print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; + if($devSize == $num_hgs && $devSize == $num_topbest) { + last; + } else { + print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; + sleep(3); + } + $retries++; + } + die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); + my $dec_score = check_output("cat $runFile | $SCORER $refs_comma_sep -l $metric"); + chomp $dec_score; + print STDERR "DECODER SCORE: $dec_score\n"; + + # save space + check_call("gzip -f $runFile"); + check_call("gzip -f $decoderLog"); + + # run optimizer + print STDERR "RUNNING OPTIMIZER AT "; + print STDERR unchecked_output("date"); + my $mergeLog="$logdir/prune-merge.log.$iteration"; + + my $score = 0; + my $icc = 0; + my $inweights="$dir/weights.$im1"; + for (my $opt_iter=1; $opt_iter<$optimization_iters; $opt_iter++) { + print STDERR "\nGENERATE OPTIMIZATION STRATEGY (OPT-ITERATION $opt_iter/$optimization_iters)\n"; + print STDERR unchecked_output("date"); + $icc++; + my $nop=$noprimary?"--no_primary":""; + my $targs=$oraclen ? "--decoder_translations='$runFile.gz' ".get_comma_sep_refs('-references',$refFiles):""; + my $bwargs=$bleu_weight!=1 ? "--bleu_weight=$bleu_weight":""; + $cmd="$MAPINPUT -w $inweights -r $dir/hgs $bwargs -s $devSize -d $rand_directions --max_similarity=$maxsim --oracle_directions=$oraclen --oracle_batch=$oracleb $targs $dirargs > $dir/agenda.$im1-$opt_iter"; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + check_call("mkdir -p $dir/splag.$im1"); + $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1-$opt_iter $dir/splag.$im1/mapinput."; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; + my @shards = grep { /^mapinput\./ } readdir(DIR); + closedir DIR; + die "No shards!" unless scalar @shards > 0; + my $joblist = ""; + my $nmappers = 0; + my @mapoutputs = (); + @cleanupcmds = (); + my %o2i = (); + my $first_shard = 1; + my $mkfile; # only used with makefiles + my $mkfilename; + if ($use_make) { + $mkfilename = "$dir/splag.$im1/domap.mk"; + open $mkfile, ">$mkfilename" or die "Couldn't write $mkfilename: $!"; + print $mkfile "all: $dir/splag.$im1/map.done\n\n"; + } + my @mkouts = (); # only used with makefiles + for my $shard (@shards) { + my $mapoutput = $shard; + my $client_name = $shard; + $client_name =~ s/mapinput.//; + $client_name = "vest.$client_name"; + $mapoutput =~ s/mapinput/mapoutput/; + push @mapoutputs, "$dir/splag.$im1/$mapoutput"; + $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; + my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep < $dir/splag.$im1/$shard | sort -t \$'\\t' -k 1 > $dir/splag.$im1/$mapoutput"; + if ($run_local) { + print STDERR "COMMAND:\n$script\n"; + check_bash_call($script); + } elsif ($use_make) { + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "#!/bin/bash\n"; + print F "$script\n"; + close F; + my $output = "$dir/splag.$im1/$mapoutput"; + push @mkouts, $output; + chmod(0755, $script_file) or die "Can't chmod $script_file: $!"; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + print $mkfile "$output: $dir/splag.$im1/$shard\n\t$script_file\n\n"; + } else { + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "$script\n"; + close F; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + + $nmappers++; + my $qcmd = "$QSUB_CMD -N $client_name -o /dev/null -e $logdir/$client_name.ER $script_file"; + my $jobid = check_output("$qcmd"); + chomp $jobid; + $jobid =~ s/^(\d+)(.*?)$/\1/g; + $jobid =~ s/^Your job (\d+) .*$/\1/; + push(@cleanupcmds, "qdel $jobid 2> /dev/null"); + print STDERR " $jobid"; + if ($joblist == "") { $joblist = $jobid; } + else {$joblist = $joblist . "\|" . $jobid; } + } + } + if ($run_local) { + print STDERR "\nProcessing line search complete.\n"; + } elsif ($use_make) { + print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; + close $mkfile; + my $mcmd = "make -j $use_make -f $mkfilename"; + print STDERR "\nExecuting: $mcmd\n"; + check_call($mcmd); + } else { + print STDERR "\nLaunched $nmappers mappers.\n"; + sleep 8; + print STDERR "Waiting for mappers to complete...\n"; + while ($nmappers > 0) { + sleep 5; + my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat | grep -v ' C '"))); + $nmappers = scalar @livejobs; + } + print STDERR "All mappers complete.\n"; + } + my $tol = 0; + my $til = 0; + for my $mo (@mapoutputs) { + my $olines = get_lines($mo); + my $ilines = get_lines($o2i{$mo}); + $tol += $olines; + $til += $ilines; + die "$mo: output lines ($olines) doesn't match input lines ($ilines)" unless $olines==$ilines; + } + print STDERR "Results for $tol/$til lines\n"; + print STDERR "\nSORTING AND RUNNING VEST REDUCER\n"; + print STDERR unchecked_output("date"); + $cmd="sort -t \$'\\t' -k 1 @mapoutputs | $REDUCER -l $metric > $dir/redoutput.$im1"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + $cmd="sort -nk3 $DIR_FLAG '-t|' $dir/redoutput.$im1 | head -1"; + # sort returns failure even when it doesn't fail for some reason + my $best=unchecked_output("$cmd"); chomp $best; + print STDERR "$best\n"; + my ($oa, $x, $xscore) = split /\|/, $best; + $score = $xscore; + print STDERR "PROJECTED SCORE: $score\n"; + if (abs($x) < $epsilon) { + print STDERR "\nOPTIMIZER: no score improvement: abs($x) < $epsilon\n"; + last; + } + my $psd = $score - $last_score; + $last_score = $score; + if (abs($psd) < $epsilon) { + print STDERR "\nOPTIMIZER: no score improvement: abs($psd) < $epsilon\n"; + last; + } + my ($origin, $axis) = split /\s+/, $oa; + + my %ori = convert($origin); + my %axi = convert($axis); + + my $finalFile="$dir/weights.$im1-$opt_iter"; + open W, ">$finalFile" or die "Can't write: $finalFile: $!"; + my $norm = 0; + for my $k (sort keys %ori) { + my $dd = $ori{$k} + $axi{$k} * $x; + $norm += $dd * $dd; + } + $norm = sqrt($norm); + $norm = 1; + for my $k (sort keys %ori) { + my $v = ($ori{$k} + $axi{$k} * $x) / $norm; + print W "$k $v\n"; + } + check_call("rm $dir/splag.$im1/*"); + $inweights = $finalFile; + } + $lastWeightsFile = "$dir/weights.$iteration"; + check_call("cp $inweights $lastWeightsFile"); + if ($icc < 2) { + print STDERR "\nREACHED STOPPING CRITERION: score change too little\n"; + last; + } + $lastPScore = $score; + $iteration++; + print STDERR "\n==========\n"; +} + +print STDERR "\nFINAL WEIGHTS: $lastWeightsFile\n(Use -w with the decoder)\n\n"; + +print STDOUT "$lastWeightsFile\n"; + +exit 0; + +sub normalize_weights { + my ($rfn, $rpts, $feat) = @_; + my @feat_names = @$rfn; + my @pts = @$rpts; + my $z = 1.0; + for (my $i=0; $i < scalar @feat_names; $i++) { + if ($feat_names[$i] eq $feat) { + $z = $pts[$i]; + last; + } + } + for (my $i=0; $i < scalar @feat_names; $i++) { + $pts[$i] /= $z; + } + print STDERR " NORM WEIGHTS: @pts\n"; + return @pts; +} + +sub get_lines { + my $fn = shift @_; + open FL, "<$fn" or die "Couldn't read $fn: $!"; + my $lc = 0; + while() { $lc++; } + return $lc; +} + +sub get_comma_sep_refs { + my ($r,$p) = @_; + my $o = check_output("echo $p"); + chomp $o; + my @files = split /\s+/, $o; + return "-$r " . join(" -$r ", @files); +} + +sub read_weights_file { + my ($file) = @_; + open F, "<$file" or die "Couldn't read $file: $!"; + my @r = (); + my $pm = -1; + while() { + next if /^#/; + next if /^\s*$/; + chomp; + if (/^(.+)\s+(.+)$/) { + my $m = $1; + my $w = $2; + die "Weights out of order: $m <= $pm" unless $m > $pm; + push @r, $w; + } else { + warn "Unexpected feature name in weight file: $_"; + } + } + close F; + return join ' ', @r; +} + +# subs +sub write_config { + my $fh = shift; + my $cleanup = "yes"; + if ($disable_clean) {$cleanup = "no";} + + print $fh "\n"; + print $fh "DECODER: $decoder\n"; + print $fh "INI FILE: $iniFile\n"; + print $fh "WORKING DIR: $dir\n"; + print $fh "SOURCE (DEV): $srcFile\n"; + print $fh "REFS (DEV): $refFiles\n"; + print $fh "EVAL METRIC: $metric\n"; + print $fh "START ITERATION: $iteration\n"; + print $fh "MAX ITERATIONS: $max_iterations\n"; + print $fh "DECODE NODES: $decode_nodes\n"; + print $fh "HEAD NODE: $host\n"; + print $fh "PMEM (DECODING): $pmem\n"; + print $fh "CLEANUP: $cleanup\n"; + print $fh "INITIAL WEIGHTS: $initialWeights\n"; +} + +sub update_weights_file { + my ($neww, $rfn, $rpts) = @_; + my @feats = @$rfn; + my @pts = @$rpts; + my $num_feats = scalar @feats; + my $num_pts = scalar @pts; + die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; + open G, ">$neww" or die; + for (my $i = 0; $i < $num_feats; $i++) { + my $f = $feats[$i]; + my $lambda = $pts[$i]; + print G "$f $lambda\n"; + } + close G; +} + +sub enseg { + my $src = shift; + my $newsrc = shift; + open(SRC, $src); + open(NEWSRC, ">$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + if ($line =~ /^\s* tags, you must include a zero-based id attribute"; + } + } else { + print NEWSRC "$line\n"; + } + $i++; + } + close SRC; + close NEWSRC; +} + +sub print_help { + + my $executable = check_output("basename $0"); chomp $executable; + print << "Help"; + +Usage: $executable [options] + + $executable [options] + Runs a complete MERT optimization and test set decoding, using + the decoder configuration in ini file. Note that many of the + options have default values that are inferred automatically + based on certain conventions. For details, refer to descriptions + of the options --decoder, --weights, and --workdir. + +Options: + + --local + Run the decoder and optimizer locally with a single thread. + + --use-make + Use make -j to run the optimizer commands (useful on large + shared-memory machines where qsub is unavailable). + + --decode-nodes + Number of decoder processes to run in parallel. [default=15] + + --decoder + Decoder binary to use. + + --density-prune + Limit the density of the hypergraph on each iteration to N times + the number of edges on the Viterbi path. + + --help + Print this message and exit. + + --iteration + Starting iteration number. If not specified, defaults to 1. + + --max-iterations + Maximum number of iterations to run. If not specified, defaults + to 10. + + --pass-suffix + If the decoder is doing multi-pass decoding, the pass suffix "2", + "3", etc., is used to control what iteration of weights is set. + + --pmem + Amount of physical memory requested for parallel decoding jobs. + + --ref-files + Dev set ref files. This option takes only a single string argument. + To use multiple files (including file globbing), this argument should + be quoted. + + --metric + Metric to optimize. + Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi + + --normalize + After each iteration, rescale all feature weights such that feature- + name has a weight of 1.0. + + --rand-directions + MERT will attempt to optimize along all of the principle directions, + set this parameter to explore other directions. Defaults to 5. + + --source-file + Dev set source file. + + --weights + A file specifying initial feature weights. The format is + FeatureName_1 value1 + FeatureName_2 value2 + + --workdir + Directory for intermediate and output files. If not specified, the + name is derived from the ini filename. Assuming that the ini + filename begins with the decoder name and ends with ini, the default + name of the working directory is inferred from the middle part of + the filename. E.g. an ini file named decoder.foo.ini would have + a default working directory name foo. + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + + +sub cmdline { + return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} + +sub escaped_shell_args { + return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { + return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { + return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc new file mode 100644 index 00000000..b046cdea --- /dev/null +++ b/pro-train/mr_pro_map.cc @@ -0,0 +1,111 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "sampler.h" +#include "filelib.h" +#include "stringlib.h" +#include "scorer.h" +#include "inside_outside.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" + +// This is Figure 4 (Algorithm Sampler) from Hopkins&May (2011) + +using namespace std; +namespace po = boost::program_options; + +boost::shared_ptr rng; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("source,s",po::value(), "Source file (ignored, except for AER)") + ("loss_function,l",po::value()->default_value("ibm_bleu"), "Loss function being optimized") + ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") + ("weights,w",po::value(), "[REQD] Current weights file") + ("kbest_size,k",po::value()->default_value(1500u), "Top k-hypotheses to extract") + ("candidate_pairs,G", po::value()->default_value(5000u), "Number of pairs to sample per hypothesis (Gamma)") + ("best_pairs,X", po::value()->default_value(50u), "Number of pairs, ranked by magnitude of objective delta, to retain (Xi)") + ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct HypInfo { + HypInfo(const vector& h, const SparseVector& feats) : hyp(h), g_(-1), x(feats) {} + double g() { + return g_; + } + private: + int sent_id; + vector hyp; + double g_; + public: + SparseVector x; +}; + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as())); + else + rng.reset(new MT19937); + const string loss_function = conf["loss_function"].as(); + ScoreType type = ScoreTypeFromString(loss_function); + DocScorer ds(type, conf["reference"].as >(), conf["source"].as()); + cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl; + Hypergraph hg; + string last_file; + ReadFile in_read(conf["input"].as()); + istream &in=*in_read.stream(); + const unsigned kbest_size = conf["kbest_size"].as(); + const unsigned gamma = conf["candidate_pairs"].as(); + const unsigned xi = conf["best_pairs"].as(); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + istringstream is(line); + int sent_id; + string file; + // path-to-file (JSON) sent_id + is >> file >> sent_id; + ReadFile rf(file); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); + + vector J_i; + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); + // if (invert_score) sentscore *= -1.0; + // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; + d->feature_values; + sentscore; + } + } + return 0; +} + diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc new file mode 100644 index 00000000..3df52020 --- /dev/null +++ b/pro-train/mr_pro_reduce.cc @@ -0,0 +1,81 @@ +#include +#include +#include +#include + +#include +#include + +#include "sparse_vector.h" +#include "error_surface.h" +#include "line_optimizer.h" +#include "b64tools.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("loss_function,l",po::value(), "Loss function being optimized") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = conf->count("loss_function") == 0; + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string loss_function = conf["loss_function"].as(); + ScoreType type = ScoreTypeFromString(loss_function); + LineOptimizer::ScoreType opt_type = LineOptimizer::MAXIMIZE_SCORE; + if (type == TER || type == AER) { + opt_type = LineOptimizer::MINIMIZE_SCORE; + } + string last_key; + vector esv; + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + size_t ks = line.find("\t"); + assert(string::npos != ks); + assert(ks > 2); + string key = line.substr(2, ks - 2); + string val = line.substr(ks + 1); + if (key != last_key) { + if (!last_key.empty()) { + float score; + double x = LineOptimizer::LineOptimize(esv, opt_type, &score); + cout << last_key << "|" << x << "|" << score << endl; + } + last_key = key; + esv.clear(); + } + if (val.size() % 4 != 0) { + cerr << "B64 encoding error 1! Skipping.\n"; + continue; + } + string encoded(val.size() / 4 * 3, '\0'); + if (!B64::b64decode(reinterpret_cast(&val[0]), val.size(), &encoded[0], encoded.size())) { + cerr << "B64 encoding error 2! Skipping.\n"; + continue; + } + esv.push_back(ErrorSurface()); + esv.back().Deserialize(type, encoded); + } + if (!esv.empty()) { + // cerr << "ESV=" << esv.size() << endl; + // for (int i = 0; i < esv.size(); ++i) { cerr << esv[i].size() << endl; } + float score; + double x = LineOptimizer::LineOptimize(esv, opt_type, &score); + cout << last_key << "|" << x << "|" << score << endl; + } + return 0; +} -- cgit v1.2.3 From bde4a34bab96052570c248f7d9ccc299a9a3f097 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 11 Jul 2011 20:39:45 -0400 Subject: sort of working hopkins&may optimizer --- pro-train/Makefile.am | 4 +- pro-train/dist-pro.pl | 308 ++++++++++-------------------- pro-train/mr_pro_generate_mapper_input.pl | 18 ++ pro-train/mr_pro_map.cc | 118 ++++++++++-- pro-train/mr_pro_reduce.cc | 167 ++++++++++++---- 5 files changed, 349 insertions(+), 266 deletions(-) create mode 100755 pro-train/mr_pro_generate_mapper_input.pl diff --git a/pro-train/Makefile.am b/pro-train/Makefile.am index 945ed5c3..fdaf43e2 100644 --- a/pro-train/Makefile.am +++ b/pro-train/Makefile.am @@ -8,6 +8,6 @@ mr_pro_map_SOURCES = mr_pro_map.cc mr_pro_map_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz mr_pro_reduce_SOURCES = mr_pro_reduce.cc -mr_pro_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +mr_pro_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/training/optimize.o $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index 35bccea4..55d7f1fa 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -21,7 +21,7 @@ my $bin_dir = $SCRIPT_DIR; die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; my $FAST_SCORE="$bin_dir/../mteval/fast_score"; die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; -my $MAPINPUT = "$bin_dir/mr_pro_generate_mapper_input"; +my $MAPINPUT = "$bin_dir/mr_pro_generate_mapper_input.pl"; my $MAPPER = "$bin_dir/mr_pro_map"; my $REDUCER = "$bin_dir/mr_pro_reduce"; my $parallelize = "$VEST_DIR/parallelize.pl"; @@ -37,8 +37,7 @@ die "Can't find decoder in $cdec" unless -x $cdec; die "Can't find $parallelize" unless -x $parallelize; die "Can't find $libcall" unless -e $libcall; my $decoder = $cdec; -my $lines_per_mapper = 400; -my $rand_directions = 15; +my $lines_per_mapper = 100; my $iteration = 1; my $run_local = 0; my $best_weights; @@ -58,7 +57,6 @@ my $metric = "ibm_bleu"; my $dir; my $iniFile; my $weights; -my $initialWeights; my $decoderOpt; my $noprimary; my $maxsim=0; @@ -67,7 +65,6 @@ my $oracleb=20; my $bleu_weight=1; my $use_make; # use make to parallelize line search my $dirargs=''; -my $density_prune; my $usefork; my $pass_suffix = ''; my $cpbin=1; @@ -76,7 +73,6 @@ Getopt::Long::Configure("no_auto_abbrev"); if (GetOptions( "decoder=s" => \$decoderOpt, "decode-nodes=i" => \$decode_nodes, - "density-prune=f" => \$density_prune, "dont-clean" => \$disable_clean, "pass-suffix=s" => \$pass_suffix, "use-fork" => \$usefork, @@ -91,8 +87,6 @@ if (GetOptions( "normalize=s" => \$normalize, "pmem=s" => \$pmem, "cpbin!" => \$cpbin, - "rand-directions=i" => \$rand_directions, - "random_directions=i" => \$rand_directions, "bleu_weight=s" => \$bleu_weight, "no-primary!" => \$noprimary, "max-similarity=s" => \$maxsim, @@ -103,18 +97,12 @@ if (GetOptions( "ref-files=s" => \$refFiles, "metric=s" => \$metric, "source-file=s" => \$srcFile, - "weights=s" => \$initialWeights, "workdir=s" => \$dir, - "opt-iterations=i" => \$optimization_iters, ) == 0 || @ARGV!=1 || $help) { print_help(); exit; } -if (defined $density_prune) { - die "--density_prune n: n must be greater than 1.0\n" unless $density_prune > 1.0; -} - if ($usefork) { $usefork = "--use-fork"; } else { $usefork = ''; } if ($metric =~ /^(combi|ter)$/i) { @@ -146,7 +134,7 @@ if ($metric =~ /^ter$|^aer$/i) { my $refs_comma_sep = get_comma_sep_refs('r',$refFiles); unless ($dir){ - $dir = "vest"; + $dir = "protrain"; } unless ($dir =~ /^\//){ # convert relative path to absolute path my $basedir = check_output("pwd"); @@ -203,18 +191,19 @@ sub dirsize { opendir ISEMPTY,$_[0]; return scalar(readdir(ISEMPTY))-1; } +my @allweights; if ($dryrun){ write_config(*STDERR); exit 0; } else { - if (-e $dir && dirsize($dir)>1 && -e "$dir/hgs" ){ # allow preexisting logfile, binaries, but not dist-vest.pl outputs + if (-e $dir && dirsize($dir)>1 && -e "$dir/hgs" ){ # allow preexisting logfile, binaries, but not dist-pro.pl outputs die "ERROR: working dir $dir already exists\n\n"; } else { -e $dir || mkdir $dir; mkdir "$dir/hgs"; modbin("$dir/bin",\$LocalConfig,\$cdec,\$SCORER,\$MAPINPUT,\$MAPPER,\$REDUCER,\$parallelize,\$sentserver,\$sentclient,\$libcall) if $cpbin; mkdir "$dir/scripts"; - my $cmdfile="$dir/rerun-vest.sh"; + my $cmdfile="$dir/rerun-pro.sh"; open CMD,'>',$cmdfile; print CMD "cd ",&getcwd,"\n"; # print CMD &escaped_cmdline,"\n"; #buggy - last arg is quoted. @@ -223,13 +212,8 @@ if ($dryrun){ close CMD; print STDERR $cline; chmod(0755,$cmdfile); - unless (-e $initialWeights) { - print STDERR "Please specify an initial weights file with --initial-weights\n"; - print_help(); - exit; - } - check_call("cp $initialWeights $dir/weights.0"); - die "Can't find weights.0" unless (-e "$dir/weights.0"); + check_call("touch $dir/weights.0"); + die "Can't find weights.0" unless (-e "$dir/weights.0"); } write_config(*STDERR); } @@ -255,6 +239,7 @@ my $random_seed = int(time / 1000); my $lastWeightsFile; my $lastPScore = 0; # main optimization loop +my @mapoutputs = (); # aggregate map outputs over all iters while (1){ print STDERR "\n\nITERATION $iteration\n==========\n"; @@ -276,10 +261,8 @@ while (1){ print STDERR unchecked_output("date"); my $im1 = $iteration - 1; my $weightsFile="$dir/weights.$im1"; + push @allweights, "-w $dir/weights.$im1"; my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; - if ($density_prune) { - $decoder_cmd .= " --density_prune $density_prune"; - } my $pcmd; if ($run_local) { $pcmd = "cat $srcFile |"; @@ -320,163 +303,111 @@ while (1){ # run optimizer print STDERR "RUNNING OPTIMIZER AT "; print STDERR unchecked_output("date"); + print STDERR " - GENERATE TRAINING EXEMPLARS\n"; my $mergeLog="$logdir/prune-merge.log.$iteration"; my $score = 0; my $icc = 0; my $inweights="$dir/weights.$im1"; - for (my $opt_iter=1; $opt_iter<$optimization_iters; $opt_iter++) { - print STDERR "\nGENERATE OPTIMIZATION STRATEGY (OPT-ITERATION $opt_iter/$optimization_iters)\n"; - print STDERR unchecked_output("date"); - $icc++; - my $nop=$noprimary?"--no_primary":""; - my $targs=$oraclen ? "--decoder_translations='$runFile.gz' ".get_comma_sep_refs('-references',$refFiles):""; - my $bwargs=$bleu_weight!=1 ? "--bleu_weight=$bleu_weight":""; - $cmd="$MAPINPUT -w $inweights -r $dir/hgs $bwargs -s $devSize -d $rand_directions --max_similarity=$maxsim --oracle_directions=$oraclen --oracle_batch=$oracleb $targs $dirargs > $dir/agenda.$im1-$opt_iter"; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - check_call("mkdir -p $dir/splag.$im1"); - $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1-$opt_iter $dir/splag.$im1/mapinput."; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; - my @shards = grep { /^mapinput\./ } readdir(DIR); - closedir DIR; - die "No shards!" unless scalar @shards > 0; - my $joblist = ""; - my $nmappers = 0; - my @mapoutputs = (); - @cleanupcmds = (); - my %o2i = (); - my $first_shard = 1; - my $mkfile; # only used with makefiles - my $mkfilename; - if ($use_make) { - $mkfilename = "$dir/splag.$im1/domap.mk"; - open $mkfile, ">$mkfilename" or die "Couldn't write $mkfilename: $!"; - print $mkfile "all: $dir/splag.$im1/map.done\n\n"; - } - my @mkouts = (); # only used with makefiles - for my $shard (@shards) { - my $mapoutput = $shard; - my $client_name = $shard; - $client_name =~ s/mapinput.//; - $client_name = "vest.$client_name"; - $mapoutput =~ s/mapinput/mapoutput/; - push @mapoutputs, "$dir/splag.$im1/$mapoutput"; - $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; - my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep < $dir/splag.$im1/$shard | sort -t \$'\\t' -k 1 > $dir/splag.$im1/$mapoutput"; - if ($run_local) { - print STDERR "COMMAND:\n$script\n"; - check_bash_call($script); - } elsif ($use_make) { - my $script_file = "$dir/scripts/map.$shard"; - open F, ">$script_file" or die "Can't write $script_file: $!"; - print F "#!/bin/bash\n"; - print F "$script\n"; - close F; - my $output = "$dir/splag.$im1/$mapoutput"; - push @mkouts, $output; - chmod(0755, $script_file) or die "Can't chmod $script_file: $!"; - if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } - print $mkfile "$output: $dir/splag.$im1/$shard\n\t$script_file\n\n"; - } else { - my $script_file = "$dir/scripts/map.$shard"; - open F, ">$script_file" or die "Can't write $script_file: $!"; - print F "$script\n"; - close F; - if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } - - $nmappers++; - my $qcmd = "$QSUB_CMD -N $client_name -o /dev/null -e $logdir/$client_name.ER $script_file"; - my $jobid = check_output("$qcmd"); - chomp $jobid; - $jobid =~ s/^(\d+)(.*?)$/\1/g; - $jobid =~ s/^Your job (\d+) .*$/\1/; - push(@cleanupcmds, "qdel $jobid 2> /dev/null"); - print STDERR " $jobid"; - if ($joblist == "") { $joblist = $jobid; } - else {$joblist = $joblist . "\|" . $jobid; } - } - } + $cmd="$MAPINPUT $dir/hgs > $dir/agenda.$im1"; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + check_call("mkdir -p $dir/splag.$im1"); + $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1 $dir/splag.$im1/mapinput."; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; + my @shards = grep { /^mapinput\./ } readdir(DIR); + closedir DIR; + die "No shards!" unless scalar @shards > 0; + my $joblist = ""; + my $nmappers = 0; + @cleanupcmds = (); + my %o2i = (); + my $first_shard = 1; + my $mkfile; # only used with makefiles + my $mkfilename; + if ($use_make) { + $mkfilename = "$dir/splag.$im1/domap.mk"; + open $mkfile, ">$mkfilename" or die "Couldn't write $mkfilename: $!"; + print $mkfile "all: $dir/splag.$im1/map.done\n\n"; + } + my @mkouts = (); # only used with makefiles + for my $shard (@shards) { + my $mapoutput = $shard; + my $client_name = $shard; + $client_name =~ s/mapinput.//; + $client_name = "pro.$client_name"; + $mapoutput =~ s/mapinput/mapoutput/; + push @mapoutputs, "$dir/splag.$im1/$mapoutput"; + $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; + my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep @allweights < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; if ($run_local) { - print STDERR "\nProcessing line search complete.\n"; + print STDERR "COMMAND:\n$script\n"; + check_bash_call($script); } elsif ($use_make) { - print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; - close $mkfile; - my $mcmd = "make -j $use_make -f $mkfilename"; - print STDERR "\nExecuting: $mcmd\n"; - check_call($mcmd); + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "#!/bin/bash\n"; + print F "$script\n"; + close F; + my $output = "$dir/splag.$im1/$mapoutput"; + push @mkouts, $output; + chmod(0755, $script_file) or die "Can't chmod $script_file: $!"; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + print $mkfile "$output: $dir/splag.$im1/$shard\n\t$script_file\n\n"; } else { - print STDERR "\nLaunched $nmappers mappers.\n"; - sleep 8; - print STDERR "Waiting for mappers to complete...\n"; - while ($nmappers > 0) { - sleep 5; - my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat | grep -v ' C '"))); - $nmappers = scalar @livejobs; - } - print STDERR "All mappers complete.\n"; + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "$script\n"; + close F; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + + $nmappers++; + my $qcmd = "$QSUB_CMD -N $client_name -o /dev/null -e $logdir/$client_name.ER $script_file"; + my $jobid = check_output("$qcmd"); + chomp $jobid; + $jobid =~ s/^(\d+)(.*?)$/\1/g; + $jobid =~ s/^Your job (\d+) .*$/\1/; + push(@cleanupcmds, "qdel $jobid 2> /dev/null"); + print STDERR " $jobid"; + if ($joblist == "") { $joblist = $jobid; } + else {$joblist = $joblist . "\|" . $jobid; } } - my $tol = 0; - my $til = 0; - for my $mo (@mapoutputs) { - my $olines = get_lines($mo); - my $ilines = get_lines($o2i{$mo}); - $tol += $olines; - $til += $ilines; - die "$mo: output lines ($olines) doesn't match input lines ($ilines)" unless $olines==$ilines; - } - print STDERR "Results for $tol/$til lines\n"; - print STDERR "\nSORTING AND RUNNING VEST REDUCER\n"; - print STDERR unchecked_output("date"); - $cmd="sort -t \$'\\t' -k 1 @mapoutputs | $REDUCER -l $metric > $dir/redoutput.$im1"; - print STDERR "COMMAND:\n$cmd\n"; - check_bash_call($cmd); - $cmd="sort -nk3 $DIR_FLAG '-t|' $dir/redoutput.$im1 | head -1"; - # sort returns failure even when it doesn't fail for some reason - my $best=unchecked_output("$cmd"); chomp $best; - print STDERR "$best\n"; - my ($oa, $x, $xscore) = split /\|/, $best; - $score = $xscore; - print STDERR "PROJECTED SCORE: $score\n"; - if (abs($x) < $epsilon) { - print STDERR "\nOPTIMIZER: no score improvement: abs($x) < $epsilon\n"; - last; - } - my $psd = $score - $last_score; - $last_score = $score; - if (abs($psd) < $epsilon) { - print STDERR "\nOPTIMIZER: no score improvement: abs($psd) < $epsilon\n"; - last; - } - my ($origin, $axis) = split /\s+/, $oa; - - my %ori = convert($origin); - my %axi = convert($axis); - - my $finalFile="$dir/weights.$im1-$opt_iter"; - open W, ">$finalFile" or die "Can't write: $finalFile: $!"; - my $norm = 0; - for my $k (sort keys %ori) { - my $dd = $ori{$k} + $axi{$k} * $x; - $norm += $dd * $dd; - } - $norm = sqrt($norm); - $norm = 1; - for my $k (sort keys %ori) { - my $v = ($ori{$k} + $axi{$k} * $x) / $norm; - print W "$k $v\n"; + } + if ($run_local) { + print STDERR "\nCompleted extraction of training exemplars.\n"; + } elsif ($use_make) { + print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; + close $mkfile; + my $mcmd = "make -j $use_make -f $mkfilename"; + print STDERR "\nExecuting: $mcmd\n"; + check_call($mcmd); + } else { + print STDERR "\nLaunched $nmappers mappers.\n"; + sleep 8; + print STDERR "Waiting for mappers to complete...\n"; + while ($nmappers > 0) { + sleep 5; + my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat | grep -v ' C '"))); + $nmappers = scalar @livejobs; } - check_call("rm $dir/splag.$im1/*"); - $inweights = $finalFile; + print STDERR "All mappers complete.\n"; } - $lastWeightsFile = "$dir/weights.$iteration"; - check_call("cp $inweights $lastWeightsFile"); - if ($icc < 2) { - print STDERR "\nREACHED STOPPING CRITERION: score change too little\n"; - last; + my $tol = 0; + my $til = 0; + print STDERR "MO: @mapoutputs\n"; + for my $mo (@mapoutputs) { + #my $olines = get_lines($mo); + #my $ilines = get_lines($o2i{$mo}); + #die "$mo: no training instances generated!" if $olines == 0; } + print STDERR "\nRUNNING CLASSIFIER (REDUCER)\n"; + print STDERR unchecked_output("date"); + $cmd="cat @mapoutputs | $REDUCER -w $dir/weights.$im1 > $dir/weights.$iteration"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + $lastWeightsFile = "$dir/weights.$iteration"; $lastPScore = $score; $iteration++; print STDERR "\n==========\n"; @@ -488,24 +419,6 @@ print STDOUT "$lastWeightsFile\n"; exit 0; -sub normalize_weights { - my ($rfn, $rpts, $feat) = @_; - my @feat_names = @$rfn; - my @pts = @$rpts; - my $z = 1.0; - for (my $i=0; $i < scalar @feat_names; $i++) { - if ($feat_names[$i] eq $feat) { - $z = $pts[$i]; - last; - } - } - for (my $i=0; $i < scalar @feat_names; $i++) { - $pts[$i] /= $z; - } - print STDERR " NORM WEIGHTS: @pts\n"; - return @pts; -} - sub get_lines { my $fn = shift @_; open FL, "<$fn" or die "Couldn't read $fn: $!"; @@ -563,7 +476,6 @@ sub write_config { print $fh "HEAD NODE: $host\n"; print $fh "PMEM (DECODING): $pmem\n"; print $fh "CLEANUP: $cleanup\n"; - print $fh "INITIAL WEIGHTS: $initialWeights\n"; } sub update_weights_file { @@ -603,6 +515,7 @@ sub enseg { } close SRC; close NEWSRC; + die "Empty dev set!" if ($i == 0); } sub print_help { @@ -634,10 +547,6 @@ Options: --decoder Decoder binary to use. - --density-prune - Limit the density of the hypergraph on each iteration to N times - the number of edges on the Viterbi path. - --help Print this message and exit. @@ -668,18 +577,9 @@ Options: After each iteration, rescale all feature weights such that feature- name has a weight of 1.0. - --rand-directions - MERT will attempt to optimize along all of the principle directions, - set this parameter to explore other directions. Defaults to 5. - --source-file Dev set source file. - --weights - A file specifying initial feature weights. The format is - FeatureName_1 value1 - FeatureName_2 value2 - --workdir Directory for intermediate and output files. If not specified, the name is derived from the ini filename. Assuming that the ini diff --git a/pro-train/mr_pro_generate_mapper_input.pl b/pro-train/mr_pro_generate_mapper_input.pl new file mode 100755 index 00000000..b30fc4fd --- /dev/null +++ b/pro-train/mr_pro_generate_mapper_input.pl @@ -0,0 +1,18 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 HG_DIR\n" unless scalar @ARGV == 1; +my $d = shift @ARGV; +die "Can't find directory $d" unless -d $d; + +opendir(DIR, $d) or die "Can't read $d: $!"; +my @hgs = grep { /\.gz$/ } readdir(DIR); +closedir DIR; + +for my $hg (@hgs) { + my $file = $hg; + my $id = $hg; + $id =~ s/(\.json)?\.gz//; + print "$d/$file $id\n"; +} + diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index b046cdea..128d93ce 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -10,6 +10,7 @@ #include "sampler.h" #include "filelib.h" #include "stringlib.h" +#include "weights.h" #include "scorer.h" #include "inside_outside.h" #include "hg_io.h" @@ -27,10 +28,10 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") - ("source,s",po::value(), "Source file (ignored, except for AER)") + ("source,s",po::value()->default_value(""), "Source file (ignored, except for AER)") ("loss_function,l",po::value()->default_value("ibm_bleu"), "Loss function being optimized") ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") - ("weights,w",po::value(), "[REQD] Current weights file") + ("weights,w",po::value >(), "[REQD] Weights files from previous and current iterations") ("kbest_size,k",po::value()->default_value(1500u), "Top k-hypotheses to extract") ("candidate_pairs,G", po::value()->default_value(5000u), "Number of pairs to sample per hypothesis (Gamma)") ("best_pairs,X", po::value()->default_value(50u), "Number of pairs, ranked by magnitude of objective delta, to retain (Xi)") @@ -44,6 +45,10 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { cerr << "Please specify one or more references using -r \n"; flag = true; } + if (!conf->count("weights")) { + cerr << "Please specify one or more weights using -w \n"; + flag = true; + } if (flag || conf->count("help")) { cerr << dcmdline_options << endl; exit(1); @@ -51,18 +56,78 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } struct HypInfo { - HypInfo(const vector& h, const SparseVector& feats) : hyp(h), g_(-1), x(feats) {} - double g() { + HypInfo(const vector& h, const SparseVector& feats) : hyp(h), g_(-100.0), x(feats) {} + + // lazy evaluation + double g(const SentenceScorer& scorer) const { + if (g_ == -100.0) + g_ = scorer.ScoreCandidate(hyp)->ComputeScore(); return g_; } - private: - int sent_id; vector hyp; - double g_; + mutable double g_; public: SparseVector x; }; +struct ThresholdAlpha { + explicit ThresholdAlpha(double t = 0.05) : threshold(t) {} + double operator()(double mag) const { + if (mag < threshold) return 0.0; else return 1.0; + } + const double threshold; +}; + +struct TrainingInstance { + TrainingInstance(const SparseVector& feats, bool positive, double diff) : x(feats), y(positive), gdiff(diff) {} + SparseVector x; +#ifdef DEBUGGING_PRO + vector a; + vector b; +#endif + bool y; + double gdiff; +}; + +struct DiffOrder { + bool operator()(const TrainingInstance& a, const TrainingInstance& b) const { + return a.gdiff > b.gdiff; + } +}; + +template +void Sample(const unsigned gamma, const unsigned xi, const vector& J_i, const SentenceScorer& scorer, const Alpha& alpha_i, bool invert_score, vector* pv) { + vector v; + for (unsigned i = 0; i < gamma; ++i) { + size_t a = rng->inclusive(0, J_i.size() - 1)(); + size_t b = rng->inclusive(0, J_i.size() - 1)(); + if (a == b) continue; + double ga = J_i[a].g(scorer); + double gb = J_i[b].g(scorer); + bool positive = ga < gb; + if (invert_score) positive = !positive; + double gdiff = fabs(ga - gb); + if (!gdiff) continue; + if (rng->next() < alpha_i(gdiff)) { + v.push_back(TrainingInstance((J_i[a].x - J_i[b].x).erase_zeros(), positive, gdiff)); +#ifdef DEBUGGING_PRO + v.back().a = J_i[a].hyp; + v.back().b = J_i[b].hyp; +#endif + } + } + vector::iterator mid = v.begin() + xi; + if (xi > v.size()) mid = v.end(); + partial_sort(v.begin(), mid, v.end(), DiffOrder()); + copy(v.begin(), mid, back_inserter(*pv)); +#ifdef DEBUGGING_PRO + if (v.size() >= 5) + for (int i =0; i < 5; ++i) { + cerr << v[i].gdiff << " y=" << v[i].y << "\tA:" << TD::GetString(v[i].a) << "\n\tB: " << TD::GetString(v[i].b) << endl; + } +#endif +} + int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); @@ -81,7 +146,15 @@ int main(int argc, char** argv) { const unsigned kbest_size = conf["kbest_size"].as(); const unsigned gamma = conf["candidate_pairs"].as(); const unsigned xi = conf["best_pairs"].as(); + vector weights_files = conf["weights"].as >(); + vector > weights(weights_files.size()); + for (int i = 0; i < weights.size(); ++i) { + Weights w; + w.InitFromFile(weights_files[i]); + w.InitVector(&weights[i]); + } while(in) { + vector v; string line; getline(in, line); if (line.empty()) continue; @@ -92,18 +165,27 @@ int main(int argc, char** argv) { is >> file >> sent_id; ReadFile rf(file); HypergraphIO::ReadFromJSON(rf.stream(), &hg); - KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); - vector J_i; - for (int i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(hg.nodes_.size() - 1, i); - if (!d) break; - float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); - // if (invert_score) sentscore *= -1.0; - // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; - d->feature_values; - sentscore; + int start = weights.size(); + start -= 4; + if (start < 0) start = 0; + for (int i = start; i < weights.size(); ++i) { + hg.Reweight(weights[i]); + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); + + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + J_i.push_back(HypInfo(d->yield, d->feature_values)); + } + } + + Sample(gamma, xi, J_i, *ds[sent_id], ThresholdAlpha(0.05), (type == TER), &v); + for (unsigned i = 0; i < v.size(); ++i) { + const TrainingInstance& vi = v[i]; + cout << vi.y << "\t" << vi.x << endl; + cout << (!vi.y) << "\t" << (vi.x * -1.0) << endl; } } return 0; diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 3df52020..2b9c5ce7 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -6,24 +7,29 @@ #include #include +#include "weights.h" #include "sparse_vector.h" -#include "error_surface.h" -#include "line_optimizer.h" -#include "b64tools.h" +#include "optimize.h" using namespace std; namespace po = boost::program_options; +// since this is a ranking model, there should be equal numbers of +// positive and negative examples so the bias should be 0 +static const double MAX_BIAS = 1e-10; + void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() - ("loss_function,l",po::value(), "Loss function being optimized") + ("weights,w", po::value(), "Weights from previous iteration (used as initialization and interpolation") + ("interpolation,p",po::value()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev") + ("memory_buffers,m",po::value()->default_value(200), "Number of memory buffers (LBFGS)") + ("sigma_squared,s",po::value()->default_value(0.5), "Sigma squared for Gaussian prior") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = conf->count("loss_function") == 0; - if (flag || conf->count("help")) { + if (conf->count("help")) { cerr << dcmdline_options << endl; exit(1); } @@ -32,50 +38,127 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); - const string loss_function = conf["loss_function"].as(); - ScoreType type = ScoreTypeFromString(loss_function); - LineOptimizer::ScoreType opt_type = LineOptimizer::MAXIMIZE_SCORE; - if (type == TER || type == AER) { - opt_type = LineOptimizer::MINIMIZE_SCORE; + string line; + vector > > training; + int lc = 0; + bool flag = false; + SparseVector old_weights; + const double psi = conf["interpolation"].as(); + if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } + if (conf.count("weights")) { + Weights w; + w.InitFromFile(conf["weights"].as()); + w.InitSparseVector(&old_weights); } - string last_key; - vector esv; - while(cin) { - string line; - getline(cin, line); + while(getline(cin, line)) { + ++lc; + if (lc % 1000 == 0) { cerr << '.'; flag = true; } + if (lc % 40000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } if (line.empty()) continue; - size_t ks = line.find("\t"); + const size_t ks = line.find("\t"); assert(string::npos != ks); - assert(ks > 2); - string key = line.substr(2, ks - 2); - string val = line.substr(ks + 1); - if (key != last_key) { - if (!last_key.empty()) { - float score; - double x = LineOptimizer::LineOptimize(esv, opt_type, &score); - cout << last_key << "|" << x << "|" << score << endl; + assert(ks == 1); + const bool y = line[0] == '1'; + SparseVector x; + size_t last_start = ks + 1; + size_t last_comma = string::npos; + size_t cur = last_start; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const double val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } + training.push_back(make_pair(y, x)); + } + if (flag) cerr << endl; + + cerr << "Number of features: " << FD::NumFeats() << endl; + vector x(FD::NumFeats(), 0.0); // x[0] is bias + for (SparseVector::const_iterator it = old_weights.begin(); + it != old_weights.end(); ++it) + x[it->first] = it->second; + vector vg(FD::NumFeats(), 0.0); + SparseVector g; + bool converged = false; + LBFGSOptimizer opt(FD::NumFeats(), conf["memory_buffers"].as()); + while(!converged) { + double cll = 0; + double dbias = 0; + g.clear(); + for (int i = 0; i < training.size(); ++i) { + const double dotprod = training[i].second.dot(x) + x[0]; // x[0] is bias + double lp_false = dotprod; + double lp_true = -dotprod; + if (0 < lp_true) { + lp_true += log1p(exp(-lp_true)); + lp_false = log1p(exp(lp_false)); + } else { + lp_true = log1p(exp(lp_true)); + lp_false += log1p(exp(-lp_false)); + } + lp_true*=-1; + lp_false*=-1; + if (training[i].first) { // true label + cll -= lp_true; + g -= training[i].second * exp(lp_false); + dbias -= exp(lp_false); + } else { // false label + cll -= lp_false; + g += training[i].second * exp(lp_true); + dbias += exp(lp_true); } - last_key = key; - esv.clear(); } - if (val.size() % 4 != 0) { - cerr << "B64 encoding error 1! Skipping.\n"; - continue; + vg.clear(); + g.init_vector(&vg); + vg[0] = dbias; +#if 1 + const double sigsq = conf["sigma_squared"].as(); + double norm = 0; + for (int i = 1; i < x.size(); ++i) { + const double mean_i = 0.0; + const double param = (x[i] - mean_i); + norm += param * param; + vg[i] += param / sigsq; + } + const double reg = norm / (2.0 * sigsq); +#else + double reg = 0; +#endif + cll += reg; + cerr << cll << " (REG=" << reg << ")\t"; + bool failed = false; + try { + opt.Optimize(cll, vg, &x); + } catch (...) { + cerr << "Exception caught, assuming convergence is close enough...\n"; + failed = true; } - string encoded(val.size() / 4 * 3, '\0'); - if (!B64::b64decode(reinterpret_cast(&val[0]), val.size(), &encoded[0], encoded.size())) { - cerr << "B64 encoding error 2! Skipping.\n"; - continue; + if (fabs(x[0]) > MAX_BIAS) { + cerr << "Biased model learned. Are your training instances wrong?\n"; + cerr << " BIAS: " << x[0] << endl; } - esv.push_back(ErrorSurface()); - esv.back().Deserialize(type, encoded); + converged = failed || opt.HasConverged(); } - if (!esv.empty()) { - // cerr << "ESV=" << esv.size() << endl; - // for (int i = 0; i < esv.size(); ++i) { cerr << esv[i].size() << endl; } - float score; - double x = LineOptimizer::LineOptimize(esv, opt_type, &score); - cout << last_key << "|" << x << "|" << score << endl; + Weights w; + if (conf.count("weights")) { + for (int i = 1; i < x.size(); ++i) + x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi); } + w.InitFromVector(x); + w.WriteToFile("-"); return 0; } -- cgit v1.2.3 From 9ab32f74dd821f08cb5863faf88d40ca60301688 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 12 Jul 2011 21:39:44 -0400 Subject: nasty bug in operator- in sparse vector --- utils/fast_sparse_vector.h | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index 4aae2039..9d72cb87 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -235,6 +235,13 @@ class FastSparseVector { } return *this; } + FastSparseVector erase_zeros(const T& EPSILON = 1e-4) const { + FastSparseVector o; + for (const_iterator it = begin(); it != end(); ++it) { + if (fabs(it->second) > EPSILON) o.set_value(it->first, it->second); + } + return o; + } const_iterator begin() const { return const_iterator(*this, false); } @@ -344,15 +351,9 @@ const FastSparseVector operator+(const FastSparseVector& x, const FastSpar template const FastSparseVector operator-(const FastSparseVector& x, const FastSparseVector& y) { - if (x.size() > y.size()) { - FastSparseVector res(x); - res -= y; - return res; - } else { - FastSparseVector res(y); - res -= x; - return res; - } + FastSparseVector res(x); + res -= y; + return res; } template -- cgit v1.2.3 From c87835f5f94b3aa954682133c40117b3f8e26585 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 12 Jul 2011 22:34:34 -0400 Subject: debugged pro trainer --- pro-train/dist-pro.pl | 9 +- pro-train/mr_pro_map.cc | 244 +++++++++++++++++++++++++++++++++++++-------- pro-train/mr_pro_reduce.cc | 57 ++++++----- utils/filelib.cc | 12 +++ utils/filelib.h | 1 + 5 files changed, 253 insertions(+), 70 deletions(-) diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index 55d7f1fa..c42e3876 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -66,6 +66,7 @@ my $bleu_weight=1; my $use_make; # use make to parallelize line search my $dirargs=''; my $usefork; +my $initial_weights; my $pass_suffix = ''; my $cpbin=1; # Process command-line options @@ -79,6 +80,7 @@ if (GetOptions( "dry-run" => \$dryrun, "epsilon=s" => \$epsilon, "help" => \$help, + "weights=s" => \$initial_weights, "interval" => \$interval, "iteration=i" => \$iteration, "local" => \$run_local, @@ -212,7 +214,7 @@ if ($dryrun){ close CMD; print STDERR $cline; chmod(0755,$cmdfile); - check_call("touch $dir/weights.0"); + check_call("cp $initial_weights $dir/weights.0"); die "Can't find weights.0" unless (-e "$dir/weights.0"); } write_config(*STDERR); @@ -239,7 +241,6 @@ my $random_seed = int(time / 1000); my $lastWeightsFile; my $lastPScore = 0; # main optimization loop -my @mapoutputs = (); # aggregate map outputs over all iters while (1){ print STDERR "\n\nITERATION $iteration\n==========\n"; @@ -262,6 +263,7 @@ while (1){ my $im1 = $iteration - 1; my $weightsFile="$dir/weights.$im1"; push @allweights, "-w $dir/weights.$im1"; + `rm -f $dir/hgs/*.gz`; my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; my $pcmd; if ($run_local) { @@ -333,6 +335,7 @@ while (1){ print $mkfile "all: $dir/splag.$im1/map.done\n\n"; } my @mkouts = (); # only used with makefiles + my @mapoutputs = (); for my $shard (@shards) { my $mapoutput = $shard; my $client_name = $shard; @@ -341,7 +344,7 @@ while (1){ $mapoutput =~ s/mapinput/mapoutput/; push @mapoutputs, "$dir/splag.$im1/$mapoutput"; $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; - my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep @allweights < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; + my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep -w $inweights -K $dir/kbest < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; if ($run_local) { print STDERR "COMMAND:\n$script\n"; check_bash_call($script); diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index 128d93ce..4324e8de 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -2,7 +2,9 @@ #include #include #include +#include +#include #include #include #include @@ -22,16 +24,63 @@ using namespace std; namespace po = boost::program_options; +struct ApproxVectorHasher { + static const size_t MASK = 0xFFFFFFFFull; + union UType { + double f; + size_t i; + }; + static inline double round(const double x) { + UType t; + t.f = x; + size_t r = t.i & MASK; + if ((r << 1) > MASK) + t.i += MASK - r + 1; + else + t.i &= (1ull - MASK); + return t.f; + } + size_t operator()(const SparseVector& x) const { + size_t h = 0x573915839; + for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) { + UType t; + t.f = it->second; + if (t.f) { + size_t z = (t.i >> 32); + boost::hash_combine(h, it->first); + boost::hash_combine(h, z); + } + } + return h; + } +}; + +struct ApproxVectorEquals { + bool operator()(const SparseVector& a, const SparseVector& b) const { + SparseVector::const_iterator bit = b.begin(); + for (SparseVector::const_iterator ait = a.begin(); ait != a.end(); ++ait) { + if (bit == b.end() || + ait->first != bit->first || + ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) + return false; + ++bit; + } + if (bit != b.end()) return false; + return true; + } +}; + boost::shared_ptr rng; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("weights,w",po::value(), "[REQD] Weights files from current iterations") + ("kbest_repository,K",po::value()->default_value("./kbest"),"K-best list repository (directory)") + ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") ("source,s",po::value()->default_value(""), "Source file (ignored, except for AER)") ("loss_function,l",po::value()->default_value("ibm_bleu"), "Loss function being optimized") - ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") - ("weights,w",po::value >(), "[REQD] Weights files from previous and current iterations") ("kbest_size,k",po::value()->default_value(1500u), "Top k-hypotheses to extract") ("candidate_pairs,G", po::value()->default_value(5000u), "Number of pairs to sample per hypothesis (Gamma)") ("best_pairs,X", po::value()->default_value(50u), "Number of pairs, ranked by magnitude of objective delta, to retain (Xi)") @@ -46,7 +95,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { flag = true; } if (!conf->count("weights")) { - cerr << "Please specify one or more weights using -w \n"; + cerr << "Please specify weights using -w \n"; flag = true; } if (flag || conf->count("help")) { @@ -56,6 +105,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } struct HypInfo { + HypInfo() : g_(-100.0) {} HypInfo(const vector& h, const SparseVector& feats) : hyp(h), g_(-100.0), x(feats) {} // lazy evaluation @@ -66,10 +116,92 @@ struct HypInfo { } vector hyp; mutable double g_; - public: SparseVector x; }; +struct HypInfoCompare { + bool operator()(const HypInfo& a, const HypInfo& b) const { + ApproxVectorEquals comp; + return (a.hyp == b.hyp && comp(a.x,b.x)); + } +}; + +struct HypInfoHasher { + size_t operator()(const HypInfo& x) const { + boost::hash > hhasher; + ApproxVectorHasher vhasher; + size_t ha = hhasher(x.hyp); + boost::hash_combine(ha, vhasher(x.x)); + return ha; + } +}; + +void WriteKBest(const string& file, const vector& kbest) { + WriteFile wf(file); + ostream& out = *wf.stream(); + out.precision(10); + for (int i = 0; i < kbest.size(); ++i) { + out << TD::GetString(kbest[i].hyp) << endl; + out << kbest[i].x << endl; + } +} + +void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; + size_t last_start = cur; + size_t last_comma = string::npos; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const double val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } +} + +void ReadKBest(const string& file, vector* kbest) { + cerr << "Reading from " << file << endl; + ReadFile rf(file); + istream& in = *rf.stream(); + string cand; + string feats; + while(getline(in, cand)) { + getline(in, feats); + assert(in); + kbest->push_back(HypInfo()); + TD::ConvertSentence(cand, &kbest->back().hyp); + ParseSparseVector(feats, 0, &kbest->back().x); + } + cerr << " read " << kbest->size() << " hypotheses\n"; +} + +void Dedup(vector* h) { + cerr << "Dedup in=" << h->size(); + tr1::unordered_set u; + while(h->size() > 0) { + u.insert(h->back()); + h->pop_back(); + } + tr1::unordered_set::iterator it = u.begin(); + while (it != u.end()) { + h->push_back(*it); + it = u.erase(it); + } + cerr << " out=" << h->size() << endl; +} + struct ThresholdAlpha { explicit ThresholdAlpha(double t = 0.05) : threshold(t) {} double operator()(double mag) const { @@ -81,6 +213,7 @@ struct ThresholdAlpha { struct TrainingInstance { TrainingInstance(const SparseVector& feats, bool positive, double diff) : x(feats), y(positive), gdiff(diff) {} SparseVector x; +#undef DEBUGGING_PRO #ifdef DEBUGGING_PRO vector a; vector b; @@ -88,6 +221,11 @@ struct TrainingInstance { bool y; double gdiff; }; +#ifdef DEBUGGING_PRO +ostream& operator<<(ostream& os, const TrainingInstance& d) { + return os << d.gdiff << " y=" << d.y << "\tA:" << TD::GetString(d.a) << "\n\tB: " << TD::GetString(d.b) << "\n\tX: " << d.x; +} +#endif struct DiffOrder { bool operator()(const TrainingInstance& a, const TrainingInstance& b) const { @@ -95,36 +233,51 @@ struct DiffOrder { } }; -template -void Sample(const unsigned gamma, const unsigned xi, const vector& J_i, const SentenceScorer& scorer, const Alpha& alpha_i, bool invert_score, vector* pv) { - vector v; +void Sample(const unsigned gamma, const unsigned xi, const vector& J_i, const SentenceScorer& scorer, const bool invert_score, vector* pv) { + vector v1, v2; + double avg_diff = 0; for (unsigned i = 0; i < gamma; ++i) { - size_t a = rng->inclusive(0, J_i.size() - 1)(); - size_t b = rng->inclusive(0, J_i.size() - 1)(); + const size_t a = rng->inclusive(0, J_i.size() - 1)(); + const size_t b = rng->inclusive(0, J_i.size() - 1)(); if (a == b) continue; double ga = J_i[a].g(scorer); double gb = J_i[b].g(scorer); - bool positive = ga < gb; + bool positive = gb < ga; if (invert_score) positive = !positive; - double gdiff = fabs(ga - gb); + const double gdiff = fabs(ga - gb); if (!gdiff) continue; - if (rng->next() < alpha_i(gdiff)) { - v.push_back(TrainingInstance((J_i[a].x - J_i[b].x).erase_zeros(), positive, gdiff)); + avg_diff += gdiff; + SparseVector xdiff = (J_i[a].x - J_i[b].x).erase_zeros(); + if (xdiff.empty()) { + cerr << "Empty diff:\n " << TD::GetString(J_i[a].hyp) << endl << "x=" << J_i[a].x << endl; + cerr << " " << TD::GetString(J_i[b].hyp) << endl << "x=" << J_i[b].x << endl; + continue; + } + v1.push_back(TrainingInstance(xdiff, positive, gdiff)); #ifdef DEBUGGING_PRO - v.back().a = J_i[a].hyp; - v.back().b = J_i[b].hyp; + v1.back().a = J_i[a].hyp; + v1.back().b = J_i[b].hyp; + cerr << "N: " << v1.back() << endl; #endif - } } - vector::iterator mid = v.begin() + xi; - if (xi > v.size()) mid = v.end(); - partial_sort(v.begin(), mid, v.end(), DiffOrder()); - copy(v.begin(), mid, back_inserter(*pv)); + avg_diff /= v1.size(); + + for (unsigned i = 0; i < v1.size(); ++i) { + double p = 1.0 / (1.0 + exp(-avg_diff - v1[i].gdiff)); + // cerr << "avg_diff=" << avg_diff << " gdiff=" << v1[i].gdiff << " p=" << p << endl; + if (rng->next() < p) v2.push_back(v1[i]); + } + vector::iterator mid = v2.begin() + xi; + if (xi > v2.size()) mid = v2.end(); + partial_sort(v2.begin(), mid, v2.end(), DiffOrder()); + copy(v2.begin(), mid, back_inserter(*pv)); #ifdef DEBUGGING_PRO - if (v.size() >= 5) - for (int i =0; i < 5; ++i) { - cerr << v[i].gdiff << " y=" << v[i].y << "\tA:" << TD::GetString(v[i].a) << "\n\tB: " << TD::GetString(v[i].b) << endl; + if (v2.size() >= 5) { + for (int i =0; i < (mid - v2.begin()); ++i) { + cerr << v2[i] << endl; } + cerr << pv->back() << endl; + } #endif } @@ -136,6 +289,7 @@ int main(int argc, char** argv) { else rng.reset(new MT19937); const string loss_function = conf["loss_function"].as(); + ScoreType type = ScoreTypeFromString(loss_function); DocScorer ds(type, conf["reference"].as >(), conf["source"].as()); cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl; @@ -146,13 +300,15 @@ int main(int argc, char** argv) { const unsigned kbest_size = conf["kbest_size"].as(); const unsigned gamma = conf["candidate_pairs"].as(); const unsigned xi = conf["best_pairs"].as(); - vector weights_files = conf["weights"].as >(); - vector > weights(weights_files.size()); - for (int i = 0; i < weights.size(); ++i) { + string weightsf = conf["weights"].as(); + vector weights; + { Weights w; - w.InitFromFile(weights_files[i]); - w.InitVector(&weights[i]); + w.InitFromFile(weightsf); + w.InitVector(&weights); } + string kbest_repo = conf["kbest_repository"].as(); + MkDirP(kbest_repo); while(in) { vector v; string line; @@ -164,24 +320,26 @@ int main(int argc, char** argv) { // path-to-file (JSON) sent_id is >> file >> sent_id; ReadFile rf(file); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + ostringstream os; vector J_i; - int start = weights.size(); - start -= 4; - if (start < 0) start = 0; - for (int i = start; i < weights.size(); ++i) { - hg.Reweight(weights[i]); - KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); - - for (int i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(hg.nodes_.size() - 1, i); - if (!d) break; - J_i.push_back(HypInfo(d->yield, d->feature_values)); - } + os << kbest_repo << "/kbest." << sent_id << ".txt.gz"; + const string kbest_file = os.str(); + if (FileExists(kbest_file)) + ReadKBest(kbest_file, &J_i); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + hg.Reweight(weights); + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); + + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + J_i.push_back(HypInfo(d->yield, d->feature_values)); } + Dedup(&J_i); + WriteKBest(kbest_file, J_i); - Sample(gamma, xi, J_i, *ds[sent_id], ThresholdAlpha(0.05), (type == TER), &v); + Sample(gamma, xi, J_i, *ds[sent_id], (type == TER), &v); for (unsigned i = 0; i < v.size(); ++i) { const TrainingInstance& vi = v[i]; cout << vi.y << "\t" << vi.x << endl; diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 2b9c5ce7..e1a7db8a 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -24,7 +24,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("weights,w", po::value(), "Weights from previous iteration (used as initialization and interpolation") ("interpolation,p",po::value()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev") ("memory_buffers,m",po::value()->default_value(200), "Number of memory buffers (LBFGS)") - ("sigma_squared,s",po::value()->default_value(0.5), "Sigma squared for Gaussian prior") + ("sigma_squared,s",po::value()->default_value(1.0), "Sigma squared for Gaussian prior") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); @@ -35,6 +35,31 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } +void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; + size_t last_start = cur; + size_t last_comma = string::npos; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const double val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } +} + int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); @@ -60,28 +85,7 @@ int main(int argc, char** argv) { assert(ks == 1); const bool y = line[0] == '1'; SparseVector x; - size_t last_start = ks + 1; - size_t last_comma = string::npos; - size_t cur = last_start; - while(cur <= line.size()) { - if (line[cur] == ' ' || cur == line.size()) { - if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { - cerr << "[ERROR] " << line << endl << " position = " << cur << endl; - exit(1); - } - const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); - if (cur < line.size()) line[cur] = 0; - const double val = strtod(&line[last_comma + 1], NULL); - x.set_value(fid, val); - - last_comma = string::npos; - last_start = cur+1; - } else { - if (line[cur] == '=') - last_comma = cur; - } - ++cur; - } + ParseSparseVector(line, ks + 1, &x); training.push_back(make_pair(y, x)); } if (flag) cerr << endl; @@ -95,6 +99,7 @@ int main(int argc, char** argv) { SparseVector g; bool converged = false; LBFGSOptimizer opt(FD::NumFeats(), conf["memory_buffers"].as()); + double ppl = 0; while(!converged) { double cll = 0; double dbias = 0; @@ -114,14 +119,18 @@ int main(int argc, char** argv) { lp_false*=-1; if (training[i].first) { // true label cll -= lp_true; + ppl += lp_true / log(2); g -= training[i].second * exp(lp_false); dbias -= exp(lp_false); } else { // false label cll -= lp_false; + ppl += lp_false / log(2); g += training[i].second * exp(lp_true); dbias += exp(lp_true); } } + ppl /= training.size(); + ppl = pow(2.0, - ppl); vg.clear(); g.init_vector(&vg); vg[0] = dbias; @@ -139,7 +148,7 @@ int main(int argc, char** argv) { double reg = 0; #endif cll += reg; - cerr << cll << " (REG=" << reg << ")\t"; + cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t"; bool failed = false; try { opt.Optimize(cll, vg, &x); diff --git a/utils/filelib.cc b/utils/filelib.cc index 79ad2847..a0969b1a 100644 --- a/utils/filelib.cc +++ b/utils/filelib.cc @@ -20,3 +20,15 @@ bool DirectoryExists(const string& dir) { return false; } +void MkDirP(const string& dir) { + if (DirectoryExists(dir)) return; + if (mkdir(dir.c_str(), 0777)) { + perror(dir.c_str()); + abort(); + } + if (chmod(dir.c_str(), 07777)) { + perror(dir.c_str()); + abort(); + } +} + diff --git a/utils/filelib.h b/utils/filelib.h index dda98671..a8622246 100644 --- a/utils/filelib.h +++ b/utils/filelib.h @@ -12,6 +12,7 @@ bool FileExists(const std::string& file_name); bool DirectoryExists(const std::string& dir_name); +void MkDirP(const std::string& dir_name); // reads from standard in if filename is - // uncompresses if file ends with .gz -- cgit v1.2.3 From 540ed79cc13b694264471d8cd8f4735e841707ae Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 12 Jul 2011 23:32:11 -0400 Subject: minor optimization --- pro-train/mr_pro_reduce.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index e1a7db8a..5382e1a5 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -149,18 +149,20 @@ int main(int argc, char** argv) { #endif cll += reg; cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t"; - bool failed = false; try { - opt.Optimize(cll, vg, &x); + vector old_x = x; + do { + opt.Optimize(cll, vg, &x); + converged = opt.HasConverged(); + } while (!converged && x == old_x); } catch (...) { cerr << "Exception caught, assuming convergence is close enough...\n"; - failed = true; + converged = true; } if (fabs(x[0]) > MAX_BIAS) { cerr << "Biased model learned. Are your training instances wrong?\n"; cerr << " BIAS: " << x[0] << endl; } - converged = failed || opt.HasConverged(); } Weights w; if (conf.count("weights")) { -- cgit v1.2.3 From a037f52a87f7d5711b5521047e7fb3fcd756c647 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 13 Jul 2011 00:14:34 -0400 Subject: escape feature names --- decoder/ff_spans.cc | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index e1da088d..bc23974d 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -13,6 +13,17 @@ using namespace std; +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + // log transform to make long spans cluster together // but preserve differences int SpanSizeTransform(unsigned span_size) { @@ -140,19 +151,19 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { word = MapIfNecessary(word); ostringstream sfid; sfid << "ES:" << TD::Convert(word); - end_span_ids_[i] = FD::Convert(sfid.str()); + end_span_ids_[i] = FD::Convert(Escape(sfid.str())); ostringstream esbiid; esbiid << "EBI:" << TD::Convert(bword) << "_" << TD::Convert(word); - end_bigram_ids_[i] = FD::Convert(esbiid.str()); + end_bigram_ids_[i] = FD::Convert(Escape(esbiid.str())); ostringstream bsbiid; bsbiid << "BBI:" << TD::Convert(bword) << "_" << TD::Convert(word); - beg_bigram_ids_[i] = FD::Convert(bsbiid.str()); + beg_bigram_ids_[i] = FD::Convert(Escape(bsbiid.str())); ostringstream bfid; bfid << "BS:" << TD::Convert(bword); - beg_span_ids_[i] = FD::Convert(bfid.str()); + beg_span_ids_[i] = FD::Convert(Escape(bfid.str())); if (use_collapsed_features_) { - end_span_vals_[i] = feat2val_[sfid.str()] + feat2val_[esbiid.str()]; - beg_span_vals_[i] = feat2val_[bfid.str()] + feat2val_[bsbiid.str()]; + end_span_vals_[i] = feat2val_[Escape(sfid.str())] + feat2val_[Escape(esbiid.str())]; + beg_span_vals_[i] = feat2val_[Escape(bfid.str())] + feat2val_[Escape(bsbiid.str())]; } } for (int i = 0; i <= lattice.size(); ++i) { @@ -167,16 +178,16 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { word = MapIfNecessary(word); ostringstream pf; pf << "S:" << TD::Convert(bword) << "_" << TD::Convert(word); - span_feats_(i,j).first = FD::Convert(pf.str()); - span_feats_(i,j).second = FD::Convert("S_" + pf.str()); + span_feats_(i,j).first = FD::Convert(Escape(pf.str())); + span_feats_(i,j).second = FD::Convert(Escape("S_" + pf.str())); ostringstream lf; const unsigned span_size = (i < j ? j - i : i - j); lf << "LS:" << SpanSizeTransform(span_size) << "_" << TD::Convert(bword) << "_" << TD::Convert(word); - len_span_feats_(i,j).first = FD::Convert(lf.str()); - len_span_feats_(i,j).second = FD::Convert("S_" + lf.str()); + len_span_feats_(i,j).first = FD::Convert(Escape(lf.str())); + len_span_feats_(i,j).second = FD::Convert(Escape("S_" + lf.str())); if (use_collapsed_features_) { - span_vals_(i,j).first = feat2val_[pf.str()] + feat2val_[lf.str()]; - span_vals_(i,j).second = feat2val_["S_" + pf.str()] + feat2val_["S_" + lf.str()]; + span_vals_(i,j).first = feat2val_[Escape(pf.str())] + feat2val_[Escape(lf.str())]; + span_vals_(i,j).second = feat2val_[Escape("S_" + pf.str())] + feat2val_[Escape("S_" + lf.str())]; } } } @@ -209,14 +220,14 @@ void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, const string& cur = TD::Convert(w); ostringstream os; os << "RB:" << prev << '_' << cur; - const int fid = FD::Convert(os.str()); + const int fid = FD::Convert(Escape(os.str())); if (fid <= 0) return; f.add_value(fid, 1.0); prev = cur; } ostringstream os; os << "RB:" << prev << '_' << ""; - f.set_value(FD::Convert(os.str()), 1.0); + f.set_value(FD::Convert(Escape(os.str())), 1.0); } (*features) += it->second; } -- cgit v1.2.3 From 34fdc73e613bbc30d59d7bd36c5db31a94a7ac68 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 13 Jul 2011 16:25:05 -0400 Subject: faster code, optional held-out test set --- pro-train/mr_pro_reduce.cc | 140 ++++++++++++++++++++++++++++----------------- 1 file changed, 89 insertions(+), 51 deletions(-) diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 5382e1a5..491ceb3a 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -7,6 +7,7 @@ #include #include +#include "filelib.h" #include "weights.h" #include "sparse_vector.h" #include "optimize.h" @@ -25,6 +26,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("interpolation,p",po::value()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev") ("memory_buffers,m",po::value()->default_value(200), "Number of memory buffers (LBFGS)") ("sigma_squared,s",po::value()->default_value(1.0), "Sigma squared for Gaussian prior") + ("testset,t",po::value(), "Optional held-out test set to tune regularizer") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); @@ -60,13 +62,79 @@ void ParseSparseVector(string& line, size_t cur, SparseVector* out) { } } +void ReadCorpus(istream* pin, vector > >* corpus) { + istream& in = *pin; + corpus->clear(); + bool flag = false; + int lc = 0; + string line; + SparseVector x; + while(getline(in, line)) { + ++lc; + if (lc % 1000 == 0) { cerr << '.'; flag = true; } + if (lc % 40000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + if (line.empty()) continue; + const size_t ks = line.find("\t"); + assert(string::npos != ks); + assert(ks == 1); + const bool y = line[0] == '1'; + x.clear(); + ParseSparseVector(line, ks + 1, &x); + corpus->push_back(make_pair(y, x)); + } + if (flag) cerr << endl; +} + +void GradAdd(const SparseVector& v, const double scale, vector* acc) { + for (SparseVector::const_iterator it = v.begin(); + it != v.end(); ++it) { + (*acc)[it->first] += it->second * scale; + } +} + +double TrainingInference(const vector& x, + const vector > >& corpus, + vector* g = NULL) { + if (g) fill(g->begin(), g->end(), 0.0); + + double cll = 0; + for (int i = 0; i < corpus.size(); ++i) { + const double dotprod = corpus[i].second.dot(x) + x[0]; // x[0] is bias + double lp_false = dotprod; + double lp_true = -dotprod; + if (0 < lp_true) { + lp_true += log1p(exp(-lp_true)); + lp_false = log1p(exp(lp_false)); + } else { + lp_true = log1p(exp(lp_true)); + lp_false += log1p(exp(-lp_false)); + } + lp_true*=-1; + lp_false*=-1; + if (corpus[i].first) { // true label + cll -= lp_true; + if (g) { + // g -= corpus[i].second * exp(lp_false); + GradAdd(corpus[i].second, -exp(lp_false), g); + (*g)[0] -= exp(lp_false); // bias + } + } else { // false label + cll -= lp_false; + if (g) { + // g += corpus[i].second * exp(lp_true); + GradAdd(corpus[i].second, exp(lp_true), g); + (*g)[0] += exp(lp_true); // bias + } + } + } + return cll; +} + int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); string line; - vector > > training; - int lc = 0; - bool flag = false; + vector > > training, testing; SparseVector old_weights; const double psi = conf["interpolation"].as(); if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } @@ -75,20 +143,11 @@ int main(int argc, char** argv) { w.InitFromFile(conf["weights"].as()); w.InitSparseVector(&old_weights); } - while(getline(cin, line)) { - ++lc; - if (lc % 1000 == 0) { cerr << '.'; flag = true; } - if (lc % 40000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } - if (line.empty()) continue; - const size_t ks = line.find("\t"); - assert(string::npos != ks); - assert(ks == 1); - const bool y = line[0] == '1'; - SparseVector x; - ParseSparseVector(line, ks + 1, &x); - training.push_back(make_pair(y, x)); + ReadCorpus(&cin, &training); + if (conf.count("testset")) { + ReadFile rf(conf["testset"].as()); + ReadCorpus(rf.stream(), &testing); } - if (flag) cerr << endl; cerr << "Number of features: " << FD::NumFeats() << endl; vector x(FD::NumFeats(), 0.0); // x[0] is bias @@ -96,44 +155,23 @@ int main(int argc, char** argv) { it != old_weights.end(); ++it) x[it->first] = it->second; vector vg(FD::NumFeats(), 0.0); - SparseVector g; bool converged = false; LBFGSOptimizer opt(FD::NumFeats(), conf["memory_buffers"].as()); - double ppl = 0; while(!converged) { - double cll = 0; - double dbias = 0; - g.clear(); - for (int i = 0; i < training.size(); ++i) { - const double dotprod = training[i].second.dot(x) + x[0]; // x[0] is bias - double lp_false = dotprod; - double lp_true = -dotprod; - if (0 < lp_true) { - lp_true += log1p(exp(-lp_true)); - lp_false = log1p(exp(lp_false)); - } else { - lp_true = log1p(exp(lp_true)); - lp_false += log1p(exp(-lp_false)); - } - lp_true*=-1; - lp_false*=-1; - if (training[i].first) { // true label - cll -= lp_true; - ppl += lp_true / log(2); - g -= training[i].second * exp(lp_false); - dbias -= exp(lp_false); - } else { // false label - cll -= lp_false; - ppl += lp_false / log(2); - g += training[i].second * exp(lp_true); - dbias += exp(lp_true); - } - } + double cll = TrainingInference(x, training, &vg); + double ppl = cll / log(2); ppl /= training.size(); - ppl = pow(2.0, - ppl); - vg.clear(); - g.init_vector(&vg); - vg[0] = dbias; + ppl = pow(2.0, ppl); + double tppl = 0.0; + + // evaluate optional held-out test set + if (testing.size()) { + tppl = TrainingInference(x, testing) / log(2); + tppl /= testing.size(); + tppl = pow(2.0, tppl); + } + + // handle regularizer #if 1 const double sigsq = conf["sigma_squared"].as(); double norm = 0; @@ -148,7 +186,7 @@ int main(int argc, char** argv) { double reg = 0; #endif cll += reg; - cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t"; + cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t TEST_PPL=" << tppl << "\t"; try { vector old_x = x; do { -- cgit v1.2.3 From 816bee82abc909335d4f3a300cff99afa4dd1da5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 13 Jul 2011 18:00:22 -0400 Subject: escape bad feature names --- decoder/ff_ngrams.cc | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc index d52667cd..04dd1906 100644 --- a/decoder/ff_ngrams.cc +++ b/decoder/ff_ngrams.cc @@ -46,6 +46,17 @@ struct State { }; } +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + class NgramDetectorImpl { // returns the number of unscored words at the left edge of a span @@ -114,11 +125,17 @@ class NgramDetectorImpl { int& fid = ft->fids[curword]; ++n; if (!fid) { - const char* code="_UBT456789"; + const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.) ostringstream os; os << code[n] << ':'; - for (int i = n-1; i >= 0; --i) - os << (i != n-1 ? "_" : "") << TD::Convert(buf[i]); + for (int i = n-1; i >= 0; --i) { + os << (i != n-1 ? "_" : ""); + const string& tok = TD::Convert(buf[i]); + if (tok.find('=') == string::npos) + os << tok; + else + os << Escape(tok); + } fid = FD::Convert(os.str()); } feats->set_value(fid, 1); -- cgit v1.2.3 From c3828b0a2deb42de5c7378e93f93f5e69efb304c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 16 Jul 2011 19:13:21 -0400 Subject: tune regularizer --- mteval/scorer.cc | 12 +++- pro-train/dist-pro.pl | 139 ++++++++++++++++++++++++++------------------- pro-train/mr_pro_reduce.cc | 128 ++++++++++++++++++++++++++++++----------- 3 files changed, 185 insertions(+), 94 deletions(-) diff --git a/mteval/scorer.cc b/mteval/scorer.cc index 2daa0daa..a83b9e2f 100644 --- a/mteval/scorer.cc +++ b/mteval/scorer.cc @@ -430,6 +430,7 @@ float BLEUScore::ComputeScore(vector* precs, float* bp) const { float log_bleu = 0; if (precs) precs->clear(); int count = 0; + vector total_precs(N()); for (int i = 0; i < N(); ++i) { if (hyp_ngram_counts[i] > 0) { float cor_count = correct_ngram_hit_counts[i]; @@ -440,14 +441,21 @@ float BLEUScore::ComputeScore(vector* precs, float* bp) const { log_bleu += lprec; ++count; } + total_precs[i] = log_bleu; } - log_bleu /= static_cast(count); + vector bleus(N()); float lbp = 0.0; if (hyp_len < ref_len) lbp = (hyp_len - ref_len) / hyp_len; log_bleu += lbp; if (bp) *bp = exp(lbp); - return exp(log_bleu); + float wb = 0; + for (int i = 0; i < N(); ++i) { + bleus[i] = exp(total_precs[i] / (i+1) + lbp); + wb += bleus[i] / pow(2.0, 4.0 - i); + } + //return wb; + return bleus.back(); } diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index c42e3876..dbfa329a 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -37,42 +37,36 @@ die "Can't find decoder in $cdec" unless -x $cdec; die "Can't find $parallelize" unless -x $parallelize; die "Can't find $libcall" unless -e $libcall; my $decoder = $cdec; -my $lines_per_mapper = 100; +my $lines_per_mapper = 30; my $iteration = 1; my $run_local = 0; my $best_weights; -my $max_iterations = 15; -my $optimization_iters = 6; +my $max_iterations = 30; my $decode_nodes = 15; # number of decode nodes -my $pmem = "9g"; +my $pmem = "4g"; my $disable_clean = 0; my %seen_weights; -my $normalize; my $help = 0; my $epsilon = 0.0001; -my $interval = 5; my $dryrun = 0; my $last_score = -10000000; my $metric = "ibm_bleu"; my $dir; my $iniFile; my $weights; -my $decoderOpt; -my $noprimary; -my $maxsim=0; -my $oraclen=0; -my $oracleb=20; -my $bleu_weight=1; -my $use_make; # use make to parallelize line search -my $dirargs=''; +my $use_make; # use make to parallelize my $usefork; my $initial_weights; my $pass_suffix = ''; my $cpbin=1; + +# regularization strength +my $tune_regularizer = 0; +my $reg = 1e-2; + # Process command-line options Getopt::Long::Configure("no_auto_abbrev"); if (GetOptions( - "decoder=s" => \$decoderOpt, "decode-nodes=i" => \$decode_nodes, "dont-clean" => \$disable_clean, "pass-suffix=s" => \$pass_suffix, @@ -81,21 +75,13 @@ if (GetOptions( "epsilon=s" => \$epsilon, "help" => \$help, "weights=s" => \$initial_weights, - "interval" => \$interval, - "iteration=i" => \$iteration, + "tune-regularizer" => \$tune_regularizer, + "reg=f" => \$reg, "local" => \$run_local, "use-make=i" => \$use_make, "max-iterations=i" => \$max_iterations, - "normalize=s" => \$normalize, "pmem=s" => \$pmem, "cpbin!" => \$cpbin, - "bleu_weight=s" => \$bleu_weight, - "no-primary!" => \$noprimary, - "max-similarity=s" => \$maxsim, - "oracle-directions=i" => \$oraclen, - "n-oracle=i" => \$oraclen, - "oracle-batch=i" => \$oracleb, - "directions-args=s" => \$dirargs, "ref-files=s" => \$refFiles, "metric=s" => \$metric, "source-file=s" => \$srcFile, @@ -108,9 +94,7 @@ if (GetOptions( if ($usefork) { $usefork = "--use-fork"; } else { $usefork = ''; } if ($metric =~ /^(combi|ter)$/i) { - $lines_per_mapper = 40; -} elsif ($metric =~ /^meteor$/i) { - $lines_per_mapper = 2000; # start up time is really high + $lines_per_mapper = 5; } ($iniFile) = @ARGV; @@ -144,8 +128,6 @@ unless ($dir =~ /^\//){ # convert relative path to absolute path $dir = "$basedir/$dir"; } -if ($decoderOpt){ $decoder = $decoderOpt; } - # Initializations and helper functions srand; @@ -378,6 +360,22 @@ while (1){ else {$joblist = $joblist . "\|" . $jobid; } } } + my @dev_outs = (); + my @devtest_outs = (); + if ($tune_regularizer) { + for (my $i = 0; $i < scalar @mapoutputs; $i++) { + if ($i % 3 == 1) { + push @devtest_outs, $mapoutputs[$i]; + } else { + push @dev_outs, $mapoutputs[$i]; + } + } + if (scalar @devtest_outs == 0) { + die "Not enough training instances for regularization tuning! Rerun without --tune-regularizer\n"; + } + } else { + @dev_outs = @mapoutputs; + } if ($run_local) { print STDERR "\nCompleted extraction of training exemplars.\n"; } elsif ($use_make) { @@ -399,7 +397,13 @@ while (1){ } my $tol = 0; my $til = 0; - print STDERR "MO: @mapoutputs\n"; + my $dev_test_file = "$dir/splag.$im1/devtest.gz"; + if ($tune_regularizer) { + my $cmd = "cat @devtest_outs | gzip > $dev_test_file"; + check_bash_call($cmd); + die "Can't find file $dev_test_file" unless -f $dev_test_file; + } + #print STDERR "MO: @mapoutputs\n"; for my $mo (@mapoutputs) { #my $olines = get_lines($mo); #my $ilines = get_lines($o2i{$mo}); @@ -407,10 +411,24 @@ while (1){ } print STDERR "\nRUNNING CLASSIFIER (REDUCER)\n"; print STDERR unchecked_output("date"); - $cmd="cat @mapoutputs | $REDUCER -w $dir/weights.$im1 > $dir/weights.$iteration"; + $cmd="cat @dev_outs | $REDUCER -w $dir/weights.$im1 -s $reg"; + if ($tune_regularizer) { + $cmd .= " -T -t $dev_test_file"; + } + $cmd .= " > $dir/weights.$iteration"; print STDERR "COMMAND:\n$cmd\n"; check_bash_call($cmd); $lastWeightsFile = "$dir/weights.$iteration"; + if ($tune_regularizer) { + open W, "<$lastWeightsFile" or die "Can't read $lastWeightsFile: $!"; + my $line = ; + close W; + my ($sharp, $label, $nreg) = split /\s|=/, $line; + print STDERR "REGULARIZATION STRENGTH ($label) IS $nreg\n"; + $reg = $nreg; + # only tune regularizer on first iteration? + $tune_regularizer = 0; + } $lastPScore = $score; $iteration++; print STDERR "\n==========\n"; @@ -473,7 +491,6 @@ sub write_config { print $fh "SOURCE (DEV): $srcFile\n"; print $fh "REFS (DEV): $refFiles\n"; print $fh "EVAL METRIC: $metric\n"; - print $fh "START ITERATION: $iteration\n"; print $fh "MAX ITERATIONS: $max_iterations\n"; print $fh "DECODE NODES: $decode_nodes\n"; print $fh "HEAD NODE: $host\n"; @@ -535,31 +552,38 @@ Usage: $executable [options] based on certain conventions. For details, refer to descriptions of the options --decoder, --weights, and --workdir. -Options: +Required: + + --ref-files + Dev set ref files. This option takes only a single string argument. + To use multiple files (including file globbing), this argument should + be quoted. + + --source-file + Dev set source file. + + --weights + Initial weights file (use empty file to start from 0) + +General options: --local Run the decoder and optimizer locally with a single thread. - --use-make - Use make -j to run the optimizer commands (useful on large - shared-memory machines where qsub is unavailable). - --decode-nodes Number of decoder processes to run in parallel. [default=15] - --decoder - Decoder binary to use. - --help Print this message and exit. - --iteration - Starting iteration number. If not specified, defaults to 1. - --max-iterations Maximum number of iterations to run. If not specified, defaults to 10. + --metric + Metric to optimize. + Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi + --pass-suffix If the decoder is doing multi-pass decoding, the pass suffix "2", "3", etc., is used to control what iteration of weights is set. @@ -567,21 +591,9 @@ Options: --pmem Amount of physical memory requested for parallel decoding jobs. - --ref-files - Dev set ref files. This option takes only a single string argument. - To use multiple files (including file globbing), this argument should - be quoted. - - --metric - Metric to optimize. - Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi - - --normalize - After each iteration, rescale all feature weights such that feature- - name has a weight of 1.0. - - --source-file - Dev set source file. + --use-make + Use make -j to run the optimizer commands (useful on large + shared-memory machines where qsub is unavailable). --workdir Directory for intermediate and output files. If not specified, the @@ -591,6 +603,14 @@ Options: the filename. E.g. an ini file named decoder.foo.ini would have a default working directory name foo. +Regularization options: + + --tune-regularizer + Hold out one third of the tuning data and used this to tune the + regularization parameter. + + --reg + Help } @@ -606,7 +626,6 @@ sub convert { } - sub cmdline { return join ' ',($0,@ORIG_ARGV); } diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 491ceb3a..9b422f33 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -16,7 +16,7 @@ using namespace std; namespace po = boost::program_options; // since this is a ranking model, there should be equal numbers of -// positive and negative examples so the bias should be 0 +// positive and negative examples, so the bias should be 0 static const double MAX_BIAS = 1e-10; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { @@ -25,8 +25,11 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("weights,w", po::value(), "Weights from previous iteration (used as initialization and interpolation") ("interpolation,p",po::value()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev") ("memory_buffers,m",po::value()->default_value(200), "Number of memory buffers (LBFGS)") - ("sigma_squared,s",po::value()->default_value(1.0), "Sigma squared for Gaussian prior") - ("testset,t",po::value(), "Optional held-out test set to tune regularizer") + ("sigma_squared,s",po::value()->default_value(0.1), "Sigma squared for Gaussian prior") + ("min_reg,r",po::value()->default_value(1e-8), "When tuning (-T) regularization strength, minimum regularization strenght") + ("max_reg,R",po::value()->default_value(10.0), "When tuning (-T) regularization strength, maximum regularization strenght") + ("testset,t",po::value(), "Optional held-out test set") + ("tune_regularizer,T", "Use the held out test set (-t) to tune the regularization strength") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); @@ -95,8 +98,6 @@ void GradAdd(const SparseVector& v, const double scale, vector* double TrainingInference(const vector& x, const vector > >& corpus, vector* g = NULL) { - if (g) fill(g->begin(), g->end(), 0.0); - double cll = 0; for (int i = 0; i < corpus.size(); ++i) { const double dotprod = corpus[i].second.dot(x) + x[0]; // x[0] is bias @@ -130,39 +131,23 @@ double TrainingInference(const vector& x, return cll; } -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - string line; - vector > > training, testing; - SparseVector old_weights; - const double psi = conf["interpolation"].as(); - if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } - if (conf.count("weights")) { - Weights w; - w.InitFromFile(conf["weights"].as()); - w.InitSparseVector(&old_weights); - } - ReadCorpus(&cin, &training); - if (conf.count("testset")) { - ReadFile rf(conf["testset"].as()); - ReadCorpus(rf.stream(), &testing); - } - - cerr << "Number of features: " << FD::NumFeats() << endl; - vector x(FD::NumFeats(), 0.0); // x[0] is bias - for (SparseVector::const_iterator it = old_weights.begin(); - it != old_weights.end(); ++it) - x[it->first] = it->second; +// return held-out log likelihood +double LearnParameters(const vector > >& training, + const vector > >& testing, + const double sigsq, + const unsigned memory_buffers, + vector* px) { + vector& x = *px; vector vg(FD::NumFeats(), 0.0); bool converged = false; - LBFGSOptimizer opt(FD::NumFeats(), conf["memory_buffers"].as()); + LBFGSOptimizer opt(FD::NumFeats(), memory_buffers); + double tppl = 0.0; while(!converged) { + fill(vg.begin(), vg.end(), 0.0); double cll = TrainingInference(x, training, &vg); double ppl = cll / log(2); ppl /= training.size(); ppl = pow(2.0, ppl); - double tppl = 0.0; // evaluate optional held-out test set if (testing.size()) { @@ -173,7 +158,6 @@ int main(int argc, char** argv) { // handle regularizer #if 1 - const double sigsq = conf["sigma_squared"].as(); double norm = 0; for (int i = 1; i < x.size(); ++i) { const double mean_i = 0.0; @@ -202,11 +186,91 @@ int main(int argc, char** argv) { cerr << " BIAS: " << x[0] << endl; } } + return tppl; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string line; + vector > > training, testing; + SparseVector old_weights; + const bool tune_regularizer = conf.count("tune_regularizer"); + if (tune_regularizer && !conf.count("testset")) { + cerr << "--tune_regularizer requires --testset to be set\n"; + return 1; + } + const double min_reg = conf["min_reg"].as(); + const double max_reg = conf["max_reg"].as(); + double sigsq = conf["sigma_squared"].as(); + assert(sigsq > 0.0); + assert(min_reg > 0.0); + assert(max_reg > 0.0); + assert(max_reg > min_reg); + const double psi = conf["interpolation"].as(); + if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } + if (conf.count("weights")) { + Weights w; + w.InitFromFile(conf["weights"].as()); + w.InitSparseVector(&old_weights); + } + ReadCorpus(&cin, &training); + if (conf.count("testset")) { + ReadFile rf(conf["testset"].as()); + ReadCorpus(rf.stream(), &testing); + } + cerr << "Number of features: " << FD::NumFeats() << endl; + vector x(FD::NumFeats(), 0.0); // x[0] is bias + for (SparseVector::const_iterator it = old_weights.begin(); + it != old_weights.end(); ++it) + x[it->first] = it->second; + double tppl = 0.0; + vector > sp; + vector smoothed; + if (tune_regularizer) { + sigsq = min_reg; + const double steps = 18; + double sweep_factor = exp((log(max_reg) - log(min_reg)) / steps); + cerr << "SWEEP FACTOR: " << sweep_factor << endl; + while(sigsq < max_reg) { + tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); + sp.push_back(make_pair(sigsq, tppl)); + sigsq *= sweep_factor; + } + smoothed.resize(sp.size(), 0); + smoothed[0] = sp[0].second; + smoothed.back() = sp.back().second; + for (int i = 1; i < sp.size()-1; ++i) { + double prev = sp[i-1].second; + double next = sp[i+1].second; + double cur = sp[i].second; + smoothed[i] = (prev*0.2) + cur * 0.6 + (0.2*next); + } + double best_ppl = 9999999; + unsigned best_i = 0; + for (unsigned i = 0; i < sp.size(); ++i) { + if (smoothed[i] < best_ppl) { + best_ppl = smoothed[i]; + best_i = i; + } + } + sigsq = sp[best_i].first; + tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); + } Weights w; if (conf.count("weights")) { for (int i = 1; i < x.size(); ++i) x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi); } + cout.precision(15); + cout << "# sigma^2=" << sigsq << "\theld out perplexity="; + if (tppl) { cout << tppl << endl; } else { cout << "N/A\n"; } + if (sp.size()) { + cout << "# Parameter sweep:\n"; + for (int i = 0; i < sp.size(); ++i) { + cout << "# " << sp[i].first << "\t" << sp[i].second << "\t" << smoothed[i] << endl; + } + } w.InitFromVector(x); w.WriteToFile("-"); return 0; -- cgit v1.2.3 From d73b5d25bd0af14a4a83490d67ba2553b6af9884 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 28 Jul 2011 17:08:59 +0100 Subject: stuff --- decoder/apply_models.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 62eff262..26cdb881 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -190,8 +190,12 @@ public: if (num_nodes > 100) every = 10; assert(in.nodes_[pregoal].out_edges_.size() == 1); if (!SILENT) cerr << " "; + int has = 0; for (int i = 0; i < in.nodes_.size(); ++i) { - if (!SILENT && i % every == 0) cerr << '.'; + if (!SILENT) { + int needs = (50 * i / in.nodes_.size()); + while (has < needs) { cerr << '.'; ++has; } + } if (strategy_==NORMAL_CP){ KBest(i, i == goal_id); } -- cgit v1.2.3 From 2c14cf2218031c29a9884bccf17e9273c71a33b2 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 18 Aug 2011 12:14:01 +0100 Subject: KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime. --- decoder/cdec_ff.cc | 5 +- decoder/ff_klm.cc | 38 +++++- decoder/ff_klm.h | 7 +- klm/compile.sh | 2 +- klm/lm/Makefile.am | 1 + klm/lm/bhiksha.cc | 93 +++++++++++++++ klm/lm/bhiksha.hh | 108 +++++++++++++++++ klm/lm/binary_format.cc | 13 ++- klm/lm/binary_format.hh | 9 +- klm/lm/build_binary.cc | 54 ++++++--- klm/lm/config.cc | 1 + klm/lm/config.hh | 5 +- klm/lm/model.cc | 67 ++++++----- klm/lm/model.hh | 12 +- klm/lm/model_test.cc | 73 ++++++++++-- klm/lm/ngram_query.cc | 9 ++ klm/lm/quantize.cc | 1 + klm/lm/quantize.hh | 4 +- klm/lm/read_arpa.cc | 6 +- klm/lm/search_hashed.cc | 2 +- klm/lm/search_hashed.hh | 3 +- klm/lm/search_trie.cc | 45 +++---- klm/lm/search_trie.hh | 20 ++-- klm/lm/test_nounk.arpa | 120 +++++++++++++++++++ klm/lm/trie.cc | 57 ++++----- klm/lm/trie.hh | 24 ++-- klm/lm/vocab.cc | 6 +- klm/lm/vocab.hh | 4 + klm/util/bit_packing.hh | 13 ++- klm/util/murmur_hash.cc | 258 ++++++++++++++++++++--------------------- klm/util/probing_hash_table.hh | 2 +- klm/util/sorted_uniform.hh | 23 +++- 32 files changed, 792 insertions(+), 293 deletions(-) create mode 100644 klm/lm/bhiksha.cc create mode 100644 klm/lm/bhiksha.hh create mode 100644 klm/lm/test_nounk.arpa diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 3451c9fb..1ef76a05 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,10 +55,7 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); - ff_registry.Register("KLanguageModel", new FFFactory >()); - ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); - ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); - ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); + ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); ff_registry.Register("RelativeSentencePosition", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9b7fe2d3..24dcb9c3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -9,6 +9,7 @@ #include "stringlib.h" #include "hg.h" #include "tdict.h" +#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; @@ -434,8 +435,37 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, features->set_value(oov_fid_, oovs); } -// instantiate templates -template class KLanguageModel; -template class KLanguageModel; -template class KLanguageModel; +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 5eafe8be..6efe50f6 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -4,8 +4,8 @@ #include #include +#include "ff_factory.h" #include "ff.h" -#include "lm/model.hh" template struct KLanguageModelImpl; @@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction { KLanguageModelImpl* pimpl_; }; +struct KLanguageModelFactory : public FactoryBase { + FP Create(std::string param) const; + std::string usage(bool params,bool verbose) const; +}; + #endif diff --git a/klm/compile.sh b/klm/compile.sh index 6ca85e1f..abe3473a 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -5,7 +5,7 @@ set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 395494bc..fae6b41a 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz noinst_LIBRARIES = libklm.a libklm_a_SOURCES = \ + bhiksha.cc \ binary_format.cc \ config.cc \ lm_exception.cc \ diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc new file mode 100644 index 00000000..bf86fd4b --- /dev/null +++ b/klm/lm/bhiksha.cc @@ -0,0 +1,93 @@ +#include "lm/bhiksha.hh" +#include "lm/config.hh" + +#include + +namespace lm { +namespace ngram { +namespace trie { + +DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) : + next_(util::BitsMask::ByMax(max_next)) {} + +const uint8_t kArrayBhikshaVersion = 0; + +void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { + uint8_t version; + uint8_t configured_bits; + if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { + UTIL_THROW(util::ErrnoException, "Could not read from binary file"); + } + if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); + config.pointer_bhiksha_bits = configured_bits; +} + +namespace { + +// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset) +uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t best_chop = 0; + int64_t lowest_change = std::numeric_limits::max(); + // There are probably faster ways but I don't care because this is only done once per order at construction time. + for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) { + int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */ + - max_offset * static_cast(chop); /* savings in bits*/ + if (change < lowest_change) { + lowest_change = change; + best_chop = chop; + } + } + return best_chop; +} + +std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t chopping = ChopBits(max_offset, max_next, config); + return (max_next >> (required - chopping)) + 1 /* we store 0 too */; +} +} // namespace + +std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { + return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; +} + +uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config); +} + +namespace { + +void *AlignTo8(void *from) { + uint8_t *val = reinterpret_cast(from); + std::size_t remainder = reinterpret_cast(val) & 7; + if (!remainder) return val; + return val + 8 - remainder; +} + +} // namespace + +ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config) + : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))), + offset_begin_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */), + offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)), + write_to_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */), + original_base_(base) {} + +void ArrayBhiksha::FinishedLoading(const Config &config) { + // *offset_begin_ = 0 but without a const_cast. + *(write_to_ - (write_to_ - offset_begin_)) = 0; + + if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected."); + + uint8_t *head_write = reinterpret_cast(original_base_); + *(head_write++) = kArrayBhikshaVersion; + *(head_write++) = config.pointer_bhiksha_bits; +} + +void ArrayBhiksha::LoadedBinary() { +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh new file mode 100644 index 00000000..cfb2b053 --- /dev/null +++ b/klm/lm/bhiksha.hh @@ -0,0 +1,108 @@ +/* Simple implementation of + * @inproceedings{bhikshacompression, + * author={Bhiksha Raj and Ed Whittaker}, + * year={2003}, + * title={Lossless Compression of Language Model Structure and Word Identifiers}, + * booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing}, + * pages={388--391}, + * } + * + * Currently only used for next pointers. + */ + +#include + +#include "lm/binary_format.hh" +#include "lm/trie.hh" +#include "util/bit_packing.hh" +#include "util/sorted_uniform.hh" + +namespace lm { +namespace ngram { +class Config; + +namespace trie { + +class DontBhiksha { + public: + static const ModelType kModelTypeAdd = static_cast(0); + + static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + + static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + + static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { + return util::RequiredBits(max_next); + } + + DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const { + out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask); + out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask); + //assert(out.end >= out.begin); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) { + util::WriteInt57(base, bit_offset, next_.bits, value); + } + + void FinishedLoading(const Config &/*config*/) {} + + void LoadedBinary() {} + + uint8_t InlineBits() const { return next_.bits; } + + private: + util::BitsMask next_; +}; + +class ArrayBhiksha { + public: + static const ModelType kModelTypeAdd = kArrayAdd; + + static void UpdateConfigFromBinary(int fd, Config &config); + + static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + + static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); + + ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { + const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor(), offset_begin_, offset_end_, index); + const uint64_t *end_it; + for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + --end_it; + out.begin = ((begin_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); + out.end = ((end_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { + uint64_t encode = value >> next_inline_.bits; + for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index; + util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask); + } + + void FinishedLoading(const Config &config); + + void LoadedBinary(); + + uint8_t InlineBits() const { return next_inline_.bits; } + + private: + const util::BitsMask next_inline_; + + const uint64_t *const offset_begin_; + const uint64_t *const offset_end_; + + uint64_t *write_to_; + + void *original_base_; +}; + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 92b1008b..e02e621a 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -40,7 +40,7 @@ struct Sanity { } }; -const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; std::size_t Align8(std::size_t in) { std::size_t off = in % 8; @@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ } } -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) { +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { + std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); + if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) + UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = backing.vocab.size() % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); + off_t alignment_cruft = adjusted_vocab % page_size; + backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); return reinterpret_cast(backing.search.get()) + alignment_cruft; } else { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 2b32b450..d28cb6c5 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,12 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off); // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); // Write header to binary file. This is done last to prevent incomplete files // from loading. diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4552c419..b7aee4de 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,12 +15,12 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" -"-u sets the default log10 probability for if the ARPA file does not have\n" -"one.\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" +"-u sets the log10 probability for if the ARPA file does not have one.\n" +" Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is either probing or trie:\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" @@ -29,10 +29,11 @@ void Usage(const char *name) { "-t is the temporary directory prefix. Default is the output file name.\n" "-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" -"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" -"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" -"Passing only an input file will print memory usage of each data structure.\n" -"If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n" +"-a compresses pointers using an array of offsets. The parameter is the\n" +" maximum number of bits encoded by the array. Memory is minimized subject\n" +" to the maximum, so pick 255 to minimize memory.\n\n" +"Get a memory estimate by passing an ARPA file without an output file name.\n"; exit(1); } @@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[3]; + std::size_t sizes[5]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = TrieModel::Size(counts, config); sizes[2] = QuantTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + 3); - std::size_t min_length = *std::max_element(sizes, sizes + 3); + sizes[3] = ArrayTrieModel::Size(counts, config); + sizes[4] = QuantArrayTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); + std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); std::size_t divide; char prefix; if (min_length < (1 << 10) * 10) { @@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::cout << prefix << "B\n" "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; } void ProbingQuantizationUnsupported() { @@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() { int main(int argc, char *argv[]) { using namespace lm::ngram; - bool quantize = false, set_backoff_bits = false; try { + bool quantize = false, set_backoff_bits = false, bhiksha = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -121,6 +126,9 @@ int main(int argc, char *argv[]) { config.backoff_bits = ParseBitCount(optarg); set_backoff_bits = true; break; + case 'a': + config.pointer_bhiksha_bits = ParseBitCount(optarg); + bhiksha = true; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -162,9 +170,17 @@ int main(int argc, char *argv[]) { ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { if (quantize) { - QuantTrieModel(from_file, config); + if (bhiksha) { + QuantArrayTrieModel(from_file, config); + } else { + QuantTrieModel(from_file, config); + } } else { - TrieModel(from_file, config); + if (bhiksha) { + ArrayTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } } else { Usage(argv[0]); @@ -173,9 +189,9 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } - catch (std::exception &e) { + catch (const std::exception &e) { std::cerr << e.what() << std::endl; - abort(); + return 1; } return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 08e1af5c..297589a4 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -20,6 +20,7 @@ Config::Config() : include_vocab(true), prob_bits(8), backoff_bits(8), + pointer_bhiksha_bits(22), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index dcc7cf35..227b8512 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -73,9 +73,12 @@ struct Config { // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize. + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; + // Bhiksha compression (simple form). Only works with trie. + uint8_t pointer_bhiksha_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a1d10b3d..27e24b1c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -21,6 +21,8 @@ size_t hash_value(const State &state) { namespace detail { +template const ModelType GenericModel::kModelType = Search::kModelType; + template size_t GenericModel::Size(const std::vector &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } @@ -56,35 +58,40 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); - std::vector counts; - // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. - ReadARPACounts(f, counts); - - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); - if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); - if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - - std::size_t vocab_size = VocabularyT::Size(counts[0], config); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); - - if (config.write_mmap) { - WriteWordsWrapper wrap(config.enumerate_vocab); - vocab_.ConfigureEnumerate(&wrap, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get()); - } else { - vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - } + try { + std::vector counts; + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. + ReadARPACounts(f, counts); + + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + + std::size_t vocab_size = VocabularyT::Size(counts[0], config); + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + + if (config.write_mmap) { + WriteWordsWrapper wrap(config.enumerate_vocab); + vocab_.ConfigureEnumerate(&wrap, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + wrap.Write(backing_.file.get()); + } else { + vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + } - if (!vocab_.SawUnk()) { - assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. - search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_logprob; + if (!vocab_.SawUnk()) { + assert(config.unknown_missing != THROW_UP); + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_logprob; + } + FinishFile(config, kModelType, counts, backing_); + } catch (util::Exception &e) { + e << " Byte: " << f.Offset(); + throw; } - FinishFile(config, kModelType, counts, backing_); } template FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { @@ -225,8 +232,10 @@ template FullScoreReturn GenericModel; // HASH_PROBING -template class GenericModel, SortedVocabulary>; // TRIE_SORTED -template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 1f49a382..21595321 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,6 +1,7 @@ #ifndef LM_MODEL__ #define LM_MODEL__ +#include "lm/bhiksha.hh" #include "lm/binary_format.hh" #include "lm/config.hh" #include "lm/facade.hh" @@ -71,6 +72,9 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: + // This is the model type returned by RecognizeBinary. + static const ModelType kModelType; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -131,8 +135,6 @@ template class GenericModel : public base::Mod Backing &MutableBacking() { return backing_; } - static const ModelType kModelType = Search::kModelType; - Backing backing_; VocabularyT vocab_; @@ -152,9 +154,11 @@ typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> ArrayTrieModel; -typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantArrayTrieModel; } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template void Stateless(const M &model) { BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); } +template void NoUnkCheck(const M &model) { + WordIndex unk_index = 0; + State state; + + FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); + BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} + template void Everything(const M &m) { Starters(m); Continuation(m); @@ -231,25 +239,38 @@ template void LoadingTest() { Config config; config.arpa_complain = Config::NONE; config.messages = NULL; - ExpectEnumerateVocab enumerate; - config.enumerate_vocab = &enumerate; config.probing_multiplier = 2.0; - ModelT m("test.arpa", config); - enumerate.Check(m.GetVocabulary()); - Everything(m); + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test.arpa", config); + enumerate.Check(m.GetVocabulary()); + Everything(m); + } + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test_nounk.arpa", config); + enumerate.Check(m.GetVocabulary()); + NoUnkCheck(m); + } } BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } - BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { + LoadingTest(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { + LoadingTest(); +} template void BinaryTest() { Config config; @@ -267,10 +288,34 @@ template void BinaryTest() { config.write_mmap = NULL; - ModelT binary("test.binary", config); - enumerate.Check(binary.GetVocabulary()); - Everything(binary); + ModelType type; + BOOST_REQUIRE(RecognizeBinary("test.binary", type)); + BOOST_CHECK_EQUAL(ModelT::kModelType, type); + + { + ModelT binary("test.binary", config); + enumerate.Check(binary.GetVocabulary()); + Everything(binary); + } unlink("test.binary"); + + // Now test without . + config.write_mmap = "test_nounk.binary"; + config.messages = NULL; + enumerate.Clear(); + { + ModelT copy_model("test_nounk.arpa", config); + enumerate.Check(copy_model.GetVocabulary()); + enumerate.Clear(); + NoUnkCheck(copy_model); + } + config.write_mmap = NULL; + { + ModelT binary("test_nounk.binary", config); + enumerate.Check(binary.GetVocabulary()); + NoUnkCheck(binary); + } + unlink("test_nounk.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { + BinaryTest(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 9454a6d1..d9db4aa2 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -99,6 +99,15 @@ int main(int argc, char *argv[]) { case lm::ngram::TRIE_SORTED: Query(argv[1], sentence_context); break; + case lm::ngram::QUANT_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::QUANT_ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; case lm::ngram::HASH_SORTED: default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 4bb6b1b8..fd371cc8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector(0); static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } @@ -108,7 +108,7 @@ class SeparatelyQuantize { }; public: - static const ModelType kModelType = QUANT_TRIE_SORTED; + static const ModelType kModelTypeAdd = kQuantAdd; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 060a97ea..455bc4ba 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; void ReadARPACounts(util::FilePiece &in, std::vector &number) { number.clear(); StringPiece line; - if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + if (line != "\\data\\") { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); } if (static_cast(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); - UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); + UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\."); } - if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); // So strtol doesn't go off the end of line. diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index c56ba7b8..82c53ec8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -98,7 +98,7 @@ template uint8_t *TemplateHashedSearch template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index f3acdefc..c62985e4 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -52,12 +52,11 @@ struct HashedSearch { Unigram unigram; - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { const ProbBackoff &entry = unigram.Lookup(word); prob = entry.prob; backoff = entry.backoff; next = static_cast(word); - return true; } }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 91f87f1c..05059ffb 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@ /* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" +#include "lm/bhiksha.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" @@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector appears. - size_t extra_count = counts[0] + 1; - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -610,9 +611,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -template class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -649,7 +650,7 @@ template class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; + BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; }; @@ -821,7 +822,7 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, So } // namespace -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -846,7 +847,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); @@ -863,7 +864,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -901,14 +902,14 @@ template void BuildTrie(const std::string &file_prefix, std::vecto /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { - for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { - i->FinishedLoading((i+1)->InsertIndex()); + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); } } -template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { quant_.SetupMemory(start, config); start += Quant::Size(counts.size(), config); unigram.Init(start); @@ -919,22 +920,24 @@ template uint8_t *TrieSearch::SetupMemory(uint8_t *start, c std::vector middle_starts(counts.size() - 2); for (unsigned char i = 2; i < counts.size(); ++i) { middle_starts[i-2] = start; - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config); } - // Crazy backwards thing so we initialize in the correct order. + // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], quant_.Mid(i), + counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1]), + config); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template void TrieSearch::LoadedBinary() { +template void TrieSearch::LoadedBinary() { unigram.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); @@ -942,7 +945,7 @@ template void TrieSearch::LoadedBinary() { longest.LoadedBinary(); } -template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -966,14 +969,16 @@ template void TrieSearch::InitializeFromARPA(const char *fi // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } -template class TrieSearch; -template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; } // namespace trie } // namespace ngram diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing; class SortedVocabulary; namespace trie { -template class TrieSearch; -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template class TrieSearch { +template class TrieSearch { public: typedef NodeRange Node; typedef ::lm::ngram::trie::Unigram Unigram; Unigram unigram; - typedef trie::BitPackedMiddle Middle; + typedef trie::BitPackedMiddle Middle; typedef trie::BitPackedLongest Longest; Longest longest; - static const ModelType kModelType = Quant::kModelType; + static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); + AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + Bhiksha::UpdateConfigFromBinary(fd, config); } static std::size_t Size(const std::vector &counts, const Config &config) { std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } @@ -55,8 +57,8 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + unigram.Find(word, prob, backoff, node); } bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template class TrieSearch { } private: - friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa new file mode 100644 index 00000000..060733d9 --- /dev/null +++ b/klm/lm/test_nounk.arpa @@ -0,0 +1,120 @@ + +\data\ +ngram 1=36 +ngram 2=45 +ngram 3=10 +ngram 4=6 +ngram 5=4 + +\1-grams: +-1.383514 , -0.30103 +-1.139057 . -0.845098 +-1.029493 +-99 -0.4149733 +-1.285941 a -0.69897 +-1.687872 also -0.30103 +-1.687872 beyond -0.30103 +-1.687872 biarritz -0.30103 +-1.687872 call -0.30103 +-1.687872 concerns -0.30103 +-1.687872 consider -0.30103 +-1.687872 considering -0.30103 +-1.687872 for -0.30103 +-1.509559 higher -0.30103 +-1.687872 however -0.30103 +-1.687872 i -0.30103 +-1.687872 immediate -0.30103 +-1.687872 in -0.30103 +-1.687872 is -0.30103 +-1.285941 little -0.69897 +-1.383514 loin -0.30103 +-1.687872 look -0.30103 +-1.285941 looking -0.4771212 +-1.206319 more -0.544068 +-1.509559 on -0.4771212 +-1.509559 screening -0.4771212 +-1.687872 small -0.30103 +-1.687872 the -0.30103 +-1.687872 to -0.30103 +-1.687872 watch -0.30103 +-1.687872 watching -0.30103 +-1.687872 what -0.30103 +-1.687872 would -0.30103 +-3.141592 foo +-2.718281 bar 3.0 +-6.535897 baz -0.0 + +\2-grams: +-0.6925742 , . +-0.7522095 , however +-0.7522095 , is +-0.0602359 . +-0.4846522 looking -0.4771214 +-1.051485 screening +-1.07153 the +-1.07153 watching +-1.07153 what +-0.09132547 a little -0.69897 +-0.2922095 also call +-0.2922095 beyond immediate +-0.2705918 biarritz . +-0.2922095 call for +-0.2922095 concerns in +-0.2922095 consider watch +-0.2922095 considering consider +-0.2834328 for , +-0.5511513 higher more +-0.5845945 higher small +-0.2834328 however , +-0.2922095 i would +-0.2922095 immediate concerns +-0.2922095 in biarritz +-0.2922095 is to +-0.09021038 little more -0.1998621 +-0.7273645 loin , +-0.6925742 loin . +-0.6708385 loin +-0.2922095 look beyond +-0.4638903 looking higher +-0.4638903 looking on -0.4771212 +-0.5136299 more . -0.4771212 +-0.3561665 more loin +-0.1649931 on a -0.4771213 +-0.1649931 screening a -0.4771213 +-0.2705918 small . +-0.287799 the screening +-0.2922095 to look +-0.2622373 watch +-0.2922095 watching considering +-0.2922095 what i +-0.2922095 would also +-2 also would -6 +-6 foo bar + +\3-grams: +-0.01916512 more . +-0.0283603 on a little -0.4771212 +-0.0283603 screening a little -0.4771212 +-0.01660496 a little more -0.09409451 +-0.3488368 looking higher +-0.3488368 looking on -0.4771212 +-0.1892331 little more loin +-0.04835128 looking on a -0.4771212 +-3 also would consider -7 +-7 to look good + +\4-grams: +-0.009249173 looking on a little -0.4771212 +-0.005464747 on a little more -0.4771212 +-0.005464747 screening a little more +-0.1453306 a little more loin +-0.01552657 looking on a -0.4771212 +-4 also would consider higher -8 + +\5-grams: +-0.003061223 looking on a little +-0.001813953 looking on a little more +-0.0432557 on a little more loin +-5 also would consider higher looking + +\end\ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 63c2a612..8c536e66 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,5 +1,6 @@ #include "lm/trie.hh" +#include "lm/bhiksha.hh" #include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" @@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { + return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { - if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : + BitPacked(), + quant_(quant), + // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. + bhiksha_(base, entries + 1, max_next, config), + next_source_(&next_source) { + if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); + BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); } -template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; @@ -75,47 +81,42 @@ template void BitPackedMiddle::Insert(WordIndex word, float quant_.Write(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); - assert(next <= next_mask_); - util::WriteInt57(base_, at_pointer, next_bits_, next); + bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } + uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); + return true; } -template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; - at_pointer *= total_bits_; +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { + uint64_t index; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; + uint64_t at_pointer = index * total_bits_; at_pointer += word_bits_; quant_.ReadBackoff(base_, at_pointer, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); return true; } -template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { - assert(next_end <= next_mask_); - uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_, last_next_write, next_bits_, next_end); +template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { + uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); + bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); + bhiksha_.FinishedLoading(config); } template void BitPackedLongest::Insert(WordIndex index, float prob) { @@ -135,8 +136,10 @@ template bool BitPackedLongest::Find(WordIndex word, float return true; } -template class BitPackedMiddle; -template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; template class BitPackedLongest; template class BitPackedLongest; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@ namespace lm { namespace ngram { +class Config; namespace trie { struct NodeRange { @@ -46,13 +47,12 @@ class Unigram { void LoadedBinary() {} - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { + void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { UnigramValue *val = unigram_ + word; prob = val->weights.prob; backoff = val->weights.backoff; next.begin = val->next; next.end = (val+1)->next; - return true; } private: @@ -67,8 +67,6 @@ class BitPacked { return insert_index_; } - void LoadedBinary() {} - protected: static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked { uint64_t insert_index_, max_vocab_; }; -template class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. - BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); void Insert(WordIndex word, float prob, float backoff); + void FinishedLoading(uint64_t next_end, const Config &config); + + void LoadedBinary() { bhiksha_.LoadedBinary(); } + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; - void FinishedLoading(uint64_t next_end); - private: Quant quant_; - uint8_t next_bits_; - uint64_t next_mask_; + Bhiksha bhiksha_; const BitPacked *next_source_; }; - template class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { @@ -120,6 +118,8 @@ template class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_.TotalBits()); } + void LoadedBinary() {} + void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7defd5c1..04979d51 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { WordIndex index = 0; while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); - if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; ssize_t ret = read(fd, &next_char, 1); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); + UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); buf.push_back(next_char); } // Ok now we have null terminated strings. diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c92518e4..9d218fff 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary { } } + // Size for purposes of file writing static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. @@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary { // Reorders reorder_vocab so that the IDs are sorted. void FinishedLoading(ProbBackoff *reorder_vocab); + // Trie stores the correct counts including in the header. If this was previously sized based on a count exluding , padding with 8 bytes will make it the correct size based on a count including . + std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } + bool SawUnk() const { return saw_unk_; } void LoadedBinary(int fd, EnumerateVocab *to); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index b35d80c8..9f47d559 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -107,9 +107,20 @@ void BitPackingSanity(); uint8_t RequiredBits(uint64_t max_value); struct BitsMask { + static BitsMask ByMax(uint64_t max_value) { + BitsMask ret; + ret.FromMax(max_value); + return ret; + } + static BitsMask ByBits(uint8_t bits) { + BitsMask ret; + ret.bits = bits; + ret.mask = (1ULL << bits) - 1; + return ret; + } void FromMax(uint64_t max_value) { bits = RequiredBits(max_value); - mask = (1 << bits) - 1; + mask = (1ULL << bits) - 1; } uint8_t bits; uint64_t mask; diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 00be0ed7..2ec342a6 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -57,7 +57,7 @@ template class IdentityAccessor { public: typedef T Key; - T operator()(const uint64_t *in) const { return *in; } + T operator()(const T *in) const { return *in; } }; struct Pivot64 { @@ -101,6 +101,27 @@ template bool SortedUniformFind(co return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } +// May return begin - 1. +template Iterator BinaryBelow( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {} + return pivot - 1; + } + } + return begin - 1; +} + // To use this template, you need to define a Pivot function to match Key. template class SortedUniformMap { public: -- cgit v1.2.3 From 5f0c8a675a8341c3b835c7597c4c92a838fa02ea Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 3 Sep 2011 17:14:18 +0100 Subject: fix sparse vector to work with boost serialization --- utils/fast_sparse_vector.h | 46 ++++++++++++++++++++++++++++++++++++++++++++++ utils/sparse_vector.h | 38 -------------------------------------- 2 files changed, 46 insertions(+), 38 deletions(-) diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index 9d72cb87..b3f9588d 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -7,6 +7,8 @@ // important: indexes are integers // important: iterators may return elements in any order +#include "config.h" + #include #include #include @@ -16,6 +18,13 @@ #include +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP +#include +#include +#endif + +#include "fdict.h" + // this is architecture dependent, it should be // detected in some way but it's probably easiest (for me) // to just set it @@ -334,8 +343,45 @@ class FastSparseVector { } data_; unsigned char local_size_; bool is_remote_; + +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP + private: + friend class boost::serialization::access; + template + void save(Archive & ar, const unsigned int version) const { + (void) version; + int eff_size = size(); + const_iterator it = this->begin(); + if (eff_size > 0) { + // 0 index is reserved as empty + if (it->first == 0) { ++it; --eff_size; } + } + ar & eff_size; + while (it != this->end()) { + const std::pair wire_pair(FD::Convert(it->first), it->second); + ar & wire_pair; + ++it; + } + } + template + void load(Archive & ar, const unsigned int version) { + (void) version; + this->clear(); + int sz; ar & sz; + for (int i = 0; i < sz; ++i) { + std::pair wire_pair; + ar & wire_pair; + this->set_value(FD::Convert(wire_pair.first), wire_pair.second); + } + } + BOOST_SERIALIZATION_SPLIT_MEMBER() +#endif }; +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP +BOOST_CLASS_TRACKING(FastSparseVector,track_never) +#endif + template const FastSparseVector operator+(const FastSparseVector& x, const FastSparseVector& y) { if (x.size() > y.size()) { diff --git a/utils/sparse_vector.h b/utils/sparse_vector.h index a55436fb..049151f7 100644 --- a/utils/sparse_vector.h +++ b/utils/sparse_vector.h @@ -1,44 +1,6 @@ #ifndef _SPARSE_VECTOR_H_ #define _SPARSE_VECTOR_H_ -#if 0 - -#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP - friend class boost::serialization::access; - template - void save(Archive & ar, const unsigned int version) const { - (void) version; - int eff_size = values_.size(); - const_iterator it = this->begin(); - if (values_.find(0) != values_.end()) { ++it; --eff_size; } - ar & eff_size; - while (it != this->end()) { - const std::pair wire_pair(FD::Convert(it->first), it->second); - ar & wire_pair; - ++it; - } - } - template - void load(Archive & ar, const unsigned int version) { - (void) version; - this->clear(); - int sz; ar & sz; - for (int i = 0; i < sz; ++i) { - std::pair wire_pair; - ar & wire_pair; - this->set_value(FD::Convert(wire_pair.first), wire_pair.second); - } - } - BOOST_SERIALIZATION_SPLIT_MEMBER() -#endif -}; - -#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP -BOOST_CLASS_TRACKING(SparseVector,track_never) -#endif - -#endif /// FIX - #include "fast_sparse_vector.h" #define SparseVector FastSparseVector -- cgit v1.2.3 From e73e7925a1b2fce06b1cdbe13e53fe6f10d56261 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 3 Sep 2011 17:27:12 +0100 Subject: fix header problem when serializing sparse vector with boost --- utils/fast_sparse_vector.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index b3f9588d..1301581a 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -19,8 +19,7 @@ #include #if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP -#include -#include +#include #endif #include "fdict.h" -- cgit v1.2.3 From 51e8f4a5b9ffc96f3486ede77fe4511918156cf4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 Sep 2011 13:24:11 +0200 Subject: fix viterbi to work with non prob_t types --- decoder/viterbi.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decoder/viterbi.h b/decoder/viterbi.h index ac0b9a11..daee3d7a 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -25,7 +25,7 @@ typename WeightFunction::Weight Viterbi(const Hypergraph& hg, typedef typename WeightFunction::Weight WeightType; const int num_nodes = hg.nodes_.size(); std::vector vit_result(num_nodes); - std::vector vit_weight(num_nodes, WeightType::Zero()); + std::vector vit_weight(num_nodes, WeightType()); for (int i = 0; i < num_nodes; ++i) { const Hypergraph::Node& cur_node = hg.nodes_[i]; -- cgit v1.2.3 From 9f7a0765905e2906c43fbb5359d00ccdac38ca7f Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 9 Sep 2011 10:15:56 +0200 Subject: rule feature refactoring --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_rules.cc | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_rules.h | 40 ++++++++++++++++++++ decoder/ff_spans.cc | 39 ------------------- decoder/ff_spans.h | 15 -------- 6 files changed, 150 insertions(+), 54 deletions(-) create mode 100644 decoder/ff_rules.cc create mode 100644 decoder/ff_rules.h diff --git a/decoder/Makefile.am b/decoder/Makefile.am index d884c431..e5f7505f 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -61,6 +61,7 @@ libcdec_a_SOURCES = \ phrasetable_fst.cc \ trule.cc \ ff.cc \ + ff_rules.cc \ ff_wordset.cc \ ff_charset.cc \ ff_lm.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 1ef76a05..588842f1 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -9,6 +9,7 @@ #include "ff_wordalign.h" #include "ff_tagger.h" #include "ff_factory.h" +#include "ff_rules.h" #include "ff_ruleshape.h" #include "ff_bleu.h" #include "ff_lm_fsa.h" @@ -53,6 +54,7 @@ void register_feature_functions() { #endif ff_registry.Register("SpanFeatures", new FFFactory()); ff_registry.Register("NgramFeatures", new FFFactory()); + ff_registry.Register("RuleIdentityFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc new file mode 100644 index 00000000..bd4c4cc0 --- /dev/null +++ b/decoder/ff_rules.cc @@ -0,0 +1,107 @@ +#include "ff_rules.h" + +#include +#include +#include + +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" + +using namespace std; + +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + +RuleIdentityFeatures::RuleIdentityFeatures(const std::string& param) { +} + +void RuleIdentityFeatures::PrepareForInput(const SentenceMetadata& smeta) { +// std::map > + rule2_fid_.clear(); +} + +void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + map::iterator it = rule2_fid_.find(edge.rule_.get()); + if (it == rule2_fid_.end()) { + const TRule& rule = *edge.rule_; + ostringstream os; + os << "R:"; + if (rule.lhs_ < 0) os << TD::Convert(-rule.lhs_) << ':'; + for (unsigned i = 0; i < rule.f_.size(); ++i) { + if (i > 0) os << '_'; + WordID w = rule.f_[i]; + if (w < 0) { os << 'N'; w = -w; } + assert(w > 0); + os << TD::Convert(w); + } + os << ':'; + for (unsigned i = 0; i < rule.e_.size(); ++i) { + if (i > 0) os << '_'; + WordID w = rule.e_[i]; + if (w <= 0) { + os << 'N' << (1-w); + } else { + os << TD::Convert(w); + } + } + it = rule2_fid_.insert(make_pair(&rule, FD::Convert(Escape(os.str())))).first; + } + features->add_value(it->second, 1); +} + +RuleNgramFeatures::RuleNgramFeatures(const std::string& param) { +} + +void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) { +// std::map > + rule2_feats_.clear(); +} + +void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + map >::iterator it = rule2_feats_.find(edge.rule_.get()); + if (it == rule2_feats_.end()) { + const TRule& rule = *edge.rule_; + it = rule2_feats_.insert(make_pair(&rule, SparseVector())).first; + SparseVector& f = it->second; + string prev = ""; + for (int i = 0; i < rule.f_.size(); ++i) { + WordID w = rule.f_[i]; + if (w < 0) w = -w; + assert(w > 0); + const string& cur = TD::Convert(w); + ostringstream os; + os << "RB:" << prev << '_' << cur; + const int fid = FD::Convert(Escape(os.str())); + if (fid <= 0) return; + f.add_value(fid, 1.0); + prev = cur; + } + ostringstream os; + os << "RB:" << prev << '_' << ""; + f.set_value(FD::Convert(Escape(os.str())), 1.0); + } + (*features) += it->second; +} + diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h new file mode 100644 index 00000000..48d8bd05 --- /dev/null +++ b/decoder/ff_rules.h @@ -0,0 +1,40 @@ +#ifndef _FF_RULES_H_ +#define _FF_RULES_H_ + +#include +#include +#include "ff.h" +#include "array2d.h" +#include "wordid.h" + +class RuleIdentityFeatures : public FeatureFunction { + public: + RuleIdentityFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + mutable std::map rule2_fid_; +}; + +class RuleNgramFeatures : public FeatureFunction { + public: + RuleNgramFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + mutable std::map > rule2_feats_; +}; + +#endif diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index bc23974d..0483517b 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -193,45 +193,6 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { } } -RuleNgramFeatures::RuleNgramFeatures(const std::string& param) { -} - -void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) { -// std::map > - rule2_feats_.clear(); -} - -void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector& ant_contexts, - SparseVector* features, - SparseVector* estimated_features, - void* context) const { - map >::iterator it = rule2_feats_.find(edge.rule_.get()); - if (it == rule2_feats_.end()) { - const TRule& rule = *edge.rule_; - it = rule2_feats_.insert(make_pair(&rule, SparseVector())).first; - SparseVector& f = it->second; - string prev = ""; - for (int i = 0; i < rule.f_.size(); ++i) { - WordID w = rule.f_[i]; - if (w < 0) w = -w; - assert(w > 0); - const string& cur = TD::Convert(w); - ostringstream os; - os << "RB:" << prev << '_' << cur; - const int fid = FD::Convert(Escape(os.str())); - if (fid <= 0) return; - f.add_value(fid, 1.0); - prev = cur; - } - ostringstream os; - os << "RB:" << prev << '_' << ""; - f.set_value(FD::Convert(Escape(os.str())), 1.0); - } - (*features) += it->second; -} - inline bool IsArity2RuleReordered(const TRule& rule) { const vector& e = rule.e_; for (int i = 0; i < e.size(); ++i) { diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index b22c4d03..24e0dede 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -44,21 +44,6 @@ class SpanFeatures : public FeatureFunction { WordID oov_; }; -class RuleNgramFeatures : public FeatureFunction { - public: - RuleNgramFeatures(const std::string& param); - protected: - virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const std::vector& ant_contexts, - SparseVector* features, - SparseVector* estimated_features, - void* context) const; - virtual void PrepareForInput(const SentenceMetadata& smeta); - private: - mutable std::map > rule2_feats_; -}; - class CMR2008ReorderingFeatures : public FeatureFunction { public: CMR2008ReorderingFeatures(const std::string& param); -- cgit v1.2.3 From 700b2abf48bf0a455064d6cf08754cbfd4e3a383 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 12 Sep 2011 19:22:59 +0100 Subject: source syntax features ~ blunsom emnlp 2008 --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_source_syntax.cc | 157 ++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_source_syntax.h | 24 +++++++ utils/stringlib.cc | 7 +- 5 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 decoder/ff_source_syntax.cc create mode 100644 decoder/ff_source_syntax.h diff --git a/decoder/Makefile.am b/decoder/Makefile.am index e5f7505f..ede1cff0 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -72,6 +72,7 @@ libcdec_a_SOURCES = \ ff_wordalign.cc \ ff_csplit.cc \ ff_tagger.cc \ + ff_source_syntax.cc \ ff_bleu.cc \ ff_factory.cc \ freqdict.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 588842f1..d562bc3a 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -14,6 +14,7 @@ #include "ff_bleu.h" #include "ff_lm_fsa.h" #include "ff_sample_fsa.h" +#include "ff_source_syntax.h" #include "ff_register.h" #include "ff_charset.h" #include "ff_wordset.h" @@ -55,6 +56,7 @@ void register_feature_functions() { ff_registry.Register("SpanFeatures", new FFFactory()); ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleIdentityFeatures", new FFFactory()); + ff_registry.Register("SourceSyntaxFeatures", new FFFactory); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc new file mode 100644 index 00000000..99acbd87 --- /dev/null +++ b/decoder/ff_source_syntax.cc @@ -0,0 +1,157 @@ +#include "ff_source_syntax.h" + +#include +#include + +#include "sentence_metadata.h" +#include "array2d.h" +#include "filelib.h" + +using namespace std; + +// implements the source side syntax features described in Blunsom et al. (EMNLP 2008) +// source trees must be represented in Penn Treebank format, e.g. +// (S (NP John) (VP (V left))) + +struct SourceSyntaxFeaturesImpl { + SourceSyntaxFeaturesImpl() {} + + void InitializeGrids(const string& tree, unsigned src_len) { + assert(tree.size() > 0); + fids_cat.clear(); + fids_fonly.clear(); + fids_ef.clear(); + src_tree.clear(); + fids_cat.resize(src_len, src_len + 1); + fids_fonly.resize(src_len, src_len + 1); + fids_ef.resize(src_len, src_len + 1); + src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); + ParseTreeString(tree, src_len); + } + + void ParseTreeString(const string& tree, unsigned src_len) { + stack > stk; // first = i, second = category + pair cur_cat; cur_cat.first = -1; + unsigned i = 0; + unsigned p = 0; + while(p < tree.size()) { + const char cur = tree[p]; + if (cur == '(') { + stk.push(cur_cat); + ++p; + unsigned k = p + 1; + while (k < tree.size() && tree[k] != ' ') { ++k; } + cur_cat.first = i; + cur_cat.second = TD::Convert(tree.substr(p, k - p)); + // cerr << "NT: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; + p = k + 1; + } else if (cur == ')') { + unsigned k = p; + while (k < tree.size() && tree[k] == ')') { ++k; } + const unsigned num_closes = k - p; + for (unsigned ci = 0; ci < num_closes; ++ci) { + // cur_cat.second spans from cur_cat.first to i + // cerr << TD::Convert(cur_cat.second) << " from " << cur_cat.first << " to " << i << endl; + // NOTE: unary rule chains end up being labeled with the top-most category + src_tree(cur_cat.first, i) = cur_cat.second; + cur_cat = stk.top(); + stk.pop(); + } + p = k; + while (p < tree.size() && (tree[p] == ' ' || tree[p] == '\t')) { ++p; } + } else if (cur == ' ' || cur == '\t') { + cerr << "Unexpected whitespace in: " << tree << endl; + abort(); + } else { // terminal symbol + unsigned k = p + 1; + do { + while (k < tree.size() && tree[k] != ')' && tree[k] != ' ') { ++k; } + // cerr << "TERM: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; + ++i; + assert(i <= src_len); + while (k < tree.size() && tree[k] == ' ') { ++k; } + p = k; + } while (p < tree.size() && tree[p] != ')'); + } + } + // cerr << "i=" << i << " src_len=" << src_len << endl; + assert(i == src_len); // make sure tree specified in src_tree is + // the same length as the source sentence + } + + WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { + //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl; + const WordID lhs = src_tree(i,j); + int& fid_cat = fids_cat(i,j); + int& fid_fonly = fids_fonly(i,j)[&rule]; + int& fid_ef = fids_ef(i,j)[&rule]; + if (fid_ef <= 0) { + ostringstream os; + os << "SYN:" << TD::Convert(lhs); + fid_cat = FD::Convert(os.str()); + os << ':'; + unsigned ntc = 0; + for (unsigned k = 0; k < rule.f_.size(); ++k) { + if (k > 0) os << '_'; + int fj = rule.f_[k]; + if (fj <= 0) { + os << '[' << TD::Convert(ants[ntc++]) << ']'; + } else { + os << TD::Convert(fj); + } + } + fid_fonly = FD::Convert(os.str()); + os << ':'; + for (unsigned k = 0; k < rule.e_.size(); ++k) { + const int ei = rule.e_[k]; + if (k > 0) os << '_'; + if (ei <= 0) + os << '[' << (1-ei) << ']'; + else + os << TD::Convert(ei); + } + fid_ef = FD::Convert(os.str()); + } + if (fid_cat > 0) + feats->set_value(fid_cat, 1.0); + if (fid_fonly > 0) + feats->set_value(fid_fonly, 1.0); + if (fid_ef > 0) + feats->set_value(fid_ef, 1.0); + return lhs; + } + + Array2D src_tree; // src_tree(i,j) NT = type + mutable Array2D fids_cat; // fires for an LHS match + mutable Array2D > fids_fonly; // fires for an f-string + mutable Array2D > fids_ef; // fires for fully lexicalized +}; + +SourceSyntaxFeatures::SourceSyntaxFeatures(const string& param) : + FeatureFunction(sizeof(WordID)) { + impl = new SourceSyntaxFeaturesImpl; +} + +SourceSyntaxFeatures::~SourceSyntaxFeatures() { + delete impl; + impl = NULL; +} + +void SourceSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + WordID ants[8]; + for (unsigned i = 0; i < ant_contexts.size(); ++i) + ants[i] = *static_cast(ant_contexts[i]); + + *static_cast(context) = + impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); +} + +void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { + impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); +} + diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h new file mode 100644 index 00000000..1e890736 --- /dev/null +++ b/decoder/ff_source_syntax.h @@ -0,0 +1,24 @@ +#ifndef _FF_SOURCE_TOOLS_H_ +#define _FF_SOURCE_TOOLS_H_ + +#include "ff.h" + +struct SourceSyntaxFeaturesImpl; + +class SourceSyntaxFeatures : public FeatureFunction { + public: + SourceSyntaxFeatures(const std::string& param); + ~SourceSyntaxFeatures(); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + SourceSyntaxFeaturesImpl* impl; +}; + +#endif diff --git a/utils/stringlib.cc b/utils/stringlib.cc index 7aaee9f0..ade02ca9 100644 --- a/utils/stringlib.cc +++ b/utils/stringlib.cc @@ -32,7 +32,12 @@ void ParseTranslatorInput(const string& line, string* input, string* ref) { void ProcessAndStripSGML(string* pline, map* out) { map& meta = *out; string& line = *pline; - string lline = LowercaseString(line); + string lline = *pline; + if (lline.find(" must be lowercase!\n"; + cerr << " " << *pline << endl; + abort(); + } if (lline.find(""); if (close == string::npos) return; // error -- cgit v1.2.3 From af28b860c3f5d5b7c58feb16620853512c8454ad Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 12 Sep 2011 23:11:17 +0100 Subject: add configuration option for perfect hashing library --- configure.ac | 53 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/configure.ac b/configure.ac index 4e708073..6fa7b914 100644 --- a/configure.ac +++ b/configure.ac @@ -11,6 +11,45 @@ AC_PROG_CXX AC_LANG_CPLUSPLUS BOOST_REQUIRE BOOST_PROGRAM_OPTIONS +AC_ARG_ENABLE(mpi, + [ --enable-mpi Build MPI binaries, assumes mpi.h is present ], + [ mpi=yes + ]) +AM_CONDITIONAL([MPI], [test "x$mpi" = xyes]) + +if test "x$mpi" = xyes +then + BOOST_SERIALIZATION + AC_DEFINE([HAVE_MPI], [1], [flag for MPI]) + # TODO BOOST_MPI needs to be implemented + LIBS="$LIBS -lboost_mpi $BOOST_SERIALIZATION_LIBS" +fi + +AM_CONDITIONAL([CMPH], false) +AC_ARG_WITH(cmph, + [AC_HELP_STRING([--with-cmph=PATH], [(optional) path to cmph perfect hashing library])], + [with_cmph=$withval], + [with_cmph=no] + ) + +if test "x$with_cmph" != 'xno' +then + SAVE_CPPFLAGS="$CPPFLAGS" + CPPFLAGS="$CPPFLAGS -I${with_cmph}/include" + + AC_CHECK_HEADER(cmph.h, + [AC_DEFINE([HAVE_CMPH], [], [flag for cmph perfect hashing library])], + [AC_MSG_ERROR([Cannot find cmph library!])]) + + LDFLAGS="$LDFLAGS -L${with_cmph}/lib" + AC_CHECK_LIB(cmph, cmph_search) + + #LIB_CMPH="-lcmph" + #LIBS="$LIBS $LIB_CMPH" + #FMTLIBS="$FMTLIBS libcmph.a" + AM_CONDITIONAL([CMPH], true) +fi + #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" @@ -27,20 +66,6 @@ AC_CHECK_HEADER(google/dense_hash_map, AC_PROG_INSTALL GTEST_LIB_CHECK -AC_ARG_ENABLE(mpi, - [ --enable-mpi Build MPI binaries, assumes mpi.h is present ], - [ mpi=yes - ]) -AM_CONDITIONAL([MPI], [test "x$mpi" = xyes]) - -if test "x$mpi" = xyes -then - BOOST_SERIALIZATION - AC_DEFINE([HAVE_MPI], [1], [flag for MPI]) - # TODO BOOST_MPI needs to be implemented - LIBS="$LIBS -lboost_mpi $BOOST_SERIALIZATION_LIBS -lmpi++ -lmpi" -fi - AM_CONDITIONAL([RAND_LM], false) AC_ARG_WITH(randlm, [AC_HELP_STRING([--with-randlm=PATH], [(optional) path to RandLM toolkit])], -- cgit v1.2.3 From b09ca8a5e6f5e8c1840e51a93c9f8e6b8c4bcc33 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 09:45:01 +0100 Subject: add one more source syntax feature --- decoder/ff_source_syntax.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 99acbd87..5b7c16f6 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -13,6 +13,13 @@ using namespace std; // source trees must be represented in Penn Treebank format, e.g. // (S (NP John) (VP (V left))) +// log transform to make long spans cluster together +// but preserve differences +inline int SpanSizeTransform(unsigned span_size) { + if (!span_size) return 0; + return static_cast(log(span_size+1) / log(1.39)) - 1; +} + struct SourceSyntaxFeaturesImpl { SourceSyntaxFeaturesImpl() {} @@ -87,8 +94,10 @@ struct SourceSyntaxFeaturesImpl { int& fid_ef = fids_ef(i,j)[&rule]; if (fid_ef <= 0) { ostringstream os; + ostringstream os2; os << "SYN:" << TD::Convert(lhs); - fid_cat = FD::Convert(os.str()); + os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); + fid_cat = FD::Convert(os2.str()); os << ':'; unsigned ntc = 0; for (unsigned k = 0; k < rule.f_.size(); ++k) { -- cgit v1.2.3 From 38a5bee71f6b49515cd105a9467ff602ff9dee64 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 13:25:46 +0100 Subject: optional support for doing perfect hashing of feature strings to save lots of memory --- decoder/decoder.cc | 22 ++++++++- utils/Makefile.am | 9 +++- utils/fdict.cc | 4 ++ utils/fdict.h | 36 ++++++++++++++ utils/perfect_hash.cc | 37 ++++++++++++++ utils/perfect_hash.h | 24 +++++++++ utils/phmt.cc | 44 +++++++++++++++++ utils/weights.cc | 132 ++++++++++++++++++++++++++++++++++---------------- utils/weights.h | 14 +++--- 9 files changed, 269 insertions(+), 53 deletions(-) create mode 100644 utils/perfect_hash.cc create mode 100644 utils/perfect_hash.h create mode 100644 utils/phmt.cc diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 76f31352..25eb2de4 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -328,6 +328,7 @@ struct DecoderImpl { bool write_gradient; // TODO Observer bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer + bool remove_intersected_rule_annotations; static void ConvertSV(const SparseVector& src, SparseVector* trg) { for (SparseVector::const_iterator it = src.begin(); it != src.end(); ++it) @@ -361,6 +362,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("per_sentence_grammar_file", po::value(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset") ("list_feature_functions,L","List available feature functions") +#ifdef HAVE_CMPH + ("cmph_perfect_feature_hash,h", po::value(), "Load perfect hash function for features") +#endif ("weights,w",po::value(),"Feature weights file (initial forest / pass 1)") ("feature_function,F",po::value >()->composing(), "Pass 1 additional feature function(s) (-L for list)") @@ -433,7 +437,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") - ("forest_output,O",po::value(),"Directory to write forests to"); + ("forest_output,O",po::value(),"Directory to write forests to") + ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); // ob.AddOptions(&opts); #ifdef FSA_RESCORING @@ -443,7 +448,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value >(&cfg_files), "Configuration file(s) - latest has priority") - ("help,h", "Print this help message and exit") + ("help,?", "Print this help message and exit") ("usage,u", po::value(), "Describe a feature function type") ("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'") ; @@ -645,6 +650,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream FD::Freeze(); // this means we can't see the feature names of not-weighted features } + if (conf.count("cmph_perfect_feature_hash")) { + cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as() << " ...\n"; + FD::EnableHash(conf["cmph_perfect_feature_hash"].as()); + cerr << " " << FD::NumFeats() << " features in map\n"; + } + // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); @@ -695,6 +706,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); oracle.show_derivation=conf.count("show_derivations"); + remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); #ifdef FSA_RESCORING cfg_options.Validate(); @@ -1010,6 +1022,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { // if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n"; // for (int i = 0; i < forest.edges_.size(); ++i) // forest.edges_[i].edge_prob_=prob_t::One(); } + if (remove_intersected_rule_annotations) { + for (unsigned i = 0; i < forest.edges_.size(); ++i) + if (forest.edges_[i].rule_ && + forest.edges_[i].rule_->parent_rule_) + forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; + } forest.Reweight(last_weights); if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; diff --git a/utils/Makefile.am b/utils/Makefile.am index 94f9be30..c50747bf 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -1,5 +1,5 @@ -noinst_PROGRAMS = ts -TESTS = ts +noinst_PROGRAMS = ts phmt +TESTS = ts phmt if HAVE_GTEST noinst_PROGRAMS += \ @@ -27,6 +27,11 @@ libutils_a_SOURCES = \ verbose.cc \ weights.cc +if HAVE_CMPH + libutils_a_SOURCES += perfect_hash.cc +endif + +phmt_SOURCES = phmt.cc ts_SOURCES = ts.cc dict_test_SOURCES = dict_test.cc dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) diff --git a/utils/fdict.cc b/utils/fdict.cc index baa0b552..676c951c 100644 --- a/utils/fdict.cc +++ b/utils/fdict.cc @@ -9,6 +9,10 @@ using namespace std; Dict FD::dict_; bool FD::frozen_ = false; +#ifdef HAVE_CMPH +PerfectHashFunction* FD::hash_ = NULL; +#endif + std::string FD::Convert(std::vector const& v) { return Convert(&*v.begin(),&*v.end()); } diff --git a/utils/fdict.h b/utils/fdict.h index f9673023..771e8b91 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -1,23 +1,56 @@ #ifndef _FDICT_H_ #define _FDICT_H_ +#include "config.h" + +#include #include #include #include "dict.h" +#ifdef HAVE_CMPH +#include "perfect_hash.h" +#include "string_to.h" +#endif + struct FD { // once the FD is frozen, new features not already in the // dictionary will return 0 static void Freeze() { frozen_ = true; } + static bool UsingPerfectHashFunction() { +#ifdef HAVE_CMPH + return hash_; +#else + return false; +#endif + } + static void EnableHash(const std::string& cmph_file) { +#ifdef HAVE_CMPH + hash_ = new PerfectHashFunction(cmph_file); +#endif + } static inline int NumFeats() { +#ifdef HAVE_CMPH + if (hash_) return hash_->number_of_keys(); +#endif return dict_.max() + 1; } static inline WordID Convert(const std::string& s) { +#ifdef HAVE_CMPH + if (hash_) return (*hash_)(s); +#endif return dict_.Convert(s, frozen_); } static inline const std::string& Convert(const WordID& w) { +#ifdef HAVE_CMPH + if (hash_) { + static std::string tls; + tls = to_string(w); + return tls; + } +#endif return dict_.Convert(w); } static std::string Convert(WordID const *i,WordID const* e); @@ -29,6 +62,9 @@ struct FD { static Dict dict_; private: static bool frozen_; +#ifdef HAVE_CMPH + static PerfectHashFunction* hash_; +#endif }; #endif diff --git a/utils/perfect_hash.cc b/utils/perfect_hash.cc new file mode 100644 index 00000000..706e2741 --- /dev/null +++ b/utils/perfect_hash.cc @@ -0,0 +1,37 @@ +#include "config.h" + +#ifdef HAVE_CMPH + +#include "perfect_hash.h" + +#include +#include + +using namespace std; + +PerfectHashFunction::~PerfectHashFunction() { + cmph_destroy(mphf_); +} + +PerfectHashFunction::PerfectHashFunction(const string& fname) { + FILE* f = fopen(fname.c_str(), "r"); + if (!f) { + cerr << "Failed to open file " << fname << " for reading: cannot load hash function.\n"; + abort(); + } + mphf_ = cmph_load(f); + if (!mphf_) { + cerr << "cmph_load failed on " << fname << "!\n"; + abort(); + } +} + +size_t PerfectHashFunction::operator()(const string& key) const { + return cmph_search(mphf_, &key[0], key.size()); +} + +size_t PerfectHashFunction::number_of_keys() const { + return cmph_size(mphf_); +} + +#endif diff --git a/utils/perfect_hash.h b/utils/perfect_hash.h new file mode 100644 index 00000000..8ac11f18 --- /dev/null +++ b/utils/perfect_hash.h @@ -0,0 +1,24 @@ +#ifndef _PERFECT_HASH_MAP_H_ +#define _PERFECT_HASH_MAP_H_ + +#include "config.h" + +#ifndef HAVE_CMPH +#error libcmph is required to use PerfectHashFunction +#endif + +#include +#include +#include "cmph.h" + +class PerfectHashFunction : boost::noncopyable { + public: + explicit PerfectHashFunction(const std::string& fname); + ~PerfectHashFunction(); + size_t operator()(const std::string& key) const; + size_t number_of_keys() const; + private: + cmph_t *mphf_; +}; + +#endif diff --git a/utils/phmt.cc b/utils/phmt.cc new file mode 100644 index 00000000..1f59afaf --- /dev/null +++ b/utils/phmt.cc @@ -0,0 +1,44 @@ +#include "config.h" + +#ifndef HAVE_CMPH +int main() { + return 0; +} +#else + +#include +#include "weights.h" +#include "fdict.h" + +using namespace std; + +int main(int argc, char** argv) { + if (argc != 2) { cerr << "Usage: " << argv[0] << " file.mphf\n"; return 1; } + FD::EnableHash(argv[1]); + cerr << "Number of keys: " << FD::NumFeats() << endl; + cerr << "LexFE = " << FD::Convert("LexFE") << endl; + cerr << "LexEF = " << FD::Convert("LexEF") << endl; + { + Weights w; + vector v(FD::NumFeats()); + v[FD::Convert("LexFE")] = 1.0; + v[FD::Convert("LexEF")] = 0.5; + w.InitFromVector(v); + cerr << "Writing...\n"; + w.WriteToFile("weights.bin"); + cerr << "Done.\n"; + } + { + Weights w; + vector v(FD::NumFeats()); + cerr << "Reading...\n"; + w.InitFromFile("weights.bin"); + cerr << "Done.\n"; + w.InitVector(&v); + assert(v[FD::Convert("LexFE")] == 1.0); + assert(v[FD::Convert("LexEF")] == 0.5); + } +} + +#endif + diff --git a/utils/weights.cc b/utils/weights.cc index b994a2fe..0916b72a 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -13,40 +13,75 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ ReadFile in_file(filename); istream& in = *in_file.stream(); assert(in); - int weight_count = 0; - bool fl = false; - string buf; - double val = 0; - while (in) { - getline(in, buf); - if (buf.size() == 0) continue; - if (buf[0] == '#') continue; - for (int i = 0; i < buf.size(); ++i) - if (buf[i] == '=') buf[i] = ' '; - int start = 0; - while(start < buf.size() && buf[start] == ' ') ++start; - int end = 0; - while(end < buf.size() && buf[end] != ' ') ++end; - const int fid = FD::Convert(buf.substr(start, end - start)); - while(end < buf.size() && buf[end] == ' ') ++end; - val = strtod(&buf.c_str()[end], NULL); - if (isnan(val)) { - cerr << FD::Convert(fid) << " has weight NaN!\n"; - abort(); + + bool read_text = true; + if (1) { + ReadFile hdrrf(filename); + istream& hi = *hdrrf.stream(); + assert(hi); + char buf[10]; + hi.get(buf, 6); + assert(hi.good()); + if (strncmp(buf, "_PHWf", 5) == 0) { + read_text = false; + } + } + + if (read_text) { + int weight_count = 0; + bool fl = false; + string buf; + weight_t val = 0; + while (in) { + getline(in, buf); + if (buf.size() == 0) continue; + if (buf[0] == '#') continue; + if (buf[0] == ' ') { + cerr << "Weights file lines may not start with whitespace.\n" << buf << endl; + abort(); + } + for (int i = buf.size() - 1; i > 0; --i) + if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; } + int start = 0; + while(start < buf.size() && buf[start] == ' ') ++start; + int end = 0; + while(end < buf.size() && buf[end] != ' ') ++end; + const int fid = FD::Convert(buf.substr(start, end - start)); + while(end < buf.size() && buf[end] == ' ') ++end; + val = strtod(&buf.c_str()[end], NULL); + if (isnan(val)) { + cerr << FD::Convert(fid) << " has weight NaN!\n"; + abort(); + } + if (wv_.size() <= fid) + wv_.resize(fid + 1); + wv_[fid] = val; + if (feature_list) { feature_list->push_back(FD::Convert(fid)); } + ++weight_count; + if (!SILENT) { + if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } + if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + } } - if (wv_.size() <= fid) - wv_.resize(fid + 1); - wv_[fid] = val; - if (feature_list) { feature_list->push_back(FD::Convert(fid)); } - ++weight_count; if (!SILENT) { - if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } - if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + if (fl) { cerr << endl; } + cerr << "Loaded " << weight_count << " feature weights\n"; + } + } else { // !read_text + char buf[6]; + in.get(buf, 6); + size_t num_keys[2]; + in.get(reinterpret_cast(&num_keys[0]), sizeof(size_t) + 1); + if (num_keys[0] != FD::NumFeats()) { + cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; + abort(); + } + wv_.resize(num_keys[0]); + in.get(reinterpret_cast(&wv_[0]), num_keys[0] * sizeof(weight_t)); + if (!in.good()) { + cerr << "Error loading weights!\n"; + abort(); } - } - if (!SILENT) { - if (fl) { cerr << endl; } - cerr << "Loaded " << weight_count << " feature weights\n"; } } @@ -54,37 +89,48 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature WriteFile out(fname); ostream& o = *out.stream(); assert(o); - if (extra) { o << "# " << *extra << endl; } - o.precision(17); - const int num_feats = FD::NumFeats(); - for (int i = 1; i < num_feats; ++i) { - const double val = (i < wv_.size() ? wv_[i] : 0.0); - if (hide_zero_value_features && val == 0.0) continue; - o << FD::Convert(i) << ' ' << val << endl; + bool write_text = !FD::UsingPerfectHashFunction(); + + if (write_text) { + if (extra) { o << "# " << *extra << endl; } + o.precision(17); + const int num_feats = FD::NumFeats(); + for (int i = 1; i < num_feats; ++i) { + const weight_t val = (i < wv_.size() ? wv_[i] : 0.0); + if (hide_zero_value_features && val == 0.0) continue; + o << FD::Convert(i) << ' ' << val << endl; + } + } else { + o.write("_PHWf", 5); + const size_t keys = FD::NumFeats(); + assert(keys <= wv_.size()); + o.write(reinterpret_cast(&keys), sizeof(keys)); + o.write(reinterpret_cast(&wv_[0]), keys * sizeof(weight_t)); } } -void Weights::InitVector(std::vector* w) const { +void Weights::InitVector(std::vector* w) const { *w = wv_; } -void Weights::InitSparseVector(SparseVector* w) const { +void Weights::InitSparseVector(SparseVector* w) const { for (int i = 1; i < wv_.size(); ++i) { - const double& weight = wv_[i]; + const weight_t& weight = wv_[i]; if (weight) w->set_value(i, weight); } } -void Weights::InitFromVector(const std::vector& w) { +void Weights::InitFromVector(const std::vector& w) { wv_ = w; if (wv_.size() > FD::NumFeats()) cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; wv_.resize(FD::NumFeats(), 0); } -void Weights::InitFromVector(const SparseVector& w) { +void Weights::InitFromVector(const SparseVector& w) { wv_.clear(); wv_.resize(FD::NumFeats(), 0.0); for (int i = 1; i < FD::NumFeats(); ++i) wv_[i] = w.value(i); } + diff --git a/utils/weights.h b/utils/weights.h index cc20283c..7664810b 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -2,21 +2,23 @@ #define _WEIGHTS_H_ #include -#include #include #include "sparse_vector.h" +// warning: in the future this will become float +typedef double weight_t; + class Weights { public: Weights() {} void InitFromFile(const std::string& fname, std::vector* feature_list = NULL); void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const; - void InitVector(std::vector* w) const; - void InitSparseVector(SparseVector* w) const; - void InitFromVector(const std::vector& w); - void InitFromVector(const SparseVector& w); + void InitVector(std::vector* w) const; + void InitSparseVector(SparseVector* w) const; + void InitFromVector(const std::vector& w); + void InitFromVector(const SparseVector& w); private: - std::vector wv_; + std::vector wv_; }; #endif -- cgit v1.2.3 From 75bff8e374f3cdcf3dc141f8b7b37858d0611234 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 15:02:36 +0100 Subject: cmph configuration option fix --- configure.ac | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/configure.ac b/configure.ac index 6fa7b914..8e06bffd 100644 --- a/configure.ac +++ b/configure.ac @@ -25,7 +25,7 @@ then LIBS="$LIBS -lboost_mpi $BOOST_SERIALIZATION_LIBS" fi -AM_CONDITIONAL([CMPH], false) +AM_CONDITIONAL([HAVE_CMPH], false) AC_ARG_WITH(cmph, [AC_HELP_STRING([--with-cmph=PATH], [(optional) path to cmph perfect hashing library])], [with_cmph=$withval], @@ -43,11 +43,7 @@ then LDFLAGS="$LDFLAGS -L${with_cmph}/lib" AC_CHECK_LIB(cmph, cmph_search) - - #LIB_CMPH="-lcmph" - #LIBS="$LIBS $LIB_CMPH" - #FMTLIBS="$FMTLIBS libcmph.a" - AM_CONDITIONAL([CMPH], true) + AM_CONDITIONAL([HAVE_CMPH], true) fi #BOOST_THREADS -- cgit v1.2.3 From 251da4347ea356f799e6c227ac8cf541c0cef2f2 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 17:36:23 +0100 Subject: get rid of bad Weights class so it no longer keeps a copy of a vector inside it --- decoder/decoder.cc | 64 ++++++++--------- decoder/decoder.h | 9 ++- mira/kbest_mira.cc | 62 ++++------------- pro-train/mr_pro_map.cc | 8 +-- pro-train/mr_pro_reduce.cc | 16 ++--- training/Makefile.am | 8 --- training/augment_grammar.cc | 4 +- training/collapse_weights.cc | 6 +- training/compute_cllh.cc | 23 +++--- training/grammar_convert.cc | 8 +-- training/mpi_batch_optimize.cc | 127 ++++++++-------------------------- training/mpi_online_optimize.cc | 69 +++++++----------- training/mr_optimize_reduce.cc | 19 ++--- utils/fdict.h | 2 + utils/phmt.cc | 8 +-- utils/weights.cc | 75 ++++++++++++-------- utils/weights.h | 22 +++--- vest/mr_vest_generate_mapper_input.cc | 6 +- 18 files changed, 201 insertions(+), 335 deletions(-) diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 25eb2de4..4d4b6245 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -159,8 +159,7 @@ struct RescoringPass { shared_ptr models; shared_ptr inter_conf; vector ffs; - shared_ptr w; // null == use previous weights - vector weight_vector; + shared_ptr > weight_vector; int fid_summary; // 0 == no summary feature double density_prune; // 0 == don't density prune double beam_prune; // 0 == don't beam prune @@ -169,7 +168,7 @@ struct RescoringPass { ostream& operator<<(ostream& os, const RescoringPass& rp) { os << "[num_fn=" << rp.ffs.size(); if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; } - if (rp.w) os << " new_weights"; + //if (rp.weight_vector.size() > 0) os << " new_weights"; if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary); if (rp.density_prune) os << " density_prune=" << rp.density_prune; if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune; @@ -181,13 +180,8 @@ struct DecoderImpl { DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg); ~DecoderImpl(); bool Decode(const string& input, DecoderObserver*); - void SetWeights(const vector& weights) { - init_weights = weights; - for (int i = 0; i < rescoring_passes.size(); ++i) { - if (rescoring_passes[i].models) - rescoring_passes[i].models->SetWeights(weights); - rescoring_passes[i].weight_vector = weights; - } + vector& CurrentWeightVector() { + return *rescoring_passes.back().weight_vector; } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } @@ -300,8 +294,7 @@ struct DecoderImpl { OracleBleu oracle; string formalism; shared_ptr translator; - Weights w_init_weights; // used with initial parse - vector init_weights; // weights used with initial parse + shared_ptr > init_weights; // weights used with initial parse vector > pffs; #ifdef FSA_RESCORING CFGOptions cfg_options; @@ -557,13 +550,18 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream exit(1); } - // load initial feature weights (and possibly freeze feature set) - if (conf.count("weights")) { - w_init_weights.InitFromFile(str("weights",conf)); - w_init_weights.InitVector(&init_weights); - init_weights.resize(FD::NumFeats()); + // load perfect hash function for features + if (conf.count("cmph_perfect_feature_hash")) { + cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as() << " ...\n"; + FD::EnableHash(conf["cmph_perfect_feature_hash"].as()); + cerr << " " << FD::NumFeats() << " features in map\n"; } + // load initial feature weights (and possibly freeze feature set) + init_weights.reset(new vector); + if (conf.count("weights")) + Weights::InitFromFile(str("weights",conf), init_weights.get()); + // cube pruning pop-limit: we may want to configure this on a per-pass basis pop_limit = conf["cubepruning_pop_limit"].as(); @@ -582,9 +580,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream RescoringPass& rp = rescoring_passes.back(); // only configure new weights if pass > 0, otherwise we reuse the initial chart weights if (nth_pass_condition && conf.count(ws)) { - rp.w.reset(new Weights); - rp.w->InitFromFile(str(ws.c_str(), conf)); - rp.w->InitVector(&rp.weight_vector); + rp.weight_vector.reset(new vector()); + Weights::InitFromFile(str(ws.c_str(), conf), rp.weight_vector.get()); } bool has_stateful = false; if (conf.count(ff)) { @@ -624,11 +621,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } // set up weight vectors since later phases may reuse weights from earlier phases - const vector* prev = &init_weights; + shared_ptr > prev_weights = init_weights; for (int pass = 0; pass < rescoring_passes.size(); ++pass) { RescoringPass& rp = rescoring_passes[pass]; - if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; } - rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs)); + if (!rp.weight_vector) { + rp.weight_vector = prev_weights; + } else { + prev_weights = rp.weight_vector; + } + rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs)); string ps = "Pass1 "; ps[4] += pass; if (!SILENT) show_models(conf,*rp.models,ps.c_str()); } @@ -650,12 +651,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream FD::Freeze(); // this means we can't see the feature names of not-weighted features } - if (conf.count("cmph_perfect_feature_hash")) { - cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as() << " ...\n"; - FD::EnableHash(conf["cmph_perfect_feature_hash"].as()); - cerr << " " << FD::NumFeats() << " features in map\n"; - } - // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); @@ -685,7 +680,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } if (!fsa_ffs.empty()) { cerr<<"FSA: "; - show_all_features(fsa_ffs,init_weights,cerr,cerr,true,true); + show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true); } #endif @@ -733,7 +728,8 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) { if (del) delete o; return res; } -void Decoder::SetWeights(const vector& weights) { pimpl_->SetWeights(weights); } +vector& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); } +const vector& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { assert(pimpl_->translator->GetDecoderType() == "SCFG"); static_cast(*pimpl_->translator).SetSupplementalGrammar(grammar_string); @@ -774,7 +770,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { translator->ProcessMarkupHints(smeta.sgml_); Timer t("Translation"); const bool translation_successful = - translator->Translate(to_translate, &smeta, init_weights, &forest); + translator->Translate(to_translate, &smeta, *init_weights, &forest); translator->SentenceComplete(); if (!translation_successful) { @@ -812,7 +808,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { for (int pass = 0; pass < rescoring_passes.size(); ++pass) { const RescoringPass& rp = rescoring_passes[pass]; - const vector& cur_weights = rp.weight_vector; + const vector& cur_weights = *rp.weight_vector; if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl; #ifdef FSA_RESCORING cfg_options.maybe_output_source(forest); @@ -933,7 +929,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { #endif } - const vector& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector); + const vector& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); // Oracle Rescoring if(get_oracle_forest) { diff --git a/decoder/decoder.h b/decoder/decoder.h index 5491369f..9d009ffa 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -7,6 +7,8 @@ #include #include +#include "weights.h" // weight_t + #undef CP_TIME //#define CP_TIME #ifdef CP_TIME @@ -39,7 +41,12 @@ struct Decoder { Decoder(int argc, char** argv); Decoder(std::istream* config_file); bool Decode(const std::string& input, DecoderObserver* observer = NULL); - void SetWeights(const std::vector& weights); + + // access this to either *read* or *write* to the decoder's last + // weight vector (i.e., the weights of the finest past) + std::vector& CurrentWeightVector(); + const std::vector& CurrentWeightVector() const; + void SetId(int id); ~Decoder(); const boost::program_options::variables_map& GetConf() const { return conf; } diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc index 6918a9a1..459a5e6f 100644 --- a/mira/kbest_mira.cc +++ b/mira/kbest_mira.cc @@ -32,21 +32,6 @@ namespace po = boost::program_options; bool invert_score; boost::shared_ptr rng; -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - void RandomPermutation(int len, vector* p_ids) { vector& ids = *p_ids; ids.resize(len); @@ -58,21 +43,6 @@ void RandomPermutation(int len, vector* p_ids) { } } -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - --mid; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -209,14 +179,16 @@ int main(int argc, char** argv) { cerr << "Mismatched number of references (" << ds.size() << ") and sources (" << corpus.size() << ")\n"; return 1; } - // load initial weights - Weights weights; - weights.InitFromFile(conf["input_weights"].as()); - SparseVector lambdas; - weights.InitSparseVector(&lambdas); ReadFile ini_rf(conf["decoder_config"].as()); Decoder decoder(ini_rf.stream()); + + // load initial weights + vector& dense_weights = decoder.CurrentWeightVector(); + SparseVector lambdas; + Weights::InitFromFile(conf["input_weights"].as(), &dense_weights); + Weights::InitSparseVector(dense_weights, &lambdas); + const double max_step_size = conf["max_step_size"].as(); const double mt_metric_scale = conf["mt_metric_scale"].as(); @@ -230,7 +202,6 @@ int main(int argc, char** argv) { double tot_loss = 0; int dots = 0; int cur_pass = 0; - vector dense_weights; SparseVector tot; tot += lambdas; // initial weights normalizer++; // count for initial weights @@ -240,27 +211,22 @@ int main(int argc, char** argv) { vector order; RandomPermutation(corpus.size(), &order); while (lcount <= max_iteration) { - dense_weights.clear(); - weights.InitFromVector(lambdas); - weights.InitVector(&dense_weights); - decoder.SetWeights(dense_weights); + lambdas.init_vector(&dense_weights); if ((cur_sent * 40 / corpus.size()) > dots) { ++dots; cerr << '.'; } if (corpus.size() == cur_sent) { cerr << " [AVG METRIC LAST PASS=" << (tot_loss / corpus.size()) << "]\n"; - ShowLargestFeatures(dense_weights); + Weights::ShowLargestFeatures(dense_weights); cur_sent = 0; tot_loss = 0; dots = 0; ostringstream os; os << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << ".gz"; - weights.WriteToFile(os.str(), true, &msg); SparseVector x = tot; x /= normalizer; ostringstream sa; sa << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "-avg.gz"; - Weights ww; - ww.InitFromVector(x); - ww.WriteToFile(sa.str(), true, &msga); + x.init_vector(&dense_weights); + Weights::WriteToFile(os.str(), dense_weights, true, &msg); ++cur_pass; RandomPermutation(corpus.size(), &order); } @@ -294,11 +260,11 @@ int main(int argc, char** argv) { ++cur_sent; } cerr << endl; - weights.WriteToFile("weights.mira-final.gz", true, &msg); + Weights::WriteToFile("weights.mira-final.gz", dense_weights, true, &msg); tot /= normalizer; - weights.InitFromVector(tot); + tot.init_vector(dense_weights); msg = "# MIRA tuned weights (averaged vector)"; - weights.WriteToFile("weights.mira-final-avg.gz", true, &msg); + Weights::WriteToFile("weights.mira-final-avg.gz", dense_weights, true, &msg); cerr << "Optimization complete.\nAVERAGED WEIGHTS: weights.mira-final-avg.gz\n"; return 0; } diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index 4324e8de..bc59285b 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -301,12 +301,8 @@ int main(int argc, char** argv) { const unsigned gamma = conf["candidate_pairs"].as(); const unsigned xi = conf["best_pairs"].as(); string weightsf = conf["weights"].as(); - vector weights; - { - Weights w; - w.InitFromFile(weightsf); - w.InitVector(&weights); - } + vector weights; + Weights::InitFromFile(weightsf, &weights); string kbest_repo = conf["kbest_repository"].as(); MkDirP(kbest_repo); while(in) { diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 9b422f33..9caaa1d1 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -194,7 +194,7 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); string line; vector > > training, testing; - SparseVector old_weights; + SparseVector old_weights; const bool tune_regularizer = conf.count("tune_regularizer"); if (tune_regularizer && !conf.count("testset")) { cerr << "--tune_regularizer requires --testset to be set\n"; @@ -210,9 +210,9 @@ int main(int argc, char** argv) { const double psi = conf["interpolation"].as(); if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } if (conf.count("weights")) { - Weights w; - w.InitFromFile(conf["weights"].as()); - w.InitSparseVector(&old_weights); + vector dt; + Weights::InitFromFile(conf["weights"].as(), &dt); + Weights::InitSparseVector(dt, &old_weights); } ReadCorpus(&cin, &training); if (conf.count("testset")) { @@ -220,8 +220,8 @@ int main(int argc, char** argv) { ReadCorpus(rf.stream(), &testing); } cerr << "Number of features: " << FD::NumFeats() << endl; - vector x(FD::NumFeats(), 0.0); // x[0] is bias - for (SparseVector::const_iterator it = old_weights.begin(); + vector x(FD::NumFeats(), 0.0); // x[0] is bias + for (SparseVector::const_iterator it = old_weights.begin(); it != old_weights.end(); ++it) x[it->first] = it->second; double tppl = 0.0; @@ -257,7 +257,6 @@ int main(int argc, char** argv) { sigsq = sp[best_i].first; tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); } - Weights w; if (conf.count("weights")) { for (int i = 1; i < x.size(); ++i) x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi); @@ -271,7 +270,6 @@ int main(int argc, char** argv) { cout << "# " << sp[i].first << "\t" << sp[i].second << "\t" << smoothed[i] << endl; } } - w.InitFromVector(x); - w.WriteToFile("-"); + Weights::WriteToFile("-", x); return 0; } diff --git a/training/Makefile.am b/training/Makefile.am index e075e417..6e2c06f5 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -12,9 +12,7 @@ bin_PROGRAMS = \ cllh_filter_grammar \ mpi_online_optimize \ mpi_batch_optimize \ - mpi_em_optimize \ compute_cllh \ - feature_expectations \ augment_grammar noinst_PROGRAMS = \ @@ -29,12 +27,6 @@ mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -feature_expectations_SOURCES = feature_expectations.cc -feature_expectations_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -mpi_em_optimize_SOURCES = mpi_em_optimize.cc optimize.cc -mpi_em_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - compute_cllh_SOURCES = compute_cllh.cc compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc index df8d4ee8..e89a92d5 100644 --- a/training/augment_grammar.cc +++ b/training/augment_grammar.cc @@ -134,9 +134,7 @@ int main(int argc, char** argv) { } else { ngram = NULL; } extra_feature = conf.count("extra_lex_feature") > 0; if (conf.count("collapse_weights")) { - Weights w; - w.InitFromFile(conf["collapse_weights"].as()); - w.InitVector(&col_weights); + Weights::InitFromFile(conf["collapse_weights"].as(), &col_weights); } clear_features = conf.count("clear_features_after_collapse") > 0; gather_rules = false; diff --git a/training/collapse_weights.cc b/training/collapse_weights.cc index 4fb742fb..dc480f6c 100644 --- a/training/collapse_weights.cc +++ b/training/collapse_weights.cc @@ -59,10 +59,8 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); const string wfile = conf["weights"].as(); const string gfile = conf["grammar"].as(); - Weights wm; - wm.InitFromFile(wfile); - vector w; - wm.InitVector(&w); + vector w; + Weights::InitFromFile(wfile, &w); MarginalMap e_tots; MarginalMap f_tots; prob_t tot; diff --git a/training/compute_cllh.cc b/training/compute_cllh.cc index 332f6d0c..b496d196 100644 --- a/training/compute_cllh.cc +++ b/training/compute_cllh.cc @@ -148,15 +148,6 @@ int main(int argc, char** argv) { if (!InitCommandLine(argc, argv, &conf)) return false; - // load initial weights - Weights weights; - if (conf.count("weights")) - weights.InitFromFile(conf["weights"].as()); - - // freeze feature set - //const bool freeze_feature_set = conf.count("freeze_feature_set"); - //if (freeze_feature_set) FD::Freeze(); - // load cdec.ini and set up decoder ReadFile ini_rf(conf["decoder_config"].as()); Decoder decoder(ini_rf.stream()); @@ -165,17 +156,22 @@ int main(int argc, char** argv) { abort(); } + // load weights + vector& weights = decoder.CurrentWeightVector(); + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &weights); + + // freeze feature set + //const bool freeze_feature_set = conf.count("freeze_feature_set"); + //if (freeze_feature_set) FD::Freeze(); + vector corpus; vector ids; ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); assert(corpus.size() > 0); assert(corpus.size() == ids.size()); - vector wv; - weights.InitVector(&wv); - decoder.SetWeights(wv); TrainingObserver observer; double objective = 0; - bool converged = false; observer.Reset(); if (rank == 0) @@ -197,3 +193,4 @@ int main(int argc, char** argv) { return 0; } + diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc index 8d292f8a..bf8abb26 100644 --- a/training/grammar_convert.cc +++ b/training/grammar_convert.cc @@ -251,12 +251,10 @@ int main(int argc, char **argv) { const bool is_split_input = (conf["format"].as() == "split"); const bool is_json_input = is_split_input || (conf["format"].as() == "json"); const bool collapse_weights = conf.count("collapse_weights"); - Weights wts; vector w; - if (conf.count("weights")) { - wts.InitFromFile(conf["weights"].as()); - wts.InitVector(&w); - } + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &w); + if (collapse_weights && !w.size()) { cerr << "--collapse_weights requires a weights file to be specified!\n"; exit(1); diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index 39a8af7d..cc5953f6 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -31,42 +31,12 @@ using namespace std; using boost::shared_ptr; namespace po = boost::program_options; -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("input_weights,w",po::value(),"Input feature weights file") ("training_data,t",po::value(),"Training data") ("decoder_config,d",po::value(),"Decoder configuration file") - ("sharded_input,s",po::value(), "Corpus and grammar files are 'sharded' so each processor loads its own input and grammar file. Argument is the directory containing the shards.") ("output_weights,o",po::value()->default_value("-"),"Output feature weights file") ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") @@ -88,14 +58,10 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { } po::notify(*conf); - if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) { + if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data")) || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; return false; } - if (conf->count("training_data") && conf->count("sharded_input")) { - cerr << "Cannot specify both --training_data and --sharded_input\n"; - return false; - } return true; } @@ -236,42 +202,9 @@ int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; - string shard_dir; - if (conf.count("sharded_input")) { - shard_dir = conf["sharded_input"].as(); - if (!DirectoryExists(shard_dir)) { - if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl; - return 1; - } - if (rank == 0) - cerr << "Shard directory: " << shard_dir << endl; - } - - // load initial weights - Weights weights; - if (rank == 0) { cerr << "Loading weights...\n"; } - weights.InitFromFile(conf["input_weights"].as()); - if (rank == 0) { cerr << "Done loading weights.\n"; } - - // freeze feature set (should be optional?) - const bool freeze_feature_set = true; - if (freeze_feature_set) FD::Freeze(); - // load cdec.ini and set up decoder vector cdec_ini; ReadConfig(conf["decoder_config"].as(), &cdec_ini); - if (shard_dir.size()) { - if (rank == 0) { - for (int i = 0; i < cdec_ini.size(); ++i) { - if (cdec_ini[i].find("grammar=") == 0) { - cerr << "!!! using sharded input and " << conf["decoder_config"].as() << " contains a grammar specification:\n" << cdec_ini[i] << "\n VERIFY THAT THIS IS CORRECT!\n"; - } - } - } - ostringstream g; - g << "grammar=" << shard_dir << "/grammar." << rank << "_of_" << size << ".gz"; - cdec_ini.push_back(g.str()); - } istringstream ini; StoreConfig(cdec_ini, &ini); if (rank == 0) cerr << "Loading grammar...\n"; @@ -282,22 +215,28 @@ int main(int argc, char** argv) { } if (rank == 0) cerr << "Done loading grammar!\n"; + // load initial weights + if (rank == 0) { cerr << "Loading weights...\n"; } + vector& lambdas = decoder->CurrentWeightVector(); + Weights::InitFromFile(conf["input_weights"].as(), &lambdas); + if (rank == 0) { cerr << "Done loading weights.\n"; } + + // freeze feature set (should be optional?) + const bool freeze_feature_set = true; + if (freeze_feature_set) FD::Freeze(); + const int num_feats = FD::NumFeats(); if (rank == 0) cerr << "Number of features: " << num_feats << endl; + lambdas.resize(num_feats); + const bool gaussian_prior = conf.count("gaussian_prior"); - vector means(num_feats, 0); + vector means(num_feats, 0); if (conf.count("means")) { if (!gaussian_prior) { cerr << "Don't use --means without --gaussian_prior!\n"; exit(1); } - Weights wm; - wm.InitFromFile(conf["means"].as()); - if (num_feats != FD::NumFeats()) { - cerr << "[ERROR] Means file had unexpected features!\n"; - exit(1); - } - wm.InitVector(&means); + Weights::InitFromFile(conf["means"].as(), &means); } shared_ptr o; if (rank == 0) { @@ -309,26 +248,13 @@ int main(int argc, char** argv) { cerr << "Optimizer: " << o->Name() << endl; } double objective = 0; - vector lambdas(num_feats, 0.0); - weights.InitVector(&lambdas); - if (lambdas.size() != num_feats) { - cerr << "Initial weights file did not have all features specified!\n feats=" - << num_feats << "\n weights file=" << lambdas.size() << endl; - lambdas.resize(num_feats, 0.0); - } vector gradient(num_feats, 0.0); - vector rcv_grad(num_feats, 0.0); + vector rcv_grad; + rcv_grad.clear(); bool converged = false; vector corpus; - if (shard_dir.size()) { - ostringstream os; os << shard_dir << "/corpus." << rank << "_of_" << size; - ReadTrainingCorpus(os.str(), 0, 1, &corpus); - cerr << os.str() << " has " << corpus.size() << " training examples. " << endl; - if (corpus.size() > 500) { corpus.resize(500); cerr << " TRUNCATING\n"; } - } else { - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); - } + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); assert(corpus.size() > 0); TrainingObserver observer; @@ -341,19 +267,20 @@ int main(int argc, char** argv) { if (rank == 0) { cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; } - decoder->SetWeights(lambdas); for (int i = 0; i < corpus.size(); ++i) decoder->Decode(corpus[i], &observer); cerr << " process " << rank << '/' << size << " done\n"; fill(gradient.begin(), gradient.end(), 0); - fill(rcv_grad.begin(), rcv_grad.end(), 0); observer.SetLocalGradientAndObjective(&gradient, &objective); double to = 0; #ifdef HAVE_MPI + rcv_grad.resize(num_feats, 0.0); mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus(), 0); - mpi::reduce(world, objective, to, plus(), 0); swap(gradient, rcv_grad); + rcv_grad.clear(); + + mpi::reduce(world, objective, to, plus(), 0); objective = to; #endif @@ -378,7 +305,7 @@ int main(int argc, char** argv) { for (int i = 0; i < gradient.size(); ++i) gnorm += gradient[i] * gradient[i]; cerr << " GNORM=" << sqrt(gnorm) << endl; - vector old = lambdas; + vector old = lambdas; int c = 0; while (old == lambdas) { ++c; @@ -387,9 +314,8 @@ int main(int argc, char** argv) { assert(c < 5); } old.clear(); - SanityCheck(lambdas); - ShowLargestFeatures(lambdas); - weights.InitFromVector(lambdas); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); converged = o->HasConverged(); if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } @@ -399,7 +325,7 @@ int main(int argc, char** argv) { ostringstream vv; vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")"; const string svv = vv.str(); - weights.WriteToFile(fname, true, &svv); + Weights::WriteToFile(fname, lambdas, true, &svv); } // rank == 0 int cint = converged; #ifdef HAVE_MPI @@ -411,3 +337,4 @@ int main(int argc, char** argv) { } return 0; } + diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 32033c19..2ef4a2e7 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -31,35 +31,6 @@ namespace mpi = boost::mpi; using namespace std; namespace po = boost::program_options; -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -250,10 +221,25 @@ int main(int argc, char** argv) { if (!InitCommandLine(argc, argv, &conf)) return 1; + vector > agenda; + if (!LoadAgenda(conf["training_agenda"].as(), &agenda)) + return 1; + if (rank == 0) + cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; + + assert(agenda.size() > 0); + + if (1) { // hack to load the feature hash functions -- TODO this should not be in cdec.ini + const string& cur_config = agenda[0].first; + const unsigned max_iteration = agenda[0].second; + ReadFile ini_rf(cur_config); + Decoder decoder(ini_rf.stream()); + } + // load initial weights - Weights weights; + vector init_weights; if (conf.count("input_weights")) - weights.InitFromFile(conf["input_weights"].as()); + Weights::InitFromFile(conf["input_weights"].as(), &init_weights); vector frozen_fids; if (conf.count("frozen_features")) { @@ -310,19 +296,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937); SparseVector x; - weights.InitSparseVector(&x); + Weights::InitSparseVector(init_weights, &x); TrainingObserver observer; int write_weights_every_ith = 100; // TODO configure int titer = -1; - vector > agenda; - if (!LoadAgenda(conf["training_agenda"].as(), &agenda)) - return 1; - if (rank == 0) - cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; - - vector lambdas; for (int ai = 0; ai < agenda.size(); ++ai) { const string& cur_config = agenda[ai].first; const unsigned max_iteration = agenda[ai].second; @@ -331,6 +310,8 @@ int main(int argc, char** argv) { // load cdec.ini and set up decoder ReadFile ini_rf(cur_config); Decoder decoder(ini_rf.stream()); + vector& lambdas = decoder.CurrentWeightVector(); + if (ai == 0) { lambdas.swap(init_weights); init_weights.clear(); } if (rank == 0) o->ResetEpoch(); // resets the learning rate-- TODO is this good? @@ -341,15 +322,13 @@ int main(int argc, char** argv) { #ifdef HAVE_MPI mpi::timer timer; #endif - weights.InitFromVector(x); - weights.InitVector(&lambdas); + x.init_vector(&lambdas); ++iter; ++titer; observer.Reset(); - decoder.SetWeights(lambdas); if (rank == 0) { converged = (iter == max_iteration); - SanityCheck(lambdas); - ShowLargestFeatures(lambdas); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); string fname = "weights.cur.gz"; if (iter % write_weights_every_ith == 0) { ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; @@ -360,7 +339,7 @@ int main(int argc, char** argv) { vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); const string svv = vv.str(); cerr << svv << endl; - weights.WriteToFile(fname, true, &svv); + Weights::WriteToFile(fname, lambdas, true, &svv); } for (int i = 0; i < size_per_proc; ++i) { diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc index b931991d..15e28fa1 100644 --- a/training/mr_optimize_reduce.cc +++ b/training/mr_optimize_reduce.cc @@ -88,25 +88,19 @@ int main(int argc, char** argv) { const bool use_b64 = conf["input_format"].as() == "b64"; - Weights weights; - weights.InitFromFile(conf["input_weights"].as()); + vector lambdas; + Weights::InitFromFile(conf["input_weights"].as(), &lambdas); const string s_obj = "**OBJ**"; int num_feats = FD::NumFeats(); cerr << "Number of features: " << num_feats << endl; const bool gaussian_prior = conf.count("gaussian_prior"); - vector means(num_feats, 0); + vector means(num_feats, 0); if (conf.count("means")) { if (!gaussian_prior) { cerr << "Don't use --means without --gaussian_prior!\n"; exit(1); } - Weights wm; - wm.InitFromFile(conf["means"].as()); - if (num_feats != FD::NumFeats()) { - cerr << "[ERROR] Means file had unexpected features!\n"; - exit(1); - } - wm.InitVector(&means); + Weights::InitFromFile(conf["means"].as(), &means); } shared_ptr o; const string omethod = conf["optimization_method"].as(); @@ -124,8 +118,6 @@ int main(int argc, char** argv) { cerr << "No state file found, assuming ITERATION 1\n"; } - vector lambdas(num_feats, 0); - weights.InitVector(&lambdas); double objective = 0; vector gradient(num_feats, 0); // 0**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; @@ -223,8 +215,7 @@ int main(int argc, char** argv) { old.clear(); SanityCheck(lambdas); ShowLargestFeatures(lambdas); - weights.InitFromVector(lambdas); - weights.WriteToFile(conf["output_weights"].as(), false); + Weights::WriteToFile(conf["output_weights"].as(), lambdas, false); const bool conv = o->HasConverged(); if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } diff --git a/utils/fdict.h b/utils/fdict.h index 771e8b91..f0871b9a 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -28,6 +28,8 @@ struct FD { } static void EnableHash(const std::string& cmph_file) { #ifdef HAVE_CMPH + assert(dict_.max() == 0); // dictionary must not have + // been added to hash_ = new PerfectHashFunction(cmph_file); #endif } diff --git a/utils/phmt.cc b/utils/phmt.cc index 1f59afaf..48d9f093 100644 --- a/utils/phmt.cc +++ b/utils/phmt.cc @@ -19,22 +19,18 @@ int main(int argc, char** argv) { cerr << "LexFE = " << FD::Convert("LexFE") << endl; cerr << "LexEF = " << FD::Convert("LexEF") << endl; { - Weights w; vector v(FD::NumFeats()); v[FD::Convert("LexFE")] = 1.0; v[FD::Convert("LexEF")] = 0.5; - w.InitFromVector(v); cerr << "Writing...\n"; - w.WriteToFile("weights.bin"); + Weights::WriteToFile("weights.bin", v); cerr << "Done.\n"; } { - Weights w; vector v(FD::NumFeats()); cerr << "Reading...\n"; - w.InitFromFile("weights.bin"); + Weights::InitFromFile("weights.bin", &v); cerr << "Done.\n"; - w.InitVector(&v); assert(v[FD::Convert("LexFE")] == 1.0); assert(v[FD::Convert("LexEF")] == 0.5); } diff --git a/utils/weights.cc b/utils/weights.cc index 0916b72a..c49000be 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -8,7 +8,10 @@ using namespace std; -void Weights::InitFromFile(const std::string& filename, vector* feature_list) { +void Weights::InitFromFile(const string& filename, + vector* pweights, + vector* feature_list) { + vector& weights = *pweights; if (!SILENT) cerr << "Reading weights from " << filename << endl; ReadFile in_file(filename); istream& in = *in_file.stream(); @@ -47,16 +50,16 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ int end = 0; while(end < buf.size() && buf[end] != ' ') ++end; const int fid = FD::Convert(buf.substr(start, end - start)); + if (feature_list) { feature_list->push_back(buf.substr(start, end - start)); } while(end < buf.size() && buf[end] == ' ') ++end; val = strtod(&buf.c_str()[end], NULL); if (isnan(val)) { cerr << FD::Convert(fid) << " has weight NaN!\n"; abort(); } - if (wv_.size() <= fid) - wv_.resize(fid + 1); - wv_[fid] = val; - if (feature_list) { feature_list->push_back(FD::Convert(fid)); } + if (weights.size() <= fid) + weights.resize(fid + 1); + weights[fid] = val; ++weight_count; if (!SILENT) { if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } @@ -76,8 +79,8 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; abort(); } - wv_.resize(num_keys[0]); - in.get(reinterpret_cast(&wv_[0]), num_keys[0] * sizeof(weight_t)); + weights.resize(num_keys[0]); + in.get(reinterpret_cast(&weights[0]), num_keys[0] * sizeof(weight_t)); if (!in.good()) { cerr << "Error loading weights!\n"; abort(); @@ -85,7 +88,10 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ } } -void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features, const string* extra) const { +void Weights::WriteToFile(const string& fname, + const vector& weights, + bool hide_zero_value_features, + const string* extra) { WriteFile out(fname); ostream& o = *out.stream(); assert(o); @@ -96,41 +102,54 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature o.precision(17); const int num_feats = FD::NumFeats(); for (int i = 1; i < num_feats; ++i) { - const weight_t val = (i < wv_.size() ? wv_[i] : 0.0); + const weight_t val = (i < weights.size() ? weights[i] : 0.0); if (hide_zero_value_features && val == 0.0) continue; o << FD::Convert(i) << ' ' << val << endl; } } else { o.write("_PHWf", 5); const size_t keys = FD::NumFeats(); - assert(keys <= wv_.size()); + assert(keys <= weights.size()); o.write(reinterpret_cast(&keys), sizeof(keys)); - o.write(reinterpret_cast(&wv_[0]), keys * sizeof(weight_t)); + o.write(reinterpret_cast(&weights[0]), keys * sizeof(weight_t)); } } -void Weights::InitVector(std::vector* w) const { - *w = wv_; +void Weights::InitSparseVector(const vector& dv, + SparseVector* sv) { + sv->clear(); + for (unsigned i = 1; i < dv.size(); ++i) { + if (dv[i]) sv->set_value(i, dv[i]); + } } -void Weights::InitSparseVector(SparseVector* w) const { - for (int i = 1; i < wv_.size(); ++i) { - const weight_t& weight = wv_[i]; - if (weight) w->set_value(i, weight); +void Weights::SanityCheck(const vector& w) { + for (int i = 0; i < w.size(); ++i) { + assert(!isnan(w[i])); + assert(!isinf(w[i])); } } -void Weights::InitFromVector(const std::vector& w) { - wv_ = w; - if (wv_.size() > FD::NumFeats()) - cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; - wv_.resize(FD::NumFeats(), 0); -} +struct FComp { + const vector& w_; + FComp(const vector& w) : w_(w) {} + bool operator()(int a, int b) const { + return fabs(w_[a]) > fabs(w_[b]); + } +}; -void Weights::InitFromVector(const SparseVector& w) { - wv_.clear(); - wv_.resize(FD::NumFeats(), 0.0); - for (int i = 1; i < FD::NumFeats(); ++i) - wv_[i] = w.value(i); +void Weights::ShowLargestFeatures(const vector& w) { + vector fnums(w.size()); + for (int i = 0; i < w.size(); ++i) + fnums[i] = i; + vector::iterator mid = fnums.begin(); + mid += (w.size() > 10 ? 10 : w.size()); + partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); + cerr << "TOP FEATURES:"; + for (vector::iterator i = fnums.begin(); i != mid; ++i) { + cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; + } + cerr << endl; } + diff --git a/utils/weights.h b/utils/weights.h index 7664810b..30f71db0 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -10,15 +10,21 @@ typedef double weight_t; class Weights { public: - Weights() {} - void InitFromFile(const std::string& fname, std::vector* feature_list = NULL); - void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const; - void InitVector(std::vector* w) const; - void InitSparseVector(SparseVector* w) const; - void InitFromVector(const std::vector& w); - void InitFromVector(const SparseVector& w); + static void InitFromFile(const std::string& fname, + std::vector* weights, + std::vector* feature_list = NULL); + static void WriteToFile(const std::string& fname, + const std::vector& weights, + bool hide_zero_value_features = true, + const std::string* extra = NULL); + static void InitSparseVector(const std::vector& dv, + SparseVector* sv); + // check for infinities, NaNs, etc + static void SanityCheck(const std::vector& w); + // write weights with largest magnitude to cerr + static void ShowLargestFeatures(const std::vector& w); private: - std::vector wv_; + Weights(); }; #endif diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index b84c44bc..0c094fd5 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -223,16 +223,16 @@ struct oracle_directions { cerr << "Forest repo: " << forest_repository << endl; assert(DirectoryExists(forest_repository)); vector features; - weights.InitFromFile(weights_file, &features); + vector dorigin; + Weights::InitFromFile(weights_file, &dorigin, &features); if (optimize_features.size()) features=optimize_features; - weights.InitSparseVector(&origin); + Weights::InitSparseVector(dorigin, &origin); fids.clear(); AddFeatureIds(features); oracles.resize(dev_set_size); } - Weights weights; void AddFeatureIds(vector const& features) { int i = fids.size(); fids.resize(fids.size()+features.size()); -- cgit v1.2.3 From bff9f7f6e3ed777c9379c0373657eeaf43a6a213 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 17:57:32 +0100 Subject: fix for crash with no rescoring --- decoder/decoder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 4d4b6245..45404c47 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -181,7 +181,7 @@ struct DecoderImpl { ~DecoderImpl(); bool Decode(const string& input, DecoderObserver*); vector& CurrentWeightVector() { - return *rescoring_passes.back().weight_vector; + return (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } -- cgit v1.2.3 From dffebff1a33e581a4a36ba060faf5a2ba8e87faa Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 18:33:08 +0100 Subject: fix weight serialization bug --- utils/weights.cc | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/utils/weights.cc b/utils/weights.cc index c49000be..ac407dfb 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -23,7 +23,7 @@ void Weights::InitFromFile(const string& filename, istream& hi = *hdrrf.stream(); assert(hi); char buf[10]; - hi.get(buf, 6); + hi.read(buf, 5); assert(hi.good()); if (strncmp(buf, "_PHWf", 5) == 0) { read_text = false; @@ -72,18 +72,20 @@ void Weights::InitFromFile(const string& filename, } } else { // !read_text char buf[6]; - in.get(buf, 6); - size_t num_keys[2]; - in.get(reinterpret_cast(&num_keys[0]), sizeof(size_t) + 1); - if (num_keys[0] != FD::NumFeats()) { - cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; + in.read(buf, 5); + size_t num_keys; + in.read(reinterpret_cast(&num_keys), sizeof(size_t)); + if (num_keys != FD::NumFeats()) { + cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys << endl; abort(); } - weights.resize(num_keys[0]); - in.get(reinterpret_cast(&weights[0]), num_keys[0] * sizeof(weight_t)); + weights.resize(num_keys); + in.read(reinterpret_cast(&weights.front()), num_keys * sizeof(weight_t)); if (!in.good()) { cerr << "Error loading weights!\n"; abort(); + } else { + cerr << " Successfully loaded " << (num_keys * sizeof(weight_t)) << " bytes\n"; } } } -- cgit v1.2.3 From ddc38ce211d4b38f66e56dfa072856a4e9de2c17 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 18:46:33 +0100 Subject: remove features that are overfitting --- decoder/ff_source_syntax.cc | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 5b7c16f6..ffe07f03 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -25,12 +25,10 @@ struct SourceSyntaxFeaturesImpl { void InitializeGrids(const string& tree, unsigned src_len) { assert(tree.size() > 0); - fids_cat.clear(); - fids_fonly.clear(); + //fids_cat.clear(); fids_ef.clear(); src_tree.clear(); - fids_cat.resize(src_len, src_len + 1); - fids_fonly.resize(src_len, src_len + 1); + //fids_cat.resize(src_len, src_len + 1); fids_ef.resize(src_len, src_len + 1); src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); ParseTreeString(tree, src_len); @@ -89,15 +87,14 @@ struct SourceSyntaxFeaturesImpl { WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl; const WordID lhs = src_tree(i,j); - int& fid_cat = fids_cat(i,j); - int& fid_fonly = fids_fonly(i,j)[&rule]; + //int& fid_cat = fids_cat(i,j); int& fid_ef = fids_ef(i,j)[&rule]; if (fid_ef <= 0) { ostringstream os; - ostringstream os2; + //ostringstream os2; os << "SYN:" << TD::Convert(lhs); - os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); - fid_cat = FD::Convert(os2.str()); + //os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); + //fid_cat = FD::Convert(os2.str()); os << ':'; unsigned ntc = 0; for (unsigned k = 0; k < rule.f_.size(); ++k) { @@ -109,7 +106,6 @@ struct SourceSyntaxFeaturesImpl { os << TD::Convert(fj); } } - fid_fonly = FD::Convert(os.str()); os << ':'; for (unsigned k = 0; k < rule.e_.size(); ++k) { const int ei = rule.e_[k]; @@ -121,18 +117,15 @@ struct SourceSyntaxFeaturesImpl { } fid_ef = FD::Convert(os.str()); } - if (fid_cat > 0) - feats->set_value(fid_cat, 1.0); - if (fid_fonly > 0) - feats->set_value(fid_fonly, 1.0); + //if (fid_cat > 0) + // feats->set_value(fid_cat, 1.0); if (fid_ef > 0) feats->set_value(fid_ef, 1.0); return lhs; } Array2D src_tree; // src_tree(i,j) NT = type - mutable Array2D fids_cat; // fires for an LHS match - mutable Array2D > fids_fonly; // fires for an f-string + // mutable Array2D fids_cat; // this tends to overfit baddly mutable Array2D > fids_ef; // fires for fully lexicalized }; -- cgit v1.2.3 From 4d87d0edc375a9a7bedddb22d075b6484daf0bf6 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 20:16:17 +0100 Subject: tool to reconstruct text weights from a hash function, key file, and (binary) weights file --- utils/Makefile.am | 5 ++++ utils/reconstruct_weights.cc | 68 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 utils/reconstruct_weights.cc diff --git a/utils/Makefile.am b/utils/Makefile.am index c50747bf..df667655 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -1,3 +1,6 @@ + +bin_PROGRAMS = reconstruct_weights + noinst_PROGRAMS = ts phmt TESTS = ts phmt @@ -11,6 +14,8 @@ noinst_PROGRAMS += \ TESTS += small_vector_test logval_test weights_test dict_test endif +reconstruct_weights_SOURCES = reconstruct_weights.cc + noinst_LIBRARIES = libutils.a libutils_a_SOURCES = \ diff --git a/utils/reconstruct_weights.cc b/utils/reconstruct_weights.cc new file mode 100644 index 00000000..d32e4f67 --- /dev/null +++ b/utils/reconstruct_weights.cc @@ -0,0 +1,68 @@ +#include +#include +#include + +#include +#include + +#include "filelib.h" +#include "fdict.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("weights,w",po::value(),"Input feature weights file") + ("keys,k",po::value(),"Keys file (list of features with dummy value at start)") + ("cmph_perfect_hash_file,h",po::value(),"cmph perfect hash function file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,?", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("cmph_perfect_hash_file") || !conf->count("weights") || !conf->count("keys")) { + cerr << "Generate a text format weights file. Options -w -k and -h are required.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +int main(int argc, char** argv) { + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + FD::EnableHash(conf["cmph_perfect_hash_file"].as()); + + // load weights + vector weights; + Weights::InitFromFile(conf["weights"].as(), &weights); + + ReadFile rf(conf["keys"].as()); + istream& in = *rf.stream(); + string key; + size_t lc = 0; + while(getline(in, key)) { + ++lc; + if (lc == 1) continue; + assert(lc <= weights.size()); + cout << key << " " << weights[lc - 1] << endl; + } + + return 0; +} + -- cgit v1.2.3 From b9d54044619b964467857b20921c19ab9135326c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 21:55:02 +0100 Subject: binary to extract features encountered --- training/Makefile.am | 4 ++ training/mpi_extract_features.cc | 151 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 training/mpi_extract_features.cc diff --git a/training/Makefile.am b/training/Makefile.am index 6e2c06f5..7ceeda34 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -11,6 +11,7 @@ bin_PROGRAMS = \ collapse_weights \ cllh_filter_grammar \ mpi_online_optimize \ + mpi_extract_features \ mpi_batch_optimize \ compute_cllh \ augment_grammar @@ -24,6 +25,9 @@ TESTS = lbfgs_test optimize_test mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_features_SOURCES = mpi_extract_features.cc +mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/training/mpi_extract_features.cc b/training/mpi_extract_features.cc new file mode 100644 index 00000000..6750aa15 --- /dev/null +++ b/training/mpi_extract_features.cc @@ -0,0 +1,151 @@ +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#endif +#include +#include + +#include "ff_register.h" +#include "verbose.h" +#include "filelib.h" +#include "fdict.h" +#include "decoder.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value(),"Training data corpus") + ("decoder_config,c",po::value(),"Decoder configuration file") + ("weights,w", po::value(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") + ("output_prefix,o",po::value()->default_value("features"),"Output path prefix"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the feature strings encountered.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + } +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + if (FD::UsingPerfectHashFunction()) { + cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; + return 1; + } + + // load optional weights + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &decoder.CurrentWeightVector()); + + vector corpus; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); + assert(corpus.size() > 0); + + TrainingObserver observer; + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + for (int i = 0; i < corpus.size(); ++i) + decoder.Decode(corpus[i], &observer); + + { + ostringstream os; + os << conf["output_prefix"].as() << '.' << rank << "_of_" << size; + WriteFile wf(os.str()); + ostream& out = *wf.stream(); + const unsigned num_feats = FD::NumFeats(); + for (unsigned i = 1; i < num_feats; ++i) { + out << FD::Convert(i) << endl; + } + cerr << "Wrote " << os.str() << endl; + } + +#ifdef HAVE_MPI + world.barrier(); +#else +#endif + + return 0; +} + -- cgit v1.2.3 From f67fee820ba941cfb7f11ee0ee5df6b356ff959c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 14 Sep 2011 12:17:04 +0100 Subject: weight_t refactoring --- pro-train/mr_pro_map.cc | 42 +++++++++++++++++++++--------------------- pro-train/mr_pro_reduce.cc | 34 +++++++++++++++++----------------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index bc59285b..0a9b75d7 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -27,7 +27,7 @@ namespace po = boost::program_options; struct ApproxVectorHasher { static const size_t MASK = 0xFFFFFFFFull; union UType { - double f; + double f; // leave as double size_t i; }; static inline double round(const double x) { @@ -40,9 +40,9 @@ struct ApproxVectorHasher { t.i &= (1ull - MASK); return t.f; } - size_t operator()(const SparseVector& x) const { + size_t operator()(const SparseVector& x) const { size_t h = 0x573915839; - for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) { + for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) { UType t; t.f = it->second; if (t.f) { @@ -56,9 +56,9 @@ struct ApproxVectorHasher { }; struct ApproxVectorEquals { - bool operator()(const SparseVector& a, const SparseVector& b) const { - SparseVector::const_iterator bit = b.begin(); - for (SparseVector::const_iterator ait = a.begin(); ait != a.end(); ++ait) { + bool operator()(const SparseVector& a, const SparseVector& b) const { + SparseVector::const_iterator bit = b.begin(); + for (SparseVector::const_iterator ait = a.begin(); ait != a.end(); ++ait) { if (bit == b.end() || ait->first != bit->first || ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) @@ -105,18 +105,18 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } struct HypInfo { - HypInfo() : g_(-100.0) {} - HypInfo(const vector& h, const SparseVector& feats) : hyp(h), g_(-100.0), x(feats) {} + HypInfo() : g_(-100.0f) {} + HypInfo(const vector& h, const SparseVector& feats) : hyp(h), g_(-100.0f), x(feats) {} // lazy evaluation double g(const SentenceScorer& scorer) const { - if (g_ == -100.0) + if (g_ == -100.0f) g_ = scorer.ScoreCandidate(hyp)->ComputeScore(); return g_; } vector hyp; - mutable double g_; - SparseVector x; + mutable float g_; + SparseVector x; }; struct HypInfoCompare { @@ -146,8 +146,8 @@ void WriteKBest(const string& file, const vector& kbest) { } } -void ParseSparseVector(string& line, size_t cur, SparseVector* out) { - SparseVector& x = *out; +void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; size_t last_start = cur; size_t last_comma = string::npos; while(cur <= line.size()) { @@ -211,15 +211,15 @@ struct ThresholdAlpha { }; struct TrainingInstance { - TrainingInstance(const SparseVector& feats, bool positive, double diff) : x(feats), y(positive), gdiff(diff) {} - SparseVector x; + TrainingInstance(const SparseVector& feats, bool positive, float diff) : x(feats), y(positive), gdiff(diff) {} + SparseVector x; #undef DEBUGGING_PRO #ifdef DEBUGGING_PRO vector a; vector b; #endif bool y; - double gdiff; + float gdiff; }; #ifdef DEBUGGING_PRO ostream& operator<<(ostream& os, const TrainingInstance& d) { @@ -235,19 +235,19 @@ struct DiffOrder { void Sample(const unsigned gamma, const unsigned xi, const vector& J_i, const SentenceScorer& scorer, const bool invert_score, vector* pv) { vector v1, v2; - double avg_diff = 0; + float avg_diff = 0; for (unsigned i = 0; i < gamma; ++i) { const size_t a = rng->inclusive(0, J_i.size() - 1)(); const size_t b = rng->inclusive(0, J_i.size() - 1)(); if (a == b) continue; - double ga = J_i[a].g(scorer); - double gb = J_i[b].g(scorer); + float ga = J_i[a].g(scorer); + float gb = J_i[b].g(scorer); bool positive = gb < ga; if (invert_score) positive = !positive; - const double gdiff = fabs(ga - gb); + const float gdiff = fabs(ga - gb); if (!gdiff) continue; avg_diff += gdiff; - SparseVector xdiff = (J_i[a].x - J_i[b].x).erase_zeros(); + SparseVector xdiff = (J_i[a].x - J_i[b].x).erase_zeros(); if (xdiff.empty()) { cerr << "Empty diff:\n " << TD::GetString(J_i[a].hyp) << endl << "x=" << J_i[a].x << endl; cerr << " " << TD::GetString(J_i[b].hyp) << endl << "x=" << J_i[b].x << endl; diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 9caaa1d1..239649c1 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -40,8 +40,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -void ParseSparseVector(string& line, size_t cur, SparseVector* out) { - SparseVector& x = *out; +void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; size_t last_start = cur; size_t last_comma = string::npos; while(cur <= line.size()) { @@ -52,7 +52,7 @@ void ParseSparseVector(string& line, size_t cur, SparseVector* out) { } const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); if (cur < line.size()) line[cur] = 0; - const double val = strtod(&line[last_comma + 1], NULL); + const weight_t val = strtod(&line[last_comma + 1], NULL); x.set_value(fid, val); last_comma = string::npos; @@ -65,13 +65,13 @@ void ParseSparseVector(string& line, size_t cur, SparseVector* out) { } } -void ReadCorpus(istream* pin, vector > >* corpus) { +void ReadCorpus(istream* pin, vector > >* corpus) { istream& in = *pin; corpus->clear(); bool flag = false; int lc = 0; string line; - SparseVector x; + SparseVector x; while(getline(in, line)) { ++lc; if (lc % 1000 == 0) { cerr << '.'; flag = true; } @@ -88,16 +88,16 @@ void ReadCorpus(istream* pin, vector > >* corpus if (flag) cerr << endl; } -void GradAdd(const SparseVector& v, const double scale, vector* acc) { - for (SparseVector::const_iterator it = v.begin(); +void GradAdd(const SparseVector& v, const double scale, vector* acc) { + for (SparseVector::const_iterator it = v.begin(); it != v.end(); ++it) { (*acc)[it->first] += it->second * scale; } } -double TrainingInference(const vector& x, - const vector > >& corpus, - vector* g = NULL) { +double TrainingInference(const vector& x, + const vector > >& corpus, + vector* g = NULL) { double cll = 0; for (int i = 0; i < corpus.size(); ++i) { const double dotprod = corpus[i].second.dot(x) + x[0]; // x[0] is bias @@ -132,13 +132,13 @@ double TrainingInference(const vector& x, } // return held-out log likelihood -double LearnParameters(const vector > >& training, - const vector > >& testing, +double LearnParameters(const vector > >& training, + const vector > >& testing, const double sigsq, const unsigned memory_buffers, - vector* px) { - vector& x = *px; - vector vg(FD::NumFeats(), 0.0); + vector* px) { + vector& x = *px; + vector vg(FD::NumFeats(), 0.0); bool converged = false; LBFGSOptimizer opt(FD::NumFeats(), memory_buffers); double tppl = 0.0; @@ -172,7 +172,7 @@ double LearnParameters(const vector > >& trainin cll += reg; cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t TEST_PPL=" << tppl << "\t"; try { - vector old_x = x; + vector old_x = x; do { opt.Optimize(cll, vg, &x); converged = opt.HasConverged(); @@ -193,7 +193,7 @@ int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); string line; - vector > > training, testing; + vector > > training, testing; SparseVector old_weights; const bool tune_regularizer = conf.count("tune_regularizer"); if (tune_regularizer && !conf.count("testset")) { -- cgit v1.2.3 From 678b11b7e5a537170c81eb577113843ee147d88f Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 14 Sep 2011 12:40:01 +0100 Subject: oxford env --- environment/LocalConfig.pm | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/environment/LocalConfig.pm b/environment/LocalConfig.pm index 7b3d950c..4e5e0d5f 100644 --- a/environment/LocalConfig.pm +++ b/environment/LocalConfig.pm @@ -44,6 +44,10 @@ my $CCONFIG = { 'HOST_REGEXP' => qr/^(barrow|chicago).lti.cs.cmu.edu$/, 'QSubMemFlag' => '-l pmem=', }, + 'OxfordDeathSnakes' => { + 'HOST_REGEXP' => qr/^(taipan|tiger).cs.ox.ac.uk$/, + 'QSubMemFlag' => '-l pmem=', + }, 'LOCAL' => { 'HOST_REGEXP' => qr/local\./, 'QSubMemFlag' => ' ', -- cgit v1.2.3 From 4c7d79b5aa38514444a45d54697ad93a6bbae539 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 14 Sep 2011 13:12:01 +0100 Subject: fix pro train bug causing it not to optimize when there is no held-out test set --- pro-train/mr_pro_reduce.cc | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 239649c1..e71347ba 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -194,7 +194,6 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); string line; vector > > training, testing; - SparseVector old_weights; const bool tune_regularizer = conf.count("tune_regularizer"); if (tune_regularizer && !conf.count("testset")) { cerr << "--tune_regularizer requires --testset to be set\n"; @@ -202,28 +201,28 @@ int main(int argc, char** argv) { } const double min_reg = conf["min_reg"].as(); const double max_reg = conf["max_reg"].as(); - double sigsq = conf["sigma_squared"].as(); + double sigsq = conf["sigma_squared"].as(); // will be overridden if parameter is tuned assert(sigsq > 0.0); assert(min_reg > 0.0); assert(max_reg > 0.0); assert(max_reg > min_reg); const double psi = conf["interpolation"].as(); if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } - if (conf.count("weights")) { - vector dt; - Weights::InitFromFile(conf["weights"].as(), &dt); - Weights::InitSparseVector(dt, &old_weights); - } ReadCorpus(&cin, &training); if (conf.count("testset")) { ReadFile rf(conf["testset"].as()); ReadCorpus(rf.stream(), &testing); } cerr << "Number of features: " << FD::NumFeats() << endl; - vector x(FD::NumFeats(), 0.0); // x[0] is bias - for (SparseVector::const_iterator it = old_weights.begin(); - it != old_weights.end(); ++it) - x[it->first] = it->second; + + vector x, prev_x; // x[0] is bias + if (conf.count("weights")) { + Weights::InitFromFile(conf["weights"].as(), &x); + prev_x = x; + } + cerr << " Number of features: " << x.size() << endl; + cerr << "Number of training examples: " << training.size() << endl; + cerr << "Number of testing examples: " << testing.size() << endl; double tppl = 0.0; vector > sp; vector smoothed; @@ -255,11 +254,12 @@ int main(int argc, char** argv) { } } sigsq = sp[best_i].first; - tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); - } + } // tune regularizer + tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); if (conf.count("weights")) { - for (int i = 1; i < x.size(); ++i) - x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi); + for (int i = 1; i < x.size(); ++i) { + x[i] = (x[i] * psi) + prev_x[i] * (1.0 - psi); + } } cout.precision(15); cout << "# sigma^2=" << sigsq << "\theld out perplexity="; -- cgit v1.2.3 From 1ef494a6ece4852d51a437c4f008500535343021 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 14 Sep 2011 14:43:03 +0100 Subject: fix for potential segv with no weights --- pro-train/mr_pro_reduce.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index e71347ba..6b491918 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -100,7 +100,7 @@ double TrainingInference(const vector& x, vector* g = NULL) { double cll = 0; for (int i = 0; i < corpus.size(); ++i) { - const double dotprod = corpus[i].second.dot(x) + x[0]; // x[0] is bias + const double dotprod = corpus[i].second.dot(x) + (x.size() ? x[0] : weight_t()); // x[0] is bias double lp_false = dotprod; double lp_true = -dotprod; if (0 < lp_true) { -- cgit v1.2.3 From b7bff725993ba7ffc960a46db9b75fc570671ab5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 14 Sep 2011 14:47:06 +0100 Subject: fix for more problems caused by hash refactoring --- pro-train/mr_pro_reduce.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 6b491918..aff410a0 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -218,6 +218,10 @@ int main(int argc, char** argv) { vector x, prev_x; // x[0] is bias if (conf.count("weights")) { Weights::InitFromFile(conf["weights"].as(), &x); + x.resize(FD::NumFeats()); + prev_x = x; + } else { + x.resize(FD::NumFeats()); prev_x = x; } cerr << " Number of features: " << x.size() << endl; -- cgit v1.2.3 From 08f1814923005f702300d661c4d67f4635fc901c Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Thu, 15 Sep 2011 12:52:59 +0100 Subject: script to filter reachable sentences, weight cleanup --- decoder/apply_models.cc | 3 +- decoder/hg.h | 8 +- training/Makefile.am | 10 +- training/cllh_filter_grammar.cc | 197 -------------------------------------- training/mpi_extract_reachable.cc | 163 +++++++++++++++++++++++++++++++ utils/feature_vector.h | 4 +- 6 files changed, 174 insertions(+), 211 deletions(-) delete mode 100644 training/cllh_filter_grammar.cc create mode 100644 training/mpi_extract_reachable.cc diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 26cdb881..40fd27e4 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -276,8 +276,7 @@ public: make_heap(cand.begin(), cand.end(), HeapCandCompare()); State2Node state2node; // "buf" in Figure 2 int pops = 0; - int pop_limit_eff=max(1,int(v.promise*pop_limit_)); - while(!cand.empty() && pops < pop_limit_eff) { + while(!cand.empty() && pops < pop_limit_) { pop_heap(cand.begin(), cand.end(), HeapCandCompare()); Candidate* item = cand.back(); cand.pop_back(); diff --git a/decoder/hg.h b/decoder/hg.h index e5ef05f8..f0ddbb76 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -49,16 +49,14 @@ public: // TODO get rid of cat_? // TODO keep cat_ and add span and/or state? :) struct Node { - Node() : id_(), cat_(), promise(1) {} + Node() : id_(), cat_() {} int id_; // equal to this object's position in the nodes_ vector WordID cat_; // non-terminal category if <0, 0 if not set WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ - double promise; // set in global pruning; in [0,infty) so that mean is 1. use: e.g. scale cube poplimit. //TODO: appears to be useless, compile without this? on the other hand, pretty cheap. void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting cat_=o.cat_; - promise=o.promise; } void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) { copy_fixed(o); @@ -81,7 +79,7 @@ public: int head_node_; // refers to a position in nodes_ TailNodeVector tail_nodes_; // contents refer to positions in nodes_ TRulePtr rule_; - FeatureVector feature_values_; + SparseVector feature_values_; prob_t edge_prob_; // dot product of weights and feat_values int id_; // equal to this object's position in the edges_ vector @@ -468,7 +466,7 @@ public: /// drop edge i if edge_margin[i] < prune_below, unless preserve_mask[i] void MarginPrune(EdgeProbs const& edge_margin,prob_t prune_below,EdgeMask const* preserve_mask=0,bool safe_inside=false,bool verbose=false); - //TODO: in my opinion, looking at the ratio of logprobs (features \dot weights) rather than the absolute difference generalizes more nicely across sentence lengths and weight vectors that are constant multiples of one another. at least make that an option. i worked around this a little in cdec by making "beam alpha per source word" but that's not helping with different tuning runs. this would also make me more comfortable about allocating Node.promise + //TODO: in my opinion, looking at the ratio of logprobs (features \dot weights) rather than the absolute difference generalizes more nicely across sentence lengths and weight vectors that are constant multiples of one another. at least make that an option. i worked around this a little in cdec by making "beam alpha per source word" but that's not helping with different tuning runs. // beam_alpha=0 means don't beam prune, otherwise drop things that are e^beam_alpha times worse than best - // prunes any edge whose prob_t on the best path taking that edge is more than e^alpha times //density=0 means don't density prune: // for density>=1.0, keep this many times the edges needed for the 1best derivation diff --git a/training/Makefile.am b/training/Makefile.am index 7ceeda34..5752859e 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -9,9 +9,9 @@ bin_PROGRAMS = \ atools \ plftools \ collapse_weights \ - cllh_filter_grammar \ - mpi_online_optimize \ + mpi_extract_reachable \ mpi_extract_features \ + mpi_online_optimize \ mpi_batch_optimize \ compute_cllh \ augment_grammar @@ -25,6 +25,9 @@ TESTS = lbfgs_test optimize_test mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc +mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + mpi_extract_features_SOURCES = mpi_extract_features.cc mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz @@ -34,9 +37,6 @@ mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/ compute_cllh_SOURCES = compute_cllh.cc compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -cllh_filter_grammar_SOURCES = cllh_filter_grammar.cc -cllh_filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - augment_grammar_SOURCES = augment_grammar.cc augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/training/cllh_filter_grammar.cc b/training/cllh_filter_grammar.cc deleted file mode 100644 index 6998ec2b..00000000 --- a/training/cllh_filter_grammar.cc +++ /dev/null @@ -1,197 +0,0 @@ -#include -#include -#include -#include // fork -#include // waitpid - -#include -#include - -#include "tdict.h" -#include "ff_register.h" -#include "verbose.h" -#include "hg.h" -#include "decoder.h" -#include "filelib.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("training_data,t",po::value(),"Training data corpus") - ("decoder_config,c",po::value(),"Decoder configuration file") - ("shards,s",po::value()->default_value(1),"Number of shards") - ("starting_shard,S",po::value()->default_value(0), "In this invocation only process shards >= S") - ("work_limit,l",po::value()->default_value(9999), "Process maximially this many shards") - ("ncpus,C",po::value()->default_value(1),"Number of CPUs to use"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* ids) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - assert(size > 0); - assert(rank < size); - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) { - c->push_back(line); - ids->push_back(lc); - } - ++lc; - } -} - -struct TrainingObserver : public DecoderObserver { - TrainingObserver() : s_lhs(-TD::Convert("S")), goal_lhs(-TD::Convert("Goal")) {} - - void Reset() { - total_complete = 0; - } - - virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { - state = 1; - used.clear(); - failed = true; - } - - virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 1); - for (int i = 0; i < hg->edges_.size(); ++i) { - const TRule* rule = hg->edges_[i].rule_.get(); - if (rule->lhs_ == s_lhs || rule->lhs_ == goal_lhs) // fragile hack to filter out glue rules - continue; - used.insert(rule); - } - state = 2; - } - - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 2); - state = 3; - } - - virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { - if (state == 3) { - failed = false; - } else { - failed = true; - } - } - - set used; - - const int s_lhs; - const int goal_lhs; - bool failed; - int total_complete; - int state; -}; - -void work(const string& fname, int rank, int size, Decoder* decoder) { - cerr << "Worker " << rank << '/' << size << " starting.\n"; - vector corpus; - vector ids; - ReadTrainingCorpus(fname, rank, size, &corpus, &ids); - assert(corpus.size() > 0); - assert(corpus.size() == ids.size()); - cerr << " " << rank << '/' << size << ": has " << corpus.size() << " sentences to process\n"; - ostringstream oc; oc << "corpus." << rank << "_of_" << size; - WriteFile foc(oc.str()); - ostringstream og; og << "grammar." << rank << "_of_" << size << ".gz"; - WriteFile fog(og.str()); - - set all_used; - TrainingObserver observer; - for (int i = 0; i < corpus.size(); ++i) { - const int sent_id = ids[i]; - const string& input = corpus[i]; - decoder->SetId(sent_id); - decoder->Decode(input, &observer); - if (observer.failed) { - // do nothing - } else { - (*foc.stream()) << input << endl; - for (set::iterator it = observer.used.begin(); it != observer.used.end(); ++it) { - if (all_used.insert(*it).second) - (*fog.stream()) << **it << endl; - } - } - } -} - -int main(int argc, char** argv) { - register_feature_functions(); - - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - const string fname = conf["training_data"].as(); - const unsigned ncpus = conf["ncpus"].as(); - const unsigned shards = conf["shards"].as(); - const unsigned start = conf["starting_shard"].as(); - const unsigned work_limit = conf["work_limit"].as(); - const unsigned eff_shards = min(start + work_limit, shards); - cerr << "Processing shards " << start << "/" << shards << " to " << eff_shards << "/" << shards << endl; - assert(ncpus > 0); - ReadFile ini_rf(conf["decoder_config"].as()); - Decoder decoder(ini_rf.stream()); - if (decoder.GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - abort(); - } - SetSilent(true); // turn off verbose decoder output - cerr << "Forking " << ncpus << " time(s)\n"; - vector children; - for (int i = 0; i < ncpus; ++i) { - pid_t pid = fork(); - if (pid < 0) { - cerr << "Fork failed!\n"; - exit(1); - } - if (pid > 0) { - children.push_back(pid); - } else { - for (int j = start; j < eff_shards; ++j) { - if (j % ncpus == i) { - cerr << " CPU " << i << " processing shard " << j << endl; - work(fname, j, shards, &decoder); - cerr << " Shard " << j << "/" << shards << " finished.\n"; - } - } - _exit(0); - } - } - for (int i = 0; i < children.size(); ++i) { - int status; - int w = waitpid(children[i], &status, 0); - if (w < 0) { cerr << "Error while waiting for children!"; return 1; } - if (WIFSIGNALED(status)) { - cerr << "Child " << i << " received signal " << WTERMSIG(status) << endl; - if (WTERMSIG(status) == 11) { cerr << " this is a SEGV- you may be trying to print temporarily created rules\n"; } - } - } - return 0; -} diff --git a/training/mpi_extract_reachable.cc b/training/mpi_extract_reachable.cc new file mode 100644 index 00000000..2a7c2b9d --- /dev/null +++ b/training/mpi_extract_reachable.cc @@ -0,0 +1,163 @@ +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#endif +#include +#include + +#include "ff_register.h" +#include "verbose.h" +#include "filelib.h" +#include "fdict.h" +#include "decoder.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value(),"Training data corpus") + ("decoder_config,c",po::value(),"Decoder configuration file") + ("weights,w", po::value(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") + ("output_prefix,o",po::value()->default_value("reachable"),"Output path prefix"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the inputs that produce reachable parallel parses.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct ReachabilityObserver : public DecoderObserver { + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + reachable = false; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + reachable = true; + } + + bool reachable; +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + if (FD::UsingPerfectHashFunction()) { + cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; + return 1; + } + + // load optional weights + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &decoder.CurrentWeightVector()); + + vector corpus; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); + assert(corpus.size() > 0); + + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + size_t num_reached = 0; + { + ostringstream os; + os << conf["output_prefix"].as() << '.' << rank << "_of_" << size; + WriteFile wf(os.str()); + ostream& out = *wf.stream(); + ReachabilityObserver observer; + for (int i = 0; i < corpus.size(); ++i) { + decoder.Decode(corpus[i], &observer); + if (observer.reachable) { + out << corpus[i] << endl; + ++num_reached; + } + corpus[i].clear(); + } + cerr << "Shard " << rank << '/' << size << " finished, wrote " + << num_reached << " instances to " << os.str() << endl; + } + + size_t total = 0; +#ifdef HAVE_MPI + reduce(world, num_reached, total, std::plus(), 0); +#else + total = num_reached; +#endif + if (rank == 0) { + cerr << "-----------------------------------------\n"; + cerr << "TOTAL = " << total << " instances\n"; + } + return 0; +} + diff --git a/utils/feature_vector.h b/utils/feature_vector.h index 733aa99e..a7b61a66 100755 --- a/utils/feature_vector.h +++ b/utils/feature_vector.h @@ -3,9 +3,9 @@ #include #include "sparse_vector.h" -#include "fdict.h" +#include "weights.h" -typedef double Featval; +typedef weight_t Featval; typedef SparseVector FeatureVector; typedef SparseVector WeightVector; typedef std::vector DenseWeightVector; -- cgit v1.2.3 From b70a0be1c34bd177e8ac7c53cb466f226008cc52 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 00:20:52 +0100 Subject: add md5 hash function --- utils/stringlib.cc | 363 +++++++++++++++++++++++++++++++++++++++++++++++++++++ utils/stringlib.h | 3 + utils/ts.cc | 6 + 3 files changed, 372 insertions(+) diff --git a/utils/stringlib.cc b/utils/stringlib.cc index ade02ca9..3a56965c 100644 --- a/utils/stringlib.cc +++ b/utils/stringlib.cc @@ -90,3 +90,366 @@ void ProcessAndStripSGML(string* pline, map* out) { } } +string SGMLOpenSegTag(const map& attr) { + ostringstream os; + os << "::const_iterator it = attr.begin(); it != attr.end(); ++it) + os << ' ' << it->first << '=' << '"' << it->second << '"'; + os << '>'; + return os.str(); +} + +class MD5 { +public: + typedef unsigned int size_type; // must be 32bit + + MD5(); + MD5(const std::string& text); + void update(const unsigned char *buf, size_type length); + void update(const char *buf, size_type length); + MD5& finalize(); + std::string hexdigest() const; + +private: + void init(); + typedef unsigned char uint1; // 8bit + typedef unsigned int uint4; // 32bit + enum {blocksize = 64}; // VC6 won't eat a const static int here + + void transform(const uint1 block[blocksize]); + static void decode(uint4 output[], const uint1 input[], size_type len); + static void encode(uint1 output[], const uint4 input[], size_type len); + + bool finalized; + uint1 buffer[blocksize]; // bytes that didn't fit in last 64 byte chunk + uint4 count[2]; // 64bit counter for number of bits (lo, hi) + uint4 state[4]; // digest so far + uint1 digest[16]; // the result + + // low level logic operations + static inline uint4 F(uint4 x, uint4 y, uint4 z); + static inline uint4 G(uint4 x, uint4 y, uint4 z); + static inline uint4 H(uint4 x, uint4 y, uint4 z); + static inline uint4 I(uint4 x, uint4 y, uint4 z); + static inline uint4 rotate_left(uint4 x, int n); + static inline void FF(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac); + static inline void GG(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac); + static inline void HH(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac); + static inline void II(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac); +}; + +// Constants for MD5Transform routine. +#define S11 7 +#define S12 12 +#define S13 17 +#define S14 22 +#define S21 5 +#define S22 9 +#define S23 14 +#define S24 20 +#define S31 4 +#define S32 11 +#define S33 16 +#define S34 23 +#define S41 6 +#define S42 10 +#define S43 15 +#define S44 21 + +/////////////////////////////////////////////// + +// F, G, H and I are basic MD5 functions. +inline MD5::uint4 MD5::F(uint4 x, uint4 y, uint4 z) { + return (x&y) | (~x&z); +} + +inline MD5::uint4 MD5::G(uint4 x, uint4 y, uint4 z) { + return (x&z) | (y&~z); +} + +inline MD5::uint4 MD5::H(uint4 x, uint4 y, uint4 z) { + return x^y^z; +} + +inline MD5::uint4 MD5::I(uint4 x, uint4 y, uint4 z) { + return y ^ (x | ~z); +} + +// rotate_left rotates x left n bits. +inline MD5::uint4 MD5::rotate_left(uint4 x, int n) { + return (x << n) | (x >> (32-n)); +} + +// FF, GG, HH, and II transformations for rounds 1, 2, 3, and 4. +// Rotation is separate from addition to prevent recomputation. +inline void MD5::FF(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) { + a = rotate_left(a+ F(b,c,d) + x + ac, s) + b; +} + +inline void MD5::GG(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) { + a = rotate_left(a + G(b,c,d) + x + ac, s) + b; +} + +inline void MD5::HH(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) { + a = rotate_left(a + H(b,c,d) + x + ac, s) + b; +} + +inline void MD5::II(uint4 &a, uint4 b, uint4 c, uint4 d, uint4 x, uint4 s, uint4 ac) { + a = rotate_left(a + I(b,c,d) + x + ac, s) + b; +} + +////////////////////////////////////////////// + +// default ctor, just initailize +MD5::MD5() +{ + init(); +} + +////////////////////////////////////////////// + +// nifty shortcut ctor, compute MD5 for string and finalize it right away +MD5::MD5(const std::string &text) +{ + init(); + update(text.c_str(), text.length()); + finalize(); +} + +////////////////////////////// + +void MD5::init() +{ + finalized=false; + + count[0] = 0; + count[1] = 0; + + // load magic initialization constants. + state[0] = 0x67452301; + state[1] = 0xefcdab89; + state[2] = 0x98badcfe; + state[3] = 0x10325476; +} + +////////////////////////////// + +// decodes input (unsigned char) into output (uint4). Assumes len is a multiple of 4. +void MD5::decode(uint4 output[], const uint1 input[], size_type len) +{ + for (unsigned int i = 0, j = 0; j < len; i++, j += 4) + output[i] = ((uint4)input[j]) | (((uint4)input[j+1]) << 8) | + (((uint4)input[j+2]) << 16) | (((uint4)input[j+3]) << 24); +} + +////////////////////////////// + +// encodes input (uint4) into output (unsigned char). Assumes len is +// a multiple of 4. +void MD5::encode(uint1 output[], const uint4 input[], size_type len) +{ + for (size_type i = 0, j = 0; j < len; i++, j += 4) { + output[j] = input[i] & 0xff; + output[j+1] = (input[i] >> 8) & 0xff; + output[j+2] = (input[i] >> 16) & 0xff; + output[j+3] = (input[i] >> 24) & 0xff; + } +} + +////////////////////////////// + +// apply MD5 algo on a block +void MD5::transform(const uint1 block[blocksize]) +{ + uint4 a = state[0], b = state[1], c = state[2], d = state[3], x[16]; + decode (x, block, blocksize); + + /* Round 1 */ + FF (a, b, c, d, x[ 0], S11, 0xd76aa478); /* 1 */ + FF (d, a, b, c, x[ 1], S12, 0xe8c7b756); /* 2 */ + FF (c, d, a, b, x[ 2], S13, 0x242070db); /* 3 */ + FF (b, c, d, a, x[ 3], S14, 0xc1bdceee); /* 4 */ + FF (a, b, c, d, x[ 4], S11, 0xf57c0faf); /* 5 */ + FF (d, a, b, c, x[ 5], S12, 0x4787c62a); /* 6 */ + FF (c, d, a, b, x[ 6], S13, 0xa8304613); /* 7 */ + FF (b, c, d, a, x[ 7], S14, 0xfd469501); /* 8 */ + FF (a, b, c, d, x[ 8], S11, 0x698098d8); /* 9 */ + FF (d, a, b, c, x[ 9], S12, 0x8b44f7af); /* 10 */ + FF (c, d, a, b, x[10], S13, 0xffff5bb1); /* 11 */ + FF (b, c, d, a, x[11], S14, 0x895cd7be); /* 12 */ + FF (a, b, c, d, x[12], S11, 0x6b901122); /* 13 */ + FF (d, a, b, c, x[13], S12, 0xfd987193); /* 14 */ + FF (c, d, a, b, x[14], S13, 0xa679438e); /* 15 */ + FF (b, c, d, a, x[15], S14, 0x49b40821); /* 16 */ + + /* Round 2 */ + GG (a, b, c, d, x[ 1], S21, 0xf61e2562); /* 17 */ + GG (d, a, b, c, x[ 6], S22, 0xc040b340); /* 18 */ + GG (c, d, a, b, x[11], S23, 0x265e5a51); /* 19 */ + GG (b, c, d, a, x[ 0], S24, 0xe9b6c7aa); /* 20 */ + GG (a, b, c, d, x[ 5], S21, 0xd62f105d); /* 21 */ + GG (d, a, b, c, x[10], S22, 0x2441453); /* 22 */ + GG (c, d, a, b, x[15], S23, 0xd8a1e681); /* 23 */ + GG (b, c, d, a, x[ 4], S24, 0xe7d3fbc8); /* 24 */ + GG (a, b, c, d, x[ 9], S21, 0x21e1cde6); /* 25 */ + GG (d, a, b, c, x[14], S22, 0xc33707d6); /* 26 */ + GG (c, d, a, b, x[ 3], S23, 0xf4d50d87); /* 27 */ + GG (b, c, d, a, x[ 8], S24, 0x455a14ed); /* 28 */ + GG (a, b, c, d, x[13], S21, 0xa9e3e905); /* 29 */ + GG (d, a, b, c, x[ 2], S22, 0xfcefa3f8); /* 30 */ + GG (c, d, a, b, x[ 7], S23, 0x676f02d9); /* 31 */ + GG (b, c, d, a, x[12], S24, 0x8d2a4c8a); /* 32 */ + + /* Round 3 */ + HH (a, b, c, d, x[ 5], S31, 0xfffa3942); /* 33 */ + HH (d, a, b, c, x[ 8], S32, 0x8771f681); /* 34 */ + HH (c, d, a, b, x[11], S33, 0x6d9d6122); /* 35 */ + HH (b, c, d, a, x[14], S34, 0xfde5380c); /* 36 */ + HH (a, b, c, d, x[ 1], S31, 0xa4beea44); /* 37 */ + HH (d, a, b, c, x[ 4], S32, 0x4bdecfa9); /* 38 */ + HH (c, d, a, b, x[ 7], S33, 0xf6bb4b60); /* 39 */ + HH (b, c, d, a, x[10], S34, 0xbebfbc70); /* 40 */ + HH (a, b, c, d, x[13], S31, 0x289b7ec6); /* 41 */ + HH (d, a, b, c, x[ 0], S32, 0xeaa127fa); /* 42 */ + HH (c, d, a, b, x[ 3], S33, 0xd4ef3085); /* 43 */ + HH (b, c, d, a, x[ 6], S34, 0x4881d05); /* 44 */ + HH (a, b, c, d, x[ 9], S31, 0xd9d4d039); /* 45 */ + HH (d, a, b, c, x[12], S32, 0xe6db99e5); /* 46 */ + HH (c, d, a, b, x[15], S33, 0x1fa27cf8); /* 47 */ + HH (b, c, d, a, x[ 2], S34, 0xc4ac5665); /* 48 */ + + /* Round 4 */ + II (a, b, c, d, x[ 0], S41, 0xf4292244); /* 49 */ + II (d, a, b, c, x[ 7], S42, 0x432aff97); /* 50 */ + II (c, d, a, b, x[14], S43, 0xab9423a7); /* 51 */ + II (b, c, d, a, x[ 5], S44, 0xfc93a039); /* 52 */ + II (a, b, c, d, x[12], S41, 0x655b59c3); /* 53 */ + II (d, a, b, c, x[ 3], S42, 0x8f0ccc92); /* 54 */ + II (c, d, a, b, x[10], S43, 0xffeff47d); /* 55 */ + II (b, c, d, a, x[ 1], S44, 0x85845dd1); /* 56 */ + II (a, b, c, d, x[ 8], S41, 0x6fa87e4f); /* 57 */ + II (d, a, b, c, x[15], S42, 0xfe2ce6e0); /* 58 */ + II (c, d, a, b, x[ 6], S43, 0xa3014314); /* 59 */ + II (b, c, d, a, x[13], S44, 0x4e0811a1); /* 60 */ + II (a, b, c, d, x[ 4], S41, 0xf7537e82); /* 61 */ + II (d, a, b, c, x[11], S42, 0xbd3af235); /* 62 */ + II (c, d, a, b, x[ 2], S43, 0x2ad7d2bb); /* 63 */ + II (b, c, d, a, x[ 9], S44, 0xeb86d391); /* 64 */ + + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + + // Zeroize sensitive information. + memset(x, 0, sizeof x); +} + +////////////////////////////// + +// MD5 block update operation. Continues an MD5 message-digest +// operation, processing another message block +void MD5::update(const unsigned char input[], size_type length) +{ + // compute number of bytes mod 64 + size_type index = count[0] / 8 % blocksize; + + // Update number of bits + if ((count[0] += (length << 3)) < (length << 3)) + count[1]++; + count[1] += (length >> 29); + + // number of bytes we need to fill in buffer + size_type firstpart = 64 - index; + + size_type i; + + // transform as many times as possible. + if (length >= firstpart) + { + // fill buffer first, transform + memcpy(&buffer[index], input, firstpart); + transform(buffer); + + // transform chunks of blocksize (64 bytes) + for (i = firstpart; i + blocksize <= length; i += blocksize) + transform(&input[i]); + + index = 0; + } + else + i = 0; + + // buffer remaining input + memcpy(&buffer[index], &input[i], length-i); +} + +////////////////////////////// + +// for convenience provide a verson with signed char +void MD5::update(const char input[], size_type length) +{ + update((const unsigned char*)input, length); +} + +////////////////////////////// + +// MD5 finalization. Ends an MD5 message-digest operation, writing the +// the message digest and zeroizing the context. +MD5& MD5::finalize() +{ + static unsigned char padding[64] = { + 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + + if (!finalized) { + // Save number of bits + unsigned char bits[8]; + encode(bits, count, 8); + + // pad out to 56 mod 64. + size_type index = count[0] / 8 % 64; + size_type padLen = (index < 56) ? (56 - index) : (120 - index); + update(padding, padLen); + + // Append length (before padding) + update(bits, 8); + + // Store state in digest + encode(digest, state, 16); + + // Zeroize sensitive information. + memset(buffer, 0, sizeof buffer); + memset(count, 0, sizeof count); + + finalized=true; + } + + return *this; +} + +////////////////////////////// + +// return hex representation of digest as string +std::string MD5::hexdigest() const +{ + if (!finalized) + return ""; + + char buf[33]; + for (int i=0; i<16; i++) + sprintf(buf+i*2, "%02x", digest[i]); + buf[32]=0; + + return std::string(buf); +} + +////////////////////////////// + +std::string md5(const std::string& str) { + MD5 md5 = MD5(str); + return md5.hexdigest(); +} + diff --git a/utils/stringlib.h b/utils/stringlib.h index 8022bb88..cafbdac3 100644 --- a/utils/stringlib.h +++ b/utils/stringlib.h @@ -249,6 +249,7 @@ inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::s } void ProcessAndStripSGML(std::string* line, std::map* out); +std::string SGMLOpenSegTag(const std::map& attr); // given the first character of a UTF8 block, find out how wide it is // see http://en.wikipedia.org/wiki/UTF-8 for more info @@ -260,4 +261,6 @@ inline unsigned int UTF8Len(unsigned char x) { else return 0; } +std::string md5(const std::string& in); + #endif diff --git a/utils/ts.cc b/utils/ts.cc index 3694e076..bf4f8f69 100644 --- a/utils/ts.cc +++ b/utils/ts.cc @@ -7,6 +7,7 @@ #include "prob.h" #include "sparse_vector.h" #include "fast_sparse_vector.h" +#include "stringlib.h" using namespace std; @@ -79,6 +80,11 @@ int main() { y -= y; } cerr << "Counted " << c << " times\n"; + + cerr << md5("this is a test") << endl; + cerr << md5("some other ||| string is") << endl; + map x; x["id"] = "12"; x["grammar"] = "/path/to/grammar.gz"; + cerr << SGMLOpenSegTag(x) << endl; return 0; } -- cgit v1.2.3 From a28c48d07df4e426a875f5381c80ebf4fbbd1de2 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 01:08:45 +0100 Subject: enable ramdisk scratch for per-sentence-grammars --- training/mpi_batch_optimize.cc | 35 +++++++++++++++++++++++++++++++++++ utils/filelib.cc | 19 +++++++++++++++++++ utils/filelib.h | 5 +---- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index cc5953f6..0ba8c530 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -22,6 +22,7 @@ namespace mpi = boost::mpi; #include "ff_register.h" #include "decoder.h" #include "filelib.h" +#include "stringlib.h" #include "optimize.h" #include "fdict.h" #include "weights.h" @@ -42,6 +43,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") ("gaussian_prior,p","Use a Gaussian prior on the weights") ("means,u", po::value(), "File containing the means for Gaussian prior") + ("per_sentence_grammar_scratch,P", po::value(), "(Optional) location of scratch space to copy per-sentence grammars for fast access, useful if a RAM disk is available") ("sigma_squared", po::value()->default_value(1.0), "Sigma squared term for spherical Gaussian prior"); po::options_description clo("Command line options"); clo.add_options() @@ -186,6 +188,36 @@ struct VectorPlus : public binary_function, vector, vector > { } }; +void MovePerSentenceGrammars(const string& root, int size, int rank, vector* c) { + if (!DirectoryExists(root)) { + cerr << "Can't find scratch space at " << root << endl; + abort(); + } + ostringstream os; + os << root << "/psg." << size << "_of_" << rank; + const string path = os.str(); + MkDirP(path); + string sent; + map attr; + for (unsigned i = 0; i < c->size(); ++i) { + sent = (*c)[i]; + attr.clear(); + ProcessAndStripSGML(&sent, &attr); + map::iterator it = attr.find("grammar"); + if (it != attr.end()) { + string src_file = it->second; + bool is_gzipped = (src_file.size() > 3) && (src_file.rfind(".gz") == (src_file.size() - 3)); + string new_name = path + "/" + md5(sent); + if (is_gzipped) new_name += ".gz"; + CopyFile(src_file, new_name); + it->second = new_name; + } + ostringstream ns; + ns << SGMLOpenSegTag(attr) << ' ' << sent << " "; + (*c)[i] = ns.str(); + } +} + int main(int argc, char** argv) { #ifdef HAVE_MPI mpi::environment env(argc, argv); @@ -257,6 +289,9 @@ int main(int argc, char** argv) { ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); assert(corpus.size() > 0); + if (conf.count("per_sentence_grammar_scratch")) + MovePerSentenceGrammars(conf["per_sentence_grammar_scratch"].as(), rank, size, &corpus); + TrainingObserver observer; while (!converged) { observer.Reset(); diff --git a/utils/filelib.cc b/utils/filelib.cc index a0969b1a..d206fc19 100644 --- a/utils/filelib.cc +++ b/utils/filelib.cc @@ -2,6 +2,12 @@ #include #include +#include +#include +#include +#include +#include +#include using namespace std; @@ -32,3 +38,16 @@ void MkDirP(const string& dir) { } } +#if 0 +void CopyFile(const string& inf, const string& outf) { + WriteFile w(outf); + CopyFile(inf,*w); +} +#else +void CopyFile(const string& inf, const string& outf) { + ofstream of(outf.c_str(), fstream::trunc|fstream::binary); + ifstream in(inf.c_str(), fstream::binary); + of << in.rdbuf(); +} +#endif + diff --git a/utils/filelib.h b/utils/filelib.h index a8622246..bb6e7415 100644 --- a/utils/filelib.h +++ b/utils/filelib.h @@ -113,9 +113,6 @@ inline void CopyFile(std::string const& inf,std::ostream &out) { CopyFile(*r,out); } -inline void CopyFile(std::string const& inf,std::string const& outf) { - WriteFile w(outf); - CopyFile(inf,*w); -} +void CopyFile(std::string const& inf,std::string const& outf); #endif -- cgit v1.2.3 From 1afbff874473c79619ce74cdf90f3c312185e4e1 Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Sat, 17 Sep 2011 01:39:07 +0100 Subject: add dep --- training/cluster-em.pl | 114 ---------------- training/cluster-ptrain.pl | 206 ----------------------------- training/compute_cllh.cc | 196 --------------------------- training/make-lexcrf-grammar.pl | 285 ---------------------------------------- training/mpi_compute_cllh.cc | 196 +++++++++++++++++++++++++++ utils/stringlib.cc | 14 +- 6 files changed, 203 insertions(+), 808 deletions(-) delete mode 100755 training/cluster-em.pl delete mode 100755 training/cluster-ptrain.pl delete mode 100644 training/compute_cllh.cc delete mode 100755 training/make-lexcrf-grammar.pl create mode 100644 training/mpi_compute_cllh.cc diff --git a/training/cluster-em.pl b/training/cluster-em.pl deleted file mode 100755 index 267ab642..00000000 --- a/training/cluster-em.pl +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/perl -w - -use strict; -my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR; } -use Getopt::Long; -my $parallel = 0; - -my $CWD=`pwd`; chomp $CWD; -my $BIN_DIR = "$CWD/.."; -my $REDUCER = "$BIN_DIR/training/mr_em_adapted_reduce"; -my $REDUCE2WEIGHTS = "$BIN_DIR/training/mr_reduce_to_weights"; -my $ADAPTER = "$BIN_DIR/training/mr_em_map_adapter"; -my $DECODER = "$BIN_DIR/decoder/cdec"; -my $COMBINER_CACHE_SIZE = 10000000; -my $PARALLEL = "/chomes/redpony/svn-trunk/sa-utils/parallelize.pl"; -die "Can't find $REDUCER" unless -f $REDUCER; -die "Can't execute $REDUCER" unless -x $REDUCER; -die "Can't find $REDUCE2WEIGHTS" unless -f $REDUCE2WEIGHTS; -die "Can't execute $REDUCE2WEIGHTS" unless -x $REDUCE2WEIGHTS; -die "Can't find $ADAPTER" unless -f $ADAPTER; -die "Can't execute $ADAPTER" unless -x $ADAPTER; -die "Can't find $DECODER" unless -f $DECODER; -die "Can't execute $DECODER" unless -x $DECODER; -my $restart = ''; -if ($ARGV[0] && $ARGV[0] eq '--restart') { shift @ARGV; $restart = 1; } - -die "Usage: $0 [--restart] training.corpus cdec.ini\n" unless (scalar @ARGV == 2); - -my $training_corpus = shift @ARGV; -my $config = shift @ARGV; -my $pmem="2500mb"; -my $nodes = 40; -my $max_iteration = 1000; -my $CFLAG = "-C 1"; -if ($parallel) { - die "Can't find $PARALLEL" unless -f $PARALLEL; - die "Can't execute $PARALLEL" unless -x $PARALLEL; -} else { $CFLAG = "-C 500"; } - -my $initial_weights = ''; - -print STDERR < \$DECODER, - "distributed" => \$DISTRIBUTED, - "sigma_squared=f" => \$sigsq, - "lbfgs_memory_buffers=i" => \$mem_buffers, - "max_iteration=i" => \$max_iteration, - "means=s" => \$means_file, - "optimizer=s" => \$OALG, - "gaussian_prior" => \$PRIOR, - "restart_if_necessary" => \$RESTART_IF_NECESSARY, - "jobs=i" => \$nodes, - "pmem=s" => \$pmem - ) or usage(); -usage() unless scalar @ARGV==3; -my $config_file = shift @ARGV; -my $training_corpus = shift @ARGV; -my $initial_weights = shift @ARGV; -unless ($DISTRIBUTED) { $LOCAL = 1; } -die "Can't find $config_file" unless -f $config_file; -die "Can't find $DECODER" unless -f $DECODER; -die "Can't execute $DECODER" unless -x $DECODER; -if ($LOCAL) { print STDERR "Will run LOCALLY.\n"; $parallel = 0; } -if ($PRIOR) { - $PRIOR_FLAG="-p --sigma_squared $sigsq"; - if ($means_file) { $PRIOR_FLAG .= " -u $means_file"; } -} - -if ($parallel) { - die "Can't find $PARALLEL" unless -f $PARALLEL; - die "Can't execute $PARALLEL" unless -x $PARALLEL; -} -unless ($parallel) { $CFLAG = "-C 500"; } -unless ($config_file =~ /^\//) { $config_file = $CWD . '/' . $config_file; } -my $clines = num_lines($training_corpus); -my $dir = "$CWD/ptrain"; - -if ($RESTART_IF_NECESSARY && -d $dir) { - $restart = 1; -} - -print STDERR <$dir/training.in"; - my $lc = 0; - while() { - chomp; - s/^\s+//; - s/\s+$//; - die "Expected A ||| B in input file" unless / \|\|\| /; - print TO "$_\n"; - $lc++; - } - close T; - close TO; -} -$training_corpus = "$dir/training.in"; - -my $iter_attempts = 1; -while ($iter < $max_iteration) { - my $cur_time = `date`; chomp $cur_time; - print STDERR "\nStarting iteration $iter...\n"; - print STDERR " time: $cur_time\n"; - my $start = time; - my $next_iter = $iter + 1; - my $dec_cmd="$DECODER -G $CFLAG -c $config_file -w $dir/weights.$iter.gz < $training_corpus 2> $dir/deco.log.$iter"; - my $opt_cmd = "$OPTIMIZER $PRIOR_FLAG -M $mem_buffers $OALG -s $dir/opt.state -i $dir/weights.$iter.gz -o $dir/weights.$next_iter.gz"; - my $pcmd = "$PARALLEL -e $dir/err -p $pmem --nodelist \"$nodelist\" -- "; - my $cmd = ""; - if ($parallel) { $cmd = $pcmd; } - $cmd .= "$dec_cmd | $opt_cmd"; - - print STDERR "EXECUTING: $cmd\n"; - my $result = `$cmd`; - my $exit_code = $? >> 8; - if ($exit_code == 99) { - $iter_attempts++; - if ($iter_attempts > $MAX_ITER_ATTEMPTS) { - die "Received restart request $iter_attempts times from optimizer, giving up\n"; - } - print STDERR "Function evaluation failed, retrying (attempt $iter_attempts)\n"; - next; - } - if ($? != 0) { - die "Error running iteration $iter: $!"; - } - chomp $result; - my $end = time; - my $diff = ($end - $start); - print STDERR " ITERATION $iter TOOK $diff SECONDS\n"; - $iter = $next_iter; - if ($result =~ /1$/) { - print STDERR "Training converged.\n"; - last; - } - $iter_attempts = 1; -} - -print "FINAL WEIGHTS: $dir/weights.$iter\n"; -`mv $dir/weights.$iter.gz $dir/weights.final.gz`; - -sub usage { - die <) { $lines++; } - close $fh; - return $lines; -} diff --git a/training/compute_cllh.cc b/training/compute_cllh.cc deleted file mode 100644 index b496d196..00000000 --- a/training/compute_cllh.cc +++ /dev/null @@ -1,196 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "config.h" -#ifdef HAVE_MPI -#include -#endif -#include -#include - -#include "verbose.h" -#include "hg.h" -#include "prob.h" -#include "inside_outside.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "weights.h" - -using namespace std; -namespace po = boost::program_options; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("weights,w",po::value(),"Input feature weights file") - ("training_data,t",po::value(),"Training data corpus") - ("decoder_config,c",po::value(),"Decoder configuration file"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* ids) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) { - c->push_back(line); - ids->push_back(lc); - } - ++lc; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct TrainingObserver : public DecoderObserver { - void Reset() { - acc_obj = 0; - } - - virtual void NotifyDecodingStart(const SentenceMetadata&) { - cur_obj = 0; - state = 1; - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { - assert(state == 1); - state = 2; - SparseVector cur_model_exp; - const prob_t z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); - cur_obj = log(z); - } - - // compute "empirical" expectations, numerator of objective - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 2); - state = 3; - SparseVector ref_exp; - const prob_t ref_z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); - - double log_ref_z; -#if 0 - if (crf_uniform_empirical) { - log_ref_z = ref_exp.dot(feature_weights); - } else { - log_ref_z = log(ref_z); - } -#else - log_ref_z = log(ref_z); -#endif - - // rounding errors means that <0 is too strict - if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { - cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; - exit(1); - } - assert(!isnan(log_ref_z)); - acc_obj += (cur_obj - log_ref_z); - } - - double acc_obj; - double cur_obj; - int state; -}; - -#ifdef HAVE_MPI -namespace mpi = boost::mpi; -#endif - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - if (size > 1) SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) - return false; - - // load cdec.ini and set up decoder - ReadFile ini_rf(conf["decoder_config"].as()); - Decoder decoder(ini_rf.stream()); - if (decoder.GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - abort(); - } - - // load weights - vector& weights = decoder.CurrentWeightVector(); - if (conf.count("weights")) - Weights::InitFromFile(conf["weights"].as(), &weights); - - // freeze feature set - //const bool freeze_feature_set = conf.count("freeze_feature_set"); - //if (freeze_feature_set) FD::Freeze(); - - vector corpus; vector ids; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); - assert(corpus.size() > 0); - assert(corpus.size() == ids.size()); - - TrainingObserver observer; - double objective = 0; - - observer.Reset(); - if (rank == 0) - cerr << "Each processor is decoding " << corpus.size() << " training examples...\n"; - - for (int i = 0; i < corpus.size(); ++i) { - decoder.SetId(ids[i]); - decoder.Decode(corpus[i], &observer); - } - -#ifdef HAVE_MPI - reduce(world, observer.acc_obj, objective, std::plus(), 0); -#else - objective = observer.acc_obj; -#endif - - if (rank == 0) - cout << "OBJECTIVE: " << objective << endl; - - return 0; -} - diff --git a/training/make-lexcrf-grammar.pl b/training/make-lexcrf-grammar.pl deleted file mode 100755 index 8cdf7718..00000000 --- a/training/make-lexcrf-grammar.pl +++ /dev/null @@ -1,285 +0,0 @@ -#!/usr/bin/perl -w -use utf8; -use strict; -my ($effile, $model1) = @ARGV; -die "Usage: $0 corpus.fr-en corpus.model1\n" unless $effile && -f $effile && $model1 && -f $model1; - -open EF, "<$effile" or die; -open M1, "<$model1" or die; -binmode(EF,":utf8"); -binmode(M1,":utf8"); -binmode(STDOUT,":utf8"); -my %model1; -while() { - chomp; - my ($f, $e, $lp) = split /\s+/; - $model1{$f}->{$e} = $lp; -} - -my $ADD_MODEL1 = 0; # found that model1 hurts performance -my $IS_FRENCH_F = 1; # indicates that the f language is french -my $IS_ARABIC_F = 0; # indicates that the f language is arabic -my $IS_URDU_F = 0; # indicates that the f language is arabic -my $ADD_PREFIX_ID = 0; -my $ADD_LEN = 1; -my $ADD_SIM = 1; -my $ADD_DICE = 1; -my $ADD_111 = 1; -my $ADD_ID = 1; -my $ADD_PUNC = 1; -my $ADD_NUM_MM = 1; -my $ADD_NULL = 1; -my $ADD_STEM_ID = 1; -my $BEAM_RATIO = 50; - -my %fdict; -my %fcounts; -my %ecounts; - -my %sdict; - -while() { - chomp; - my ($f, $e) = split /\s*\|\|\|\s*/; - my @es = split /\s+/, $e; - my @fs = split /\s+/, $f; - for my $ew (@es){ $ecounts{$ew}++; } - push @fs, '' if $ADD_NULL; - for my $fw (@fs){ $fcounts{$fw}++; } - for my $fw (@fs){ - for my $ew (@es){ - $fdict{$fw}->{$ew}++; - } - } -} - -print STDERR "Dice 0\n" if $ADD_DICE; -print STDERR "OneOneOne 0\nId_OneOneOne 0\n" if $ADD_111; -print STDERR "Identical 0\n" if $ADD_ID; -print STDERR "PuncMiss 0\n" if $ADD_PUNC; -print STDERR "IsNull 0\n" if $ADD_NULL; -print STDERR "Model1 0\n" if $ADD_MODEL1; -print STDERR "DLen 0\n" if $ADD_LEN; -print STDERR "NumMM 0\nNumMatch 0\n" if $ADD_NUM_MM; -print STDERR "OrthoSim 0\n" if $ADD_SIM; -print STDERR "PfxIdentical 0\n" if ($ADD_PREFIX_ID); -my $fc = 1000000; -my $sids = 1000000; -for my $f (sort keys %fdict) { - my $re = $fdict{$f}; - my $max; - for my $e (sort {$re->{$b} <=> $re->{$a}} keys %$re) { - my $efcount = $re->{$e}; - unless (defined $max) { $max = $efcount; } - my $m1 = $model1{$f}->{$e}; - unless (defined $m1) { next; } - $fc++; - my $dice = 2 * $efcount / ($ecounts{$e} + $fcounts{$f}); - my $feats = "F$fc=1"; - my $oe = $e; - my $of = $f; # normalized form - if ($IS_FRENCH_F) { - # see http://en.wikipedia.org/wiki/Use_of_the_circumflex_in_French - $of =~ s/â/as/g; - $of =~ s/ê/es/g; - $of =~ s/î/is/g; - $of =~ s/ô/os/g; - $of =~ s/û/us/g; - } elsif ($IS_ARABIC_F) { - if (length($of) > 1 && !($of =~ /\d/)) { - $of =~ s/\$/sh/g; - } - } elsif ($IS_URDU_F) { - if (length($of) > 1 && !($of =~ /\d/)) { - $of =~ s/\$/sh/g; - } - $oe =~ s/^-e-//; - $oe =~ s/^al-/al/; - $of =~ s/([a-z])\~/$1$1/g; - $of =~ s/E/'/g; - $of =~ s/^Aw/o/g; - $of =~ s/\|/a/g; - $of =~ s/@/h/g; - $of =~ s/c/ch/g; - $of =~ s/x/kh/g; - $of =~ s/\*/dh/g; - $of =~ s/w/o/g; - $of =~ s/Z/dh/g; - $of =~ s/y/i/g; - $of =~ s/Y/a/g; - $of = lc $of; - } - my $len_e = length($oe); - my $len_f = length($of); - $feats .= " Model1=$m1" if ($ADD_MODEL1); - $feats .= " Dice=$dice" if $ADD_DICE; - my $is_null = undef; - if ($ADD_NULL && $f eq '') { - $feats .= " IsNull=1"; - $is_null = 1; - } - if ($ADD_LEN) { - if (!$is_null) { - my $dlen = abs($len_e - $len_f); - $feats .= " DLen=$dlen"; - } - } - my $f_num = ($of =~ /^-?\d[0-9\.\,]+%?$/ && (length($of) > 3)); - my $e_num = ($oe =~ /^-?\d[0-9\.\,]+%?$/ && (length($oe) > 3)); - my $both_non_numeric = (!$e_num && !$f_num); - if ($ADD_NUM_MM && (($f_num && !$e_num) || ($e_num && !$f_num))) { - $feats .= " NumMM=1"; - } - if ($ADD_NUM_MM && ($f_num && $e_num) && ($oe eq $of)) { - $feats .= " NumMatch=1"; - } - if ($ADD_STEM_ID) { - my $el = 4; - my $fl = 4; - if ($oe =~ /^al|re|co/) { $el++; } - if ($of =~ /^al|re|co/) { $fl++; } - if ($oe =~ /^trans|inter/) { $el+=2; } - if ($of =~ /^trans|inter/) { $fl+=2; } - if ($fl > length($of)) { $fl = length($of); } - if ($el > length($oe)) { $el = length($oe); } - my $sf = substr $of, 0, $fl; - my $se = substr $oe, 0, $el; - my $id = $sdict{$sf}->{$se}; - if (!$id) { - $sids++; - $sdict{$sf}->{$se} = $sids; - $id = $sids; - print STDERR "S$sids 0\n" - } - $feats .= " S$id=1"; - } - if ($ADD_PREFIX_ID) { - if ($len_e > 3 && $len_f > 3 && $both_non_numeric) { - my $pe = substr $oe, 0, 3; - my $pf = substr $of, 0, 3; - if ($pe eq $pf) { $feats .= " PfxIdentical=1"; } - } - } - if ($ADD_SIM) { - my $ld = 0; - my $eff = $len_e; - if ($eff < $len_f) { $eff = $len_f; } - if (!$is_null) { - $ld = ($eff - levenshtein($oe, $of)) / sqrt($eff); - } - $feats .= " OrthoSim=$ld"; - } - my $ident = ($e eq $f); - if ($ident && $ADD_ID) { $feats .= " Identical=1"; } - if ($ADD_111 && ($efcount == 1 && $ecounts{$e} == 1 && $fcounts{$f} == 1)) { - if ($ident && $ADD_ID) { - $feats .= " Id_OneOneOne=1"; - } - $feats .= " OneOneOne=1"; - } - if ($ADD_PUNC) { - if (($f =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $e =~ /[a-z]+/) || - ($e =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $f =~ /[a-z]+/)) { - $feats .= " PuncMiss=1"; - } - } - my $r = (0.5 - rand)/5; - print STDERR "F$fc $r\n"; - print "$f ||| $e ||| $feats\n"; - } -} - -sub levenshtein -{ - # $s1 and $s2 are the two strings - # $len1 and $len2 are their respective lengths - # - my ($s1, $s2) = @_; - my ($len1, $len2) = (length $s1, length $s2); - - # If one of the strings is empty, the distance is the length - # of the other string - # - return $len2 if ($len1 == 0); - return $len1 if ($len2 == 0); - - my %mat; - - # Init the distance matrix - # - # The first row to 0..$len1 - # The first column to 0..$len2 - # The rest to 0 - # - # The first row and column are initialized so to denote distance - # from the empty string - # - for (my $i = 0; $i <= $len1; ++$i) - { - for (my $j = 0; $j <= $len2; ++$j) - { - $mat{$i}{$j} = 0; - $mat{0}{$j} = $j; - } - - $mat{$i}{0} = $i; - } - - # Some char-by-char processing is ahead, so prepare - # array of chars from the strings - # - my @ar1 = split(//, $s1); - my @ar2 = split(//, $s2); - - for (my $i = 1; $i <= $len1; ++$i) - { - for (my $j = 1; $j <= $len2; ++$j) - { - # Set the cost to 1 iff the ith char of $s1 - # equals the jth of $s2 - # - # Denotes a substitution cost. When the char are equal - # there is no need to substitute, so the cost is 0 - # - my $cost = ($ar1[$i-1] eq $ar2[$j-1]) ? 0 : 1; - - # Cell $mat{$i}{$j} equals the minimum of: - # - # - The cell immediately above plus 1 - # - The cell immediately to the left plus 1 - # - The cell diagonally above and to the left plus the cost - # - # We can either insert a new char, delete a char or - # substitute an existing char (with an associated cost) - # - $mat{$i}{$j} = min([$mat{$i-1}{$j} + 1, - $mat{$i}{$j-1} + 1, - $mat{$i-1}{$j-1} + $cost]); - } - } - - # Finally, the Levenshtein distance equals the rightmost bottom cell - # of the matrix - # - # Note that $mat{$x}{$y} denotes the distance between the substrings - # 1..$x and 1..$y - # - return $mat{$len1}{$len2}; -} - - -# minimal element of a list -# -sub min -{ - my @list = @{$_[0]}; - my $min = $list[0]; - - foreach my $i (@list) - { - $min = $i if ($i < $min); - } - - return $min; -} - diff --git a/training/mpi_compute_cllh.cc b/training/mpi_compute_cllh.cc new file mode 100644 index 00000000..b496d196 --- /dev/null +++ b/training/mpi_compute_cllh.cc @@ -0,0 +1,196 @@ +#include +#include +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#endif +#include +#include + +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("weights,w",po::value(),"Input feature weights file") + ("training_data,t",po::value(),"Training data corpus") + ("decoder_config,c",po::value(),"Decoder configuration file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* ids) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) { + c->push_back(line); + ids->push_back(lc); + } + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + acc_obj = 0; + } + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + cur_obj = 0; + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + assert(state == 1); + state = 2; + SparseVector cur_model_exp; + const prob_t z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector ref_exp; + const prob_t ref_z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + + double log_ref_z; +#if 0 + if (crf_uniform_empirical) { + log_ref_z = ref_exp.dot(feature_weights); + } else { + log_ref_z = log(ref_z); + } +#else + log_ref_z = log(ref_z); +#endif + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!isnan(log_ref_z)); + acc_obj += (cur_obj - log_ref_z); + } + + double acc_obj; + double cur_obj; + int state; +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + // load weights + vector& weights = decoder.CurrentWeightVector(); + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &weights); + + // freeze feature set + //const bool freeze_feature_set = conf.count("freeze_feature_set"); + //if (freeze_feature_set) FD::Freeze(); + + vector corpus; vector ids; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); + assert(corpus.size() > 0); + assert(corpus.size() == ids.size()); + + TrainingObserver observer; + double objective = 0; + + observer.Reset(); + if (rank == 0) + cerr << "Each processor is decoding " << corpus.size() << " training examples...\n"; + + for (int i = 0; i < corpus.size(); ++i) { + decoder.SetId(ids[i]); + decoder.Decode(corpus[i], &observer); + } + +#ifdef HAVE_MPI + reduce(world, observer.acc_obj, objective, std::plus(), 0); +#else + objective = observer.acc_obj; +#endif + + if (rank == 0) + cout << "OBJECTIVE: " << objective << endl; + + return 0; +} + diff --git a/utils/stringlib.cc b/utils/stringlib.cc index 3a56965c..1a152985 100644 --- a/utils/stringlib.cc +++ b/utils/stringlib.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -104,11 +105,11 @@ public: typedef unsigned int size_type; // must be 32bit MD5(); - MD5(const std::string& text); + MD5(const string& text); void update(const unsigned char *buf, size_type length); void update(const char *buf, size_type length); MD5& finalize(); - std::string hexdigest() const; + string hexdigest() const; private: void init(); @@ -209,7 +210,7 @@ MD5::MD5() ////////////////////////////////////////////// // nifty shortcut ctor, compute MD5 for string and finalize it right away -MD5::MD5(const std::string &text) +MD5::MD5(const string &text) { init(); update(text.c_str(), text.length()); @@ -433,8 +434,7 @@ MD5& MD5::finalize() ////////////////////////////// // return hex representation of digest as string -std::string MD5::hexdigest() const -{ +string MD5::hexdigest() const { if (!finalized) return ""; @@ -443,12 +443,12 @@ std::string MD5::hexdigest() const sprintf(buf+i*2, "%02x", digest[i]); buf[32]=0; - return std::string(buf); + return string(buf); } ////////////////////////////// -std::string md5(const std::string& str) { +string md5(const string& str) { MD5 md5 = MD5(str); return md5.hexdigest(); } -- cgit v1.2.3 From ce830ec51477f345c811987e11a9ed4322edcac0 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 16:19:11 +0100 Subject: make fix --- training/Makefile.am | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/training/Makefile.am b/training/Makefile.am index 5752859e..0b598fd5 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -13,7 +13,7 @@ bin_PROGRAMS = \ mpi_extract_features \ mpi_online_optimize \ mpi_batch_optimize \ - compute_cllh \ + mpi_compute_cllh \ augment_grammar noinst_PROGRAMS = \ @@ -34,8 +34,8 @@ mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteva mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -compute_cllh_SOURCES = compute_cllh.cc -compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc +mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz augment_grammar_SOURCES = augment_grammar.cc augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -- cgit v1.2.3 From 10cfa1082059db646148af1884117082335a48e7 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 17:06:40 +0100 Subject: source span size features --- decoder/ff_source_syntax.cc | 62 +++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_source_syntax.h | 17 +++++++++++++ 2 files changed, 79 insertions(+) diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index ffe07f03..2df31c3a 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -157,3 +157,65 @@ void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); } +struct SourceSpanSizeFeaturesImpl { + SourceSpanSizeFeaturesImpl() {} + + void InitializeGrids(unsigned src_len) { + fids.clear(); + fids.resize(src_len, src_len + 1); + } + + int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { + int& fid = fids(i,j)[&rule]; + if (fid <= 0) { + ostringstream os; + os << "SSS:"; + unsigned ntc = 0; + for (unsigned k = 0; k < rule.f_.size(); ++k) { + if (k > 0) os << '_'; + int fj = rule.f_[k]; + if (fj <= 0) { + os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; + } else { + os << TD::Convert(fj); + } + } + fid = FD::Convert(os.str()); + } + if (fid > 0) + feats->set_value(fid, 1.0); + return SpanSizeTransform(j - i); + } + + mutable Array2D > fids; +}; + +SourceSpanSizeFeatures::SourceSpanSizeFeatures(const string& param) : + FeatureFunction(sizeof(char)) { + impl = new SourceSpanSizeFeaturesImpl; +} + +SourceSpanSizeFeatures::~SourceSpanSizeFeatures() { + delete impl; + impl = NULL; +} + +void SourceSpanSizeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + int ants[8]; + for (unsigned i = 0; i < ant_contexts.size(); ++i) + ants[i] = *static_cast(ant_contexts[i]); + + *static_cast(context) = + impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); +} + +void SourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) { + impl->InitializeGrids(smeta.GetSourceLength()); +} + + diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h index 1e890736..279563e1 100644 --- a/decoder/ff_source_syntax.h +++ b/decoder/ff_source_syntax.h @@ -21,4 +21,21 @@ class SourceSyntaxFeatures : public FeatureFunction { SourceSyntaxFeaturesImpl* impl; }; +struct SourceSpanSizeFeaturesImpl; +class SourceSpanSizeFeatures : public FeatureFunction { + public: + SourceSpanSizeFeatures(const std::string& param); + ~SourceSpanSizeFeatures(); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + SourceSpanSizeFeaturesImpl* impl; +}; + #endif -- cgit v1.2.3 From e7d2352ed630d16a790113223cd8a80155f61615 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 17:11:55 +0100 Subject: enable sss features --- decoder/cdec_ff.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index d562bc3a..69f40c93 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -57,6 +57,7 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleIdentityFeatures", new FFFactory()); ff_registry.Register("SourceSyntaxFeatures", new FFFactory); + ff_registry.Register("SourceSpanSizeFeatures", new FFFactory); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); -- cgit v1.2.3 From 5d7ac6050aab3eac5121a2168fe9bd81453d118a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 22:38:42 +0100 Subject: arity > 0 rules only for sss features --- decoder/ff_source_syntax.cc | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 2df31c3a..fc341bb0 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -166,24 +166,26 @@ struct SourceSpanSizeFeaturesImpl { } int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { - int& fid = fids(i,j)[&rule]; - if (fid <= 0) { - ostringstream os; - os << "SSS:"; - unsigned ntc = 0; - for (unsigned k = 0; k < rule.f_.size(); ++k) { - if (k > 0) os << '_'; - int fj = rule.f_[k]; - if (fj <= 0) { - os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; - } else { - os << TD::Convert(fj); + if (rule.Arity() > 0) { + int& fid = fids(i,j)[&rule]; + if (fid <= 0) { + ostringstream os; + os << "SSS:"; + unsigned ntc = 0; + for (unsigned k = 0; k < rule.f_.size(); ++k) { + if (k > 0) os << '_'; + int fj = rule.f_[k]; + if (fj <= 0) { + os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; + } else { + os << TD::Convert(fj); + } } + fid = FD::Convert(os.str()); } - fid = FD::Convert(os.str()); + if (fid > 0) + feats->set_value(fid, 1.0); } - if (fid > 0) - feats->set_value(fid, 1.0); return SpanSizeTransform(j - i); } -- cgit v1.2.3 From 388081290e99fdd6eacc9d761ebfdea69647fa72 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 22:42:19 +0100 Subject: add target side for sss features --- decoder/ff_source_syntax.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index fc341bb0..035132b4 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -181,6 +181,15 @@ struct SourceSpanSizeFeaturesImpl { os << TD::Convert(fj); } } + os << ':'; + for (unsigned k = 0; k < rule.e_.size(); ++k) { + const int ei = rule.e_[k]; + if (k > 0) os << '_'; + if (ei <= 0) + os << '[' << (1-ei) << ']'; + else + os << TD::Convert(ei); + } fid = FD::Convert(os.str()); } if (fid > 0) -- cgit v1.2.3 From f111672dd611f78656fceb3df3729a290453ef56 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 21 Sep 2011 18:23:50 -0400 Subject: Updated kenlm. Includes left state support but not the cdec-side use of it. Updated binary format. --- compound-split/de/charlm.rev.5gm.de.klm | Bin 17376695 -> 17376711 bytes klm/compile.sh | 10 +- klm/lm/bhiksha.hh | 2 +- klm/lm/binary_format.cc | 16 +- klm/lm/binary_format.hh | 20 +- klm/lm/blank.hh | 14 - klm/lm/left.hh | 181 ++++++ klm/lm/left_test.cc | 360 +++++++++++ klm/lm/model.cc | 132 ++-- klm/lm/model.hh | 47 +- klm/lm/model_test.cc | 184 ++++-- klm/lm/model_type.hh | 16 + klm/lm/quantize.cc | 4 +- klm/lm/quantize.hh | 13 +- klm/lm/return.hh | 39 ++ klm/lm/search_hashed.cc | 79 ++- klm/lm/search_hashed.hh | 43 +- klm/lm/search_trie.cc | 1052 ++++++++++--------------------- klm/lm/search_trie.hh | 37 +- klm/lm/trie.cc | 5 +- klm/lm/trie.hh | 10 +- klm/lm/trie_sort.cc | 261 ++++++++ klm/lm/trie_sort.hh | 94 +++ klm/lm/virtual_interface.hh | 26 +- klm/lm/vocab.cc | 44 +- klm/lm/vocab.hh | 10 +- klm/test.sh | 2 +- klm/util/bit_packing.hh | 14 + klm/util/exception.cc | 5 + klm/util/exception.hh | 6 + klm/util/file.cc | 74 +++ klm/util/file.hh | 74 +++ klm/util/file_piece.cc | 18 +- klm/util/file_piece.hh | 14 +- klm/util/mmap.cc | 18 +- klm/util/mmap.hh | 4 +- klm/util/murmur_hash.cc | 258 ++++---- klm/util/scoped.cc | 24 - klm/util/scoped.hh | 58 +- klm/util/sized_iterator.hh | 107 ++++ klm/util/tokenize_piece.hh | 69 ++ 41 files changed, 2261 insertions(+), 1183 deletions(-) create mode 100644 klm/lm/left.hh create mode 100644 klm/lm/left_test.cc create mode 100644 klm/lm/model_type.hh create mode 100644 klm/lm/return.hh create mode 100644 klm/lm/trie_sort.cc create mode 100644 klm/lm/trie_sort.hh create mode 100644 klm/util/file.cc create mode 100644 klm/util/file.hh delete mode 100644 klm/util/scoped.cc create mode 100644 klm/util/sized_iterator.hh create mode 100644 klm/util/tokenize_piece.hh diff --git a/compound-split/de/charlm.rev.5gm.de.klm b/compound-split/de/charlm.rev.5gm.de.klm index e8d114bd..28d09b54 100644 Binary files a/compound-split/de/charlm.rev.5gm.de.klm and b/compound-split/de/charlm.rev.5gm.de.klm differ diff --git a/klm/compile.sh b/klm/compile.sh index abe3473a..56f2e9b2 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -3,10 +3,12 @@ #If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU #I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh +#don't need to use if compiling with moses Makefiles already + set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do - g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,file,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,trie_sort,virtual_interface,vocab}; do + g++ -I. -O3 -DNDEBUG $CXXFLAGS -c $i.cc -o $i.o done -g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary -g++ -I. -O3 $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query +g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary +g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index cfb2b053..ff7fe452 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -12,7 +12,7 @@ #include -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/trie.hh" #include "util/bit_packing.hh" #include "util/sorted_uniform.hh" diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index e02e621a..27cada13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -19,10 +19,10 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; // This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; -const long int kMagicVersion = 4; +const long int kMagicVersion = 5; // Test values. struct Sanity { @@ -42,12 +42,6 @@ struct Sanity { const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - if (!off) return in; - return in + 8 - off; -} - std::size_t TotalHeaderSize(unsigned char order) { return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -119,7 +113,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } } -void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing) { +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing) { if (config.write_mmap) { if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); @@ -130,6 +124,7 @@ void FinishFile(const Config &config, ModelType model_type, const std::vector(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast(params.fixed.model_type) << " but this is not implemented for in this inference code."); UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); } + UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } void SeekPastHeader(int fd, const Parameters ¶ms) { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index d28cb6c5..e9df0892 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -2,6 +2,7 @@ #define LM_BINARY_FORMAT__ #include "lm/config.hh" +#include "lm/model_type.hh" #include "lm/read_arpa.hh" #include "util/file_piece.hh" @@ -16,13 +17,6 @@ namespace lm { namespace ngram { -/* Not the best numbering system, but it grew this way for historical reasons - * and I want to preserve existing binary files. */ -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; - -const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); -const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); - /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in * this header designed for use by decoder authors. @@ -36,8 +30,14 @@ struct FixedWidthParameters { ModelType model_type; // Does the end of the file have the actual strings in the vocabulary? bool has_vocabulary; + unsigned int search_version; }; +inline std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + return off ? (in + 8 - off) : in; +} + // Parameters stored in the header of a binary file. struct Parameters { FixedWidthParameters fixed; @@ -64,7 +64,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t // Write header to binary file. This is done last to prevent incomplete files // from loading. -void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing); namespace detail { @@ -72,7 +72,7 @@ bool IsBinaryFormat(int fd); void ReadHeader(int fd, Parameters ¶ms); -void MatchCheck(ModelType model_type, const Parameters ¶ms); +void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); void SeekPastHeader(int fd, const Parameters ¶ms); @@ -90,7 +90,7 @@ template void LoadLM(const char *file, const Config &config, To &to) if (detail::IsBinaryFormat(backing.file.get())) { Parameters params; detail::ReadHeader(backing.file.get(), params); - detail::MatchCheck(To::kModelType, params); + detail::MatchCheck(To::kModelType, To::kVersion, params); // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 162411a9..2fb64cd0 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -38,20 +38,6 @@ inline bool HasExtension(const float &backoff) { return compare.i != interpret.i; } -/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or - * "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI - * with default settings on the benchmark data set are like this. Since search - * proceeds by finding "quux", "baz quux", "bar baz quux", and finally - * "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are - * inserted. The blanks have probability kBlankProb and backoff kBlankBackoff. - * A blank is recognized by kBlankProb in the probability field; kBlankBackoff - * must be 0 so that inference asseses zero backoff from these blanks. - */ -const float kBlankProb = -std::numeric_limits::infinity(); -const float kBlankBackoff = kNoExtensionBackoff; -const uint32_t kBlankProbQuant = 0; -const uint32_t kBlankBackoffQuant = 0; - } // namespace ngram } // namespace lm #endif // LM_BLANK__ diff --git a/klm/lm/left.hh b/klm/lm/left.hh new file mode 100644 index 00000000..df69e97a --- /dev/null +++ b/klm/lm/left.hh @@ -0,0 +1,181 @@ +#ifndef LM_LEFT__ +#define LM_LEFT__ + +#include "lm/max_order.hh" +#include "lm/model.hh" +#include "lm/return.hh" + +#include + +namespace lm { +namespace ngram { + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1]; + } + + int Compare(const Left &other) const { + if (length != other.length) { + return (int)length - (int)other.length; + } + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return 0; + } + + uint64_t pointers[kMaxOrder - 1]; + unsigned char length; +}; + +struct ChartState { + bool operator==(const ChartState &other) { + return (left == other.left) && (right == other.right) && (full == other.full); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + int rres = right.Compare(other.right); + if (rres) return rres; + return (int)full - (int)other.full; + } + + Left left; + State right; + bool full; +}; + +template class RuleScore { + public: + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + out.left.length = 0; + out.right.length = 0; + } + + void BeginSentence() { + out_.right = model_.BeginSentenceState(); + // out_.left is empty. + left_done_ = true; + } + + void Terminal(WordIndex word) { + State copy(out_.right); + FullScoreReturn ret = model_.FullScore(copy, word, out_.right); + ProcessRet(ret); + if (out_.right.length != copy.length + 1) left_done_ = true; + } + + // Faster version of NonTerminal for the case where the rule begins with a non-terminal. + void BeginNonTerminal(const ChartState &in, float prob) { + prob_ = prob; + out_ = in; + left_write_ = out_.left.pointers + out_.left.length; + left_done_ = in.full; + } + + void NonTerminal(const ChartState &in, float prob) { + prob_ += prob; + + if (!in.left.length) { + if (in.full) { + for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + } + return; + } + + if (!out_.right.length) { + out_.right = in.right; + if (left_done_) return; + if (left_write_ != out_.left.pointers) { + left_done_ = true; + } else { + out_.left = in.left; + left_write_ = out_.left.pointers + in.left.length; + left_done_ = in.full; + } + return; + } + + float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; + float *back = backoffs, *back2 = backoffs2; + unsigned char next_use; + FullScoreReturn ret; + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + unsigned char extend_length = 2; + for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + std::swap(back, back2); + } + + if (in.full) { + for (const float *i = back; i != back + next_use; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + return; + } + + // Right state was minimized, so it's already independent of the new words to the left. + if (in.right.length < in.left.length) { + out_.right = in.right; + return; + } + + // Shift exisiting words down. + for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + *(i + in.right.length) = *i; + } + // Add words from in.right. + std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); + std::copy(back, back + next_use, out_.right.backoff + in.right.length); + out_.right.length = in.right.length + next_use; + } + + float Finish() { + out_.left.length = left_write_ - out_.left.pointers; + out_.full = left_done_; + return prob_; + } + + private: + void ProcessRet(const FullScoreReturn &ret) { + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + *(left_write_++) = ret.extend_left; + } + + const M &model_; + + ChartState &out_; + + bool left_done_; + + uint64_t *left_write_; + + float prob_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_LEFT__ diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc new file mode 100644 index 00000000..8bb91cb3 --- /dev/null +++ b/klm/lm/left_test.cc @@ -0,0 +1,360 @@ +#include "lm/left.hh" +#include "lm/model.hh" + +#include "util/tokenize_piece.hh" + +#include + +#define BOOST_TEST_MODULE LeftTest +#include +#include + +namespace lm { +namespace ngram { +namespace { + +#define Term(word) score.Terminal(m.GetVocabulary().Index(word)); +#define VCheck(word, value) BOOST_CHECK_EQUAL(m.GetVocabulary().Index(word), value); + +template void Short(const M &m) { + ChartState base; + { + RuleScore score(m, base); + Term("more"); + Term("loin"); + BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001); + } + BOOST_CHECK(base.full); + BOOST_CHECK_EQUAL(2, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("loin", base.right.words[0]); + + ChartState more_left; + { + RuleScore score(m, more_left); + Term("little"); + score.NonTerminal(base, -1.206319 - 0.3561665); + // p(little more loin | null context) + BOOST_CHECK_CLOSE(-1.56538, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(3, more_left.left.length); + BOOST_CHECK_EQUAL(1, more_left.right.length); + VCheck("loin", more_left.right.words[0]); + BOOST_CHECK(more_left.full); + + ChartState shorter; + { + RuleScore score(m, shorter); + Term("to"); + score.NonTerminal(base, -1.206319 - 0.3561665); + BOOST_CHECK_CLOSE(-0.30103 - 1.687872 - 1.206319 - 0.3561665, score.Finish(), 0.01); + } + BOOST_CHECK_EQUAL(1, shorter.left.length); + BOOST_CHECK_EQUAL(1, shorter.right.length); + VCheck("loin", shorter.right.words[0]); + BOOST_CHECK(shorter.full); +} + +template void Charge(const M &m) { + ChartState base; + { + RuleScore score(m, base); + Term("on"); + Term("more"); + BOOST_CHECK_CLOSE(-1.509559 -0.4771212 -1.206319, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("more", base.right.words[0]); + BOOST_CHECK(base.full); + + ChartState extend; + { + RuleScore score(m, extend); + Term("looking"); + score.NonTerminal(base, -1.509559 -0.4771212 -1.206319); + BOOST_CHECK_CLOSE(-3.91039, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, extend.left.length); + BOOST_CHECK_EQUAL(1, extend.right.length); + VCheck("more", extend.right.words[0]); + BOOST_CHECK(extend.full); + + ChartState tobos; + { + RuleScore score(m, tobos); + score.BeginSentence(); + score.NonTerminal(extend, -3.91039); + BOOST_CHECK_CLOSE(-3.471169, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(0, tobos.left.length); + BOOST_CHECK_EQUAL(1, tobos.right.length); +} + +template float LeftToRight(const M &m, const std::vector &words) { + float ret = 0.0; + State right = m.NullContextState(); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + State copy(right); + ret += m.Score(copy, *i, right); + } + return ret; +} + +template float RightToLeft(const M &m, const std::vector &words) { + float ret = 0.0; + ChartState state; + state.left.length = 0; + state.right.length = 0; + state.full = false; + for (std::vector::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) { + ChartState copy(state); + RuleScore score(m, state); + score.Terminal(*i); + score.NonTerminal(copy, ret); + ret = score.Finish(); + } + return ret; +} + +template float TreeMiddle(const M &m, const std::vector &words) { + std::vector > states(words.size()); + for (unsigned int i = 0; i < words.size(); ++i) { + RuleScore score(m, states[i].first); + score.Terminal(words[i]); + states[i].second = score.Finish(); + } + while (states.size() > 1) { + std::vector > upper((states.size() + 1) / 2); + for (unsigned int i = 0; i < states.size() / 2; ++i) { + RuleScore score(m, upper[i].first); + score.NonTerminal(states[i*2].first, states[i*2].second); + score.NonTerminal(states[i*2+1].first, states[i*2+1].second); + upper[i].second = score.Finish(); + } + if (states.size() % 2) { + upper.back() = states.back(); + } + std::swap(states, upper); + } + return states.empty() ? 0 : states.back().second; +} + +template void LookupVocab(const M &m, const StringPiece &str, std::vector &out) { + out.clear(); + for (util::PieceIterator<' '> i(str); i; ++i) { + out.push_back(m.GetVocabulary().Index(*i)); + } +} + +#define TEXT_TEST(str) \ +{ \ + std::vector words; \ + LookupVocab(m, str, words); \ + float expect = LeftToRight(m, words); \ + BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \ + BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \ +} + +// Build sentences, or parts thereof, from right to left. +template void GrowBig(const M &m) { + TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); + TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); + TEXT_TEST("on a little more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look"); + TEXT_TEST("also would consider higher to look"); + TEXT_TEST("also would consider higher"); + TEXT_TEST("would consider higher to look"); + TEXT_TEST("consider higher to look"); + TEXT_TEST("consider higher to"); + TEXT_TEST("consider higher"); +} + +template void AlsoWouldConsiderHigher(const M &m) { + ChartState also; + { + RuleScore score(m, also); + score.Terminal(m.GetVocabulary().Index("also")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState would; + { + RuleScore score(m, would); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState combine_also_would; + { + RuleScore score(m, combine_also_would); + score.NonTerminal(also, -1.687872); + score.NonTerminal(would, -1.687872); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, combine_also_would.right.length); + + ChartState also_would; + { + RuleScore score(m, also_would); + score.Terminal(m.GetVocabulary().Index("also")); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, also_would.right.length); + + ChartState consider; + { + RuleScore score(m, consider); + score.Terminal(m.GetVocabulary().Index("consider")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, consider.left.length); + BOOST_CHECK_EQUAL(1, consider.right.length); + BOOST_CHECK(!consider.full); + + ChartState higher; + float higher_score; + { + RuleScore score(m, higher); + score.Terminal(m.GetVocabulary().Index("higher")); + higher_score = score.Finish(); + } + BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001); + BOOST_CHECK_EQUAL(1, higher.left.length); + BOOST_CHECK_EQUAL(1, higher.right.length); + BOOST_CHECK(!higher.full); + VCheck("higher", higher.right.words[0]); + BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001); + + ChartState consider_higher; + { + RuleScore score(m, consider_higher); + score.NonTerminal(consider, -1.687872); + score.NonTerminal(higher, higher_score); + BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, consider_higher.left.length); + BOOST_CHECK(!consider_higher.full); + + ChartState full; + { + RuleScore score(m, full); + score.NonTerminal(combine_also_would, -1.687872 - 2.0); + score.NonTerminal(consider_higher, -1.509559 - 1.687872 - 0.30103); + BOOST_CHECK_CLOSE(-10.6879, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(4, full.right.length); +} + +template void GrowSmall(const M &m) { + TEXT_TEST("in biarritz watching considering looking . "); + TEXT_TEST("in biarritz watching considering looking ."); + TEXT_TEST("in biarritz"); +} + +#define CHECK_SCORE(str, val) \ +{ \ + float got = val; \ + std::vector indices; \ + LookupVocab(m, str, indices); \ + BOOST_CHECK_CLOSE(LeftToRight(m, indices), got, 0.001); \ +} + +template void FullGrow(const M &m) { + std::vector words; + LookupVocab(m, "in biarritz watching considering looking . ", words); + + ChartState lexical[7]; + float lexical_scores[7]; + for (unsigned int i = 0; i < 7; ++i) { + RuleScore score(m, lexical[i]); + score.Terminal(words[i]); + lexical_scores[i] = score.Finish(); + } + CHECK_SCORE("in", lexical_scores[0]); + CHECK_SCORE("biarritz", lexical_scores[1]); + CHECK_SCORE("watching", lexical_scores[2]); + CHECK_SCORE("", lexical_scores[6]); + + ChartState l1[4]; + float l1_scores[4]; + { + RuleScore score(m, l1[0]); + score.NonTerminal(lexical[0], lexical_scores[0]); + score.NonTerminal(lexical[1], lexical_scores[1]); + CHECK_SCORE("in biarritz", l1_scores[0] = score.Finish()); + } + { + RuleScore score(m, l1[1]); + score.NonTerminal(lexical[2], lexical_scores[2]); + score.NonTerminal(lexical[3], lexical_scores[3]); + CHECK_SCORE("watching considering", l1_scores[1] = score.Finish()); + } + { + RuleScore score(m, l1[2]); + score.NonTerminal(lexical[4], lexical_scores[4]); + score.NonTerminal(lexical[5], lexical_scores[5]); + CHECK_SCORE("looking .", l1_scores[2] = score.Finish()); + } + BOOST_CHECK_EQUAL(l1[2].left.length, 1); + l1[3] = lexical[6]; + l1_scores[3] = lexical_scores[6]; + + ChartState l2[2]; + float l2_scores[2]; + { + RuleScore score(m, l2[0]); + score.NonTerminal(l1[0], l1_scores[0]); + score.NonTerminal(l1[1], l1_scores[1]); + CHECK_SCORE("in biarritz watching considering", l2_scores[0] = score.Finish()); + } + { + RuleScore score(m, l2[1]); + score.NonTerminal(l1[2], l1_scores[2]); + score.NonTerminal(l1[3], l1_scores[3]); + CHECK_SCORE("looking . ", l2_scores[1] = score.Finish()); + } + BOOST_CHECK_EQUAL(l2[1].left.length, 1); + BOOST_CHECK(l2[1].full); + + ChartState top; + { + RuleScore score(m, top); + score.NonTerminal(l2[0], l2_scores[0]); + score.NonTerminal(l2[1], l2_scores[1]); + CHECK_SCORE("in biarritz watching considering looking . ", score.Finish()); + } +} + +template void Everything() { + Config config; + config.messages = NULL; + M m("test.arpa", config); + + Short(m); + Charge(m); + GrowBig(m); + AlsoWouldConsiderHigher(m); + GrowSmall(m); + FullGrow(m); +} + +BOOST_AUTO_TEST_CASE(ProbingAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(TrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(QuantTrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(ArrayQuantTrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(ArrayTrieAll) { + Everything(); +} + +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 27e24b1c..ca581d8a 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -16,7 +16,7 @@ namespace lm { namespace ngram { size_t hash_value(const State &state) { - return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); } namespace detail { @@ -41,11 +41,11 @@ template GenericModel::Ge // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); - begin_sentence.valid_length_ = 1; - begin_sentence.history_[0] = vocab_.BeginSentence(); - begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; State null_context = State(); - null_context.valid_length_ = 0; + null_context.length = 0; P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } @@ -87,7 +87,7 @@ template void GenericModel void GenericModel FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state); - if (ret.ngram_length - 1 < in_state.valid_length_) { - ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); + if (ret.ngram_length - 1 < in_state.length) { + ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob); } return ret; } @@ -131,32 +131,80 @@ template void GenericModel FullScoreReturn GenericModel::ExtendLeft( + const WordIndex *add_rbegin, const WordIndex *add_rend, + const float *backoff_in, + uint64_t extend_pointer, + unsigned char extend_length, + float *backoff_out, + unsigned char &next_use) const { + FullScoreReturn ret; + float subtract_me; + typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me)); + ret.prob = subtract_me; + ret.ngram_length = extend_length; + next_use = 0; + // If this function is called, then it does depend on left words. + ret.independent_left = false; + ret.extend_left = extend_pointer; + const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1; + const WordIndex *i = add_rbegin; + for (; ; ++i, ++backoff_out, ++mid_iter) { + if (i == add_rend) { + // Ran out of words. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + if (mid_iter == search_.MiddleEnd()) break; + if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) { + // Didn't match a word. + ret.independent_left = true; + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + ret.ngram_length = mid_iter - search_.MiddleBegin() + 2; + if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1; + } + + if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) { + // The last backoff weight, for Order() - 1. + ret.prob += backoff_in[i - add_rbegin]; + } else { + ret.ngram_length = P::Order(); + } + ret.independent_left = true; + ret.prob -= subtract_me; + return ret; } namespace { // Do a paraonoid copy of history, assuming new_word has already been copied -// (hence the -1). out_state.valid_length_ could be zero so I avoided using +// (hence the -1). out_state.length could be zero so I avoided using // std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { - WordIndex *out = out_state.history_ + 1; - const WordIndex *in_end = from + static_cast(out_state.valid_length_) - 1; + WordIndex *out = out_state.words + 1; + const WordIndex *in_end = from + static_cast(out_state.length) - 1; for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in; } } // namespace @@ -175,17 +223,17 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel class GenericModel : public base::Mod // This is the model type returned by RecognizeBinary. static const ModelType kModelType; + static const unsigned int kVersion = Search::kVersion; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -114,6 +116,25 @@ template class GenericModel : public base::Mod */ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const; + /* More efficient version of FullScore where a partial n-gram has already + * been scored. + * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if + * the n-gram does not end up extending further left, then 0 is returned. + */ + FullScoreReturn ExtendLeft( + // Additional context in reverse order. This will update add_rend to + const WordIndex *add_rbegin, const WordIndex *add_rend, + // Backoff weights to use. + const float *backoff_in, + // extend_left returned by a previous query. + uint64_t extend_pointer, + // Length of n-gram that the pointer corresponds to. + unsigned char extend_length, + // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)] + float *backoff_out, + // Amount of additional content that should be considered by the next call. + unsigned char &next_use) const; + private: friend void LoadLM<>(const char *file, const Config &config, GenericModel &to); diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 57c7291c..2654071f 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -10,8 +10,8 @@ namespace lm { namespace ngram { std::ostream &operator<<(std::ostream &o, const State &state) { - o << "State length " << static_cast(state.valid_length_) << ':'; - for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) { + o << "State length " << static_cast(state.length) << ':'; + for (const WordIndex *i = state.words; i < state.words + state.length; ++i) { o << ' ' << *i; } return o; @@ -19,25 +19,26 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { -#define StartTest(word, ngram, score) \ +#define StartTest(word, ngram, score, indep_left) \ ret = model.FullScore( \ state, \ model.GetVocabulary().Index(word), \ out);\ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ - BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.valid_length_); \ + BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.length); \ + BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \ {\ - WordIndex context[state.valid_length_ + 1]; \ + WordIndex context[state.length + 1]; \ context[0] = model.GetVocabulary().Index(word); \ - std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \ + std::copy(state.words, state.words + state.length, context + 1); \ State get_state; \ - model.GetState(context, context + state.valid_length_ + 1, get_state); \ + model.GetState(context, context + state.length + 1, get_state); \ BOOST_CHECK_EQUAL(out, get_state); \ } -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ +#define AppendTest(word, ngram, score, indep_left) \ + StartTest(word, ngram, score, indep_left) \ state = out; template void Starters(const M &model) { @@ -45,12 +46,12 @@ template void Starters(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - StartTest("looking", 2, -0.4846522); + StartTest("looking", 2, -0.4846522, true); // , probability plus backoff - StartTest(",", 1, -1.383514 + -0.4149733); + StartTest(",", 1, -1.383514 + -0.4149733, true); // probability plus backoff - StartTest("this_is_not_found", 1, -1.995635 + -0.4149733); + StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true); } template void Continuation(const M &model) { @@ -58,46 +59,64 @@ template void Continuation(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); + AppendTest("looking", 2, -0.484652, true); + AppendTest("on", 3, -0.348837, true); + AppendTest("a", 4, -0.0155266, true); + AppendTest("little", 5, -0.00306122, true); State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 1, -2.29666); - AppendTest("more", 1, -1.20632 - 20.0); - AppendTest(".", 2, -0.51363); - AppendTest("", 3, -0.0191651); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("the", 1, -4.04005, true); + AppendTest("biarritz", 1, -1.9889, true); + AppendTest("not_found", 1, -2.29666, true); + AppendTest("more", 1, -1.20632 - 20.0, true); + AppendTest(".", 2, -0.51363, true); + AppendTest("", 3, -0.0191651, true); + BOOST_CHECK_EQUAL(0, state.length); state = preserve; - AppendTest("more", 5, -0.00181395); - BOOST_CHECK_EQUAL(4, state.valid_length_); - AppendTest("loin", 5, -0.0432557); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("more", 5, -0.00181395, true); + BOOST_CHECK_EQUAL(4, state.length); + AppendTest("loin", 5, -0.0432557, true); + BOOST_CHECK_EQUAL(1, state.length); } template void Blanks(const M &model) { FullScoreReturn ret; State state(model.NullContextState()); State out; - AppendTest("also", 1, -1.687872); - AppendTest("would", 2, -2); - AppendTest("consider", 3, -3); + AppendTest("also", 1, -1.687872, false); + AppendTest("would", 2, -2, true); + AppendTest("consider", 3, -3, true); State preserve = state; - AppendTest("higher", 4, -4); - AppendTest("looking", 5, -5); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("higher", 4, -4, true); + AppendTest("looking", 5, -5, true); + BOOST_CHECK_EQUAL(1, state.length); state = preserve; - AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); + // also would consider not_found + AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true); state = model.NullContextState(); // higher looking is a blank. - AppendTest("higher", 1, -1.509559); - AppendTest("looking", 1, -1.285941 - 0.30103); - AppendTest("not_found", 1, -1.995635 - 0.4771212); + AppendTest("higher", 1, -1.509559, false); + AppendTest("looking", 2, -1.285941 - 0.30103, false); + + State higher_looking = state; + + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("not_found", 1, -1.995635 - 0.4771212, true); + + state = higher_looking; + // higher looking consider + AppendTest("consider", 1, -1.687872 - 0.4771212, true); + + state = model.NullContextState(); + AppendTest("would", 1, -1.687872, false); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("consider", 2, -1.687872 -0.30103, false); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("higher", 3, -1.509559 - 0.30103, false); + BOOST_CHECK_EQUAL(3, state.length); + AppendTest("looking", 4, -1.285941 - 0.30103, false); } template void Unknowns(const M &model) { @@ -105,14 +124,14 @@ template void Unknowns(const M &model) { State state(model.NullContextState()); State out; - AppendTest("not_found", 1, -1.995635); + AppendTest("not_found", 1, -1.995635, false); State preserve = state; - AppendTest("not_found2", 2, -15.0); - AppendTest("not_found3", 2, -15.0 - 2.0); + AppendTest("not_found2", 2, -15.0, true); + AppendTest("not_found3", 2, -15.0 - 2.0, true); state = preserve; - AppendTest("however", 2, -4); - AppendTest("not_found3", 3, -6); + AppendTest("however", 2, -4, true); + AppendTest("not_found3", 3, -6, true); } template void MinimalState(const M &model) { @@ -120,22 +139,66 @@ template void MinimalState(const M &model) { State state(model.NullContextState()); State out; - AppendTest("baz", 1, -6.535897); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("baz", 1, -6.535897, true); + BOOST_CHECK_EQUAL(0, state.length); state = model.NullContextState(); - AppendTest("foo", 1, -3.141592); - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 2, -6.0); + AppendTest("foo", 1, -3.141592, true); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 2, -6.0, true); // Has to include the backoff weight. - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 1, -2.718281 + 3.0); - BOOST_CHECK_EQUAL(1, state.valid_length_); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 1, -2.718281 + 3.0, true); + BOOST_CHECK_EQUAL(1, state.length); state = model.NullContextState(); - AppendTest("to", 1, -1.687872); - AppendTest("look", 2, -0.2922095); - BOOST_CHECK_EQUAL(2, state.valid_length_); - AppendTest("good", 3, -7); + AppendTest("to", 1, -1.687872, false); + AppendTest("look", 2, -0.2922095, true); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("good", 3, -7, true); +} + +template void ExtendLeftTest(const M &model) { + State right; + FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right)); + const float kLittleProb = -1.285941; + BOOST_CHECK_CLOSE(kLittleProb, little.prob, 0.001); + unsigned char next_use; + float backoff_out[4]; + + FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use)); + BOOST_CHECK_EQUAL(0, next_use); + BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left); + BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001); + BOOST_CHECK_EQUAL(1, extend_none.ngram_length); + + const WordIndex a = model.GetVocabulary().Index("a"); + float backoff_in = 3.14; + // a little + FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001); + BOOST_CHECK_EQUAL(2, extend_a.ngram_length); + BOOST_CHECK(!extend_a.independent_left); + + const WordIndex on = model.GetVocabulary().Index("on"); + FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_on.ngram_length); + BOOST_CHECK(!extend_on.independent_left); + + const WordIndex both[2] = {a, on}; + float backoff_in_arr[4]; + FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(2, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_both.ngram_length); + BOOST_CHECK(!extend_both.independent_left); + BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left); } #define StatelessTest(word, provide, ngram, score) \ @@ -166,17 +229,17 @@ template void Stateless(const M &model) { // looking StatelessTest(1, 2, 2, -0.484652); // on - AppendTest("on", 3, -0.348837); + AppendTest("on", 3, -0.348837, true); StatelessTest(2, 3, 3, -0.348837); StatelessTest(2, 2, 3, -0.348837); StatelessTest(2, 1, 2, -0.4638903); // a StatelessTest(3, 4, 4, -0.0155266); // little - AppendTest("little", 5, -0.00306122); + AppendTest("little", 5, -0.00306122, true); StatelessTest(4, 5, 5, -0.00306122); // the - AppendTest("the", 1, -4.04005); + AppendTest("the", 1, -4.04005, true); StatelessTest(5, 5, 1, -4.04005); // No context of the. StatelessTest(5, 0, 1, -1.687872); @@ -189,8 +252,8 @@ template void Stateless(const M &model) { WordIndex unk[1]; unk[0] = 0; model.GetState(unk, unk + 1, state); - BOOST_CHECK_EQUAL(1, state.valid_length_); - BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); + BOOST_CHECK_EQUAL(1, state.length); + BOOST_CHECK_EQUAL(static_cast(0), state.words[0]); } template void NoUnkCheck(const M &model) { @@ -207,6 +270,7 @@ template void Everything(const M &m) { Blanks(m); Unknowns(m); MinimalState(m); + ExtendLeftTest(m); Stateless(m); } @@ -245,6 +309,7 @@ template void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); Everything(m); } { @@ -252,6 +317,7 @@ template void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test_nounk.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); NoUnkCheck(m); } } diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh new file mode 100644 index 00000000..5057ed25 --- /dev/null +++ b/klm/lm/model_type.hh @@ -0,0 +1,16 @@ +#ifndef LM_MODEL_TYPE__ +#define LM_MODEL_TYPE__ + +namespace lm { +namespace ngram { + +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); + +} // namespace ngram +} // namespace lm +#endif // LM_MODEL_TYPE__ diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index fd371cc8..98a5d048 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -1,5 +1,6 @@ #include "lm/quantize.hh" +#include "lm/binary_format.hh" #include "lm/lm_exception.hh" #include @@ -70,8 +71,7 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vec void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { float *centers = start_ + TableStart(order); - *(centers++) = kBlankProb; - MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 0b71d14a..4cf4236e 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -1,9 +1,9 @@ #ifndef LM_QUANTIZE_H__ #define LM_QUANTIZE_H__ -#include "lm/binary_format.hh" // for ModelType #include "lm/blank.hh" #include "lm/config.hh" +#include "lm/model_type.hh" #include "util/bit_packing.hh" #include @@ -36,6 +36,9 @@ class DontQuantize { prob = util::ReadNonPositiveFloat31(base, bit_offset); backoff = util::ReadFloat32(base, bit_offset + 31); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { backoff = util::ReadFloat32(base, bit_offset + 31); } @@ -77,7 +80,7 @@ class SeparatelyQuantize { Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} uint64_t EncodeProb(float value) const { - return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + return Encode(value, 0); } uint64_t EncodeBackoff(float value) const { @@ -132,6 +135,10 @@ class SeparatelyQuantize { (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); prob = prob_.Decode(both >> backoff_.Bits()); @@ -179,7 +186,7 @@ class SeparatelyQuantize { void SetupMemory(void *start, const Config &config); static const bool kTrain = true; - // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + // Assumes 0.0 is removed from backoff. void Train(uint8_t order, std::vector &prob, std::vector &backoff); // Train just probabilities (for longest order). void TrainProb(uint8_t order, std::vector &prob); diff --git a/klm/lm/return.hh b/klm/lm/return.hh new file mode 100644 index 00000000..15571960 --- /dev/null +++ b/klm/lm/return.hh @@ -0,0 +1,39 @@ +#ifndef LM_RETURN__ +#define LM_RETURN__ + +#include + +namespace lm { +/* Structure returned by scoring routines. */ +struct FullScoreReturn { + // log10 probability + float prob; + + /* The length of n-gram matched. Do not use this for recombination. + * Consider a model containing only the following n-grams: + * -1 foo + * -3.14 bar + * -2.718 baz -5 + * -6 foo bar + * + * If you score ``bar'' then ngram_length is 1 and recombination state is the + * empty string because bar has zero backoff and does not extend to the + * right. + * If you score ``foo'' then ngram_length is 1 and recombination state is + * ``foo''. + * + * Ideally, keep output states around and compare them. Failing that, + * get out_state.ValidLength() and use that length for recombination. + */ + unsigned char ngram_length; + + /* Left extension information. If independent_left is set, then prob is + * independent of words to the left (up to additional backoff). Otherwise, + * extend_left indicates how to efficiently extend further to the left. + */ + bool independent_left; + uint64_t extend_left; // Defined only if independent_left +}; + +} // namespace lm +#endif // LM_RETURN__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 82c53ec8..334adf12 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -1,10 +1,12 @@ #include "lm/search_hashed.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/read_arpa.hh" #include "lm/vocab.hh" +#include "util/bit_packing.hh" #include "util/file_piece.hh" #include @@ -48,30 +50,77 @@ class ActivateUnigram { ProbBackoff *modify_; }; -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { - - ReadNGramHeader(f, n); +template void FixSRI(int lower, float negative_lower_prob, unsigned int n, const uint64_t *keys, const WordIndex *vocab_ids, ProbBackoff *unigrams, std::vector &middle) { ProbBackoff blank; - blank.prob = kBlankProb; - blank.backoff = kBlankBackoff; + blank.backoff = kNoExtensionBackoff; + // Fix SRI's stupidity. + // Note that negative_lower_prob is the negative of the probability (so it's currently >= 0). We still want the sign bit off to indicate left extension, so I just do -= on the backoffs. + blank.prob = negative_lower_prob; + // An entry was found at lower (order lower + 2). + // We need to insert blanks starting at lower + 1 (order lower + 3). + unsigned int fix = static_cast(lower + 1); + uint64_t backoff_hash = detail::CombineWordHash(static_cast(vocab_ids[1]), vocab_ids[2]); + if (fix == 0) { + // Insert a missing bigram. + blank.prob -= unigrams[vocab_ids[1]].backoff; + SetExtension(unigrams[vocab_ids[1]].backoff); + // Bigram including a unigram's backoff + middle[0].Insert(Middle::Packing::Make(keys[0], blank)); + fix = 1; + } else { + for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); + } + // fix >= 1. Insert trigrams and above. + for (; fix <= n - 3; ++fix) { + typename Middle::MutableIterator gotit; + if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { + float &backoff = gotit->MutableValue().backoff; + SetExtension(backoff); + blank.prob -= backoff; + } + middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); + backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); + } +} + +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { + ReadNGramHeader(f, n); // vocab ids of words in reverse order WordIndex vocab_ids[n]; uint64_t keys[n - 1]; typename Store::Packing::Value value; - typename Middle::ConstIterator found; + typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, vocab_ids, value, warn); + keys[0] = detail::CombineWordHash(static_cast(*vocab_ids), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } + // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. + util::SetSign(value.prob); store.Insert(Store::Packing::Make(keys[n-2], value)); - // Go back and insert blanks. - for (int lower = n - 3; lower >= 0; --lower) { - if (middle[lower].Find(keys[lower], found)) break; - middle[lower].Insert(Middle::Packing::Make(keys[lower], blank)); + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + int lower; + util::FloatEnc fix_prob; + for (lower = n - 3; ; --lower) { + if (lower == -1) { + fix_prob.f = unigrams[vocab_ids[0]].prob; + fix_prob.i &= ~util::kSignBit; + unigrams[vocab_ids[0]].prob = fix_prob.f; + break; + } + if (middle[lower].UnsafeMutableFind(keys[lower], found)) { + // Turn off sign bit to indicate that it extends left. + fix_prob.f = found->MutableValue().prob; + fix_prob.i &= ~util::kSignBit; + found->MutableValue().prob = fix_prob.f; + // We don't need to recurse further down because this entry already set the bits for lower entries. + break; + } } + if (lower != static_cast(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle); activate(vocab_ids, n); } @@ -107,15 +156,15 @@ template template void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); + ReadNGrams(f, 2, counts[1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle(middle_.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -133,7 +182,7 @@ template void TemplateHashedSearch; -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); } // namespace detail } // namespace ngram diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index c62985e4..e289fd11 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -1,15 +1,18 @@ #ifndef LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED__ -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/config.hh" #include "lm/read_arpa.hh" +#include "lm/return.hh" #include "lm/weights.hh" +#include "util/bit_packing.hh" #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include +#include #include namespace util { class FilePiece; } @@ -52,9 +55,14 @@ struct HashedSearch { Unigram unigram; - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { const ProbBackoff &entry = unigram.Lookup(word); - prob = entry.prob; + util::FloatEnc val; + val.f = entry.prob; + ret.independent_left = (val.i & util::kSignBit); + ret.extend_left = static_cast(word); + val.i |= util::kSignBit; + ret.prob = val.f; backoff = entry.backoff; next = static_cast(word); } @@ -67,6 +75,8 @@ template class TemplateHashedSearch : public Has typedef LongestT Longest; Longest longest; + static const unsigned int kVersion = 0; + // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} @@ -85,11 +95,33 @@ template class TemplateHashedSearch : public Has const Middle *MiddleBegin() const { return &*middle_.begin(); } const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + util::FloatEnc val; + if (extend_length == 1) { + val.f = unigram.Lookup(static_cast(extend_pointer)).prob; + } else { + typename Middle::ConstIterator found; + if (!middle_[extend_length - 2].Find(extend_pointer, found)) { + std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; + abort(); + } + val.f = found->GetValue().prob; + } + val.i |= util::kSignBit; + prob = val.f; + return extend_pointer; + } + + bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; + util::FloatEnc enc; + enc.f = found->GetValue().prob; + ret.independent_left = (enc.i & util::kSignBit); + ret.extend_left = node; + enc.i |= util::kSignBit; + ret.prob = enc.f; backoff = found->GetValue().backoff; return true; } @@ -105,6 +137,7 @@ template class TemplateHashedSearch : public Has } bool LookupLongest(WordIndex word, float &prob, Node &node) const { + // Sign bit is always on because longest n-grams do not extend left. node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 05059ffb..6479813b 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -2,26 +2,25 @@ #include "lm/search_trie.hh" #include "lm/bhiksha.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" #include "lm/quantize.hh" -#include "lm/read_arpa.hh" #include "lm/trie.hh" +#include "lm/trie_sort.hh" #include "lm/vocab.hh" #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/ersatz_progress.hh" -#include "util/file_piece.hh" -#include "util/have.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" +#include "util/sized_iterator.hh" #include -#include #include #include -#include +#include #include #include #include @@ -29,575 +28,221 @@ #include #include #include -#include -#include -#include namespace lm { namespace ngram { namespace trie { namespace { -/* An entry is a n-gram with probability. It consists of: - * WordIndex[order] - * float probability - * backoff probability (omitted for highest order n-gram) - * These are stored consecutively in memory. We want to sort them. - * - * The problem is the length depends on order (but all n-grams being compared - * have the same order). Allocating each entry on the heap (i.e. std::vector - * or std::string) then sorting pointers is the normal solution. But that's - * too memory inefficient. A lot of this code is just here to force std::sort - * to work with records where length is specified at runtime (and avoid using - * Boost for LM code). I could have used qsort, but the point is to also - * support __gnu_cxx:parallel_sort which doesn't have a qsort version. - */ - -class EntryIterator { - public: - EntryIterator() {} - - EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast(ptr)), size_(size) {} - - bool operator==(const EntryIterator &other) const { - return ptr_ == other.ptr_; - } - bool operator<(const EntryIterator &other) const { - return ptr_ < other.ptr_; - } - EntryIterator &operator+=(std::ptrdiff_t amount) { - ptr_ += amount * size_; - return *this; - } - std::ptrdiff_t operator-(const EntryIterator &other) const { - return (ptr_ - other.ptr_) / size_; - } - - const void *Data() const { return ptr_; } - void *Data() { return ptr_; } - std::size_t EntrySize() const { return size_; } - - private: - uint8_t *ptr_; - std::size_t size_; -}; - -class EntryProxy { - public: - EntryProxy() {} - - EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} - - operator std::string() const { - return std::string(reinterpret_cast(inner_.Data()), inner_.EntrySize()); - } - - EntryProxy &operator=(const EntryProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); - return *this; - } - - EntryProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), inner_.EntrySize()); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast(inner_.Data()); - } - - private: - friend class util::ProxyIterator; - - typedef std::string value_type; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator NGramIter; - -// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. -class PartialViewProxy { - public: - PartialViewProxy() : attention_size_(0), inner_() {} - - PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} - - operator std::string() const { - return std::string(reinterpret_cast(inner_.Data()), attention_size_); - } - - PartialViewProxy &operator=(const PartialViewProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), attention_size_); - return *this; - } - - PartialViewProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), attention_size_); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast(inner_.Data()); - } - - private: - friend class util::ProxyIterator; - - typedef std::string value_type; - - const std::size_t attention_size_; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator PartialIter; - -template class CompareRecords : public std::binary_function { - public: - explicit CompareRecords(unsigned char order) : order_(order) {} - - bool operator()(const Proxy &first, const Proxy &second) const { - return Compare(first.Indices(), second.Indices()); - } - bool operator()(const Proxy &first, const std::string &second) const { - return Compare(first.Indices(), reinterpret_cast(second.data())); - } - bool operator()(const std::string &first, const Proxy &second) const { - return Compare(reinterpret_cast(first.data()), second.Indices()); - } - bool operator()(const std::string &first, const std::string &second) const { - return Compare(reinterpret_cast(first.data()), reinterpret_cast(second.data())); - } - - private: - bool Compare(const WordIndex *first, const WordIndex *second) const { - const WordIndex *end = first + order_; - for (; first != end; ++first, ++second) { - if (*first < *second) return true; - if (*first > *second) return false; - } - return false; - } - - unsigned char order_; -}; - -FILE *OpenOrThrow(const char *name, const char *mode) { - FILE *ret = fopen(name, mode); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); - return ret; -} - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - void ReadOrThrow(FILE *from, void *data, size_t size) { - if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); -} - -const std::size_t kCopyBufSize = 512; -void CopyOrThrow(FILE *from, FILE *to, size_t size) { - char buf[std::min(size, kCopyBufSize)]; - for (size_t i = 0; i < size; i += kCopyBufSize) { - std::size_t amount = std::min(size - i, kCopyBufSize); - ReadOrThrow(from, buf, amount); - WriteOrThrow(to, buf, amount); - } + UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read"); } -void CopyRestOrThrow(FILE *from, FILE *to) { - char buf[kCopyBufSize]; - size_t amount; - while ((amount = fread(buf, 1, kCopyBufSize, from))) { - WriteOrThrow(to, buf, amount); +int Compare(unsigned char order, const void *first_void, const void *second_void) { + const WordIndex *first = reinterpret_cast(first_void), *second = reinterpret_cast(second_void); + const WordIndex *end = first + order; + for (; first != end; ++first, ++second) { + if (*first < *second) return -1; + if (*first > *second) return 1; } - if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read"); -} - -void RemoveOrThrow(const char *name) { - if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name); + return 0; } -std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) { - const std::size_t entry_size = sizeof(WordIndex) * order + weights_size; - const std::size_t prefix_size = sizeof(WordIndex) * (order - 1); - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << '_' << batch; - std::string ret(assembled.str()); - util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w")); - // Compress entries that being with the same (order-1) words. - for (const uint8_t *group_begin = static_cast(mem_begin); group_begin != static_cast(mem_end);) { - const uint8_t *group_end; - for (group_end = group_begin + entry_size; - (group_end != static_cast(mem_end)) && !memcmp(group_begin, group_end, prefix_size); - group_end += entry_size) {} - WriteOrThrow(out.get(), group_begin, prefix_size); - WordIndex group_size = (group_end - group_begin) / entry_size; - WriteOrThrow(out.get(), &group_size, sizeof(group_size)); - for (const uint8_t *i = group_begin; i != group_end; i += entry_size) { - WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex)); - WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size); - } - group_begin = group_end; - } - return ret; -} +struct ProbPointer { + unsigned char array; + uint64_t index; +}; -class SortedFileReader { +// Array of n-grams and float indices. +class BackoffMessages { public: - SortedFileReader() : ended_(false) {} - - void Init(const std::string &name, unsigned char order) { - file_.reset(OpenOrThrow(name.c_str(), "r")); - header_.resize(order - 1); - NextHeader(); + void Init(std::size_t entry_size) { + current_ = NULL; + allocated_ = NULL; + entry_size_ = entry_size; } - // Preceding words. - const WordIndex *Header() const { - return &*header_.begin(); - } - const std::vector &HeaderVector() const { return header_;} - - std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } - - void NextHeader() { - if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) { - if (feof(file_.get())) { - ended_ = true; - } else { - UTIL_THROW(util::ErrnoException, "Short read of counts"); + void Add(const WordIndex *to, ProbPointer index) { + while (current_ + entry_size_ > allocated_) { + std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get(); + Resize(std::max(allocated_size * 2, entry_size_)); + } + memcpy(current_, to, entry_size_ - sizeof(ProbPointer)); + *reinterpret_cast(current_ + entry_size_ - sizeof(ProbPointer)) = index; + current_ += entry_size_; + } + + void Apply(float *const *const base, FILE *unigrams) { + FinishedAdding(); + if (current_ == allocated_) return; + rewind(unigrams); + ProbBackoff weights; + WordIndex unigram = 0; + ReadOrThrow(unigrams, &weights, sizeof(weights)); + for (; current_ != allocated_; current_ += entry_size_) { + const WordIndex &cur_word = *reinterpret_cast(current_); + for (; unigram < cur_word; ++unigram) { + ReadOrThrow(unigrams, &weights, sizeof(weights)); } + if (!HasExtension(weights.backoff)) { + weights.backoff = kExtensionBackoff; + UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); + WriteOrThrow(unigrams, &weights, sizeof(weights)); + } + const ProbPointer &write_to = *reinterpret_cast(current_ + sizeof(WordIndex)); + base[write_to.array][write_to.index] += weights.backoff; } + backing_.reset(); + } + + void Apply(float *const *const base, RecordReader &reader) { + FinishedAdding(); + if (current_ == allocated_) return; + // We'll also use the same buffer to record messages to blanks that they extend. + WordIndex *extend_out = reinterpret_cast(current_); + const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); + for (reader.Rewind(); reader && (current_ != allocated_); ) { + switch (Compare(order, reader.Data(), current_)) { + case -1: + ++reader; + break; + case 1: + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + for (const WordIndex *w = reinterpret_cast(current_); w != reinterpret_cast(current_) + order; ++w, ++extend_out) *extend_out = *w; + current_ += entry_size_; + break; + case 0: + float &backoff = reinterpret_cast((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff; + if (!HasExtension(backoff)) { + backoff = kExtensionBackoff; + reader.Overwrite(&backoff, sizeof(float)); + } else { + const ProbPointer &write_to = *reinterpret_cast(current_ + entry_size_ - sizeof(ProbPointer)); + base[write_to.array][write_to.index] += backoff; + } + current_ += entry_size_; + break; + } + } + // Now this is a list of blanks that extend right. + entry_size_ = sizeof(WordIndex) * order; + Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); + current_ = (uint8_t*)backing_.get(); } - WordIndex ReadCount() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - WordIndex ReadWord() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - template void ReadWeights(Weights &weights) { - ReadOrThrow(file_.get(), &weights, sizeof(Weights)); + // Call after Apply + bool Extends(unsigned char order, const WordIndex *words) { + if (current_ == allocated_) return false; + assert(order * sizeof(WordIndex) == entry_size_); + while (true) { + switch(Compare(order, words, current_)) { + case 1: + current_ += entry_size_; + if (current_ == allocated_) return false; + break; + case -1: + return false; + case 0: + return true; + } + } } - bool Ended() const { - return ended_; + private: + void FinishedAdding() { + Resize(current_ - (uint8_t*)backing_.get()); + current_ = (uint8_t*)backing_.get(); } - void Rewind() { - rewind(file_.get()); - ended_ = false; - NextHeader(); + void Resize(std::size_t to) { + std::size_t current = current_ - (uint8_t*)backing_.get(); + backing_.call_realloc(to); + current_ = (uint8_t*)backing_.get() + current; + allocated_ = (uint8_t*)backing_.get() + to; } - FILE *File() { return file_.get(); } - - private: - util::scoped_FILE file_; + util::scoped_malloc backing_; - std::vector header_; + uint8_t *current_, *allocated_; - bool ended_; + std::size_t entry_size_; }; -void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { - WriteOrThrow(to, from.Header(), from.HeaderBytes()); - WordIndex count = from.ReadCount(); - WriteOrThrow(to, &count, sizeof(WordIndex)); - - CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); -} - -void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) { - SortedFileReader first, second; - first.Init(first_name.c_str(), order); - RemoveOrThrow(first_name.c_str()); - second.Init(second_name.c_str(), order); - RemoveOrThrow(second_name.c_str()); - util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); - while (!first.Ended() && !second.Ended()) { - if (first.HeaderVector() < second.HeaderVector()) { - CopyFullRecord(first, out_file.get(), weights_size); - first.NextHeader(); - continue; - } - if (first.HeaderVector() > second.HeaderVector()) { - CopyFullRecord(second, out_file.get(), weights_size); - second.NextHeader(); - continue; - } - // Merge at the entry level. - WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); - WordIndex first_count = first.ReadCount(), second_count = second.ReadCount(); - WordIndex total_count = first_count + second_count; - WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); - - WordIndex first_word = first.ReadWord(), second_word = second.ReadWord(); - WordIndex first_index = 0, second_index = 0; - while (true) { - if (first_word < second_word) { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), weights_size); - if (++first_index == first_count) break; - first_word = first.ReadWord(); - } else { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), weights_size); - if (++second_index == second_count) break; - second_word = second.ReadWord(); - } - } - if (first_index == first_count) { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } else { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } - first.NextHeader(); - second.NextHeader(); - } - - for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) { - CopyFullRecord(remaining, out_file.get(), weights_size); - } -} - -const char *kContextSuffix = "_contexts"; - -void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - // Sort just the contexts using the same memory. - PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); - PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - - std::sort(context_begin, context_end, CompareRecords(order - 1)); - - std::string name(ngram_file_name + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); - - // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. - if (context_begin == context_end) return; - PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Indices(), context_size); - const WordIndex *previous = i->Indices(); - ++i; - for (; i != context_end; ++i) { - if (memcmp(previous, i->Indices(), context_size)) { - WriteOrThrow(out.get(), i->Indices(), context_size); - previous = i->Indices(); - } - } -} +const float kBadProb = std::numeric_limits::infinity(); -class ContextReader { +class SRISucks { public: - ContextReader() : valid_(false) {} - - ContextReader(const char *name, unsigned char order) { - Reset(name, order); - } - - void Reset(const char *name, unsigned char order) { - file_.reset(OpenOrThrow(name, "r")); - length_ = sizeof(WordIndex) * static_cast(order); - words_.resize(order); - valid_ = true; - ++*this; - } - - ContextReader &operator++() { - if (1 != fread(&*words_.begin(), length_, 1, file_.get())) { - if (!feof(file_.get())) - UTIL_THROW(util::ErrnoException, "Short read"); - valid_ = false; + SRISucks() { + for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i) + i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1)); + } + + void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) { + assert(prob_basis != kBadProb); + ProbPointer pointer; + pointer.array = order - 1; + pointer.index = values_[order - 1].size(); + for (unsigned char i = begin; i < order; ++i) { + messages_[i - 1].Add(to, pointer); } - return *this; + values_[order - 1].push_back(prob_basis); } - const WordIndex *operator*() const { return &*words_.begin(); } - - operator bool() const { return valid_; } - - FILE *GetFile() { return file_.get(); } - - private: - util::scoped_FILE file_; - - size_t length_; - - std::vector words_; - - bool valid_; -}; - -void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - std::string first_name(first_base + kContextSuffix); - std::string second_name(second_base + kContextSuffix); - ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1); - RemoveOrThrow(first_name.c_str()); - RemoveOrThrow(second_name.c_str()); - std::string out_name(out_base + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); - while (first && second) { - for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { - if (f == *first + order - 1) { - // Equal. - WriteOrThrow(out.get(), *first, context_size); - ++first; - ++second; - break; + void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { + for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { + it_[i] = &*values_[i].begin(); } - if (*f < *s) { - // First lower - WriteOrThrow(out.get(), *first, context_size); - ++first; - break; - } else if (*f > *s) { - WriteOrThrow(out.get(), *second, context_size); - ++second; - break; + messages_[0].Apply(it_, unigram_file); + BackoffMessages *messages = messages_ + 1; + const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */; + for (; reader != end; ++messages, ++reader) { + messages->Apply(it_, *reader); } } - } - ContextReader &remaining = first ? first : second; - if (!remaining) return; - WriteOrThrow(out.get(), *remaining, context_size); - CopyRestOrThrow(remaining.GetFile(), out.get()); -} -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { - ReadNGramHeader(f, order); - const size_t count = counts[order - 1]; - // Size of weights. Does it include backoff? - const size_t words_size = sizeof(WordIndex) * order; - const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); - const size_t entry_size = words_size + weights_size; - const size_t batch_size = std::min(count, mem.size() / entry_size); - uint8_t *const begin = reinterpret_cast(mem.get()); - std::deque files; - for (std::size_t batch = 0, done = 0; done < count; ++batch) { - uint8_t *out = begin; - uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; - if (order == counts.size()) { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } - } else { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } + ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) { + assert(order > 1); + ProbBackoff ret; + ret.prob = *(it_[order - 1]++); + ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff; + return ret; } - // Sort full records by full n-gram. - EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // parallel_sort uses too much RAM - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); - WriteContextFile(begin, out_end, files.back(), entry_size, order); - - done += (out_end - begin) / entry_size; - } - // All individual files created. Merge them. - - std::size_t merge_count = 0; - while (files.size() > 1) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); - files.push_back(assembled.str()); - MergeSortedFiles(files[0], files[1], files.back(), weights_size, order); - MergeContextFiles(files[0], files[1], files.back(), order); - files.pop_front(); - files.pop_front(); - } - if (!files.empty()) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merged"; - std::string merged_name(assembled.str()); - if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); - std::string context_name = files[0] + kContextSuffix; - merged_name += kContextSuffix; - if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); - } -} - -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { - PositiveProbWarn warn(config.positive_log_probability); - { - std::string unigram_name = file_prefix + "unigrams"; - util::scoped_fd unigram_file; - // In case appears. - size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); - Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); - CheckSpecials(config, vocab); - if (!vocab.SawUnk()) ++counts[0]; - } + const std::vector &Values(unsigned char order) const { + return values_[order - 1]; + } - // Only use as much buffer as we need. - size_t buffer_use = 0; - for (unsigned int order = 2; order < counts.size(); ++order) { - buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); - } - buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); - buffer = std::min(buffer, buffer_use); + private: + // This used to be one array. Then I needed to separate it by order for quantization to work. + std::vector values_[kMaxOrder - 1]; + BackoffMessages messages_[kMaxOrder - 1]; - util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); - if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + float *it_[kMaxOrder - 1]; +}; - for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); - } - ReadEnd(f); -} +class FindBlanks { + public: + FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) + : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {} -bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { - for (; words != words_end; ++words, ++header) { - if (*words != *header) { - //assert(*words <= *header); - return false; + float UnigramProb(WordIndex index) const { + return unigrams_[index].prob; } - } - return true; -} -// Phase to count n-grams, including blanks inserted because they were pruned but have extensions -class JustCount { - public: - template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) - : counts_(counts), longest_counts_(counts + order - 1) {} - - void Unigrams(WordIndex begin, WordIndex end) { - counts_[0] += end - begin; + void Unigram(WordIndex /*index*/) { + ++counts_[0]; } - void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) { - ++counts_[mid_idx + 1]; + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) { + sri_.Send(lower, order, indices + 1, prob_basis); + ++counts_[order - 1]; } - void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) { - ++counts_[mid_idx + 1]; + void Middle(const unsigned char order, const void * /*data*/) { + ++counts_[order - 1]; } - void Longest(WordIndex /*key*/, Prob /*prob*/) { + void Longest(const void * /*data*/) { ++*longest_counts_; } @@ -608,167 +253,156 @@ class JustCount { private: uint64_t *const counts_, *const longest_counts_; + + const ProbBackoff *unigrams_; + + SRISucks &sri_; }; // Phase to actually write n-grams to the trie. template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), unigrams_(unigrams), middle_(middle), longest_(longest), - bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)) {} + bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)), + order_(order), + sri_(sri) {} - void Unigrams(WordIndex begin, WordIndex end) { - uint64_t next = bigram_pack_.InsertIndex(); - for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) { - i->next = next; - } + float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; } + + void Unigram(WordIndex word) { + unigrams_[word].next = bigram_pack_.InsertIndex(); } - void MiddleBlank(const unsigned char mid_idx, WordIndex key) { - middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) { + ProbBackoff weights = sri_.GetBlank(order_, order, indices); + middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff); } - void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) { - // Order (mid_idx+2). - ContextReader &context = contexts_[mid_idx + 1]; - if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) { + void Middle(const unsigned char order, const void *data) { + RecordReader &context = contexts_[order - 1]; + const WordIndex *words = reinterpret_cast(data); + ProbBackoff weights = *reinterpret_cast(words + order); + if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) { SetExtension(weights.backoff); ++context; } - middle_[mid_idx].Insert(key, weights.prob, weights.backoff); + middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff); } - void Longest(WordIndex key, Prob prob) { - longest_.Insert(key, prob.prob); + void Longest(const void *data) { + const WordIndex *words = reinterpret_cast(data); + longest_.Insert(words[order_ - 1], reinterpret_cast(words + order_)->prob); } void Cleanup() {} private: - ContextReader *contexts_; + RecordReader *contexts_; UnigramValue *const unigrams_; BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; + const unsigned char order_; + SRISucks &sri_; }; -template class RecursiveInsert { - public: - template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : - doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { - } +struct Gram { + Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {} - // Outer unigram loop. - void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { - util::ErsatzProgress progress(progress_out, message, unigram_count + 1); - for (words_[0] = 0; ; ++words_[0]) { - progress.Set(words_[0]); - WordIndex min_continue = unigram_count; - for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { - if (other->Ended()) continue; - min_continue = std::min(min_continue, other->Header()[0]); - } - // This will write at unigram_count. This is by design so that the next pointers will make sense. - doing_.Unigrams(words_[0], min_continue + 1); - if (min_continue == unigram_count) break; - words_[0] = min_continue; - Middle(0); - } - doing_.Cleanup(); - } + const WordIndex *begin, *end; - private: - void Middle(const unsigned char mid_idx) { - // (mid_idx + 2)-gram. - if (mid_idx == order_minus_2_) { - Longest(); - return; - } - // Orders [2, order) + // For queue, this is the direction we want. + bool operator<(const Gram &other) const { + return std::lexicographical_compare(other.begin, other.end, begin, end); + } +}; - SortedFileReader &reader = inputs_[mid_idx]; +template class BlankManager { + public: + BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) { + for (float *i = basis_; i != basis_ + kMaxOrder - 1; ++i) *i = kBadProb; + } - if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) { - // This order doesn't have a header match, but longer ones might. - MiddleAllBlank(mid_idx); - return; + void Visit(const WordIndex *to, unsigned char length, float prob) { + basis_[length - 1] = prob; + unsigned char overlap = std::min(length - 1, been_length_); + const WordIndex *cur; + WordIndex *pre; + for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) { + if (*pre != *cur) break; } - - // There is a header match. - WordIndex count = reader.ReadCount(); - WordIndex current = reader.ReadWord(); - while (count) { - WordIndex min_continue = std::numeric_limits::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - while (true) { - if (current > min_continue) { - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - break; - } - ProbBackoff weights; - reader.ReadWeights(weights); - doing_.Middle(mid_idx, words_, current, weights); - --count; - if (current == min_continue) { - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - if (count) current = reader.ReadWord(); - break; - } - if (!count) break; - current = reader.ReadWord(); - } + if (cur == to + length - 1) { + *pre = *cur; + been_length_ = length; + return; } - // Count is now zero. Finish off remaining blanks. - MiddleAllBlank(mid_idx); - reader.NextHeader(); - } - - void MiddleAllBlank(const unsigned char mid_idx) { - while (true) { - WordIndex min_continue = std::numeric_limits::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - if (min_continue == std::numeric_limits::max()) return; - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); + // There are blanks to insert starting with order blank. + unsigned char blank = cur - to + 1; + UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); + const float *lower_basis; + for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {} + unsigned char based_on = lower_basis - basis_ + 1; + for (; cur != to + length - 1; ++blank, ++cur, ++pre) { + assert(*lower_basis != kBadProb); + doing_.MiddleBlank(blank, to, based_on, *lower_basis); + *pre = *cur; + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + basis_[blank - 1] = kBadProb; } + been_length_ = length; } - void Longest() { - SortedFileReader &reader = *(inputs_end_ - 1); - if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return; - WordIndex count = reader.ReadCount(); - for (WordIndex i = 0; i < count; ++i) { - WordIndex word = reader.ReadWord(); - Prob prob; - reader.ReadWeights(prob); - doing_.Longest(word, prob); - } - reader.NextHeader(); - return; - } + private: + const unsigned char total_order_; - Doing doing_; + WordIndex been_[kMaxOrder]; + unsigned char been_length_; - SortedFileReader *inputs_; - SortedFileReader *inputs_end_; + float basis_[kMaxOrder]; + + Doing &doing_; +}; - WordIndex words_[kMaxOrder]; +template void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { + util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + unsigned int unigram = 0; + std::priority_queue grams; + grams.push(Gram(&unigram, 1)); + for (unsigned char i = 2; i <= total_order; ++i) { + if (input[i-2]) grams.push(Gram(reinterpret_cast(input[i-2].Data()), i)); + } - const unsigned char order_minus_2_; -}; + BlankManager blank(total_order, doing); + + while (true) { + Gram top = grams.top(); + grams.pop(); + unsigned char order = top.end - top.begin; + if (order == 1) { + blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); + doing.Unigram(unigram); + progress.Set(unigram); + if (++unigram == unigram_count + 1) break; + grams.push(top); + } else { + if (order == total_order) { + blank.Visit(top.begin, order, reinterpret_cast(top.end)->prob); + doing.Longest(top.begin); + } else { + blank.Visit(top.begin, order, reinterpret_cast(top.end)->prob); + doing.Middle(order, top.begin); + } + RecordReader &reader = input[order - 2]; + if (++reader) grams.push(top); + } + } + assert(grams.empty()); + doing.Cleanup(); +} void SanityCheckCounts(const std::vector &initial, const std::vector &fixed) { if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); @@ -778,120 +412,122 @@ void SanityCheckCounts(const std::vector &initial, const std::vector void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - ProbBackoff weights; - std::vector probs, backoffs; - probs.reserve(count); +template void TrainQuantizer(uint8_t order, uint64_t count, const std::vector &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { + std::vector probs(additional), backoffs; + probs.reserve(count + additional.size()); backoffs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const ProbBackoff &weights = *reinterpret_cast(reinterpret_cast(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; } quant.Train(order, probs, backoffs); } -template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - Prob weights; +template void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector probs, backoffs; probs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const Prob &weights = *reinterpret_cast(reinterpret_cast(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + ++progress; } quant.TrainProb(order, probs); } +void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { + // Fill unigram probabilities. + try { + rewind(file); + for (WordIndex i = 0; i < unigram_count; ++i) { + ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff)); + if (contexts && *reinterpret_cast(contexts.Data()) == i) { + SetExtension(unigrams[i].weights.backoff); + ++contexts; + } + } + } catch (util::Exception &e) { + e << " while re-reading unigram probabilities"; + throw; + } +} + } // namespace template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { - std::vector inputs(counts.size() - 1); - std::vector contexts(counts.size() - 1); + RecordReader inputs[kMaxOrder - 1]; + RecordReader contexts[kMaxOrder - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast(i) << "_merged"; - inputs[i-2].Init(assembled.str(), i); - RemoveOrThrow(assembled.str().c_str()); + inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); + util::RemoveOrThrow(assembled.str().c_str()); assembled << kContextSuffix; - contexts[i-2].Reset(assembled.str().c_str(), i-1); - RemoveOrThrow(assembled.str().c_str()); + contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex)); + util::RemoveOrThrow(assembled.str().c_str()); } + SRISucks sri; std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); + std::string temp(file_prefix); temp += "unigrams"; + util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); + util::scoped_memory unigrams; + MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); + FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast(unigrams.get()), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); } - for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { - if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading"); + for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { + if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; + util::scoped_FILE unigram_file; + { + std::string name(file_prefix + "unigrams"); + unigram_file.reset(OpenOrThrow(name.c_str(), "r")); + util::RemoveOrThrow(name.c_str()); + } + sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + for (unsigned char i = 2; i <= counts.size(); ++i) { + inputs[i-2].Rewind(); + } if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); for (unsigned char i = 2; i < counts.size(); ++i) { - TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); quant.FinishedLoading(config); } + UnigramValue *unigrams = out.unigram.Raw(); + PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams); + unigram_file.reset(); + for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - inserter.Apply(config.messages, "Building trie", fixed_counts[0]); - } - - // Fill unigram probabilities. - try { - std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); - for (WordIndex i = 0; i < counts[0]; ++i) { - ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); - if (contexts[0] && **contexts[0] == i) { - SetExtension(unigrams[i].weights.backoff); - ++contexts[0]; - } - } - RemoveOrThrow(name.c_str()); - } catch (util::Exception &e) { - e << " while re-reading unigram probabilities"; - throw; + WriteEntries writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { - const ContextReader &context = contexts[order - 2]; + const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast(order) << "-gram has the context (i.e. all but the last word):"; - for (const WordIndex *i = *context; i != *context + order - 1; ++i) { + e << "An " << static_cast(order) << "-gram has context"; + const WordIndex *ctx = reinterpret_cast(context.Data()); + for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; } e << " so this context must appear in the model as a " << static_cast(order - 1) << "-gram but it does not"; @@ -945,6 +581,14 @@ template void TrieSearch::LoadedBin longest.LoadedBinary(); } +namespace { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} +} // namespace + template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 2f39c09f..c3e02a98 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -1,10 +1,16 @@ #ifndef LM_SEARCH_TRIE__ #define LM_SEARCH_TRIE__ -#include "lm/binary_format.hh" +#include "lm/config.hh" +#include "lm/model_type.hh" +#include "lm/return.hh" #include "lm/trie.hh" #include "lm/weights.hh" +#include "util/file_piece.hh" + +#include + #include namespace lm { @@ -30,6 +36,8 @@ template class TrieSearch { static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); + static const unsigned int kVersion = 0; + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); @@ -57,12 +65,16 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + unigram.Find(word, ret.prob, backoff, node); + ret.independent_left = (node.begin == node.end); + ret.extend_left = static_cast(word); } - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); + bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false; + ret.independent_left = (node.begin == node.end); + return true; } bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { @@ -76,14 +88,25 @@ template class TrieSearch { bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { // TODO: don't decode backoff. assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + FullScoreReturn ignored; + float ignored_backoff; + LookupUnigram(*begin, ignored_backoff, node, ignored); for (const WordIndex *i = begin + 1; i < end; ++i) { if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; } return true; } + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + if (extend_length == 1) { + float ignored; + Node ret; + unigram.Find(static_cast(extend_pointer), prob, ignored, ret); + return ret; + } + return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob); + } + private: friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 8c536e66..4e60b184 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -86,7 +86,7 @@ template void BitPackedMiddle::Inse ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; @@ -94,6 +94,9 @@ template bool BitPackedMiddle::Find uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; + + pointer = at_pointer; + quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 53612064..a9f5e417 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -94,10 +94,18 @@ template class BitPackedMiddle : public BitPacked { void LoadedBinary() { bhiksha_.LoadedBinary(); } - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; + NodeRange ReadEntry(uint64_t pointer, float &prob) { + quant_.ReadProb(base_, pointer, prob); + NodeRange ret; + // pointer/total_bits_ should always round down. + bhiksha_.ReadNext(base_, pointer + quant_.TotalBits(), pointer / total_bits_, total_bits_, ret); + return ret; + } + private: Quant quant_; Bhiksha bhiksha_; diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc new file mode 100644 index 00000000..01c4e490 --- /dev/null +++ b/klm/lm/trie_sort.cc @@ -0,0 +1,261 @@ +#include "lm/trie_sort.hh" + +#include "lm/config.hh" +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/file_piece.hh" +#include "util/mmap.hh" +#include "util/proxy_iterator.hh" +#include "util/sized_iterator.hh" + +#include +#include +#include +#include +#include +#include + +namespace lm { +namespace ngram { +namespace trie { + +const char *kContextSuffix = "_contexts"; + +FILE *OpenOrThrow(const char *name, const char *mode) { + FILE *ret = fopen(name, mode); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); + return ret; +} + +void WriteOrThrow(FILE *to, const void *data, size_t size) { + assert(size); + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + +namespace { + +typedef util::SizedIterator NGramIter; + +// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. +class PartialViewProxy { + public: + PartialViewProxy() : attention_size_(0), inner_() {} + + PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} + + operator std::string() const { + return std::string(reinterpret_cast(inner_.Data()), attention_size_); + } + + PartialViewProxy &operator=(const PartialViewProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), attention_size_); + return *this; + } + + PartialViewProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), attention_size_); + return *this; + } + + const void *Data() const { return inner_.Data(); } + void *Data() { return inner_.Data(); } + + private: + friend class util::ProxyIterator; + + typedef std::string value_type; + + const std::size_t attention_size_; + + typedef util::SizedInnerIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef util::ProxyIterator PartialIter; + +std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << '_' << batch; + std::string ret(assembled.str()); + util::scoped_fd out(util::CreateOrThrow(ret.c_str())); + util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); + return ret; +} + +void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + // Sort just the contexts using the same memory. + PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); + PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); + + std::sort(context_begin, context_end, util::SizedCompare(EntryCompare(order - 1))); + + std::string name(ngram_file_name + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + + // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. + if (context_begin == context_end) return; + PartialIter i(context_begin); + WriteOrThrow(out.get(), i->Data(), context_size); + const void *previous = i->Data(); + ++i; + for (; i != context_end; ++i) { + if (memcmp(previous, i->Data(), context_size)) { + WriteOrThrow(out.get(), i->Data(), context_size); + previous = i->Data(); + } + } +} + +struct ThrowCombine { + void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const { + UTIL_THROW(FormatLoadException, "Duplicate n-gram detected."); + } +}; + +// Useful for context files that just contain records with no value. +struct FirstCombine { + void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { + WriteOrThrow(out, first, entry_size); + } +}; + +template void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) { + std::size_t entry_size = sizeof(WordIndex) * order + weights_size; + RecordReader first, second; + first.Init(first_name.c_str(), entry_size); + util::RemoveOrThrow(first_name.c_str()); + second.Init(second_name.c_str(), entry_size); + util::RemoveOrThrow(second_name.c_str()); + util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); + EntryCompare less(order); + while (first && second) { + if (less(first.Data(), second.Data())) { + WriteOrThrow(out_file.get(), first.Data(), entry_size); + ++first; + } else if (less(second.Data(), first.Data())) { + WriteOrThrow(out_file.get(), second.Data(), entry_size); + ++second; + } else { + combine(entry_size, first.Data(), second.Data(), out_file.get()); + ++first; ++second; + } + } + for (RecordReader &remains = (first ? second : first); remains; ++remains) { + WriteOrThrow(out_file.get(), remains.Data(), entry_size); + } +} + +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem.size() / entry_size); + uint8_t *const begin = reinterpret_cast(mem.get()); + std::deque files; + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } + // Sort full records by full n-gram. + util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + // parallel_sort uses too much RAM + std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare(EntryCompare(order))); + files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order)); + WriteContextFile(begin, out_end, files.back(), entry_size, order); + + done += (out_end - begin) / entry_size; + } + + // All individual files created. Merge them. + + std::size_t merge_count = 0; + while (files.size() > 1) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); + files.push_back(assembled.str()); + MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); + MergeSortedFiles(files[0], files[1], files.back(), 0, order, FirstCombine()); + files.pop_front(); + files.pop_front(); + } + if (!files.empty()) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << "_merged"; + std::string merged_name(assembled.str()); + if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + std::string context_name = files[0] + kContextSuffix; + merged_name += kContextSuffix; + if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); + } +} + +} // namespace + +void RecordReader::Init(const std::string &name, std::size_t entry_size) { + file_.reset(OpenOrThrow(name.c_str(), "r+")); + data_.reset(malloc(entry_size)); + UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); + remains_ = true; + entry_size_ = entry_size; + ++*this; +} + +void RecordReader::Overwrite(const void *start, std::size_t amount) { + long internal = (uint8_t*)start - (uint8_t*)data_.get(); + UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); + WriteOrThrow(file_.get(), start, amount); + long forward = entry_size_ - internal - amount; + if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); +} + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + PositiveProbWarn warn(config.positive_log_probability); + { + std::string unigram_name = file_prefix + "unigrams"; + util::scoped_fd unigram_file; + // In case appears. + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); + Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); + CheckSpecials(config, vocab); + if (!vocab.SawUnk()) ++counts[0]; + } + + // Only use as much buffer as we need. + size_t buffer_use = 0; + for (unsigned int order = 2; order < counts.size(); ++order) { + buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + } + buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min(buffer, buffer_use); + + util::scoped_memory mem; + mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); + if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + + for (unsigned char order = 2; order <= counts.size(); ++order) { + ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); + } + ReadEnd(f); +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh new file mode 100644 index 00000000..a6916483 --- /dev/null +++ b/klm/lm/trie_sort.hh @@ -0,0 +1,94 @@ +#ifndef LM_TRIE_SORT__ +#define LM_TRIE_SORT__ + +#include "lm/word_index.hh" + +#include "util/file.hh" +#include "util/scoped.hh" + +#include +#include +#include +#include + +#include + +namespace util { class FilePiece; } + +// Step of trie builder: create sorted files. +namespace lm { +namespace ngram { +class SortedVocabulary; +class Config; + +namespace trie { + +extern const char *kContextSuffix; +FILE *OpenOrThrow(const char *name, const char *mode); +void WriteOrThrow(FILE *to, const void *data, size_t size); + +class EntryCompare : public std::binary_function { + public: + explicit EntryCompare(unsigned char order) : order_(order) {} + + bool operator()(const void *first_void, const void *second_void) const { + const WordIndex *first = static_cast(first_void); + const WordIndex *second = static_cast(second_void); + const WordIndex *end = first + order_; + for (; first != end; ++first, ++second) { + if (*first < *second) return true; + if (*first > *second) return false; + } + return false; + } + private: + unsigned char order_; +}; + +class RecordReader { + public: + RecordReader() : remains_(true) {} + + void Init(const std::string &name, std::size_t entry_size); + + void *Data() { return data_.get(); } + const void *Data() const { return data_.get(); } + + RecordReader &operator++() { + std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get()); + if (!ret) { + UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file"); + remains_ = false; + } + return *this; + } + + operator bool() const { return remains_; } + + void Rewind() { + rewind(file_.get()); + remains_ = true; + ++*this; + } + + std::size_t EntrySize() const { return entry_size_; } + + void Overwrite(const void *start, std::size_t amount); + + private: + util::scoped_malloc data_; + + bool remains_; + + std::size_t entry_size_; + + util::scoped_FILE file_; +}; + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); + +} // namespace trie +} // namespace ngram +} // namespace lm + +#endif // LM_TRIE_SORT__ diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 08627efd..6a5a0196 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -1,37 +1,13 @@ #ifndef LM_VIRTUAL_INTERFACE__ #define LM_VIRTUAL_INTERFACE__ +#include "lm/return.hh" #include "lm/word_index.hh" #include "util/string_piece.hh" #include namespace lm { - -/* Structure returned by scoring routines. */ -struct FullScoreReturn { - // log10 probability - float prob; - - /* The length of n-gram matched. Do not use this for recombination. - * Consider a model containing only the following n-grams: - * -1 foo - * -3.14 bar - * -2.718 baz -5 - * -6 foo bar - * - * If you score ``bar'' then ngram_length is 1 and recombination state is the - * empty string because bar has zero backoff and does not extend to the - * right. - * If you score ``foo'' then ngram_length is 1 and recombination state is - * ``foo''. - * - * Ideally, keep output states around and compare them. Failing that, - * get out_state.ValidLength() and use that length for recombination. - */ - unsigned char ngram_length; -}; - namespace base { template class ModelFacade; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 04979d51..03b0767a 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -1,5 +1,6 @@ #include "lm/vocab.hh" +#include "lm/binary_format.hh" #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/config.hh" @@ -56,16 +57,6 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { } } -void WriteOrThrow(int fd, const void *data_void, std::size_t size) { - const uint8_t *data = static_cast(data_void); - while (size) { - ssize_t ret = write(fd, data, size); - if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); - data += ret; - size -= ret; - } -} - } // namespace WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {} @@ -80,7 +71,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { void WriteWordsWrapper::Write(int fd) { if ((off_t)-1 == lseek(fd, 0, SEEK_END)) UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); - WriteOrThrow(fd, buffer_.data(), buffer_.size()); + util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} @@ -146,15 +137,28 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { SetSpecial(Index(""), Index(""), 0); } +namespace { +const unsigned int kProbingVocabularyVersion = 0; +} // namespace + +namespace detail { +struct ProbingVocabularyHeader { + // Lowest unused vocab id. This is also the number of words, including . + unsigned int version; + WordIndex bound; +}; +} // namespace detail + ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { - return Lookup::Size(entries, config.probing_multiplier); + return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { - lookup_ = Lookup(start, allocated); - available_ = 1; + header_ = static_cast(start); + lookup_ = Lookup(static_cast(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated); + bound_ = 1; saw_unk_ = false; } @@ -172,20 +176,24 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } else { - if (enumerate_) enumerate_->Add(available_, str); - lookup_.Insert(Lookup::Packing::Make(hashed, available_)); - return available_++; + if (enumerate_) enumerate_->Add(bound_, str); + lookup_.Insert(Lookup::Packing::Make(hashed, bound_)); + return bound_++; } } void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { lookup_.FinishedInserting(); + header_->bound = bound_; + header_->version = kProbingVocabularyVersion; SetSpecial(Index(""), Index(""), 0); } void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { + UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); lookup_.LoadedBinary(); - available_ = ReadWords(fd, to); + ReadWords(fd, to); + bound_ = header_->bound; SetSpecial(Index(""), Index(""), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 9d218fff..41e97052 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -25,6 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len); inline uint64_t HashForVocab(const StringPiece &str) { return HashForVocab(str.data(), str.length()); } +class ProbingVocabularyHeader; } // namespace detail class WriteWordsWrapper : public EnumerateVocab { @@ -113,10 +114,7 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()). - // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. - // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. - // Specifically, the binary file format does not currently indicate whether is in count or not. - WordIndex Bound() const { return available_; } + WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -141,11 +139,13 @@ class ProbingVocabulary : public base::Vocabulary { Lookup lookup_; - WordIndex available_; + WordIndex bound_; bool saw_unk_; EnumerateVocab *enumerate_; + + detail::ProbingVocabularyHeader *header_; }; void MissingUnknown(const Config &config) throw(SpecialWordMissingException); diff --git a/klm/test.sh b/klm/test.sh index d02a3dc9..fb33300a 100755 --- a/klm/test.sh +++ b/klm/test.sh @@ -2,7 +2,7 @@ #Run tests. Requires Boost. set -e ./compile.sh -for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/model_test; do +for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/{model,left}_test; do g++ -I. -O3 $CXXFLAGS $i.cc {lm,util}/*.o -lboost_test_exec_monitor -lz -o $i pushd $(dirname $i) >/dev/null && ./$(basename $i) || echo "$i failed"; popd >/dev/null done diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 9f47d559..33266b94 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -86,6 +86,20 @@ inline void WriteFloat32(void *base, uint64_t bit_off, float value) { const uint32_t kSignBit = 0x80000000; +inline void SetSign(float &to) { + FloatEnc enc; + enc.f = to; + enc.i |= kSignBit; + to = enc.f; +} + +inline void UnsetSign(float &to) { + FloatEnc enc; + enc.f = to; + enc.i &= ~kSignBit; + to = enc.f; +} + inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 62280970..96951495 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -79,4 +79,9 @@ ErrnoException::ErrnoException() throw() : errno_(errno) { ErrnoException::~ErrnoException() throw() {} +EndOfFileException::EndOfFileException() throw() { + *this << "End of file"; +} +EndOfFileException::~EndOfFileException() throw() {} + } // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 81675a57..6d6a37cb 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -105,6 +105,12 @@ class ErrnoException : public Exception { int errno_; }; +class EndOfFileException : public Exception { + public: + EndOfFileException() throw(); + ~EndOfFileException() throw(); +}; + } // namespace util #endif // UTIL_EXCEPTION__ diff --git a/klm/util/file.cc b/klm/util/file.cc new file mode 100644 index 00000000..d707568e --- /dev/null +++ b/klm/util/file.cc @@ -0,0 +1,74 @@ +#include "util/file.hh" + +#include "util/exception.hh" + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace util { + +scoped_fd::~scoped_fd() { + if (fd_ != -1 && close(fd_)) { + std::cerr << "Could not close file " << fd_ << std::endl; + std::abort(); + } +} + +scoped_FILE::~scoped_FILE() { + if (file_ && std::fclose(file_)) { + std::cerr << "Could not close file " << std::endl; + std::abort(); + } +} + +int OpenReadOrThrow(const char *name) { + int ret; + UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name); + return ret; +} + +int CreateOrThrow(const char *name) { + int ret; + UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR)), ErrnoException, "while creating " << name); + return ret; +} + +off_t SizeFile(int fd) { + struct stat sb; + if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; + return sb.st_size; +} + +void ReadOrThrow(int fd, void *to_void, std::size_t amount) { + uint8_t *to = static_cast(to_void); + while (amount) { + ssize_t ret = read(fd, to, amount); + if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + amount -= ret; + to += ret; + } +} + +void WriteOrThrow(int fd, const void *data_void, std::size_t size) { + const uint8_t *data = static_cast(data_void); + while (size) { + ssize_t ret = write(fd, data, size); + if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); + data += ret; + size -= ret; + } +} + +void RemoveOrThrow(const char *name) { + UTIL_THROW_IF(std::remove(name), util::ErrnoException, "Could not remove " << name); +} + +} // namespace util diff --git a/klm/util/file.hh b/klm/util/file.hh new file mode 100644 index 00000000..d6cca41d --- /dev/null +++ b/klm/util/file.hh @@ -0,0 +1,74 @@ +#ifndef UTIL_FILE__ +#define UTIL_FILE__ + +#include +#include + +namespace util { + +class scoped_fd { + public: + scoped_fd() : fd_(-1) {} + + explicit scoped_fd(int fd) : fd_(fd) {} + + ~scoped_fd(); + + void reset(int to) { + scoped_fd other(fd_); + fd_ = to; + } + + int get() const { return fd_; } + + int operator*() const { return fd_; } + + int release() { + int ret = fd_; + fd_ = -1; + return ret; + } + + operator bool() { return fd_ != -1; } + + private: + int fd_; + + scoped_fd(const scoped_fd &); + scoped_fd &operator=(const scoped_fd &); +}; + +class scoped_FILE { + public: + explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} + + ~scoped_FILE(); + + std::FILE *get() { return file_; } + const std::FILE *get() const { return file_; } + + void reset(std::FILE *to = NULL) { + scoped_FILE other(file_); + file_ = to; + } + + private: + std::FILE *file_; +}; + +int OpenReadOrThrow(const char *name); + +int CreateOrThrow(const char *name); + +// Return value for SizeFile when it can't size properly. +const off_t kBadSize = -1; +off_t SizeFile(int fd); + +void ReadOrThrow(int fd, void *to, std::size_t size); +void WriteOrThrow(int fd, const void *data_void, std::size_t size); + +void RemoveOrThrow(const char *name); + +} // namespace util + +#endif // UTIL_FILE__ diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index cbe4234f..b57582a0 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -1,6 +1,7 @@ #include "util/file_piece.hh" #include "util/exception.hh" +#include "util/file.hh" #include #include @@ -21,11 +22,6 @@ namespace util { -EndOfFileException::EndOfFileException() throw() { - *this << "End of file"; -} -EndOfFileException::~EndOfFileException() throw() {} - ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } @@ -40,18 +36,6 @@ GZException::GZException(void *file) { // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; -int OpenReadOrThrow(const char *name) { - int ret; - UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name); - return ret; -} - -off_t SizeFile(int fd) { - struct stat sb; - if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; - return sb.st_size; -} - FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) : file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index a5c00910..a627f38c 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -3,9 +3,9 @@ #include "util/ersatz_progress.hh" #include "util/exception.hh" +#include "util/file.hh" #include "util/have.hh" #include "util/mmap.hh" -#include "util/scoped.hh" #include "util/string_piece.hh" #include @@ -14,12 +14,6 @@ namespace util { -class EndOfFileException : public Exception { - public: - EndOfFileException() throw(); - ~EndOfFileException() throw(); -}; - class ParseNumberException : public Exception { public: explicit ParseNumberException(StringPiece value) throw(); @@ -33,14 +27,8 @@ class GZException : public Exception { ~GZException() throw() {} }; -int OpenReadOrThrow(const char *name); - extern const bool kSpaces[256]; -// Return value for SizeFile when it can't size properly. -const off_t kBadSize = -1; -off_t SizeFile(int fd); - // Memory backing the returned StringPiece may vanish on the next call. class FilePiece { public: diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index e7c0643b..5ce7adc9 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -1,6 +1,6 @@ #include "util/exception.hh" +#include "util/file.hh" #include "util/mmap.hh" -#include "util/scoped.hh" #include @@ -66,20 +66,6 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int return ret; } -namespace { -void ReadAll(int fd, void *to_void, std::size_t amount) { - uint8_t *to = static_cast(to_void); - while (amount) { - ssize_t ret = read(fd, to, amount); - if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); - if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); - amount -= ret; - to += ret; - } -} - -} // namespace - const int kFileFlags = #ifdef MAP_FILE MAP_FILE | MAP_SHARED @@ -106,7 +92,7 @@ void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_m out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED); if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc"); if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed."); - ReadAll(fd, out.get(), size); + ReadOrThrow(fd, out.get(), size); break; } } diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh index e4439fa4..b0eb6672 100644 --- a/klm/util/mmap.hh +++ b/klm/util/mmap.hh @@ -2,8 +2,6 @@ #define UTIL_MMAP__ // Utilities for mmaped files. -#include "util/scoped.hh" - #include #include @@ -11,6 +9,8 @@ namespace util { +class scoped_fd; + // (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. class scoped_mmap { public: diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index fec47fd9..d58a0727 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc deleted file mode 100644 index a4cc5016..00000000 --- a/klm/util/scoped.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "util/scoped.hh" - -#include - -#include -#include - -namespace util { - -scoped_fd::~scoped_fd() { - if (fd_ != -1 && close(fd_)) { - std::cerr << "Could not close file " << fd_ << std::endl; - abort(); - } -} - -scoped_FILE::~scoped_FILE() { - if (file_ && fclose(file_)) { - std::cerr << "Could not close file " << std::endl; - abort(); - } -} - -} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index d36a7df3..12e6652b 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,10 +1,11 @@ #ifndef UTIL_SCOPED__ #define UTIL_SCOPED__ -/* Other scoped objects in the style of scoped_ptr. */ +#include "util/exception.hh" +/* Other scoped objects in the style of scoped_ptr. */ #include -#include +#include namespace util { @@ -34,52 +35,33 @@ template class scoped_thing { scoped_thing &operator=(const scoped_thing &); }; -class scoped_fd { +class scoped_malloc { public: - scoped_fd() : fd_(-1) {} + scoped_malloc() : p_(NULL) {} - explicit scoped_fd(int fd) : fd_(fd) {} + scoped_malloc(void *p) : p_(p) {} - ~scoped_fd(); + ~scoped_malloc() { std::free(p_); } - void reset(int to) { - scoped_fd other(fd_); - fd_ = to; + void reset(void *p = NULL) { + scoped_malloc other(p_); + p_ = p; } - int get() const { return fd_; } - - int operator*() const { return fd_; } - - int release() { - int ret = fd_; - fd_ = -1; - return ret; + void call_realloc(std::size_t to) { + void *ret; + UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + p_ = ret; } - private: - int fd_; - - scoped_fd(const scoped_fd &); - scoped_fd &operator=(const scoped_fd &); -}; - -class scoped_FILE { - public: - explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} - - ~scoped_FILE(); - - std::FILE *get() { return file_; } - const std::FILE *get() const { return file_; } - - void reset(std::FILE *to = NULL) { - scoped_FILE other(file_); - file_ = to; - } + void *get() { return p_; } + const void *get() const { return p_; } private: - std::FILE *file_; + void *p_; + + scoped_malloc(const scoped_malloc &); + scoped_malloc &operator=(const scoped_malloc &); }; // Hat tip to boost. diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh new file mode 100644 index 00000000..47dfc245 --- /dev/null +++ b/klm/util/sized_iterator.hh @@ -0,0 +1,107 @@ +#ifndef UTIL_SIZED_ITERATOR__ +#define UTIL_SIZED_ITERATOR__ + +#include "util/proxy_iterator.hh" + +#include +#include + +#include +#include + +namespace util { + +class SizedInnerIterator { + public: + SizedInnerIterator() {} + + SizedInnerIterator(void *ptr, std::size_t size) : ptr_(static_cast(ptr)), size_(size) {} + + bool operator==(const SizedInnerIterator &other) const { + return ptr_ == other.ptr_; + } + bool operator<(const SizedInnerIterator &other) const { + return ptr_ < other.ptr_; + } + SizedInnerIterator &operator+=(std::ptrdiff_t amount) { + ptr_ += amount * size_; + return *this; + } + std::ptrdiff_t operator-(const SizedInnerIterator &other) const { + return (ptr_ - other.ptr_) / size_; + } + + const void *Data() const { return ptr_; } + void *Data() { return ptr_; } + std::size_t EntrySize() const { return size_; } + + private: + uint8_t *ptr_; + std::size_t size_; +}; + +class SizedProxy { + public: + SizedProxy() {} + + SizedProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} + + operator std::string() const { + return std::string(reinterpret_cast(inner_.Data()), inner_.EntrySize()); + } + + SizedProxy &operator=(const SizedProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); + return *this; + } + + SizedProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), inner_.EntrySize()); + return *this; + } + + const void *Data() const { return inner_.Data(); } + void *Data() { return inner_.Data(); } + + private: + friend class util::ProxyIterator; + + typedef std::string value_type; + + typedef SizedInnerIterator InnerIterator; + + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef ProxyIterator SizedIterator; + +inline SizedIterator SizedIt(void *ptr, std::size_t size) { return SizedIterator(SizedProxy(ptr, size)); } + +// Useful wrapper for a comparison function i.e. sort. +template class SizedCompare : public std::binary_function { + public: + explicit SizedCompare(const Delegate &delegate = Delegate()) : delegate_(delegate) {} + + bool operator()(const Proxy &first, const Proxy &second) const { + return delegate_(first.Data(), second.Data()); + } + bool operator()(const Proxy &first, const std::string &second) const { + return delegate_(first.Data(), second.data()); + } + bool operator()(const std::string &first, const Proxy &second) const { + return delegate_(first.data(), second.Data()); + } + bool operator()(const std::string &first, const std::string &second) const { + return delegate_(first.data(), second.data()); + } + + const Delegate &GetDelegate() const { return delegate_; } + + private: + const Delegate delegate_; +}; + +} // namespace util +#endif // UTIL_SIZED_ITERATOR__ diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh new file mode 100644 index 00000000..ee1c7ab2 --- /dev/null +++ b/klm/util/tokenize_piece.hh @@ -0,0 +1,69 @@ +#ifndef UTIL_TOKENIZE_PIECE__ +#define UTIL_TOKENIZE_PIECE__ + +#include "util/string_piece.hh" + +#include + +/* Usage: + * + * for (PieceIterator<' '> i(" foo \r\n bar "); i; ++i) { + * std::cout << *i << "\n"; + * } + * + */ + +namespace util { + +// Tokenize a StringPiece using an iterator interface. boost::tokenizer doesn't work with StringPiece. +template class PieceIterator : public boost::iterator_facade, const StringPiece, boost::forward_traversal_tag> { + public: + // Default construct is end, which is also accessed by kEndPieceIterator; + PieceIterator() {} + + explicit PieceIterator(const StringPiece &str) + : after_(str) { + increment(); + } + + bool operator!() const { + return after_.data() == 0; + } + operator bool() const { + return after_.data() != 0; + } + + static PieceIterator end() { + return PieceIterator(); + } + + private: + friend class boost::iterator_core_access; + + void increment() { + const char *start = after_.data(); + for (; (start != after_.data() + after_.size()) && (d == *start); ++start) {} + if (start == after_.data() + after_.size()) { + // End condition. + after_.clear(); + return; + } + const char *finish = start; + for (; (finish != after_.data() + after_.size()) && (d != *finish); ++finish) {} + current_ = StringPiece(start, finish - start); + after_ = StringPiece(finish, after_.data() + after_.size() - finish); + } + + bool equal(const PieceIterator &other) const { + return after_.data() == other.after_.data(); + } + + const StringPiece &dereference() const { return current_; } + + StringPiece current_; + StringPiece after_; +}; + +} // namespace util + +#endif // UTIL_TOKENIZE_PIECE__ -- cgit v1.2.3 From e2d78e3fc1cb414d9d68af4cb4ee397b0c1f8dcc Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 22 Sep 2011 03:18:24 -0400 Subject: Sorry forgot the Makefile.am updates --- klm/lm/Makefile.am | 1 + klm/util/Makefile.am | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index fae6b41a..54fd7f68 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -23,6 +23,7 @@ libklm_a_SOURCES = \ search_hashed.cc \ search_trie.cc \ trie.cc \ + trie_sort.cc \ virtual_interface.cc \ vocab.cc diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index f4f7d158..a8d6299b 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -22,9 +22,9 @@ libklm_util_a_SOURCES = \ ersatz_progress.cc \ bit_packing.cc \ exception.cc \ + file.cc \ file_piece.cc \ mmap.cc \ - murmur_hash.cc \ - scoped.cc + murmur_hash.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -- cgit v1.2.3 From e1b61419329c83709018ca397a29d069e4294bd1 Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Fri, 23 Sep 2011 15:44:35 +0100 Subject: make show_partition work even in absence of feature functions --- decoder/decoder.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 45404c47..c4fe3c4d 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -794,6 +794,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; } + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " Partition log(Z): " << log(z) << endl; + } + SummaryFeature summary_feature_type = kNODE_RISK; if (conf["summary_feature_type"].as() == "edge_risk") summary_feature_type = kEDGE_RISK; -- cgit v1.2.3 From 8ecf63852d730f99e7c1bbacfbffdf518d5a0c3f Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Fri, 23 Sep 2011 20:49:43 +0100 Subject: stub work to talk to new kenlm --- decoder/ff_klm.cc | 349 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 349 insertions(+) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 24dcb9c3..016aad26 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,6 +12,353 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#undef NEW_KENLM +#ifdef NEW_KENLM + +#include "lm/left.hh" + +using namespace std; + +// -x : rules include and +// -n NAME : feature id is NAME +bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { + vector const& argv=SplitOnWhitespace(in); + *explicit_markers = false; + *featname="LanguageModel"; + *mapfile = ""; +#define LMSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'x': + *explicit_markers = true; + break; + case 'm': + LMSPEC_NEXTARG; *mapfile=*i; + break; + case 'n': + LMSPEC_NEXTARG; *featname=*i; + break; +#undef LMSPEC_NEXTARG + default: + fail: + cerr<<"Unknown KLanguageModel option "<empty()) + *filename=s; + else { + cerr<<"More than one filename provided. "; + goto usage; + } + } + } + if (!filename->empty()) + return true; +usage: + cerr << "KLanguageModel is incorrect!\n"; + return false; +} + +template +string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { + return "KLanguageModel"; +} + +struct VMapper : public lm::ngram::EnumerateVocab { + VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +template +class KLanguageModelImpl { + + static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { + return *static_cast(state); + } + + inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, &lmstate, ngram_->StateSize()); + } + + public: + double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { + double sum = 0.0; + if (oovs) *oovs = 0; + const vector& e = rule.e(); + lm::ngram::ChartState state; + lm::ngram::RuleScore ruleScore(*ngram_, state); + unsigned i = 0; + if (e.size()) { + if (e[i] == kCDEC_SOS) { + ++i; + ruleScore.BeginSentence(); + } else if (e[i] <= 0) { // special case for left-edge NT + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + ++i; + } + } + for (; i < e.size(); ++i) { + if (e[i] <= 0) { + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); + ruleScore.NonTerminal(prevState, 0.0f); // TODO + } else { + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, + // maybe handle emission + const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id + const bool is_oov = (cur_word == 0); + if (is_oov) (*oovs) += 1.0; + ruleScore.Terminal(cur_word); + } + } + if (remnant) SetRemnantLMState(state, remnant); + return ruleScore.Finish(); + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for (n-1 left words) and (n-1 right words) + double FinalTraversalCost(const void* state, double* oovs) { + if (add_sos_eos_) { // rules do not produce , so do it here + lm::ngram::ChartState cstate; + lm::ngram::RuleScore ruleScore(*ngram_, cstate); + ruleScore.BeginSentence(); + SetRemnantLMState(cstate, dummy_state_); + dummy_ants_[1] = state; + *oovs = 0; + return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + } else { // rules DO produce ... + double p = 0; + cerr << "not implemented"; abort(); // TODO + //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + //if (UnscoredSize(state) > 0) { // are there unscored words + // if (kSOS_ != IthUnscoredWord(0, state)) { + // p -= 100 * UnscoredSize(state); + // } + //} + return p; + } + } + + // if this is not a class-based LM, returns w untransformed, + // otherwise returns a word class mapping of w, + // returns TD::Convert("") if there is no mapping for w + WordID ClassifyWordIfNecessary(WordID w) const { + if (word2class_map_.empty()) return w; + if (w >= word2class_map_.size()) + return kCDEC_UNK; + else + return word2class_map_[w]; + } + + // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 + lm::WordIndex MapWord(WordID w) const { + if (w >= cdec2klm_map_.size()) + return 0; + else + return cdec2klm_map_[w]; + } + + public: + KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : + kCDEC_UNK(TD::Convert("")) , + kCDEC_SOS(TD::Convert("")) , + add_sos_eos_(!explicit_markers) { + { + VMapper vm(&cdec2klm_map_); + lm::ngram::Config conf; + conf.enumerate_vocab = &vm; + ngram_ = new Model(filename.c_str(), conf); + } + order_ = ngram_->Order(); + cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; + state_size_ = sizeof(lm::ngram::ChartState); + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); + kSOS_ = MapWord(kCDEC_SOS); + assert(kSOS_ > 0); + kEOS_ = MapWord(TD::Convert("")); + assert(kEOS_ > 0); + assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant + + // handle class-based LMs (unambiguous word->class mapping reqd.) + if (mapfile.size()) + LoadWordClasses(mapfile); + } + + void LoadWordClasses(const string& file) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector dummy; + int lc = 0; + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + while(in) { + getline(in, line); + if (!in) continue; + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2) { + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(dummy[0], dummy[1]); + } + } + + void AddWordToClassMapping_(WordID word, WordID cls) { + if (word2class_map_.size() <= word) { + word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); + assert(word2class_map_.size() > word); + } + if(word2class_map_[word] != kCDEC_UNK) { + cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; + abort(); + } + word2class_map_[word] = cls; + } + + ~KLanguageModelImpl() { + delete ngram_; + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + const WordID kCDEC_SOS; + lm::WordIndex kSOS_; // - requires special handling. + lm::WordIndex kEOS_; // + Model* ngram_; + const bool add_sos_eos_; // flag indicating whether the hypergraph produces and + // if this is true, FinalTransitionFeatures will "add" and + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + char* dummy_state_; + vector dummy_ants_; + vector cdec2klm_map_; + vector word2class_map_; // if this is a class-based LM, this is the word->class mapping + TRulePtr dummy_rule_; +}; + +template +KLanguageModel::KLanguageModel(const string& param) { + string filename, mapfile, featname; + bool explicit_markers; + if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { + abort(); + } + try { + pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } + fid_ = FD::Convert(featname); + oov_fid_ = FD::Convert(featname+"_OOV"); + // cerr << "FID: " << oov_fid_ << endl; + SetStateSize(pimpl_->ReserveStateSize()); +} + +template +Features KLanguageModel::features() const { + return single_feature(fid_); +} + +template +KLanguageModel::~KLanguageModel() { + delete pimpl_; +} + +template +void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + double est = 0; + double oovs = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); + if (oovs && oov_fid_) + features->set_value(oov_fid_, oovs); +} + +template +void KLanguageModel::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + double oovs = 0; + double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); + features->set_value(fid_, lm); + if (oov_fid_ && oovs) + features->set_value(oov_fid_, oovs); +} + +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} + +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} + +#else + using namespace std; static const unsigned char HAS_FULL_CONTEXT = 1; @@ -469,3 +816,5 @@ boost::shared_ptr KLanguageModelFactory::Create(std::string par std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } + +#endif -- cgit v1.2.3 From fcbb924a575df56de53eacce886ebf9ccf3283ed Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 23 Sep 2011 16:09:56 -0400 Subject: Add ZeroRemaining --- klm/lm/left.hh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/klm/lm/left.hh b/klm/lm/left.hh index df69e97a..837be765 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -26,6 +26,11 @@ struct Left { return 0; } + void ZeroRemaining() { + for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) + *i = 0; + } + uint64_t pointers[kMaxOrder - 1]; unsigned char length; }; @@ -43,6 +48,11 @@ struct ChartState { return (int)full - (int)other.full; } + void ZeroRemaining() { + left.ZeroRemaining(); + right.ZeroRemaining(); + } + Left left; State right; bool full; -- cgit v1.2.3 From d71c74f3924e6c207f3ebfab470b9a30e2551dde Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 23 Sep 2011 16:38:01 -0400 Subject: Go through ff_klm and try to fix it for the new version. --- decoder/ff_klm.cc | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 016aad26..3b2113ad 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -90,19 +90,12 @@ class KLanguageModelImpl { return *static_cast(state); } - inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } - public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { - double sum = 0.0; if (oovs) *oovs = 0; const vector& e = rule.e(); lm::ngram::ChartState state; - lm::ngram::RuleScore ruleScore(*ngram_, state); + lm::ngram::RuleScore ruleScore(*ngram_, remnant ? *static_cast(remnant) : state); unsigned i = 0; if (e.size()) { if (e[i] == kCDEC_SOS) { @@ -123,12 +116,13 @@ class KLanguageModelImpl { // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id const bool is_oov = (cur_word == 0); - if (is_oov) (*oovs) += 1.0; + if (is_oov && oovs) (*oovs) += 1.0; ruleScore.Terminal(cur_word); } } - if (remnant) SetRemnantLMState(state, remnant); - return ruleScore.Finish(); + double ret = ruleScore.Finish(); + state.ZeroRemaining(); + return ret; } // this assumes no target words on final unary -> goal rule. is that ok? @@ -138,10 +132,9 @@ class KLanguageModelImpl { lm::ngram::ChartState cstate; lm::ngram::RuleScore ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - SetRemnantLMState(cstate, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.Terminal(kEOS_); + return ruleScore.Finish(); } else { // rules DO produce ... double p = 0; cerr << "not implemented"; abort(); // TODO @@ -187,14 +180,8 @@ class KLanguageModelImpl { } order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = sizeof(lm::ngram::ChartState); // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); kSOS_ = MapWord(kCDEC_SOS); assert(kSOS_ > 0); kEOS_ = MapWord(TD::Convert("")); @@ -243,10 +230,9 @@ class KLanguageModelImpl { ~KLanguageModelImpl() { delete ngram_; - delete[] dummy_state_; } - int ReserveStateSize() const { return state_size_; } + int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } private: const WordID kCDEC_UNK; @@ -261,12 +247,8 @@ class KLanguageModelImpl { // the sentence) with 0, and anything else with -100 int order_; - int state_size_; - char* dummy_state_; - vector dummy_ants_; vector cdec2klm_map_; vector word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; }; template -- cgit v1.2.3 From 2e5720a8e7141a75ae549c6be74f50bd18068ef1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sat, 24 Sep 2011 07:58:58 -0400 Subject: Belated documentation --- klm/lm/left.hh | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++------ klm/lm/model.cc | 5 ----- klm/lm/model.hh | 25 +++++++++++---------- 3 files changed, 76 insertions(+), 24 deletions(-) diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 837be765..effa0560 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -1,3 +1,40 @@ +/* Efficient left and right language model state for sentence fragments. + * Intended usage: + * Store ChartState with every chart entry. + * To do a rule application: + * 1. Make a ChartState object for your new entry. + * 2. Construct RuleScore. + * 3. Going from left to right, call Terminal or NonTerminal. + * For terminals, just pass the vocab id. + * For non-terminals, pass that non-terminal's ChartState. + * If your decoder expects scores inclusive of subtree scores (i.e. you + * label entries with the highest-scoring path), pass the non-terminal's + * score as prob. + * If your decoder expects relative scores and will walk the chart later, + * pass prob = 0.0. + * In other words, the only effect of prob is that it gets added to the + * returned log probability. + * 4. Call Finish. It returns the log probability. + * + * There's a couple more details: + * Do not pass to Terminal as it is formally not a word in the sentence, + * only context. Instead, call BeginSentence. If called, it should be the + * first call after RuleScore is constructed (since is always the + * leftmost). + * + * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal. + * + * Hashing and sorting comparison operators are provided. All state objects + * are POD. If you intend to use memcmp on raw state objects, you must call + * ZeroRemaining first, as the value of array entries beyond length is + * otherwise undefined. + * + * Usage is of course not limited to chart decoding. Anything that generates + * sentence fragments missing left context could benefit. For example, a + * phrase-based decoder could pre-score phrases, storing ChartState with each + * phrase, even if hypotheses are generated left-to-right. + */ + #ifndef LM_LEFT__ #define LM_LEFT__ @@ -5,6 +42,8 @@ #include "lm/model.hh" #include "lm/return.hh" +#include "util/murmur_hash.hh" + #include namespace lm { @@ -18,23 +57,30 @@ struct Left { } int Compare(const Left &other) const { - if (length != other.length) { - return (int)length - (int)other.length; - } + if (length != other.length) return length < other.length ? -1 : 1; if (pointers[length - 1] > other.pointers[length - 1]) return 1; if (pointers[length - 1] < other.pointers[length - 1]) return -1; return 0; } + bool operator<(const Left &other) const { + if (length != other.length) return length < other.length; + return pointers[length - 1] < other.pointers[length - 1]; + } + void ZeroRemaining() { for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) *i = 0; } - uint64_t pointers[kMaxOrder - 1]; unsigned char length; + uint64_t pointers[kMaxOrder - 1]; }; +inline size_t hash_value(const Left &left) { + return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); +} + struct ChartState { bool operator==(const ChartState &other) { return (left == other.left) && (right == other.right) && (full == other.full); @@ -48,16 +94,27 @@ struct ChartState { return (int)full - (int)other.full; } + bool operator<(const ChartState &other) const { + return Compare(other) == -1; + } + void ZeroRemaining() { left.ZeroRemaining(); right.ZeroRemaining(); } Left left; - State right; bool full; + State right; }; +inline size_t hash_value(const ChartState &state) { + size_t hashes[2]; + hashes[0] = hash_value(state.left); + hashes[1] = hash_value(state.right); + return util::MurmurHashNative(hashes, sizeof(size_t), state.full); +} + template class RuleScore { public: explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { @@ -73,8 +130,7 @@ template class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - FullScoreReturn ret = model_.FullScore(copy, word, out_.right); - ProcessRet(ret); + ProcessRet(model_.FullScore(copy, word, out_.right)); if (out_.right.length != copy.length + 1) left_done_ = true; } diff --git a/klm/lm/model.cc b/klm/lm/model.cc index ca581d8a..25f1ab7c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -14,11 +14,6 @@ namespace lm { namespace ngram { - -size_t hash_value(const State &state) { - return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); -} - namespace detail { template const ModelType GenericModel::kModelType = Search::kModelType; diff --git a/klm/lm/model.hh b/klm/lm/model.hh index fe91af2e..c278acd6 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -12,6 +12,8 @@ #include "lm/vocab.hh" #include "lm/weights.hh" +#include "util/murmur_hash.hh" + #include #include @@ -28,21 +30,18 @@ class State { public: bool operator==(const State &other) const { if (length != other.length) return false; - const WordIndex *end = words + length; - for (const WordIndex *first = words, *second = other.words; - first != end; ++first, ++second) { - if (*first != *second) return false; - } - // If the histories are equal, so are the backoffs. - return true; + return !memcmp(words, other.words, length * sizeof(WordIndex)); } // Three way comparison function. int Compare(const State &other) const { - if (length == other.length) { - return memcmp(words, other.words, length * sizeof(WordIndex)); - } - return (length < other.length) ? -1 : 1; + if (length != other.length) return length < other.length ? -1 : 1; + return memcmp(words, other.words, length * sizeof(WordIndex)); + } + + bool operator<(const State &other) const { + if (length != other.length) return length < other.length; + return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; } // Call this before using raw memcmp. @@ -62,7 +61,9 @@ class State { unsigned char length; }; -size_t hash_value(const State &state); +inline size_t hash_value(const State &state) { + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); +} namespace detail { -- cgit v1.2.3 From 747309fcb0e0b1c6d060a68286ba1cf5ed1fbfa4 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sat, 24 Sep 2011 11:33:22 -0400 Subject: Chris says remnant and oovs should not be null, so stop checking. Also, we were not properly doing ZeroRemaining, sorry. --- decoder/ff_klm.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3b2113ad..6d9aca54 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -92,10 +92,9 @@ class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { - if (oovs) *oovs = 0; + *oovs = 0; const vector& e = rule.e(); - lm::ngram::ChartState state; - lm::ngram::RuleScore ruleScore(*ngram_, remnant ? *static_cast(remnant) : state); + lm::ngram::RuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; if (e.size()) { if (e[i] == kCDEC_SOS) { @@ -115,13 +114,12 @@ class KLanguageModelImpl { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - const bool is_oov = (cur_word == 0); - if (is_oov && oovs) (*oovs) += 1.0; + if (cur_word == 0) (*oovs) += 1.0; ruleScore.Terminal(cur_word); } } double ret = ruleScore.Finish(); - state.ZeroRemaining(); + static_cast(remnant)->ZeroRemaining(); return ret; } -- cgit v1.2.3 From 5ef94f59e08d2f25bee8520c4233829207d1c034 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 25 Sep 2011 19:18:36 -0400 Subject: Fix trie sort merging --- klm/lm/trie_sort.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 01c4e490..86f28493 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -146,7 +146,7 @@ template void MergeSortedFiles(const std::string &first_name, co ++first; ++second; } } - for (RecordReader &remains = (first ? second : first); remains; ++remains) { + for (RecordReader &remains = (first ? first : second); remains; ++remains) { WriteOrThrow(out_file.get(), remains.Data(), entry_size); } } @@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0], files[1], files.back(), 0, order, FirstCombine()); + MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine()); files.pop_front(); files.pop_front(); } -- cgit v1.2.3 From 32288c27a523a1152afa019b9152f4401c3097ce Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 26 Sep 2011 16:54:16 -0400 Subject: Fix trie pointer segfault --- klm/lm/bhiksha.hh | 2 ++ klm/lm/trie.cc | 6 ++---- klm/lm/trie.hh | 7 ++++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index ff7fe452..bc705959 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -11,6 +11,7 @@ */ #include +#include #include "lm/model_type.hh" #include "lm/trie.hh" @@ -78,6 +79,7 @@ class ArrayBhiksha { util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); out.end = ((end_it - offset_begin_) << next_inline_.bits) | util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); + //assert(out.end >= out.begin); } void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 4e60b184..20075bb8 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -91,16 +91,14 @@ template bool BitPackedMiddle::Find if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } - uint64_t index = at_pointer; + pointer = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; - pointer = at_pointer; - quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); - bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); + bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); return true; } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index a9f5e417..06cc96ac 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -99,10 +99,11 @@ template class BitPackedMiddle : public BitPacked { bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; NodeRange ReadEntry(uint64_t pointer, float &prob) { - quant_.ReadProb(base_, pointer, prob); + uint64_t addr = pointer * total_bits_; + addr += word_bits_; + quant_.ReadProb(base_, addr, prob); NodeRange ret; - // pointer/total_bits_ should always round down. - bhiksha_.ReadNext(base_, pointer + quant_.TotalBits(), pointer / total_bits_, total_bits_, ret); + bhiksha_.ReadNext(base_, addr + quant_.TotalBits(), pointer, total_bits_, ret); return ret; } -- cgit v1.2.3 From 1706bda5f393808583c6ab21a5d073b204827f52 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 28 Sep 2011 16:04:41 +0100 Subject: fix broken compile on weights test --- utils/weights_test.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/utils/weights_test.cc b/utils/weights_test.cc index 8a4c26ef..938b311f 100644 --- a/utils/weights_test.cc +++ b/utils/weights_test.cc @@ -14,11 +14,10 @@ class WeightsTest : public testing::Test { virtual void TearDown() { } }; - TEST_F(WeightsTest,Load) { - Weights w; - w.InitFromFile("test_data/weights"); - w.WriteToFile("-"); + vector v; + Weights::InitFromFile("test_data/weights", &v); + Weights::WriteToFile("-", v); } int main(int argc, char **argv) { -- cgit v1.2.3 From 5e83a23897b5160fa96788c5828d23ab1912d1a6 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 28 Sep 2011 16:05:50 +0100 Subject: upgrade gtest m4 config --- configure.ac | 4 ++-- m4/gtest.m4 | 65 ++++++++++++++++++++++++++++++++++++------------------------ 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/configure.ac b/configure.ac index 8e06bffd..2e9cc36d 100644 --- a/configure.ac +++ b/configure.ac @@ -60,7 +60,7 @@ AC_CHECK_HEADER(google/dense_hash_map, [AC_DEFINE([HAVE_SPARSEHASH], [], [flag for google::dense_hash_map])]) AC_PROG_INSTALL -GTEST_LIB_CHECK +GTEST_LIB_CHECK(1.0) AM_CONDITIONAL([RAND_LM], false) AC_ARG_WITH(randlm, @@ -113,4 +113,4 @@ then AM_CONDITIONAL([GLC], true) fi -AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile pro-train/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile) +AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile pro-train/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile gi/cbgi/Makefile gi/ml/Makefile) diff --git a/m4/gtest.m4 b/m4/gtest.m4 index b015ddeb..28ccd2de 100644 --- a/m4/gtest.m4 +++ b/m4/gtest.m4 @@ -12,10 +12,10 @@ AC_DEFUN([GTEST_LIB_CHECK], dnl Provide a flag to enable or disable Google Test usage. AC_ARG_ENABLE([gtest], [AS_HELP_STRING([--enable-gtest], - [Enable tests using the Google C++ Testing Framework.] - [(Default is enabled.)])], + [Enable tests using the Google C++ Testing Framework. + (Default is enabled.)])], [], - [enable_gtest=check]) + [enable_gtest=]) AC_ARG_VAR([GTEST_CONFIG], [The exact path of Google Test's 'gtest-config' script.]) AC_ARG_VAR([GTEST_CPPFLAGS], @@ -29,33 +29,46 @@ AC_ARG_VAR([GTEST_LIBS], AC_ARG_VAR([GTEST_VERSION], [The version of Google Test available.]) HAVE_GTEST="no" -AS_IF([test "x$enable_gtest" != "xno"], - [AC_PATH_PROG([GTEST_CONFIG], [gtest-config]) - AS_IF([test -x "$GTEST_CONFIG"], - [AS_IF([test "x$1" != "x"], - [_min_version="--min-version=$1" +AS_IF([test "x${enable_gtest}" != "xno"], + [AC_MSG_CHECKING([for 'gtest-config']) + AS_IF([test "x${enable_gtest}" != "xyes"], + [AS_IF([test -x "${enable_gtest}/scripts/gtest-config"], + [GTEST_CONFIG="${enable_gtest}/scripts/gtest-config"], + [GTEST_CONFIG="${enable_gtest}/bin/gtest-config"]) + AS_IF([test -x "${GTEST_CONFIG}"], [], + [AC_MSG_RESULT([no]) + AC_MSG_ERROR([dnl +Unable to locate either a built or installed Google Test. +The specific location '${enable_gtest}' was provided for a built or installed +Google Test, but no 'gtest-config' script could be found at this location.]) + ])], + [AC_PATH_PROG([GTEST_CONFIG], [gtest-config])]) + AS_IF([test -x "${GTEST_CONFIG}"], + [AC_MSG_RESULT([${GTEST_CONFIG}]) + m4_ifval([$1], + [_gtest_min_version="--min-version=$1" AC_MSG_CHECKING([for Google Test at least version >= $1])], - [_min_version="--min-version=0" + [_gtest_min_version="--min-version=0" AC_MSG_CHECKING([for Google Test])]) - AS_IF([$GTEST_CONFIG $_min_version], + AS_IF([${GTEST_CONFIG} ${_gtest_min_version}], [AC_MSG_RESULT([yes]) - HAVE_GTEST="yes"], - [AC_MSG_RESULT([no])])]) - AS_IF([test "x$HAVE_GTEST" = "xyes"], - [GTEST_CPPFLAGS=$($GTEST_CONFIG --cppflags) - GTEST_CXXFLAGS=$($GTEST_CONFIG --cxxflags) - GTEST_LDFLAGS=$($GTEST_CONFIG --ldflags) - GTEST_LIBS=$($GTEST_CONFIG --libs | sed 's/la/a/') - GTEST_VERSION=$($GTEST_CONFIG --version) + HAVE_GTEST='yes'], + [AC_MSG_RESULT([no])])], + [AC_MSG_RESULT([no])]) + AS_IF([test "x${HAVE_GTEST}" = "xyes"], + [GTEST_CPPFLAGS=`${GTEST_CONFIG} --cppflags` + GTEST_CXXFLAGS=`${GTEST_CONFIG} --cxxflags` + GTEST_LDFLAGS=`${GTEST_CONFIG} --ldflags` + GTEST_LIBS=`${GTEST_CONFIG} --libs` + GTEST_VERSION=`${GTEST_CONFIG} --version` AC_DEFINE([HAVE_GTEST],[1],[Defined when Google Test is available.])], - [AS_IF([test "x$enable_gtest" = "xyes"], - [AC_MSG_ERROR([ - The Google C++ Testing Framework was explicitly enabled, but a viable version - could not be found on the system. -])])])]) + [AS_IF([test "x${enable_gtest}" = "xyes"], + [AC_MSG_ERROR([dnl +Google Test was enabled, but no viable version could be found.]) + ])])]) AC_SUBST([HAVE_GTEST]) AM_CONDITIONAL([HAVE_GTEST],[test "x$HAVE_GTEST" = "xyes"]) -AS_IF([test "x$HAVE_GTEST" = "xyes"], - [AS_IF([test "x$2" != "x"],[$2],[:])], - [AS_IF([test "x$3" != "x"],[$3],[:])]) +dnl AS_IF([test "x$HAVE_GTEST" = "xyes"], [] []) +dnl [m4_ifval([$2], [$2])], +dnl [m4_ifval([$3], [$3])]) ]) -- cgit v1.2.3 From b77d23a3032f42be3705e88ae1734bae779fb9a3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 28 Sep 2011 16:19:09 +0100 Subject: test fixes --- decoder/grammar_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 62b8f958..cde00efa 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -15,12 +15,12 @@ using namespace std; class GrammarTest : public testing::Test { public: GrammarTest() { - wts.InitFromFile("test_data/weights.gt"); + Weights::InitFromFile("test_data/weights.gt", &wts); } protected: virtual void SetUp() { } virtual void TearDown() { } - Weights wts; + vector wts; }; TEST_F(GrammarTest,TestTextGrammar) { -- cgit v1.2.3 From 09278065247830c1c9be88c45a60a5a8017f8e9c Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Wed, 28 Sep 2011 16:21:20 +0100 Subject: fix --- configure.ac | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configure.ac b/configure.ac index 2e9cc36d..dd151076 100644 --- a/configure.ac +++ b/configure.ac @@ -113,4 +113,5 @@ then AM_CONDITIONAL([GLC], true) fi -AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile pro-train/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile gi/cbgi/Makefile gi/ml/Makefile) +AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile pro-train/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile) +# gi/cbgi/Makefile gi/ml/Makefile) -- cgit v1.2.3 From 0acc92a0eecf04a2c429f6f7685bfcaa68c7ec3a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 11 Oct 2011 12:06:32 +0100 Subject: check in some experimental particle filtering code, some gitignore fixes --- .gitignore | 24 +- Makefile.am | 2 +- configure.ac | 2 +- gi/markov_al/Makefile.am | 6 + gi/markov_al/README | 2 + gi/markov_al/ml.cc | 470 +++++++++++++++++++++++++++++ gi/pf/Makefile.am | 21 ++ gi/pf/README | 2 + gi/pf/base_measures.cc | 112 +++++++ gi/pf/base_measures.h | 116 +++++++ gi/pf/brat.cc | 554 ++++++++++++++++++++++++++++++++++ gi/pf/cbgi.cc | 340 +++++++++++++++++++++ gi/pf/cfg_wfst_composer.cc | 730 +++++++++++++++++++++++++++++++++++++++++++++ gi/pf/cfg_wfst_composer.h | 46 +++ gi/pf/dpnaive.cc | 349 ++++++++++++++++++++++ gi/pf/itg.cc | 224 ++++++++++++++ gi/pf/pfbrat.cc | 554 ++++++++++++++++++++++++++++++++++ gi/pf/pfdist.cc | 621 ++++++++++++++++++++++++++++++++++++++ gi/pf/pfdist.new.cc | 620 ++++++++++++++++++++++++++++++++++++++ gi/pf/pfnaive.cc | 385 ++++++++++++++++++++++++ gi/pf/reachability.cc | 64 ++++ gi/pf/reachability.h | 28 ++ gi/pf/tpf.cc | 99 ++++++ m4/acx_pthread.m4 | 363 ++++++++++++++++++++++ utils/ccrp_nt.h | 169 +++++++++++ utils/ccrp_onetable.h | 241 +++++++++++++++ 26 files changed, 6141 insertions(+), 3 deletions(-) create mode 100644 gi/markov_al/Makefile.am create mode 100644 gi/markov_al/README create mode 100644 gi/markov_al/ml.cc create mode 100644 gi/pf/Makefile.am create mode 100644 gi/pf/README create mode 100644 gi/pf/base_measures.cc create mode 100644 gi/pf/base_measures.h create mode 100644 gi/pf/brat.cc create mode 100644 gi/pf/cbgi.cc create mode 100644 gi/pf/cfg_wfst_composer.cc create mode 100644 gi/pf/cfg_wfst_composer.h create mode 100644 gi/pf/dpnaive.cc create mode 100644 gi/pf/itg.cc create mode 100644 gi/pf/pfbrat.cc create mode 100644 gi/pf/pfdist.cc create mode 100644 gi/pf/pfdist.new.cc create mode 100644 gi/pf/pfnaive.cc create mode 100644 gi/pf/reachability.cc create mode 100644 gi/pf/reachability.h create mode 100644 gi/pf/tpf.cc create mode 100644 m4/acx_pthread.m4 create mode 100644 utils/ccrp_nt.h create mode 100644 utils/ccrp_onetable.h diff --git a/.gitignore b/.gitignore index 2a287bbc..5efe37b0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,27 @@ +pro-train/.deps +pro-train/mr_pro_map +pro-train/mr_pro_reduce +utils/reconstruct_weights +decoder/.libs +training/augment_grammar +training/mpi_batch_optimize +training/mpi_compute_cllh +training/mpi_em_optimize +training/mpi_extract_features +training/mpi_extract_reachable klm/lm/build_binary extools/extractor_monolingual +gi/pf/.deps +gi/pf/brat +gi/pf/cbgi +gi/pf/dpnaive +gi/pf/itg +gi/pf/libpf.a +gi/pf/pfbrat +gi/pf/pfdist +gi/pf/pfnaive +gi/markov_al/.deps +gi/markov_al/ml gi/posterior-regularisation/prjava/lib/*.jar klm/lm/libklm.a klm/util/.deps @@ -120,4 +142,4 @@ gi/posterior-regularisation/prjava/lib/prjava-20100715.jar *.dvi *.ps *.toc -*~ \ No newline at end of file +*~ diff --git a/Makefile.am b/Makefile.am index 98b4bac7..59c2fc0a 100644 --- a/Makefile.am +++ b/Makefile.am @@ -1,7 +1,7 @@ # warning - the subdirectories in the following list should # be kept in topologically sorted order. Also, DO NOT introduce # cyclic dependencies between these directories! -SUBDIRS = utils mteval klm/util klm/lm decoder phrasinator training mira vest pro-train extools +SUBDIRS = utils mteval klm/util klm/lm decoder phrasinator training mira vest pro-train extools gi/pf gi/markov_al #gi/pyp-topics/src gi/clda/src gi/posterior-regularisation/prjava diff --git a/configure.ac b/configure.ac index 2e9cc36d..131a1705 100644 --- a/configure.ac +++ b/configure.ac @@ -113,4 +113,4 @@ then AM_CONDITIONAL([GLC], true) fi -AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile pro-train/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile gi/cbgi/Makefile gi/ml/Makefile) +AC_OUTPUT(Makefile utils/Makefile mteval/Makefile extools/Makefile decoder/Makefile phrasinator/Makefile training/Makefile vest/Makefile pro-train/Makefile klm/util/Makefile klm/lm/Makefile mira/Makefile gi/pyp-topics/src/Makefile gi/clda/src/Makefile gi/pf/Makefile gi/markov_al/Makefile) diff --git a/gi/markov_al/Makefile.am b/gi/markov_al/Makefile.am new file mode 100644 index 00000000..fe3e3349 --- /dev/null +++ b/gi/markov_al/Makefile.am @@ -0,0 +1,6 @@ +bin_PROGRAMS = ml + +ml_SOURCES = ml.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -funroll-loops -I$(top_srcdir)/utils $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder +AM_LDFLAGS = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz diff --git a/gi/markov_al/README b/gi/markov_al/README new file mode 100644 index 00000000..9c10f7cd --- /dev/null +++ b/gi/markov_al/README @@ -0,0 +1,2 @@ +Experimental translation models with Markovian dependencies. + diff --git a/gi/markov_al/ml.cc b/gi/markov_al/ml.cc new file mode 100644 index 00000000..1e71edd6 --- /dev/null +++ b/gi/markov_al/ml.cc @@ -0,0 +1,470 @@ +#include +#include + +#include +#include +#include +#include + +#include "tdict.h" +#include "filelib.h" +#include "sampler.h" +#include "ccrp_onetable.h" +#include "array2d.h" + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +void PrintTopCustomers(const CCRP_OneTable& crp) { + for (CCRP_OneTable::const_iterator it = crp.begin(); it != crp.end(); ++it) { + cerr << " " << TD::Convert(it->first) << " = " << it->second << endl; + } +} + +void PrintAlignment(const vector& src, const vector& trg, const vector& a) { + cerr << TD::GetString(src) << endl << TD::GetString(trg) << endl; + Array2D al(src.size(), trg.size()); + for (int i = 0; i < a.size(); ++i) + if (a[i] != 255) al(a[i], i) = true; + cerr << al << endl; +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read parallel data from") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct Unigram; +struct Bigram { + Bigram() : trg(), cond() {} + Bigram(WordID prev, WordID cur, WordID t) : trg(t) { cond.first = prev; cond.second = cur; } + const pair& ConditioningPair() const { + return cond; + } + WordID& prev_src() { return cond.first; } + WordID& cur_src() { return cond.second; } + const WordID& prev_src() const { return cond.first; } + const WordID& cur_src() const { return cond.second; } + WordID trg; + private: + pair cond; +}; + +struct Unigram { + Unigram() : cur_src(), trg() {} + Unigram(WordID s, WordID t) : cur_src(s), trg(t) {} + WordID cur_src; + WordID trg; +}; + +ostream& operator<<(ostream& os, const Bigram& b) { + os << "( " << TD::Convert(b.trg) << " | " << TD::Convert(b.prev_src()) << " , " << TD::Convert(b.cur_src()) << " )"; + return os; +} + +ostream& operator<<(ostream& os, const Unigram& u) { + os << "( " << TD::Convert(u.trg) << " | " << TD::Convert(u.cur_src) << " )"; + return os; +} + +bool operator==(const Bigram& a, const Bigram& b) { + return a.trg == b.trg && a.cur_src() == b.cur_src() && a.prev_src() == b.prev_src(); +} + +bool operator==(const Unigram& a, const Unigram& b) { + return a.trg == b.trg && a.cur_src == b.cur_src; +} + +size_t hash_value(const Bigram& b) { + size_t h = boost::hash_value(b.prev_src()); + boost::hash_combine(h, boost::hash_value(b.cur_src())); + boost::hash_combine(h, boost::hash_value(b.trg)); + return h; +} + +size_t hash_value(const Unigram& u) { + size_t h = boost::hash_value(u.cur_src); + boost::hash_combine(h, boost::hash_value(u.trg)); + return h; +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +struct UnigramModel { + UnigramModel(size_t src_voc_size, size_t trg_voc_size) : + unigrams(TD::NumWords() + 1, CCRP_OneTable(1,1,1,1)), + p0(1.0 / trg_voc_size) {} + + void increment(const Bigram& b) { + unigrams[b.cur_src()].increment(b.trg); + } + + void decrement(const Bigram& b) { + unigrams[b.cur_src()].decrement(b.trg); + } + + double prob(const Bigram& b) const { + const double q0 = unigrams[b.cur_src()].prob(b.trg, p0); + return q0; + } + + double LogLikelihood() const { + double llh = 0; + for (unsigned i = 0; i < unigrams.size(); ++i) { + const CCRP_OneTable& crp = unigrams[i]; + if (crp.num_customers() > 0) { + llh += crp.log_crp_prob(); + llh += crp.num_tables() * log(p0); + } + } + return llh; + } + + void ResampleHyperparameters(MT19937* rng) { + for (unsigned i = 0; i < unigrams.size(); ++i) + unigrams[i].resample_hyperparameters(rng); + } + + vector > unigrams; // unigrams[src].prob(trg, p0) = p(trg|src) + + const double p0; +}; + +struct BigramModel { + BigramModel(size_t src_voc_size, size_t trg_voc_size) : + unigrams(TD::NumWords() + 1, CCRP_OneTable(1,1,1,1)), + p0(1.0 / trg_voc_size) {} + + void increment(const Bigram& b) { + BigramMap::iterator it = bigrams.find(b.ConditioningPair()); + if (it == bigrams.end()) { + it = bigrams.insert(make_pair(b.ConditioningPair(), CCRP_OneTable(1,1,1,1))).first; + } + if (it->second.increment(b.trg)) + unigrams[b.cur_src()].increment(b.trg); + } + + void decrement(const Bigram& b) { + BigramMap::iterator it = bigrams.find(b.ConditioningPair()); + assert(it != bigrams.end()); + if (it->second.decrement(b.trg)) { + unigrams[b.cur_src()].decrement(b.trg); + if (it->second.num_customers() == 0) + bigrams.erase(it); + } + } + + double prob(const Bigram& b) const { + const double q0 = unigrams[b.cur_src()].prob(b.trg, p0); + const BigramMap::const_iterator it = bigrams.find(b.ConditioningPair()); + if (it == bigrams.end()) return q0; + return it->second.prob(b.trg, q0); + } + + double LogLikelihood() const { + double llh = 0; + for (unsigned i = 0; i < unigrams.size(); ++i) { + const CCRP_OneTable& crp = unigrams[i]; + if (crp.num_customers() > 0) { + llh += crp.log_crp_prob(); + llh += crp.num_tables() * log(p0); + } + } + for (BigramMap::const_iterator it = bigrams.begin(); it != bigrams.end(); ++it) { + const CCRP_OneTable& crp = it->second; + const WordID cur_src = it->first.second; + llh += crp.log_crp_prob(); + for (CCRP_OneTable::const_iterator bit = crp.begin(); bit != crp.end(); ++bit) { + llh += log(unigrams[cur_src].prob(bit->second, p0)); + } + } + return llh; + } + + void ResampleHyperparameters(MT19937* rng) { + for (unsigned i = 0; i < unigrams.size(); ++i) + unigrams[i].resample_hyperparameters(rng); + for (BigramMap::iterator it = bigrams.begin(); it != bigrams.end(); ++it) + it->second.resample_hyperparameters(rng); + } + + typedef unordered_map, CCRP_OneTable, boost::hash > > BigramMap; + BigramMap bigrams; // bigrams[(src-1,src)].prob(trg, q0) = p(trg|src,src-1) + vector > unigrams; // unigrams[src].prob(trg, p0) = p(trg|src) + + const double p0; +}; + +struct BigramAlignmentModel { + BigramAlignmentModel(size_t src_voc_size, size_t trg_voc_size) : bigrams(TD::NumWords() + 1, CCRP_OneTable(1,1,1,1)), p0(1.0 / src_voc_size) {} + void increment(WordID prev, WordID next) { + bigrams[prev].increment(next); // hierarchy? + } + void decrement(WordID prev, WordID next) { + bigrams[prev].decrement(next); // hierarchy? + } + double prob(WordID prev, WordID next) { + return bigrams[prev].prob(next, p0); + } + double LogLikelihood() const { + double llh = 0; + for (unsigned i = 0; i < bigrams.size(); ++i) { + const CCRP_OneTable& crp = bigrams[i]; + if (crp.num_customers() > 0) { + llh += crp.log_crp_prob(); + llh += crp.num_tables() * log(p0); + } + } + return llh; + } + + vector > bigrams; // bigrams[prev].prob(next, p0) = p(next|prev) + const double p0; +}; + +struct Alignment { + vector a; +}; + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const unsigned samples = conf["samples"].as(); + + boost::shared_ptr prng; + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + cerr << "Reading corpus...\n"; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "F-corpus size: " << corpusf.size() << " sentences\t (" << vocabf.size() << " word types)\n"; + cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n"; + assert(corpusf.size() == corpuse.size()); + const size_t corpus_len = corpusf.size(); + const WordID kNULL = TD::Convert(""); + const WordID kBOS = TD::Convert(""); + const WordID kEOS = TD::Convert(""); + Bigram TT(kBOS, TD::Convert("我"), TD::Convert("i")); + Bigram TT2(kBOS, TD::Convert("要"), TD::Convert("i")); + + UnigramModel model(vocabf.size(), vocabe.size()); + vector alignments(corpus_len); + for (unsigned ci = 0; ci < corpus_len; ++ci) { + const vector& src = corpusf[ci]; + const vector& trg = corpuse[ci]; + vector& alg = alignments[ci].a; + alg.resize(trg.size()); + int lenp1 = src.size() + 1; + WordID prev_src = kBOS; + for (int j = 0; j < trg.size(); ++j) { + int samp = lenp1 * rng.next(); + --samp; + if (samp < 0) samp = 255; + alg[j] = samp; + WordID cur_src = (samp == 255 ? kNULL : src[alg[j]]); + Bigram b(prev_src, cur_src, trg[j]); + model.increment(b); + prev_src = cur_src; + } + Bigram b(prev_src, kEOS, kEOS); + model.increment(b); + } + cerr << "Initial LLH: " << model.LogLikelihood() << endl; + + SampleSet ss; + for (unsigned si = 0; si < 50; ++si) { + for (unsigned ci = 0; ci < corpus_len; ++ci) { + const vector& src = corpusf[ci]; + const vector& trg = corpuse[ci]; + vector& alg = alignments[ci].a; + WordID prev_src = kBOS; + for (unsigned j = 0; j < trg.size(); ++j) { + unsigned char& a_j = alg[j]; + WordID cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]); + Bigram b(prev_src, cur_e_a_j, trg[j]); + //cerr << "DEC: " << b << "\t" << nextb << endl; + model.decrement(b); + ss.clear(); + for (unsigned i = 0; i <= src.size(); ++i) { + const WordID cur_src = (i ? src[i-1] : kNULL); + b.cur_src() = cur_src; + ss.add(model.prob(b)); + } + int sampled_a_j = rng.SelectSample(ss); + a_j = (sampled_a_j ? sampled_a_j - 1 : 255); + cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]); + b.cur_src() = cur_e_a_j; + //cerr << "INC: " << b << "\t" << nextb << endl; + model.increment(b); + prev_src = cur_e_a_j; + } + } + cerr << '.' << flush; + if (si % 10 == 9) { + cerr << "[LLH prev=" << model.LogLikelihood(); + //model.ResampleHyperparameters(&rng); + cerr << " new=" << model.LogLikelihood() << "]\n"; + //pair xx = make_pair(kBOS, TD::Convert("我")); + //PrintTopCustomers(model.bigrams.find(xx)->second); + cerr << "p(" << TT << ") = " << model.prob(TT) << endl; + cerr << "p(" << TT2 << ") = " << model.prob(TT2) << endl; + PrintAlignment(corpusf[0], corpuse[0], alignments[0].a); + } + } + { + // MODEL 2 + BigramModel model(vocabf.size(), vocabe.size()); + BigramAlignmentModel amodel(vocabf.size(), vocabe.size()); + for (unsigned ci = 0; ci < corpus_len; ++ci) { + const vector& src = corpusf[ci]; + const vector& trg = corpuse[ci]; + vector& alg = alignments[ci].a; + WordID prev_src = kBOS; + for (int j = 0; j < trg.size(); ++j) { + WordID cur_src = (alg[j] == 255 ? kNULL : src[alg[j]]); + Bigram b(prev_src, cur_src, trg[j]); + model.increment(b); + amodel.increment(prev_src, cur_src); + prev_src = cur_src; + } + amodel.increment(prev_src, kEOS); + Bigram b(prev_src, kEOS, kEOS); + model.increment(b); + } + cerr << "Initial LLH: " << model.LogLikelihood() << " " << amodel.LogLikelihood() << endl; + + SampleSet ss; + for (unsigned si = 0; si < samples; ++si) { + for (unsigned ci = 0; ci < corpus_len; ++ci) { + const vector& src = corpusf[ci]; + const vector& trg = corpuse[ci]; + vector& alg = alignments[ci].a; + WordID prev_src = kBOS; + for (unsigned j = 0; j < trg.size(); ++j) { + unsigned char& a_j = alg[j]; + WordID cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]); + Bigram b(prev_src, cur_e_a_j, trg[j]); + WordID next_src = kEOS; + WordID next_trg = kEOS; + if (j < (trg.size() - 1)) { + next_src = (alg[j+1] == 255 ? kNULL : src[alg[j + 1]]); + next_trg = trg[j + 1]; + } + Bigram nextb(cur_e_a_j, next_src, next_trg); + //cerr << "DEC: " << b << "\t" << nextb << endl; + model.decrement(b); + model.decrement(nextb); + amodel.decrement(prev_src, cur_e_a_j); + amodel.decrement(cur_e_a_j, next_src); + ss.clear(); + for (unsigned i = 0; i <= src.size(); ++i) { + const WordID cur_src = (i ? src[i-1] : kNULL); + b.cur_src() = cur_src; + ss.add(model.prob(b) * model.prob(nextb) * amodel.prob(prev_src, cur_src) * amodel.prob(cur_src, next_src)); + //cerr << log(ss[ss.size() - 1]) << "\t" << b << endl; + } + int sampled_a_j = rng.SelectSample(ss); + a_j = (sampled_a_j ? sampled_a_j - 1 : 255); + cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]); + b.cur_src() = cur_e_a_j; + nextb.prev_src() = cur_e_a_j; + //cerr << "INC: " << b << "\t" << nextb << endl; + //exit(1); + model.increment(b); + model.increment(nextb); + amodel.increment(prev_src, cur_e_a_j); + amodel.increment(cur_e_a_j, next_src); + prev_src = cur_e_a_j; + } + } + cerr << '.' << flush; + if (si % 10 == 9) { + cerr << "[LLH prev=" << (model.LogLikelihood() + amodel.LogLikelihood()); + //model.ResampleHyperparameters(&rng); + cerr << " new=" << model.LogLikelihood() << "]\n"; + pair xx = make_pair(kBOS, TD::Convert("我")); + cerr << "p(" << TT << ") = " << model.prob(TT) << endl; + cerr << "p(" << TT2 << ") = " << model.prob(TT2) << endl; + pair xx2 = make_pair(kBOS, TD::Convert("要")); + PrintTopCustomers(model.bigrams.find(xx)->second); + //PrintTopCustomers(amodel.bigrams[TD::Convert("")]); + //PrintTopCustomers(model.unigrams[TD::Convert("")]); + PrintAlignment(corpusf[0], corpuse[0], alignments[0].a); + } + } + } + return 0; +} + diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am new file mode 100644 index 00000000..c9764ad5 --- /dev/null +++ b/gi/pf/Makefile.am @@ -0,0 +1,21 @@ +bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive + +noinst_LIBRARIES = libpf.a +libpf_a_SOURCES = base_measures.cc reachability.cc cfg_wfst_composer.cc + +itg_SOURCES = itg.cc + +dpnaive_SOURCES = dpnaive.cc + +pfdist_SOURCES = pfdist.cc + +pfnaive_SOURCES = pfnaive.cc + +cbgi_SOURCES = cbgi.cc + +brat_SOURCES = brat.cc + +pfbrat_SOURCES = pfbrat.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -funroll-loops -I$(top_srcdir)/utils $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder +AM_LDFLAGS = libpf.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz diff --git a/gi/pf/README b/gi/pf/README new file mode 100644 index 00000000..62e47541 --- /dev/null +++ b/gi/pf/README @@ -0,0 +1,2 @@ +Experimental Bayesian alignment tools. Nothing to see here. + diff --git a/gi/pf/base_measures.cc b/gi/pf/base_measures.cc new file mode 100644 index 00000000..f8ddfd32 --- /dev/null +++ b/gi/pf/base_measures.cc @@ -0,0 +1,112 @@ +#include "base_measures.h" + +#include + +#include "filelib.h" + +using namespace std; + +void Model1::LoadModel1(const string& fname) { + cerr << "Loading Model 1 parameters from " << fname << " ..." << endl; + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + unsigned lc = 0; + while(getline(in, line)) { + ++lc; + int cur = 0; + int start = 0; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + const WordID src = TD::Convert(&line[0]); + ++cur; + start = cur; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + WordID trg = TD::Convert(&line[start]); + const double logprob = strtod(&line[cur + 1], NULL); + if (src >= ttable.size()) ttable.resize(src + 1); + ttable[src][trg].logeq(logprob); + } + cerr << " read " << lc << " parameters.\n"; +} + +prob_t PhraseConditionalBase::p0(const vector& vsrc, + const vector& vtrg, + int start_src, int start_trg) const { + const int flen = vsrc.size() - start_src; + const int elen = vtrg.size() - start_trg; + prob_t uniform_src_alignment; uniform_src_alignment.logeq(-log(flen + 1)); + prob_t p; + p.logeq(log_poisson(elen, flen + 0.01)); // elen | flen ~Pois(flen + 0.01) + for (int i = 0; i < elen; ++i) { // for each position i in e-RHS + const WordID trg = vtrg[i + start_trg]; + prob_t tp = prob_t::Zero(); + for (int j = -1; j < flen; ++j) { + const WordID src = j < 0 ? 0 : vsrc[j + start_src]; + tp += kM1MIXTURE * model1(src, trg); + tp += kUNIFORM_MIXTURE * kUNIFORM_TARGET; + } + tp *= uniform_src_alignment; // draw a_i ~uniform + p *= tp; // draw e_i ~Model1(f_a_i) / uniform + } + if (p.is_0()) { + cerr << "Zero! " << vsrc << "\nTRG=" << vtrg << endl; + abort(); + } + return p; +} + +prob_t PhraseJointBase::p0(const vector& vsrc, + const vector& vtrg, + int start_src, int start_trg) const { + const int flen = vsrc.size() - start_src; + const int elen = vtrg.size() - start_trg; + prob_t uniform_src_alignment; uniform_src_alignment.logeq(-log(flen + 1)); + prob_t p; + p.logeq(log_poisson(flen, 1.0)); // flen ~Pois(1) + // elen | flen ~Pois(flen + 0.01) + prob_t ptrglen; ptrglen.logeq(log_poisson(elen, flen + 0.01)); + p *= ptrglen; + p *= kUNIFORM_SOURCE.pow(flen); // each f in F ~Uniform + for (int i = 0; i < elen; ++i) { // for each position i in E + const WordID trg = vtrg[i + start_trg]; + prob_t tp = prob_t::Zero(); + for (int j = -1; j < flen; ++j) { + const WordID src = j < 0 ? 0 : vsrc[j + start_src]; + tp += kM1MIXTURE * model1(src, trg); + tp += kUNIFORM_MIXTURE * kUNIFORM_TARGET; + } + tp *= uniform_src_alignment; // draw a_i ~uniform + p *= tp; // draw e_i ~Model1(f_a_i) / uniform + } + if (p.is_0()) { + cerr << "Zero! " << vsrc << "\nTRG=" << vtrg << endl; + abort(); + } + return p; +} + +JumpBase::JumpBase() : p(200) { + for (unsigned src_len = 1; src_len < 200; ++src_len) { + map& cpd = p[src_len]; + int min_jump = 1 - src_len; + int max_jump = src_len; + prob_t z; + for (int j = min_jump; j <= max_jump; ++j) { + prob_t& cp = cpd[j]; + if (j < 0) + cp.logeq(log_poisson(1.5-j, 1)); + else if (j > 0) + cp.logeq(log_poisson(j, 1)); + cp.poweq(0.2); + z += cp; + } + for (int j = min_jump; j <= max_jump; ++j) { + cpd[j] /= z; + } + } +} + diff --git a/gi/pf/base_measures.h b/gi/pf/base_measures.h new file mode 100644 index 00000000..df17aa62 --- /dev/null +++ b/gi/pf/base_measures.h @@ -0,0 +1,116 @@ +#ifndef _BASE_MEASURES_H_ +#define _BASE_MEASURES_H_ + +#include +#include +#include +#include +#include + +#include "trule.h" +#include "prob.h" +#include "tdict.h" + +inline double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +inline std::ostream& operator<<(std::ostream& os, const std::vector& p) { + os << '['; + for (int i = 0; i < p.size(); ++i) + os << (i==0 ? "" : " ") << TD::Convert(p[i]); + return os << ']'; +} + +struct Model1 { + explicit Model1(const std::string& fname) : + kNULL(TD::Convert("")), + kZERO() { + LoadModel1(fname); + } + + void LoadModel1(const std::string& fname); + + // returns prob 0 if src or trg is not found + const prob_t& operator()(WordID src, WordID trg) const { + if (src == 0) src = kNULL; + if (src < ttable.size()) { + const std::map& cpd = ttable[src]; + const std::map::const_iterator it = cpd.find(trg); + if (it != cpd.end()) + return it->second; + } + return kZERO; + } + + const WordID kNULL; + const prob_t kZERO; + std::vector > ttable; +}; + +struct PhraseConditionalBase { + explicit PhraseConditionalBase(const Model1& m1, const double m1mixture, const unsigned vocab_e_size) : + model1(m1), + kM1MIXTURE(m1mixture), + kUNIFORM_MIXTURE(1.0 - m1mixture), + kUNIFORM_TARGET(1.0 / vocab_e_size) { + assert(m1mixture >= 0.0 && m1mixture <= 1.0); + assert(vocab_e_size > 0); + } + + // return p0 of rule.e_ | rule.f_ + prob_t operator()(const TRule& rule) const { + return p0(rule.f_, rule.e_, 0, 0); + } + + prob_t p0(const std::vector& vsrc, const std::vector& vtrg, int start_src, int start_trg) const; + + const Model1& model1; + const prob_t kM1MIXTURE; // Model 1 mixture component + const prob_t kUNIFORM_MIXTURE; // uniform mixture component + const prob_t kUNIFORM_TARGET; +}; + +struct PhraseJointBase { + explicit PhraseJointBase(const Model1& m1, const double m1mixture, const unsigned vocab_e_size, const unsigned vocab_f_size) : + model1(m1), + kM1MIXTURE(m1mixture), + kUNIFORM_MIXTURE(1.0 - m1mixture), + kUNIFORM_SOURCE(1.0 / vocab_f_size), + kUNIFORM_TARGET(1.0 / vocab_e_size) { + assert(m1mixture >= 0.0 && m1mixture <= 1.0); + assert(vocab_e_size > 0); + } + + // return p0 of rule.e_ | rule.f_ + prob_t operator()(const TRule& rule) const { + return p0(rule.f_, rule.e_, 0, 0); + } + + prob_t p0(const std::vector& vsrc, const std::vector& vtrg, int start_src, int start_trg) const; + + const Model1& model1; + const prob_t kM1MIXTURE; // Model 1 mixture component + const prob_t kUNIFORM_MIXTURE; // uniform mixture component + const prob_t kUNIFORM_SOURCE; + const prob_t kUNIFORM_TARGET; +}; + +// base distribution for jump size multinomials +// basically p(0) = 0 and then, p(1) is max, and then +// you drop as you move to the max jump distance +struct JumpBase { + JumpBase(); + + const prob_t& operator()(int jump, unsigned src_len) const { + assert(jump != 0); + const std::map::const_iterator it = p[src_len].find(jump); + assert(it != p[src_len].end()); + return it->second; + } + std::vector > p; +}; + + +#endif diff --git a/gi/pf/brat.cc b/gi/pf/brat.cc new file mode 100644 index 00000000..4c6ba3ef --- /dev/null +++ b/gi/pf/brat.cc @@ -0,0 +1,554 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "viterbi.h" +#include "hg.h" +#include "trule.h" +#include "tdict.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp_nt.h" +#include "cfg_wfst_composer.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +static unsigned kMAX_SRC_PHRASE; +static unsigned kMAX_TRG_PHRASE; +struct FSTState; + +size_t hash_value(const TRule& r) { + size_t h = 2 - r.lhs_; + boost::hash_combine(h, boost::hash_value(r.e_)); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +struct ConditionalBase { + explicit ConditionalBase(const double m1mixture, const unsigned vocab_e_size, const string& model1fname) : + kM1MIXTURE(m1mixture), + kUNIFORM_MIXTURE(1.0 - m1mixture), + kUNIFORM_TARGET(1.0 / vocab_e_size), + kNULL(TD::Convert("")) { + assert(m1mixture >= 0.0 && m1mixture <= 1.0); + assert(vocab_e_size > 0); + LoadModel1(model1fname); + } + + void LoadModel1(const string& fname) { + cerr << "Loading Model 1 parameters from " << fname << " ..." << endl; + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + unsigned lc = 0; + while(getline(in, line)) { + ++lc; + int cur = 0; + int start = 0; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + const WordID src = TD::Convert(&line[0]); + ++cur; + start = cur; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + WordID trg = TD::Convert(&line[start]); + const double logprob = strtod(&line[cur + 1], NULL); + if (src >= ttable.size()) ttable.resize(src + 1); + ttable[src][trg].logeq(logprob); + } + cerr << " read " << lc << " parameters.\n"; + } + + // return logp0 of rule.e_ | rule.f_ + prob_t operator()(const TRule& rule) const { + const int flen = rule.f_.size(); + const int elen = rule.e_.size(); + prob_t uniform_src_alignment; uniform_src_alignment.logeq(-log(flen + 1)); + prob_t p; + p.logeq(log_poisson(elen, flen + 0.01)); // elen | flen ~Pois(flen + 0.01) + for (int i = 0; i < elen; ++i) { // for each position i in e-RHS + const WordID trg = rule.e_[i]; + prob_t tp = prob_t::Zero(); + for (int j = -1; j < flen; ++j) { + const WordID src = j < 0 ? kNULL : rule.f_[j]; + const map::const_iterator it = ttable[src].find(trg); + if (it != ttable[src].end()) { + tp += kM1MIXTURE * it->second; + } + tp += kUNIFORM_MIXTURE * kUNIFORM_TARGET; + } + tp *= uniform_src_alignment; // draw a_i ~uniform + p *= tp; // draw e_i ~Model1(f_a_i) / uniform + } + return p; + } + + const prob_t kM1MIXTURE; // Model 1 mixture component + const prob_t kUNIFORM_MIXTURE; // uniform mixture component + const prob_t kUNIFORM_TARGET; + const WordID kNULL; + vector > ttable; +}; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read parallel data from") + ("max_src_phrase",po::value()->default_value(3),"Maximum length of source language phrases") + ("max_trg_phrase",po::value()->default_value(3),"Maximum length of target language phrases") + ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +struct UniphraseLM { + UniphraseLM(const vector >& corpus, + const set& vocab, + const po::variables_map& conf) : + phrases_(1,1), + gen_(1,1), + corpus_(corpus), + uniform_word_(1.0 / vocab.size()), + gen_p0_(0.5), + p_end_(0.5), + use_poisson_(conf.count("poisson_length") > 0) {} + + void ResampleHyperparameters(MT19937* rng) { + phrases_.resample_hyperparameters(rng); + gen_.resample_hyperparameters(rng); + cerr << " " << phrases_.concentration(); + } + + CCRP_NoTable > phrases_; + CCRP_NoTable gen_; + vector > z_; // z_[i] is there a phrase boundary after the ith word + const vector >& corpus_; + const double uniform_word_; + const double gen_p0_; + const double p_end_; // in base length distribution, p of the end of a phrase + const bool use_poisson_; +}; + +struct Reachability { + boost::multi_array edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring? + boost::multi_array max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid + + Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : + edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), + max_src_delta(boost::extents[srclen][trglen]) { + ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); + } + + private: + struct SState { + SState() : prev_src_covered(), prev_trg_covered() {} + SState(int i, int j) : prev_src_covered(i), prev_trg_covered(j) {} + int prev_src_covered; + int prev_trg_covered; + }; + + struct NState { + NState() : next_src_covered(), next_trg_covered() {} + NState(int i, int j) : next_src_covered(i), next_trg_covered(j) {} + int next_src_covered; + int next_trg_covered; + }; + + void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) { + typedef boost::multi_array, 2> array_type; + array_type a(boost::extents[srclen + 1][trglen + 1]); + a[0][0].push_back(SState()); + for (int i = 0; i < srclen; ++i) { + for (int j = 0; j < trglen; ++j) { + if (a[i][j].size() == 0) continue; + const SState prev(i,j); + for (int k = 1; k <= src_max_phrase_len; ++k) { + if ((i + k) > srclen) continue; + for (int l = 1; l <= trg_max_phrase_len; ++l) { + if ((j + l) > trglen) continue; + a[i + k][j + l].push_back(prev); + } + } + } + } + a[0][0].clear(); + cerr << "Final cell contains " << a[srclen][trglen].size() << " back pointers\n"; + assert(a[srclen][trglen].size() > 0); + + typedef boost::multi_array rarray_type; + rarray_type r(boost::extents[srclen + 1][trglen + 1]); +// typedef boost::multi_array, 2> narray_type; +// narray_type b(boost::extents[srclen + 1][trglen + 1]); + r[srclen][trglen] = true; + for (int i = srclen; i >= 0; --i) { + for (int j = trglen; j >= 0; --j) { + vector& prevs = a[i][j]; + if (!r[i][j]) { prevs.clear(); } +// const NState nstate(i,j); + for (int k = 0; k < prevs.size(); ++k) { + r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true; + int src_delta = i - prevs[k].prev_src_covered; + edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true; + short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered]; + if (src_delta > msd) msd = src_delta; +// b[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(nstate); + } + } + } + assert(!edges[0][0][1][0]); + assert(!edges[0][0][0][1]); + assert(!edges[0][0][0][0]); + cerr << " MAX SRC DELTA[0][0] = " << max_src_delta[0][0] << endl; + assert(max_src_delta[0][0] > 0); + //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n"; + //for (int i = 0; i < b[0][0].size(); ++i) { + // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n"; + //} + } +}; + +ostream& operator<<(ostream& os, const FSTState& q); +struct FSTState { + explicit FSTState(int src_size) : + trg_covered_(), + src_covered_(), + src_coverage_(src_size) {} + + FSTState(short trg_covered, short src_covered, const vector& src_coverage, const vector& src_prefix) : + trg_covered_(trg_covered), + src_covered_(src_covered), + src_coverage_(src_coverage), + src_prefix_(src_prefix) { + if (src_coverage_.size() == src_covered) { + assert(src_prefix.size() == 0); + } + } + + // if we extend by the word at src_position, what are + // the next states that are reachable and lie on a valid + // path to the final state? + vector Extensions(int src_position, int src_len, int trg_len, const Reachability& r) const { + assert(src_position < src_coverage_.size()); + if (src_coverage_[src_position]) { + cerr << "Trying to extend " << *this << " with position " << src_position << endl; + abort(); + } + vector ncvg = src_coverage_; + ncvg[src_position] = true; + + vector res; + const int trg_remaining = trg_len - trg_covered_; + if (trg_remaining <= 0) { + cerr << "Target appears to have been covered: " << *this << " (trg_len=" << trg_len << ",trg_covered=" << trg_covered_ << ")" << endl; + abort(); + } + const int src_remaining = src_len - src_covered_; + if (src_remaining <= 0) { + cerr << "Source appears to have been covered: " << *this << endl; + abort(); + } + + for (int tc = 1; tc <= kMAX_TRG_PHRASE; ++tc) { + if (r.edges[src_covered_][trg_covered_][src_prefix_.size() + 1][tc]) { + int nc = src_prefix_.size() + 1 + src_covered_; + res.push_back(FSTState(trg_covered_ + tc, nc, ncvg, vector())); + } + } + + if ((src_prefix_.size() + 1) < r.max_src_delta[src_covered_][trg_covered_]) { + vector nsp = src_prefix_; + nsp.push_back(src_position); + res.push_back(FSTState(trg_covered_, src_covered_, ncvg, nsp)); + } + + if (res.size() == 0) { + cerr << *this << " can't be extended!\n"; + abort(); + } + return res; + } + + short trg_covered_, src_covered_; + vector src_coverage_; + vector src_prefix_; +}; +bool operator<(const FSTState& q, const FSTState& r) { + if (q.trg_covered_ != r.trg_covered_) return q.trg_covered_ < r.trg_covered_; + if (q.src_covered_!= r.src_covered_) return q.src_covered_ < r.src_covered_; + if (q.src_coverage_ != r.src_coverage_) return q.src_coverage_ < r.src_coverage_; + return q.src_prefix_ < r.src_prefix_; +} + +ostream& operator<<(ostream& os, const FSTState& q) { + os << "[" << q.trg_covered_ << " : "; + for (int i = 0; i < q.src_coverage_.size(); ++i) + os << q.src_coverage_[i]; + os << " : <"; + for (int i = 0; i < q.src_prefix_.size(); ++i) { + if (i != 0) os << ' '; + os << q.src_prefix_[i]; + } + return os << ">]"; +} + +struct MyModel { + MyModel(ConditionalBase& rcp0) : rp0(rcp0) {} + typedef unordered_map, CCRP_NoTable, boost::hash > > SrcToRuleCRPMap; + + void DecrementRule(const TRule& rule) { + SrcToRuleCRPMap::iterator it = rules.find(rule.f_); + assert(it != rules.end()); + it->second.decrement(rule); + if (it->second.num_customers() == 0) rules.erase(it); + } + + void IncrementRule(const TRule& rule) { + SrcToRuleCRPMap::iterator it = rules.find(rule.f_); + if (it == rules.end()) { + CCRP_NoTable crp(1,1); + it = rules.insert(make_pair(rule.f_, crp)).first; + } + it->second.increment(rule); + } + + // conditioned on rule.f_ + prob_t RuleConditionalProbability(const TRule& rule) const { + const prob_t base = rp0(rule); + SrcToRuleCRPMap::const_iterator it = rules.find(rule.f_); + if (it == rules.end()) { + return base; + } else { + const double lp = it->second.logprob(rule, log(base)); + prob_t q; q.logeq(lp); + return q; + } + } + + const ConditionalBase& rp0; + SrcToRuleCRPMap rules; +}; + +struct MyFST : public WFST { + MyFST(const vector& ssrc, const vector& strg, MyModel* m) : + src(ssrc), trg(strg), + r(src.size(),trg.size(),kMAX_SRC_PHRASE, kMAX_TRG_PHRASE), + model(m) { + FSTState in(src.size()); + cerr << " INIT: " << in << endl; + init = GetNode(in); + for (int i = 0; i < in.src_coverage_.size(); ++i) in.src_coverage_[i] = true; + in.src_covered_ = src.size(); + in.trg_covered_ = trg.size(); + cerr << "FINAL: " << in << endl; + final = GetNode(in); + } + virtual const WFSTNode* Final() const; + virtual const WFSTNode* Initial() const; + + const WFSTNode* GetNode(const FSTState& q); + map > m; + const vector& src; + const vector& trg; + Reachability r; + const WFSTNode* init; + const WFSTNode* final; + MyModel* model; +}; + +struct MyNode : public WFSTNode { + MyNode(const FSTState& q, MyFST* fst) : state(q), container(fst) {} + virtual vector > ExtendInput(unsigned srcindex) const; + const FSTState state; + mutable MyFST* container; +}; + +vector > MyNode::ExtendInput(unsigned srcindex) const { + cerr << "EXTEND " << state << " with " << srcindex << endl; + vector ext = state.Extensions(srcindex, container->src.size(), container->trg.size(), container->r); + vector > res(ext.size()); + for (unsigned i = 0; i < ext.size(); ++i) { + res[i].first = container->GetNode(ext[i]); + if (ext[i].src_prefix_.size() == 0) { + const unsigned trg_from = state.trg_covered_; + const unsigned trg_to = ext[i].trg_covered_; + const unsigned prev_prfx_size = state.src_prefix_.size(); + res[i].second.reset(new TRule); + res[i].second->lhs_ = -TD::Convert("X"); + vector& src = res[i].second->f_; + vector& trg = res[i].second->e_; + src.resize(prev_prfx_size + 1); + for (unsigned j = 0; j < prev_prfx_size; ++j) + src[j] = container->src[state.src_prefix_[j]]; + src[prev_prfx_size] = container->src[srcindex]; + for (unsigned j = trg_from; j < trg_to; ++j) + trg.push_back(container->trg[j]); + res[i].second->scores_.set_value(FD::Convert("Proposal"), log(container->model->RuleConditionalProbability(*res[i].second))); + } + } + return res; +} + +const WFSTNode* MyFST::GetNode(const FSTState& q) { + boost::shared_ptr& res = m[q]; + if (!res) { + res.reset(new MyNode(q, this)); + } + return &*res; +} + +const WFSTNode* MyFST::Final() const { + return final; +} + +const WFSTNode* MyFST::Initial() const { + return init; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + kMAX_TRG_PHRASE = conf["max_trg_phrase"].as(); + kMAX_SRC_PHRASE = conf["max_src_phrase"].as(); + + if (!conf.count("model1")) { + cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; + return 1; + } + shared_ptr prng; + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "f-Corpus size: " << corpusf.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabf.size() << " types\n"; + cerr << "f-Corpus size: " << corpuse.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabe.size() << " types\n"; + assert(corpusf.size() == corpuse.size()); + + ConditionalBase lp0(conf["model1_interpolation_weight"].as(), + vocabe.size(), + conf["model1"].as()); + MyModel m(lp0); + + TRule x("[X] ||| kAnwntR myN ||| at the convent ||| 0"); + m.IncrementRule(x); + TRule y("[X] ||| nY dyN ||| gave ||| 0"); + m.IncrementRule(y); + + + MyFST fst(corpusf[0], corpuse[0], &m); + ifstream in("./kimura.g"); + assert(in); + CFG_WFSTComposer comp(fst); + Hypergraph hg; + bool succeed = comp.Compose(&in, &hg); + hg.PrintGraphviz(); + if (succeed) { cerr << "SUCCESS.\n"; } else { cerr << "FAILURE REPORTED.\n"; } + +#if 0 + ifstream in2("./amnabooks.g"); + assert(in2); + MyFST fst2(corpusf[1], corpuse[1], &m); + CFG_WFSTComposer comp2(fst2); + Hypergraph hg2; + bool succeed2 = comp2.Compose(&in2, &hg2); + if (succeed2) { cerr << "SUCCESS.\n"; } else { cerr << "FAILURE REPORTED.\n"; } +#endif + + SparseVector w; w.set_value(FD::Convert("Proposal"), 1.0); + hg.Reweight(w); + cerr << ViterbiFTree(hg) << endl; + return 0; +} + diff --git a/gi/pf/cbgi.cc b/gi/pf/cbgi.cc new file mode 100644 index 00000000..20204e8a --- /dev/null +++ b/gi/pf/cbgi.cc @@ -0,0 +1,340 @@ +#include +#include +#include + +#include +#include + +#include "sampler.h" +#include "filelib.h" +#include "hg_io.h" +#include "hg.h" +#include "ccrp_nt.h" +#include "trule.h" +#include "inside_outside.h" + +using namespace std; +using namespace std::tr1; + +double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +double log_decay(unsigned x, const double& b) { + assert(b > 1.0); + assert(x > 0); + return log(b - 1) - x * log(b); +} + +size_t hash_value(const TRule& r) { + // TODO fix hash function + size_t h = boost::hash_value(r.e_) * boost::hash_value(r.f_) * r.lhs_; + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +struct SimpleBase { + SimpleBase(unsigned esize, unsigned fsize, unsigned ntsize = 144) : + uniform_e(-log(esize)), + uniform_f(-log(fsize)), + uniform_nt(-log(ntsize)) { + } + + // binomial coefficient + static double choose(unsigned n, unsigned k) { + return exp(lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)); + } + + // count the number of patterns of terminals and NTs in the rule, given elen and flen + static double log_number_of_patterns(const unsigned flen, const unsigned elen) { + static vector > counts; + if (elen >= counts.size()) counts.resize(elen + 1); + if (flen >= counts[elen].size()) counts[elen].resize(flen + 1); + double& count = counts[elen][flen]; + if (count) return log(count); + const unsigned max_arity = min(elen, flen); + for (unsigned a = 0; a <= max_arity; ++a) + count += choose(elen, a) * choose(flen, a); + return log(count); + } + + // return logp0 of rule | LHS + double operator()(const TRule& rule) const { + const unsigned flen = rule.f_.size(); + const unsigned elen = rule.e_.size(); +#if 0 + double p = 0; + p += log_poisson(flen, 0.5); // flen ~Pois(0.5) + p += log_poisson(elen, flen); // elen | flen ~Pois(flen) + p -= log_number_of_patterns(flen, elen); // pattern | flen,elen ~Uniform + for (unsigned i = 0; i < flen; ++i) { // for each position in f-RHS + if (rule.f_[i] <= 0) // according to pattern + p += uniform_nt; // draw NT ~Uniform + else + p += uniform_f; // draw f terminal ~Uniform + } + p -= lgamma(rule.Arity() + 1); // draw permutation ~Uniform + for (unsigned i = 0; i < elen; ++i) { // for each position in e-RHS + if (rule.e_[i] > 0) // according to pattern + p += uniform_e; // draw e|f term ~Uniform + // TODO this should prob be model 1 + } +#else + double p = 0; + bool is_abstract = rule.f_[0] <= 0; + p += log(0.5); + if (is_abstract) { + if (flen == 2) p += log(0.99); else p += log(0.01); + } else { + p += log_decay(flen, 3); + } + + for (unsigned i = 0; i < flen; ++i) { // for each position in f-RHS + if (rule.f_[i] <= 0) // according to pattern + p += uniform_nt; // draw NT ~Uniform + else + p += uniform_f; // draw f terminal ~Uniform + } +#endif + return p; + } + const double uniform_e; + const double uniform_f; + const double uniform_nt; + vector arities; +}; + +MT19937* rng = NULL; + +template +struct MHSamplerEdgeProb { + MHSamplerEdgeProb(const Hypergraph& hg, + const map >& rdp, + const Base& logp0, + const bool exclude_multiword_terminals) : edge_probs(hg.edges_.size()) { + for (int i = 0; i < edge_probs.size(); ++i) { + const TRule& rule = *hg.edges_[i].rule_; + const map >::const_iterator it = rdp.find(rule.lhs_); + assert(it != rdp.end()); + const CCRP_NoTable& crp = it->second; + edge_probs[i].logeq(crp.logprob(rule, logp0(rule))); + if (exclude_multiword_terminals && rule.f_[0] > 0 && rule.f_.size() > 1) + edge_probs[i] = prob_t::Zero(); + } + } + inline prob_t operator()(const Hypergraph::Edge& e) const { + return edge_probs[e.id_]; + } + prob_t DerivationProb(const vector& d) const { + prob_t p = prob_t::One(); + for (unsigned i = 0; i < d.size(); ++i) + p *= edge_probs[d[i]]; + return p; + } + vector edge_probs; +}; + +template +struct ModelAndData { + ModelAndData() : + base_lh(prob_t::One()), + logp0(10000, 10000), + mh_samples(), + mh_rejects() {} + + void SampleCorpus(const string& hgpath, int i); + void ResampleHyperparameters() { + for (map >::iterator it = rules.begin(); it != rules.end(); ++it) + it->second.resample_hyperparameters(rng); + } + + CCRP_NoTable& RuleCRP(int lhs) { + map >::iterator it = rules.find(lhs); + if (it == rules.end()) { + rules.insert(make_pair(lhs, CCRP_NoTable(1,1))); + it = rules.find(lhs); + } + return it->second; + } + + void IncrementRule(const TRule& rule) { + CCRP_NoTable& crp = RuleCRP(rule.lhs_); + if (crp.increment(rule)) { + prob_t p; p.logeq(logp0(rule)); + base_lh *= p; + } + } + + void DecrementRule(const TRule& rule) { + CCRP_NoTable& crp = RuleCRP(rule.lhs_); + if (crp.decrement(rule)) { + prob_t p; p.logeq(logp0(rule)); + base_lh /= p; + } + } + + void DecrementDerivation(const Hypergraph& hg, const vector& d) { + for (unsigned i = 0; i < d.size(); ++i) { + const TRule& rule = *hg.edges_[d[i]].rule_; + DecrementRule(rule); + } + } + + void IncrementDerivation(const Hypergraph& hg, const vector& d) { + for (unsigned i = 0; i < d.size(); ++i) { + const TRule& rule = *hg.edges_[d[i]].rule_; + IncrementRule(rule); + } + } + + prob_t Likelihood() const { + prob_t p = prob_t::One(); + for (map >::const_iterator it = rules.begin(); it != rules.end(); ++it) { + prob_t q; q.logeq(it->second.log_crp_prob()); + p *= q; + } + p *= base_lh; + return p; + } + + void ResampleDerivation(const Hypergraph& hg, vector* sampled_derivation); + + map > rules; // [lhs] -> distribution over RHSs + prob_t base_lh; + SimpleBase logp0; + vector > samples; // sampled derivations + unsigned int mh_samples; + unsigned int mh_rejects; +}; + +template +void ModelAndData::SampleCorpus(const string& hgpath, int n) { + vector hgs(n); hgs.clear(); + boost::unordered_map acc; + map tot; + for (int i = 0; i < n; ++i) { + ostringstream os; + os << hgpath << '/' << i << ".json.gz"; + if (!FileExists(os.str())) continue; + hgs.push_back(Hypergraph()); + ReadFile rf(os.str()); + HypergraphIO::ReadFromJSON(rf.stream(), &hgs.back()); + } + cerr << "Read " << hgs.size() << " alignment hypergraphs.\n"; + samples.resize(hgs.size()); + const unsigned SAMPLES = 2000; + const unsigned burnin = 3 * SAMPLES / 4; + const unsigned every = 20; + for (unsigned s = 0; s < SAMPLES; ++s) { + if (s % 10 == 0) { + if (s > 0) { cerr << endl; ResampleHyperparameters(); } + cerr << "[" << s << " LLH=" << log(Likelihood()) << " REJECTS=" << ((double)mh_rejects / mh_samples) << " LHS's=" << rules.size() << " base=" << log(base_lh) << "] "; + } + cerr << '.'; + for (unsigned i = 0; i < hgs.size(); ++i) { + ResampleDerivation(hgs[i], &samples[i]); + if (s > burnin && s % every == 0) { + for (unsigned j = 0; j < samples[i].size(); ++j) { + const TRule& rule = *hgs[i].edges_[samples[i][j]].rule_; + ++acc[rule]; + ++tot[rule.lhs_]; + } + } + } + } + cerr << endl; + for (boost::unordered_map::iterator it = acc.begin(); it != acc.end(); ++it) { + cout << it->first << " MyProb=" << log(it->second)-log(tot[it->first.lhs_]) << endl; + } +} + +template +void ModelAndData::ResampleDerivation(const Hypergraph& hg, vector* sampled_deriv) { + vector cur; + cur.swap(*sampled_deriv); + + const prob_t p_cur = Likelihood(); + DecrementDerivation(hg, cur); + if (cur.empty()) { + // first iteration, create restaurants + for (int i = 0; i < hg.edges_.size(); ++i) + RuleCRP(hg.edges_[i].rule_->lhs_); + } + MHSamplerEdgeProb wf(hg, rules, logp0, cur.empty()); +// MHSamplerEdgeProb wf(hg, rules, logp0, false); + const prob_t q_cur = wf.DerivationProb(cur); + vector node_probs; + Inside >(hg, &node_probs, wf); + queue q; + q.push(hg.nodes_.size() - 3); + while(!q.empty()) { + unsigned cur_node_id = q.front(); +// cerr << "NODE=" << cur_node_id << endl; + q.pop(); + const Hypergraph::Node& node = hg.nodes_[cur_node_id]; + const unsigned num_in_edges = node.in_edges_.size(); + unsigned sampled_edge = 0; + if (num_in_edges == 1) { + sampled_edge = node.in_edges_[0]; + } else { + prob_t z; + assert(num_in_edges > 1); + SampleSet ss; + for (unsigned j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; + prob_t p = wf.edge_probs[edge.id_]; // edge proposal prob + for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k) + p *= node_probs[edge.tail_nodes_[k]]; + ss.add(p); +// cerr << log(ss[j]) << " ||| " << edge.rule_->AsString() << endl; + z += p; + } +// for (unsigned j = 0; j < num_in_edges; ++j) { +// const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; +// cerr << exp(log(ss[j] / z)) << " ||| " << edge.rule_->AsString() << endl; +// } +// cerr << " --- \n"; + sampled_edge = node.in_edges_[rng->SelectSample(ss)]; + } + sampled_deriv->push_back(sampled_edge); + const Hypergraph::Edge& edge = hg.edges_[sampled_edge]; + for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) { + q.push(edge.tail_nodes_[j]); + } + } + IncrementDerivation(hg, *sampled_deriv); + +// cerr << "sampled derivation contains " << sampled_deriv->size() << " edges\n"; +// cerr << "DERIV:\n"; +// for (int i = 0; i < sampled_deriv->size(); ++i) { +// cerr << " " << hg.edges_[(*sampled_deriv)[i]].rule_->AsString() << endl; +// } + + if (cur.empty()) return; // accept first sample + + ++mh_samples; + // only need to do MH if proposal is different to current state + if (cur != *sampled_deriv) { + const prob_t q_prop = wf.DerivationProb(*sampled_deriv); + const prob_t p_prop = Likelihood(); + if (!rng->AcceptMetropolisHastings(p_prop, p_cur, q_prop, q_cur)) { + ++mh_rejects; + DecrementDerivation(hg, *sampled_deriv); + IncrementDerivation(hg, cur); + swap(cur, *sampled_deriv); + } + } +} + +int main(int argc, char** argv) { + rng = new MT19937; + ModelAndData m; + m.SampleCorpus("./hgs", 50); + // m.SampleCorpus("./btec/hgs", 5000); + return 0; +} + diff --git a/gi/pf/cfg_wfst_composer.cc b/gi/pf/cfg_wfst_composer.cc new file mode 100644 index 00000000..a31b5be8 --- /dev/null +++ b/gi/pf/cfg_wfst_composer.cc @@ -0,0 +1,730 @@ +#include "cfg_wfst_composer.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include "fast_lexical_cast.hpp" + +#include "phrasetable_fst.h" +#include "sparse_vector.h" +#include "tdict.h" +#include "hg.h" + +using boost::shared_ptr; +namespace po = boost::program_options; +using namespace std; +using namespace std::tr1; + +WFSTNode::~WFSTNode() {} +WFST::~WFST() {} + +// Define the following macro if you want to see lots of debugging output +// when you run the chart parser +#undef DEBUG_CHART_PARSER + +// A few constants used by the chart parser /////////////// +static const int kMAX_NODES = 2000000; +static const string kPHRASE_STRING = "X"; +static bool constants_need_init = true; +static WordID kUNIQUE_START; +static WordID kPHRASE; +static TRulePtr kX1X2; +static TRulePtr kX1; +static WordID kEPS; +static TRulePtr kEPSRule; + +static void InitializeConstants() { + if (constants_need_init) { + kPHRASE = TD::Convert(kPHRASE_STRING) * -1; + kUNIQUE_START = TD::Convert("S") * -1; + kX1X2.reset(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]")); + kX1.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + kEPSRule.reset(new TRule("[X] ||| ||| ")); + kEPS = TD::Convert(""); + constants_need_init = false; + } +} +//////////////////////////////////////////////////////////// + +class EGrammarNode { + friend bool CFG_WFSTComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); + friend void AddGrammarRule(const string& r, map* g); + public: +#ifdef DEBUG_CHART_PARSER + string hint; +#endif + EGrammarNode() : is_some_rule_complete(false), is_root(false) {} + const map& GetTerminals() const { return tptr; } + const map& GetNonTerminals() const { return ntptr; } + bool HasNonTerminals() const { return (!ntptr.empty()); } + bool HasTerminals() const { return (!tptr.empty()); } + bool RuleCompletes() const { + return (is_some_rule_complete || (ntptr.empty() && tptr.empty())); + } + bool GrammarContinues() const { + return !(ntptr.empty() && tptr.empty()); + } + bool IsRoot() const { + return is_root; + } + // these are the features associated with the rule from the start + // node up to this point. If you use these features, you must + // not Extend() this rule. + const SparseVector& GetCFGProductionFeatures() const { + return input_features; + } + + const EGrammarNode* Extend(const WordID& t) const { + if (t < 0) { + map::const_iterator it = ntptr.find(t); + if (it == ntptr.end()) return NULL; + return &it->second; + } else { + map::const_iterator it = tptr.find(t); + if (it == tptr.end()) return NULL; + return &it->second; + } + } + + private: + map tptr; + map ntptr; + SparseVector input_features; + bool is_some_rule_complete; + bool is_root; +}; +typedef map EGrammar; // indexed by the rule LHS + +// edges are immutable once created +struct Edge { +#ifdef DEBUG_CHART_PARSER + static int id_count; + const int id; +#endif + const WordID cat; // lhs side of rule proved/being proved + const EGrammarNode* const dot; // dot position + const WFSTNode* const q; // start of span + const WFSTNode* const r; // end of span + const Edge* const active_parent; // back pointer, NULL for PREDICT items + const Edge* const passive_parent; // back pointer, NULL for SCAN and PREDICT items + TRulePtr tps; // translations + shared_ptr > features; // features from CFG rule + + bool IsPassive() const { + // when a rule is completed, this value will be set + return static_cast(features); + } + bool IsActive() const { return !IsPassive(); } + bool IsInitial() const { + return !(active_parent || passive_parent); + } + bool IsCreatedByScan() const { + return active_parent && !passive_parent && !dot->IsRoot(); + } + bool IsCreatedByPredict() const { + return dot->IsRoot(); + } + bool IsCreatedByComplete() const { + return active_parent && passive_parent; + } + + // constructor for PREDICT + Edge(WordID c, const EGrammarNode* d, const WFSTNode* q_and_r) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(NULL), passive_parent(NULL), tps() {} + Edge(WordID c, const EGrammarNode* d, const WFSTNode* q_and_r, const Edge* act_parent) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(act_parent), passive_parent(NULL), tps() {} + + // constructors for SCAN + Edge(WordID c, const EGrammarNode* d, const WFSTNode* i, const WFSTNode* j, + const Edge* act_par, const TRulePtr& translations) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations) {} + + Edge(WordID c, const EGrammarNode* d, const WFSTNode* i, const WFSTNode* j, + const Edge* act_par, const TRulePtr& translations, + const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations), + features(new SparseVector(feats)) {} + + // constructors for COMPLETE + Edge(WordID c, const EGrammarNode* d, const WFSTNode* i, const WFSTNode* j, + const Edge* act_par, const Edge *pas_par) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps() { + assert(pas_par->IsPassive()); + assert(act_par->IsActive()); + } + + Edge(WordID c, const EGrammarNode* d, const WFSTNode* i, const WFSTNode* j, + const Edge* act_par, const Edge *pas_par, const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(), + features(new SparseVector(feats)) { + assert(pas_par->IsPassive()); + assert(act_par->IsActive()); + } + + // constructor for COMPLETE query + Edge(const WFSTNode* _r) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(NULL), + r(_r), active_parent(NULL), passive_parent(NULL), tps() {} + // constructor for MERGE quere + Edge(const WFSTNode* _q, int) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(_q), + r(NULL), active_parent(NULL), passive_parent(NULL), tps() {} +}; +#ifdef DEBUG_CHART_PARSER +int Edge::id_count = 0; +#endif + +ostream& operator<<(ostream& os, const Edge& e) { + string type = "PREDICT"; + if (e.IsCreatedByScan()) + type = "SCAN"; + else if (e.IsCreatedByComplete()) + type = "COMPLETE"; + os << "[" +#ifdef DEBUG_CHART_PARSER + << '(' << e.id << ") " +#else + << '(' << &e << ") " +#endif + << "q=" << e.q << ", r=" << e.r + << ", cat="<< TD::Convert(e.cat*-1) << ", dot=" + << e.dot +#ifdef DEBUG_CHART_PARSER + << e.dot->hint +#endif + << (e.IsActive() ? ", Active" : ", Passive") + << ", " << type; +#ifdef DEBUG_CHART_PARSER + if (e.active_parent) { os << ", act.parent=(" << e.active_parent->id << ')'; } + if (e.passive_parent) { os << ", psv.parent=(" << e.passive_parent->id << ')'; } +#endif + if (e.tps) { os << ", tps=" << e.tps->AsString(); } + return os << ']'; +} + +struct Traversal { + const Edge* const edge; // result from the active / passive combination + const Edge* const active; + const Edge* const passive; + Traversal(const Edge* me, const Edge* a, const Edge* p) : edge(me), active(a), passive(p) {} +}; + +struct UniqueTraversalHash { + size_t operator()(const Traversal* t) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(t->active); + x = ((x << 5) + x) ^ reinterpret_cast(t->passive); + x = ((x << 5) + x) ^ t->edge->IsActive(); + return x; + } +}; + +struct UniqueTraversalEquals { + size_t operator()(const Traversal* a, const Traversal* b) const { + return (a->passive == b->passive && a->active == b->active && a->edge->IsActive() == b->edge->IsActive()); + } +}; + +struct UniqueEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + if (e->IsActive()) { + x = ((x << 5) + x) ^ reinterpret_cast(e->dot); + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + x += 13; + } else { // with passive edges, we don't care about the dot + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + } + return x; + } +}; + +struct UniqueEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + if (a->IsActive() != b->IsActive()) return false; + if (a->IsActive()) { + return (a->cat == b->cat) && (a->dot == b->dot) && (a->q == b->q) && (a->r == b->r); + } else { + return (a->cat == b->cat) && (a->q == b->q) && (a->r == b->r); + } + } +}; + +struct REdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + return x; + } +}; + +struct REdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->r == b->r); + } +}; + +struct QEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + return x; + } +}; + +struct QEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->q == b->q); + } +}; + +struct EdgeQueue { + queue q; + EdgeQueue() {} + void clear() { while(!q.empty()) q.pop(); } + bool HasWork() const { return !q.empty(); } + const Edge* Next() { const Edge* res = q.front(); q.pop(); return res; } + void AddEdge(const Edge* s) { q.push(s); } +}; + +class CFG_WFSTComposerImpl { + public: + CFG_WFSTComposerImpl(WordID start_cat, + const WFSTNode* q_0, + const WFSTNode* q_final) : start_cat_(start_cat), q_0_(q_0), q_final_(q_final) {} + + // returns false if the intersection is empty + bool Compose(const EGrammar& g, Hypergraph* forest) { + goal_node = NULL; + EGrammar::const_iterator sit = g.find(start_cat_); + forest->ReserveNodes(kMAX_NODES); + assert(sit != g.end()); + Edge* init = new Edge(start_cat_, &sit->second, q_0_); + assert(IncorporateNewEdge(init)); + while (exp_agenda.HasWork() || agenda.HasWork()) { + while(exp_agenda.HasWork()) { + const Edge* edge = exp_agenda.Next(); + FinishEdge(edge, forest); + } + if (agenda.HasWork()) { + const Edge* edge = agenda.Next(); +#ifdef DEBUG_CHART_PARSER + cerr << "processing (" << edge->id << ')' << endl; +#endif + if (edge->IsActive()) { + if (edge->dot->HasTerminals()) + DoScan(edge); + if (edge->dot->HasNonTerminals()) { + DoMergeWithPassives(edge); + DoPredict(edge, g); + } + } else { + DoComplete(edge); + } + } + } + if (goal_node) { + forest->PruneUnreachable(goal_node->id_); + forest->EpsilonRemove(kEPS); + } + FreeAll(); + return goal_node; + } + + void FreeAll() { + for (int i = 0; i < free_list_.size(); ++i) + delete free_list_[i]; + free_list_.clear(); + for (int i = 0; i < traversal_free_list_.size(); ++i) + delete traversal_free_list_[i]; + traversal_free_list_.clear(); + all_traversals.clear(); + exp_agenda.clear(); + agenda.clear(); + tps2node.clear(); + edge2node.clear(); + all_edges.clear(); + passive_edges.clear(); + active_edges.clear(); + } + + ~CFG_WFSTComposerImpl() { + FreeAll(); + } + + // returns the total number of edges created during composition + int EdgesCreated() const { + return free_list_.size(); + } + + private: + void DoScan(const Edge* edge) { + // here, we assume that the FST will potentially have many more outgoing + // edges than the grammar, which will be just a couple. If you want to + // efficiently handle the case where both are relatively large, this code + // will need to change how the intersection is done. The best general + // solution would probably be the Baeza-Yates double binary search. + + const EGrammarNode* dot = edge->dot; + const WFSTNode* r = edge->r; + const map& terms = dot->GetTerminals(); + for (map::const_iterator git = terms.begin(); + git != terms.end(); ++git) { + + if (!(TD::Convert(git->first)[0] >= '0' && TD::Convert(git->first)[0] <= '9')) { + std::cerr << "TERMINAL SYMBOL: " << TD::Convert(git->first) << endl; + abort(); + } + std::vector > extensions = r->ExtendInput(atoi(TD::Convert(git->first))); + for (unsigned nsi = 0; nsi < extensions.size(); ++nsi) { + const WFSTNode* next_r = extensions[nsi].first; + const EGrammarNode* next_dot = &git->second; + const bool grammar_continues = next_dot->GrammarContinues(); + const bool rule_completes = next_dot->RuleCompletes(); + if (extensions[nsi].second) + cerr << "!!! " << extensions[nsi].second->AsString() << endl; + // cerr << " rule completes: " << rule_completes << " after consuming " << TD::Convert(git->first) << endl; + assert(grammar_continues || rule_completes); + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + if (rule_completes) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, extensions[nsi].second, input_features)); + if (grammar_continues) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, extensions[nsi].second)); + } + } + } + + void DoPredict(const Edge* edge, const EGrammar& g) { + const EGrammarNode* dot = edge->dot; + const map& non_terms = dot->GetNonTerminals(); + for (map::const_iterator git = non_terms.begin(); + git != non_terms.end(); ++git) { + const WordID nt_to_predict = git->first; + //cerr << edge->id << " -- " << TD::Convert(nt_to_predict*-1) << endl; + EGrammar::const_iterator egi = g.find(nt_to_predict); + if (egi == g.end()) { + cerr << "[ERROR] Can't find any grammar rules with a LHS of type " + << TD::Convert(-1*nt_to_predict) << '!' << endl; + continue; + } + assert(edge->IsActive()); + const EGrammarNode* new_dot = &egi->second; + Edge* new_edge = new Edge(nt_to_predict, new_dot, edge->r, edge); + IncorporateNewEdge(new_edge); + } + } + + void DoComplete(const Edge* passive) { +#ifdef DEBUG_CHART_PARSER + cerr << " complete: " << *passive << endl; +#endif + const WordID completed_nt = passive->cat; + const WFSTNode* q = passive->q; + const WFSTNode* next_r = passive->r; + const Edge query(q); + const pair::iterator, + unordered_multiset::iterator > p = + active_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* active = *it; +#ifdef DEBUG_CHART_PARSER + cerr << " pos: " << *active << endl; +#endif + const EGrammarNode* next_dot = active->dot->Extend(completed_nt); + if (!next_dot) continue; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + // add up to 2 rules + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + void DoMergeWithPassives(const Edge* active) { + // edge is active, has non-terminals, we need to find the passives that can extend it + assert(active->IsActive()); + assert(active->dot->HasNonTerminals()); +#ifdef DEBUG_CHART_PARSER + cerr << " merge active with passives: ACT=" << *active << endl; +#endif + const Edge query(active->r, 1); + const pair::iterator, + unordered_multiset::iterator > p = + passive_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* passive = *it; + const EGrammarNode* next_dot = active->dot->Extend(passive->cat); + if (!next_dot) continue; + const WFSTNode* next_r = passive->r; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + // take ownership of edge memory, add to various indexes, etc + // returns true if this edge is new + bool IncorporateNewEdge(Edge* edge) { + free_list_.push_back(edge); + if (edge->passive_parent && edge->active_parent) { + Traversal* t = new Traversal(edge, edge->active_parent, edge->passive_parent); + traversal_free_list_.push_back(t); + if (all_traversals.find(t) != all_traversals.end()) { + return false; + } else { + all_traversals.insert(t); + } + } + exp_agenda.AddEdge(edge); + return true; + } + + bool FinishEdge(const Edge* edge, Hypergraph* hg) { + bool is_new = false; + if (all_edges.find(edge) == all_edges.end()) { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NEW\n"; +#endif + all_edges.insert(edge); + is_new = true; + if (edge->IsPassive()) passive_edges.insert(edge); + if (edge->IsActive()) active_edges.insert(edge); + agenda.AddEdge(edge); + } else { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NOT NEW.\n"; +#endif + } + AddEdgeToTranslationForest(edge, hg); + return is_new; + } + + // build the translation forest + void AddEdgeToTranslationForest(const Edge* edge, Hypergraph* hg) { + assert(hg->nodes_.size() < kMAX_NODES); + Hypergraph::Node* tps = NULL; + // first add any target language rules + if (edge->tps) { + Hypergraph::Node*& node = tps2node[(size_t)edge->tps.get()]; + if (!node) { + // cerr << "Creating phrases for " << edge->tps << endl; + const TRulePtr& rule = edge->tps; + node = hg->AddNode(kPHRASE); + Hypergraph::Edge* hg_edge = hg->AddEdge(rule, Hypergraph::TailNodeVector()); + hg_edge->feature_values_ += rule->GetFeatureValues(); + hg->ConnectEdgeToHeadNode(hg_edge, node); + } + tps = node; + } + Hypergraph::Node*& head_node = edge2node[edge]; + if (!head_node) + head_node = hg->AddNode(kPHRASE); + if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_final_ && edge->IsPassive()) { + assert(goal_node == NULL || goal_node == head_node); + goal_node = head_node; + } + Hypergraph::TailNodeVector tail; + SparseVector extra; + if (edge->IsCreatedByPredict()) { + // extra.set_value(FD::Convert("predict"), 1); + } else if (edge->IsCreatedByScan()) { + tail.push_back(edge2node[edge->active_parent]->id_); + if (tps) { + tail.push_back(tps->id_); + } + //extra.set_value(FD::Convert("scan"), 1); + } else if (edge->IsCreatedByComplete()) { + tail.push_back(edge2node[edge->active_parent]->id_); + tail.push_back(edge2node[edge->passive_parent]->id_); + //extra.set_value(FD::Convert("complete"), 1); + } else { + assert(!"unexpected edge type!"); + } + //cerr << head_node->id_ << "<--" << *edge << endl; + +#ifdef DEBUG_CHART_PARSER + for (int i = 0; i < tail.size(); ++i) + if (tail[i] == head_node->id_) { + cerr << "ERROR: " << *edge << "\n i=" << i << endl; + if (i == 1) { cerr << "\tP: " << *edge->passive_parent << endl; } + if (i == 0) { cerr << "\tA: " << *edge->active_parent << endl; } + assert(!"self-loop found!"); + } +#endif + Hypergraph::Edge* hg_edge = NULL; + if (tail.size() == 0) { + hg_edge = hg->AddEdge(kEPSRule, tail); + } else if (tail.size() == 1) { + hg_edge = hg->AddEdge(kX1, tail); + } else if (tail.size() == 2) { + hg_edge = hg->AddEdge(kX1X2, tail); + } + if (edge->features) + hg_edge->feature_values_ += *edge->features; + hg_edge->feature_values_ += extra; + hg->ConnectEdgeToHeadNode(hg_edge, head_node); + } + + Hypergraph::Node* goal_node; + EdgeQueue exp_agenda; + EdgeQueue agenda; + unordered_map tps2node; + unordered_map edge2node; + unordered_set all_traversals; + unordered_set all_edges; + unordered_multiset passive_edges; + unordered_multiset active_edges; + vector free_list_; + vector traversal_free_list_; + const WordID start_cat_; + const WFSTNode* const q_0_; + const WFSTNode* const q_final_; +}; + +#ifdef DEBUG_CHART_PARSER +static string TrimRule(const string& r) { + size_t start = r.find(" |||") + 5; + size_t end = r.rfind(" |||"); + return r.substr(start, end - start); +} +#endif + +void AddGrammarRule(const string& r, EGrammar* g) { + const size_t pos = r.find(" ||| "); + if (pos == string::npos || r[0] != '[') { + cerr << "Bad rule: " << r << endl; + return; + } + const size_t rpos = r.rfind(" ||| "); + string feats; + string rs = r; + if (rpos != pos) { + feats = r.substr(rpos + 5); + rs = r.substr(0, rpos); + } + string rhs = rs.substr(pos + 5); + string trule = rs + " ||| " + rhs + " ||| " + feats; + TRule tr(trule); + cerr << "X: " << tr.e_[0] << endl; +#ifdef DEBUG_CHART_PARSER + string hint_last_rule; +#endif + EGrammarNode* cur = &(*g)[tr.GetLHS()]; + cur->is_root = true; + for (int i = 0; i < tr.FLength(); ++i) { + WordID sym = tr.f()[i]; +#ifdef DEBUG_CHART_PARSER + hint_last_rule = TD::Convert(sym < 0 ? -sym : sym); + cur->hint += " <@@> (*" + hint_last_rule + ") " + TrimRule(tr.AsString()); +#endif + if (sym < 0) + cur = &cur->ntptr[sym]; + else + cur = &cur->tptr[sym]; + } +#ifdef DEBUG_CHART_PARSER + cur->hint += " <@@> (" + hint_last_rule + "*) " + TrimRule(tr.AsString()); +#endif + cur->is_some_rule_complete = true; + cur->input_features = tr.GetFeatureValues(); +} + +CFG_WFSTComposer::~CFG_WFSTComposer() { + delete pimpl_; +} + +CFG_WFSTComposer::CFG_WFSTComposer(const WFST& wfst) { + InitializeConstants(); + pimpl_ = new CFG_WFSTComposerImpl(kUNIQUE_START, wfst.Initial(), wfst.Final()); +} + +bool CFG_WFSTComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest) { + // first, convert the src forest into an EGrammar + EGrammar g; + const int nedges = src_forest.edges_.size(); + const int nnodes = src_forest.nodes_.size(); + vector cats(nnodes); + bool assign_cats = false; + for (int i = 0; i < nnodes; ++i) + if (assign_cats) { + cats[i] = TD::Convert("CAT_" + boost::lexical_cast(i)) * -1; + } else { + cats[i] = src_forest.nodes_[i].cat_; + } + // construct the grammar + for (int i = 0; i < nedges; ++i) { + const Hypergraph::Edge& edge = src_forest.edges_[i]; + const vector& src = edge.rule_->f(); + EGrammarNode* cur = &g[cats[edge.head_node_]]; + cur->is_root = true; + int ntc = 0; + for (int j = 0; j < src.size(); ++j) { + WordID sym = src[j]; + if (sym <= 0) { + sym = cats[edge.tail_nodes_[ntc]]; + ++ntc; + cur = &cur->ntptr[sym]; + } else { + cur = &cur->tptr[sym]; + } + } + cur->is_some_rule_complete = true; + cur->input_features = edge.feature_values_; + } + EGrammarNode& goal_rule = g[kUNIQUE_START]; + assert((goal_rule.ntptr.size() == 1 && goal_rule.tptr.size() == 0) || + (goal_rule.ntptr.size() == 0 && goal_rule.tptr.size() == 1)); + + return pimpl_->Compose(g, trg_forest); +} + +bool CFG_WFSTComposer::Compose(istream* in, Hypergraph* trg_forest) { + EGrammar g; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + AddGrammarRule(line, &g); + } + + return pimpl_->Compose(g, trg_forest); +} diff --git a/gi/pf/cfg_wfst_composer.h b/gi/pf/cfg_wfst_composer.h new file mode 100644 index 00000000..cf47f459 --- /dev/null +++ b/gi/pf/cfg_wfst_composer.h @@ -0,0 +1,46 @@ +#ifndef _CFG_WFST_COMPOSER_H_ +#define _CFG_WFST_COMPOSER_H_ + +#include +#include +#include + +#include "trule.h" +#include "wordid.h" + +class CFG_WFSTComposerImpl; +class Hypergraph; + +struct WFSTNode { + virtual ~WFSTNode(); + // returns the next states reachable by consuming srcindex (which identifies a word) + // paired with the output string generated by taking that transition. + virtual std::vector > ExtendInput(unsigned srcindex) const = 0; +}; + +struct WFST { + virtual ~WFST(); + virtual const WFSTNode* Final() const = 0; + virtual const WFSTNode* Initial() const = 0; +}; + +class CFG_WFSTComposer { + public: + ~CFG_WFSTComposer(); + explicit CFG_WFSTComposer(const WFST& wfst); + bool Compose(const Hypergraph& in_forest, Hypergraph* trg_forest); + + // reads the grammar from a file. There must be a single top-level + // S -> X rule. Anything else is possible. Format is: + // [S] ||| [SS,1] + // [SS] ||| [NP,1] [VP,2] ||| Feature1=0.2 Feature2=-2.3 + // [SS] ||| [VP,1] [NP,2] ||| Feature1=0.8 + // [NP] ||| [DET,1] [N,2] ||| Feature3=2 + // ... + bool Compose(std::istream* grammar_file, Hypergraph* trg_forest); + + private: + CFG_WFSTComposerImpl* pimpl_; +}; + +#endif diff --git a/gi/pf/dpnaive.cc b/gi/pf/dpnaive.cc new file mode 100644 index 00000000..582d1be7 --- /dev/null +++ b/gi/pf/dpnaive.cc @@ -0,0 +1,349 @@ +#include +#include +#include + +#include +#include +#include + +#include "base_measures.h" +#include "trule.h" +#include "tdict.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp_nt.h" + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +static unsigned kMAX_SRC_PHRASE; +static unsigned kMAX_TRG_PHRASE; +struct FSTState; + +size_t hash_value(const TRule& r) { + size_t h = 2 - r.lhs_; + boost::hash_combine(h, boost::hash_value(r.e_)); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read parallel data from") + ("max_src_phrase",po::value()->default_value(4),"Maximum length of source language phrases") + ("max_trg_phrase",po::value()->default_value(4),"Maximum length of target language phrases") + ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_e, + set* vocab_f) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +shared_ptr prng; + +template +struct ModelAndData { + explicit ModelAndData(const Base& b, const vector >& ce, const vector >& cf, const set& ve, const set& vf) : + rng(&*prng), + p0(b), + baseprob(prob_t::One()), + corpuse(ce), + corpusf(cf), + vocabe(ve), + vocabf(vf), + rules(1,1), + mh_samples(), + mh_rejects(), + kX(-TD::Convert("X")), + derivations(corpuse.size()) {} + + void ResampleHyperparameters() { + rules.resample_hyperparameters(&*prng); + } + + void InstantiateRule(const pair& from, + const pair& to, + const vector& sentf, + const vector& sente, + TRule* rule) const { + rule->f_.clear(); + rule->e_.clear(); + rule->lhs_ = kX; + for (short i = from.first; i < to.first; ++i) + rule->f_.push_back(sentf[i]); + for (short i = from.second; i < to.second; ++i) + rule->e_.push_back(sente[i]); + } + + void DecrementDerivation(const vector >& d, const vector& sentf, const vector& sente) { + if (d.size() < 2) return; + TRule x; + for (int i = 1; i < d.size(); ++i) { + InstantiateRule(d[i], d[i-1], sentf, sente, &x); + //cerr << "REMOVE: " << x.AsString() << endl; + if (rules.decrement(x)) { + baseprob /= p0(x); + //cerr << " (REMOVED ONLY INSTANCE)\n"; + } + } + } + + void PrintDerivation(const vector >& d, const vector& sentf, const vector& sente) { + if (d.size() < 2) return; + TRule x; + for (int i = 1; i < d.size(); ++i) { + InstantiateRule(d[i], d[i-1], sentf, sente, &x); + cerr << i << '/' << (d.size() - 1) << ": " << x << endl; + } + } + + void IncrementDerivation(const vector >& d, const vector& sentf, const vector& sente) { + if (d.size() < 2) return; + TRule x; + for (int i = 1; i < d.size(); ++i) { + InstantiateRule(d[i], d[i-1], sentf, sente, &x); + if (rules.increment(x)) { + baseprob *= p0(x); + } + } + } + + prob_t Likelihood() const { + prob_t p; + p.logeq(rules.log_crp_prob()); + return p * baseprob; + } + + prob_t DerivationProposalProbability(const vector >& d, const vector& sentf, const vector& sente) const { + prob_t p = prob_t::One(); + if (d.size() < 2) return p; + TRule x; + for (int i = 1; i < d.size(); ++i) { + InstantiateRule(d[i], d[i-1], sentf, sente, &x); + prob_t rp; rp.logeq(rules.logprob(x, log(p0(x)))); + p *= rp; + } + return p; + } + + void Sample(); + + MT19937* rng; + const Base& p0; + prob_t baseprob; // cached value of generating the table table labels from p0 + // this can't be used if we go to a hierarchical prior! + const vector >& corpuse, corpusf; + const set& vocabe, vocabf; + CCRP_NoTable rules; + unsigned mh_samples, mh_rejects; + const int kX; + vector > > derivations; +}; + +template +void ModelAndData::Sample() { + unsigned MAXK = 4; + unsigned MAXL = 4; + TRule x; + x.lhs_ = -TD::Convert("X"); + for (int samples = 0; samples < 1000; ++samples) { + if (samples % 1 == 0 && samples > 0) { + //ResampleHyperparameters(); + cerr << " [" << samples << " LLH=" << log(Likelihood()) << " MH=" << ((double)mh_rejects / mh_samples) << "]\n"; + for (int i = 0; i < 10; ++i) { + cerr << "SENTENCE: " << TD::GetString(corpusf[i]) << " ||| " << TD::GetString(corpuse[i]) << endl; + PrintDerivation(derivations[i], corpusf[i], corpuse[i]); + } + } + cerr << '.' << flush; + for (int s = 0; s < corpuse.size(); ++s) { + const vector& sentf = corpusf[s]; + const vector& sente = corpuse[s]; +// cerr << " CUSTOMERS: " << rules.num_customers() << endl; +// cerr << "SENTENCE: " << TD::GetString(sentf) << " ||| " << TD::GetString(sente) << endl; + + vector >& deriv = derivations[s]; + const prob_t p_cur = Likelihood(); + DecrementDerivation(deriv, sentf, sente); + + boost::multi_array a(boost::extents[sentf.size() + 1][sente.size() + 1]); + boost::multi_array trans(boost::extents[sentf.size() + 1][sente.size() + 1][MAXK][MAXL]); + a[0][0] = prob_t::One(); + for (int i = 0; i < sentf.size(); ++i) { + for (int j = 0; j < sente.size(); ++j) { + const prob_t src_a = a[i][j]; + x.f_.clear(); + for (int k = 1; k <= MAXK; ++k) { + if (i + k > sentf.size()) break; + x.f_.push_back(sentf[i + k - 1]); + x.e_.clear(); + for (int l = 1; l <= MAXL; ++l) { + if (j + l > sente.size()) break; + x.e_.push_back(sente[j + l - 1]); + trans[i][j][k - 1][l - 1].logeq(rules.logprob(x, log(p0(x)))); + a[i + k][j + l] += src_a * trans[i][j][k - 1][l - 1]; + } + } + } + } +// cerr << "Inside: " << log(a[sentf.size()][sente.size()]) << endl; + const prob_t q_cur = DerivationProposalProbability(deriv, sentf, sente); + + vector > newderiv; + int cur_i = sentf.size(); + int cur_j = sente.size(); + while(cur_i > 0 && cur_j > 0) { + newderiv.push_back(pair(cur_i, cur_j)); +// cerr << "NODE: (" << cur_i << "," << cur_j << ")\n"; + SampleSet ss; + vector > nexts; + for (int k = 1; k <= MAXK; ++k) { + const int hyp_i = cur_i - k; + if (hyp_i < 0) break; + for (int l = 1; l <= MAXL; ++l) { + const int hyp_j = cur_j - l; + if (hyp_j < 0) break; + const prob_t& inside = a[hyp_i][hyp_j]; + if (inside == prob_t::Zero()) continue; + const prob_t& transp = trans[hyp_i][hyp_j][k - 1][l - 1]; + if (transp == prob_t::Zero()) continue; + const prob_t p = inside * transp; + ss.add(p); + nexts.push_back(pair(hyp_i, hyp_j)); +// cerr << " (" << hyp_i << "," << hyp_j << ") <--- " << log(p) << endl; + } + } +// cerr << " sample set has " << nexts.size() << " elements.\n"; + const int selected = rng->SelectSample(ss); + cur_i = nexts[selected].first; + cur_j = nexts[selected].second; + } + newderiv.push_back(pair(0,0)); + const prob_t q_new = DerivationProposalProbability(newderiv, sentf, sente); + IncrementDerivation(newderiv, sentf, sente); +// cerr << "SANITY: " << q_new << " " <(); + kMAX_SRC_PHRASE = conf["max_src_phrase"].as(); + + if (!conf.count("model1")) { + cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; + return 1; + } + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); +// MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "f-Corpus size: " << corpusf.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabf.size() << " types\n"; + cerr << "f-Corpus size: " << corpuse.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabe.size() << " types\n"; + assert(corpusf.size() == corpuse.size()); + + Model1 m1(conf["model1"].as()); + PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + + ModelAndData posterior(lp0, corpuse, corpusf, vocabe, vocabf); + posterior.Sample(); + + return 0; +} + diff --git a/gi/pf/itg.cc b/gi/pf/itg.cc new file mode 100644 index 00000000..2c2a86f9 --- /dev/null +++ b/gi/pf/itg.cc @@ -0,0 +1,224 @@ +#include +#include +#include + +#include +#include +#include + +#include "viterbi.h" +#include "hg.h" +#include "trule.h" +#include "tdict.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp_nt.h" +#include "ccrp_onetable.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +ostream& operator<<(ostream& os, const vector& p) { + os << '['; + for (int i = 0; i < p.size(); ++i) + os << (i==0 ? "" : " ") << TD::Convert(p[i]); + return os << ']'; +} + +size_t hash_value(const TRule& r) { + size_t h = boost::hash_value(r.e_); + boost::hash_combine(h, -r.lhs_); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +struct Model1 { + explicit Model1(const string& fname) : + kNULL(TD::Convert("")), + kZERO() { + LoadModel1(fname); + } + + void LoadModel1(const string& fname) { + cerr << "Loading Model 1 parameters from " << fname << " ..." << endl; + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + unsigned lc = 0; + while(getline(in, line)) { + ++lc; + int cur = 0; + int start = 0; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + const WordID src = TD::Convert(&line[0]); + ++cur; + start = cur; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + WordID trg = TD::Convert(&line[start]); + const double logprob = strtod(&line[cur + 1], NULL); + if (src >= ttable.size()) ttable.resize(src + 1); + ttable[src][trg].logeq(logprob); + } + cerr << " read " << lc << " parameters.\n"; + } + + // returns prob 0 if src or trg is not found! + const prob_t& operator()(WordID src, WordID trg) const { + if (src == 0) src = kNULL; + if (src < ttable.size()) { + const map& cpd = ttable[src]; + const map::const_iterator it = cpd.find(trg); + if (it != cpd.end()) + return it->second; + } + return kZERO; + } + + const WordID kNULL; + const prob_t kZERO; + vector > ttable; +}; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("particles,p",po::value()->default_value(25),"Number of particles") + ("input,i",po::value(),"Read parallel data from") + ("max_src_phrase",po::value()->default_value(7),"Maximum length of source language phrases") + ("max_trg_phrase",po::value()->default_value(7),"Maximum length of target language phrases") + ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("inverse_model1,M",po::value(),"Inverse Model 1 parameters (used in backward estimate)") + ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const size_t kMAX_TRG_PHRASE = conf["max_trg_phrase"].as(); + const size_t kMAX_SRC_PHRASE = conf["max_src_phrase"].as(); + const unsigned particles = conf["particles"].as(); + const unsigned samples = conf["samples"].as(); + + if (!conf.count("model1")) { + cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; + return 1; + } + shared_ptr prng; + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + cerr << "Reading corpus...\n"; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "F-corpus size: " << corpusf.size() << " sentences\t (" << vocabf.size() << " word types)\n"; + cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n"; + assert(corpusf.size() == corpuse.size()); + + const int kLHS = -TD::Convert("X"); + Model1 m1(conf["model1"].as()); + Model1 invm1(conf["inverse_model1"].as()); + for (int si = 0; si < conf["samples"].as(); ++si) { + cerr << '.' << flush; + for (int ci = 0; ci < corpusf.size(); ++ci) { + const vector& src = corpusf[ci]; + const vector& trg = corpuse[ci]; + for (int i = 0; i < src.size(); ++i) { + for (int j = 0; j < trg.size(); ++j) { + const int eff_max_src = min(src.size() - i, kMAX_SRC_PHRASE); + for (int k = 0; k < eff_max_src; ++k) { + const int eff_max_trg = (k == 0 ? 1 : min(trg.size() - j, kMAX_TRG_PHRASE)); + for (int l = 0; l < eff_max_trg; ++l) { + } + } + } + } + } + } +} + diff --git a/gi/pf/pfbrat.cc b/gi/pf/pfbrat.cc new file mode 100644 index 00000000..4c6ba3ef --- /dev/null +++ b/gi/pf/pfbrat.cc @@ -0,0 +1,554 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "viterbi.h" +#include "hg.h" +#include "trule.h" +#include "tdict.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp_nt.h" +#include "cfg_wfst_composer.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +static unsigned kMAX_SRC_PHRASE; +static unsigned kMAX_TRG_PHRASE; +struct FSTState; + +size_t hash_value(const TRule& r) { + size_t h = 2 - r.lhs_; + boost::hash_combine(h, boost::hash_value(r.e_)); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +struct ConditionalBase { + explicit ConditionalBase(const double m1mixture, const unsigned vocab_e_size, const string& model1fname) : + kM1MIXTURE(m1mixture), + kUNIFORM_MIXTURE(1.0 - m1mixture), + kUNIFORM_TARGET(1.0 / vocab_e_size), + kNULL(TD::Convert("")) { + assert(m1mixture >= 0.0 && m1mixture <= 1.0); + assert(vocab_e_size > 0); + LoadModel1(model1fname); + } + + void LoadModel1(const string& fname) { + cerr << "Loading Model 1 parameters from " << fname << " ..." << endl; + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + unsigned lc = 0; + while(getline(in, line)) { + ++lc; + int cur = 0; + int start = 0; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + const WordID src = TD::Convert(&line[0]); + ++cur; + start = cur; + while(cur < line.size() && line[cur] != ' ') { ++cur; } + assert(cur != line.size()); + line[cur] = 0; + WordID trg = TD::Convert(&line[start]); + const double logprob = strtod(&line[cur + 1], NULL); + if (src >= ttable.size()) ttable.resize(src + 1); + ttable[src][trg].logeq(logprob); + } + cerr << " read " << lc << " parameters.\n"; + } + + // return logp0 of rule.e_ | rule.f_ + prob_t operator()(const TRule& rule) const { + const int flen = rule.f_.size(); + const int elen = rule.e_.size(); + prob_t uniform_src_alignment; uniform_src_alignment.logeq(-log(flen + 1)); + prob_t p; + p.logeq(log_poisson(elen, flen + 0.01)); // elen | flen ~Pois(flen + 0.01) + for (int i = 0; i < elen; ++i) { // for each position i in e-RHS + const WordID trg = rule.e_[i]; + prob_t tp = prob_t::Zero(); + for (int j = -1; j < flen; ++j) { + const WordID src = j < 0 ? kNULL : rule.f_[j]; + const map::const_iterator it = ttable[src].find(trg); + if (it != ttable[src].end()) { + tp += kM1MIXTURE * it->second; + } + tp += kUNIFORM_MIXTURE * kUNIFORM_TARGET; + } + tp *= uniform_src_alignment; // draw a_i ~uniform + p *= tp; // draw e_i ~Model1(f_a_i) / uniform + } + return p; + } + + const prob_t kM1MIXTURE; // Model 1 mixture component + const prob_t kUNIFORM_MIXTURE; // uniform mixture component + const prob_t kUNIFORM_TARGET; + const WordID kNULL; + vector > ttable; +}; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read parallel data from") + ("max_src_phrase",po::value()->default_value(3),"Maximum length of source language phrases") + ("max_trg_phrase",po::value()->default_value(3),"Maximum length of target language phrases") + ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +struct UniphraseLM { + UniphraseLM(const vector >& corpus, + const set& vocab, + const po::variables_map& conf) : + phrases_(1,1), + gen_(1,1), + corpus_(corpus), + uniform_word_(1.0 / vocab.size()), + gen_p0_(0.5), + p_end_(0.5), + use_poisson_(conf.count("poisson_length") > 0) {} + + void ResampleHyperparameters(MT19937* rng) { + phrases_.resample_hyperparameters(rng); + gen_.resample_hyperparameters(rng); + cerr << " " << phrases_.concentration(); + } + + CCRP_NoTable > phrases_; + CCRP_NoTable gen_; + vector > z_; // z_[i] is there a phrase boundary after the ith word + const vector >& corpus_; + const double uniform_word_; + const double gen_p0_; + const double p_end_; // in base length distribution, p of the end of a phrase + const bool use_poisson_; +}; + +struct Reachability { + boost::multi_array edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring? + boost::multi_array max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid + + Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : + edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), + max_src_delta(boost::extents[srclen][trglen]) { + ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); + } + + private: + struct SState { + SState() : prev_src_covered(), prev_trg_covered() {} + SState(int i, int j) : prev_src_covered(i), prev_trg_covered(j) {} + int prev_src_covered; + int prev_trg_covered; + }; + + struct NState { + NState() : next_src_covered(), next_trg_covered() {} + NState(int i, int j) : next_src_covered(i), next_trg_covered(j) {} + int next_src_covered; + int next_trg_covered; + }; + + void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) { + typedef boost::multi_array, 2> array_type; + array_type a(boost::extents[srclen + 1][trglen + 1]); + a[0][0].push_back(SState()); + for (int i = 0; i < srclen; ++i) { + for (int j = 0; j < trglen; ++j) { + if (a[i][j].size() == 0) continue; + const SState prev(i,j); + for (int k = 1; k <= src_max_phrase_len; ++k) { + if ((i + k) > srclen) continue; + for (int l = 1; l <= trg_max_phrase_len; ++l) { + if ((j + l) > trglen) continue; + a[i + k][j + l].push_back(prev); + } + } + } + } + a[0][0].clear(); + cerr << "Final cell contains " << a[srclen][trglen].size() << " back pointers\n"; + assert(a[srclen][trglen].size() > 0); + + typedef boost::multi_array rarray_type; + rarray_type r(boost::extents[srclen + 1][trglen + 1]); +// typedef boost::multi_array, 2> narray_type; +// narray_type b(boost::extents[srclen + 1][trglen + 1]); + r[srclen][trglen] = true; + for (int i = srclen; i >= 0; --i) { + for (int j = trglen; j >= 0; --j) { + vector& prevs = a[i][j]; + if (!r[i][j]) { prevs.clear(); } +// const NState nstate(i,j); + for (int k = 0; k < prevs.size(); ++k) { + r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true; + int src_delta = i - prevs[k].prev_src_covered; + edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true; + short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered]; + if (src_delta > msd) msd = src_delta; +// b[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(nstate); + } + } + } + assert(!edges[0][0][1][0]); + assert(!edges[0][0][0][1]); + assert(!edges[0][0][0][0]); + cerr << " MAX SRC DELTA[0][0] = " << max_src_delta[0][0] << endl; + assert(max_src_delta[0][0] > 0); + //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n"; + //for (int i = 0; i < b[0][0].size(); ++i) { + // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n"; + //} + } +}; + +ostream& operator<<(ostream& os, const FSTState& q); +struct FSTState { + explicit FSTState(int src_size) : + trg_covered_(), + src_covered_(), + src_coverage_(src_size) {} + + FSTState(short trg_covered, short src_covered, const vector& src_coverage, const vector& src_prefix) : + trg_covered_(trg_covered), + src_covered_(src_covered), + src_coverage_(src_coverage), + src_prefix_(src_prefix) { + if (src_coverage_.size() == src_covered) { + assert(src_prefix.size() == 0); + } + } + + // if we extend by the word at src_position, what are + // the next states that are reachable and lie on a valid + // path to the final state? + vector Extensions(int src_position, int src_len, int trg_len, const Reachability& r) const { + assert(src_position < src_coverage_.size()); + if (src_coverage_[src_position]) { + cerr << "Trying to extend " << *this << " with position " << src_position << endl; + abort(); + } + vector ncvg = src_coverage_; + ncvg[src_position] = true; + + vector res; + const int trg_remaining = trg_len - trg_covered_; + if (trg_remaining <= 0) { + cerr << "Target appears to have been covered: " << *this << " (trg_len=" << trg_len << ",trg_covered=" << trg_covered_ << ")" << endl; + abort(); + } + const int src_remaining = src_len - src_covered_; + if (src_remaining <= 0) { + cerr << "Source appears to have been covered: " << *this << endl; + abort(); + } + + for (int tc = 1; tc <= kMAX_TRG_PHRASE; ++tc) { + if (r.edges[src_covered_][trg_covered_][src_prefix_.size() + 1][tc]) { + int nc = src_prefix_.size() + 1 + src_covered_; + res.push_back(FSTState(trg_covered_ + tc, nc, ncvg, vector())); + } + } + + if ((src_prefix_.size() + 1) < r.max_src_delta[src_covered_][trg_covered_]) { + vector nsp = src_prefix_; + nsp.push_back(src_position); + res.push_back(FSTState(trg_covered_, src_covered_, ncvg, nsp)); + } + + if (res.size() == 0) { + cerr << *this << " can't be extended!\n"; + abort(); + } + return res; + } + + short trg_covered_, src_covered_; + vector src_coverage_; + vector src_prefix_; +}; +bool operator<(const FSTState& q, const FSTState& r) { + if (q.trg_covered_ != r.trg_covered_) return q.trg_covered_ < r.trg_covered_; + if (q.src_covered_!= r.src_covered_) return q.src_covered_ < r.src_covered_; + if (q.src_coverage_ != r.src_coverage_) return q.src_coverage_ < r.src_coverage_; + return q.src_prefix_ < r.src_prefix_; +} + +ostream& operator<<(ostream& os, const FSTState& q) { + os << "[" << q.trg_covered_ << " : "; + for (int i = 0; i < q.src_coverage_.size(); ++i) + os << q.src_coverage_[i]; + os << " : <"; + for (int i = 0; i < q.src_prefix_.size(); ++i) { + if (i != 0) os << ' '; + os << q.src_prefix_[i]; + } + return os << ">]"; +} + +struct MyModel { + MyModel(ConditionalBase& rcp0) : rp0(rcp0) {} + typedef unordered_map, CCRP_NoTable, boost::hash > > SrcToRuleCRPMap; + + void DecrementRule(const TRule& rule) { + SrcToRuleCRPMap::iterator it = rules.find(rule.f_); + assert(it != rules.end()); + it->second.decrement(rule); + if (it->second.num_customers() == 0) rules.erase(it); + } + + void IncrementRule(const TRule& rule) { + SrcToRuleCRPMap::iterator it = rules.find(rule.f_); + if (it == rules.end()) { + CCRP_NoTable crp(1,1); + it = rules.insert(make_pair(rule.f_, crp)).first; + } + it->second.increment(rule); + } + + // conditioned on rule.f_ + prob_t RuleConditionalProbability(const TRule& rule) const { + const prob_t base = rp0(rule); + SrcToRuleCRPMap::const_iterator it = rules.find(rule.f_); + if (it == rules.end()) { + return base; + } else { + const double lp = it->second.logprob(rule, log(base)); + prob_t q; q.logeq(lp); + return q; + } + } + + const ConditionalBase& rp0; + SrcToRuleCRPMap rules; +}; + +struct MyFST : public WFST { + MyFST(const vector& ssrc, const vector& strg, MyModel* m) : + src(ssrc), trg(strg), + r(src.size(),trg.size(),kMAX_SRC_PHRASE, kMAX_TRG_PHRASE), + model(m) { + FSTState in(src.size()); + cerr << " INIT: " << in << endl; + init = GetNode(in); + for (int i = 0; i < in.src_coverage_.size(); ++i) in.src_coverage_[i] = true; + in.src_covered_ = src.size(); + in.trg_covered_ = trg.size(); + cerr << "FINAL: " << in << endl; + final = GetNode(in); + } + virtual const WFSTNode* Final() const; + virtual const WFSTNode* Initial() const; + + const WFSTNode* GetNode(const FSTState& q); + map > m; + const vector& src; + const vector& trg; + Reachability r; + const WFSTNode* init; + const WFSTNode* final; + MyModel* model; +}; + +struct MyNode : public WFSTNode { + MyNode(const FSTState& q, MyFST* fst) : state(q), container(fst) {} + virtual vector > ExtendInput(unsigned srcindex) const; + const FSTState state; + mutable MyFST* container; +}; + +vector > MyNode::ExtendInput(unsigned srcindex) const { + cerr << "EXTEND " << state << " with " << srcindex << endl; + vector ext = state.Extensions(srcindex, container->src.size(), container->trg.size(), container->r); + vector > res(ext.size()); + for (unsigned i = 0; i < ext.size(); ++i) { + res[i].first = container->GetNode(ext[i]); + if (ext[i].src_prefix_.size() == 0) { + const unsigned trg_from = state.trg_covered_; + const unsigned trg_to = ext[i].trg_covered_; + const unsigned prev_prfx_size = state.src_prefix_.size(); + res[i].second.reset(new TRule); + res[i].second->lhs_ = -TD::Convert("X"); + vector& src = res[i].second->f_; + vector& trg = res[i].second->e_; + src.resize(prev_prfx_size + 1); + for (unsigned j = 0; j < prev_prfx_size; ++j) + src[j] = container->src[state.src_prefix_[j]]; + src[prev_prfx_size] = container->src[srcindex]; + for (unsigned j = trg_from; j < trg_to; ++j) + trg.push_back(container->trg[j]); + res[i].second->scores_.set_value(FD::Convert("Proposal"), log(container->model->RuleConditionalProbability(*res[i].second))); + } + } + return res; +} + +const WFSTNode* MyFST::GetNode(const FSTState& q) { + boost::shared_ptr& res = m[q]; + if (!res) { + res.reset(new MyNode(q, this)); + } + return &*res; +} + +const WFSTNode* MyFST::Final() const { + return final; +} + +const WFSTNode* MyFST::Initial() const { + return init; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + kMAX_TRG_PHRASE = conf["max_trg_phrase"].as(); + kMAX_SRC_PHRASE = conf["max_src_phrase"].as(); + + if (!conf.count("model1")) { + cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; + return 1; + } + shared_ptr prng; + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "f-Corpus size: " << corpusf.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabf.size() << " types\n"; + cerr << "f-Corpus size: " << corpuse.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabe.size() << " types\n"; + assert(corpusf.size() == corpuse.size()); + + ConditionalBase lp0(conf["model1_interpolation_weight"].as(), + vocabe.size(), + conf["model1"].as()); + MyModel m(lp0); + + TRule x("[X] ||| kAnwntR myN ||| at the convent ||| 0"); + m.IncrementRule(x); + TRule y("[X] ||| nY dyN ||| gave ||| 0"); + m.IncrementRule(y); + + + MyFST fst(corpusf[0], corpuse[0], &m); + ifstream in("./kimura.g"); + assert(in); + CFG_WFSTComposer comp(fst); + Hypergraph hg; + bool succeed = comp.Compose(&in, &hg); + hg.PrintGraphviz(); + if (succeed) { cerr << "SUCCESS.\n"; } else { cerr << "FAILURE REPORTED.\n"; } + +#if 0 + ifstream in2("./amnabooks.g"); + assert(in2); + MyFST fst2(corpusf[1], corpuse[1], &m); + CFG_WFSTComposer comp2(fst2); + Hypergraph hg2; + bool succeed2 = comp2.Compose(&in2, &hg2); + if (succeed2) { cerr << "SUCCESS.\n"; } else { cerr << "FAILURE REPORTED.\n"; } +#endif + + SparseVector w; w.set_value(FD::Convert("Proposal"), 1.0); + hg.Reweight(w); + cerr << ViterbiFTree(hg) << endl; + return 0; +} + diff --git a/gi/pf/pfdist.cc b/gi/pf/pfdist.cc new file mode 100644 index 00000000..18dfd03b --- /dev/null +++ b/gi/pf/pfdist.cc @@ -0,0 +1,621 @@ +#include +#include +#include + +#include +#include +#include + +#include "base_measures.h" +#include "reachability.h" +#include "viterbi.h" +#include "hg.h" +#include "trule.h" +#include "tdict.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp_nt.h" +#include "ccrp_onetable.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +shared_ptr prng; + +size_t hash_value(const TRule& r) { + size_t h = boost::hash_value(r.e_); + boost::hash_combine(h, -r.lhs_); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("particles,p",po::value()->default_value(30),"Number of particles") + ("filter_frequency,f",po::value()->default_value(5),"Number of time steps between filterings") + ("input,i",po::value(),"Read parallel data from") + ("max_src_phrase",po::value()->default_value(5),"Maximum length of source language phrases") + ("max_trg_phrase",po::value()->default_value(5),"Maximum length of target language phrases") + ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("inverse_model1,M",po::value(),"Inverse Model 1 parameters (used in backward estimate)") + ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +#if 0 +struct MyConditionalModel { + MyConditionalModel(PhraseConditionalBase& rcp0) : rp0(&rcp0), base(prob_t::One()), src_phrases(1,1), src_jumps(200, CCRP_NoTable(1,1)) {} + + prob_t srcp0(const vector& src) const { + prob_t p(1.0 / 3000.0); + p.poweq(src.size()); + prob_t lenp; lenp.logeq(log_poisson(src.size(), 1.0)); + p *= lenp; + return p; + } + + void DecrementRule(const TRule& rule) { + const RuleCRPMap::iterator it = rules.find(rule.f_); + assert(it != rules.end()); + if (it->second.decrement(rule)) { + base /= (*rp0)(rule); + if (it->second.num_customers() == 0) + rules.erase(it); + } + if (src_phrases.decrement(rule.f_)) + base /= srcp0(rule.f_); + } + + void IncrementRule(const TRule& rule) { + RuleCRPMap::iterator it = rules.find(rule.f_); + if (it == rules.end()) + it = rules.insert(make_pair(rule.f_, CCRP_NoTable(1,1))).first; + if (it->second.increment(rule)) { + base *= (*rp0)(rule); + } + if (src_phrases.increment(rule.f_)) + base *= srcp0(rule.f_); + } + + void IncrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + IncrementRule(*rules[i]); + } + + void DecrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + DecrementRule(*rules[i]); + } + + void IncrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].increment(dist)) + base *= jp0(dist, src_len); + } + + void DecrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].decrement(dist)) + base /= jp0(dist, src_len); + } + + void IncrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + IncrementJump(js[i], src_len); + } + + void DecrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + DecrementJump(js[i], src_len); + } + + // p(jump = dist | src_len , z) + prob_t JumpProbability(int dist, unsigned src_len) { + const prob_t p0 = jp0(dist, src_len); + const double lp = src_jumps[src_len].logprob(dist, log(p0)); + prob_t q; q.logeq(lp); + return q; + } + + // p(rule.f_ | z) * p(rule.e_ | rule.f_ , z) + prob_t RuleProbability(const TRule& rule) const { + const prob_t p0 = (*rp0)(rule); + prob_t srcp; srcp.logeq(src_phrases.logprob(rule.f_, log(srcp0(rule.f_)))); + const RuleCRPMap::const_iterator it = rules.find(rule.f_); + if (it == rules.end()) return srcp * p0; + const double lp = it->second.logprob(rule, log(p0)); + prob_t q; q.logeq(lp); + return q * srcp; + } + + prob_t Likelihood() const { + prob_t p = base; + for (RuleCRPMap::const_iterator it = rules.begin(); + it != rules.end(); ++it) { + prob_t cl; cl.logeq(it->second.log_crp_prob()); + p *= cl; + } + for (unsigned l = 1; l < src_jumps.size(); ++l) { + if (src_jumps[l].num_customers() > 0) { + prob_t q; + q.logeq(src_jumps[l].log_crp_prob()); + p *= q; + } + } + return p; + } + + JumpBase jp0; + const PhraseConditionalBase* rp0; + prob_t base; + typedef unordered_map, CCRP_NoTable, boost::hash > > RuleCRPMap; + RuleCRPMap rules; + CCRP_NoTable > src_phrases; + vector > src_jumps; +}; + +#endif + +struct MyJointModel { + MyJointModel(PhraseJointBase& rcp0) : + rp0(rcp0), base(prob_t::One()), rules(1,1), src_jumps(200, CCRP_NoTable(1,1)) {} + + void DecrementRule(const TRule& rule) { + if (rules.decrement(rule)) + base /= rp0(rule); + } + + void IncrementRule(const TRule& rule) { + if (rules.increment(rule)) + base *= rp0(rule); + } + + void IncrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + IncrementRule(*rules[i]); + } + + void DecrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + DecrementRule(*rules[i]); + } + + void IncrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].increment(dist)) + base *= jp0(dist, src_len); + } + + void DecrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].decrement(dist)) + base /= jp0(dist, src_len); + } + + void IncrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + IncrementJump(js[i], src_len); + } + + void DecrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + DecrementJump(js[i], src_len); + } + + // p(jump = dist | src_len , z) + prob_t JumpProbability(int dist, unsigned src_len) { + const prob_t p0 = jp0(dist, src_len); + const double lp = src_jumps[src_len].logprob(dist, log(p0)); + prob_t q; q.logeq(lp); + return q; + } + + // p(rule.f_ | z) * p(rule.e_ | rule.f_ , z) + prob_t RuleProbability(const TRule& rule) const { + prob_t p; p.logeq(rules.logprob(rule, log(rp0(rule)))); + return p; + } + + prob_t Likelihood() const { + prob_t p = base; + prob_t q; q.logeq(rules.log_crp_prob()); + p *= q; + for (unsigned l = 1; l < src_jumps.size(); ++l) { + if (src_jumps[l].num_customers() > 0) { + prob_t q; + q.logeq(src_jumps[l].log_crp_prob()); + p *= q; + } + } + return p; + } + + JumpBase jp0; + const PhraseJointBase& rp0; + prob_t base; + CCRP_NoTable rules; + vector > src_jumps; +}; + +struct BackwardEstimate { + BackwardEstimate(const Model1& m1, const vector& src, const vector& trg) : + model1_(m1), src_(src), trg_(trg) { + } + const prob_t& operator()(const vector& src_cov, unsigned trg_cov) const { + assert(src_.size() == src_cov.size()); + assert(trg_cov <= trg_.size()); + prob_t& e = cache_[src_cov][trg_cov]; + if (e.is_0()) { + if (trg_cov == trg_.size()) { e = prob_t::One(); return e; } + vector r(src_.size() + 1); r.clear(); + r.push_back(0); // NULL word + for (int i = 0; i < src_cov.size(); ++i) + if (!src_cov[i]) r.push_back(src_[i]); + const prob_t uniform_alignment(1.0 / r.size()); + e.logeq(log_poisson(trg_.size() - trg_cov, r.size() - 1)); // p(trg len remaining | src len remaining) + for (unsigned j = trg_cov; j < trg_.size(); ++j) { + prob_t p; + for (unsigned i = 0; i < r.size(); ++i) + p += model1_(r[i], trg_[j]); + if (p.is_0()) { + cerr << "ERROR: p(" << TD::Convert(trg_[j]) << " | " << TD::GetString(r) << ") = 0!\n"; + abort(); + } + p *= uniform_alignment; + e *= p; + } + } + return e; + } + const Model1& model1_; + const vector& src_; + const vector& trg_; + mutable unordered_map, map, boost::hash > > cache_; +}; + +struct BackwardEstimateSym { + BackwardEstimateSym(const Model1& m1, + const Model1& invm1, const vector& src, const vector& trg) : + model1_(m1), invmodel1_(invm1), src_(src), trg_(trg) { + } + const prob_t& operator()(const vector& src_cov, unsigned trg_cov) const { + assert(src_.size() == src_cov.size()); + assert(trg_cov <= trg_.size()); + prob_t& e = cache_[src_cov][trg_cov]; + if (e.is_0()) { + if (trg_cov == trg_.size()) { e = prob_t::One(); return e; } + vector r(src_.size() + 1); r.clear(); + for (int i = 0; i < src_cov.size(); ++i) + if (!src_cov[i]) r.push_back(src_[i]); + r.push_back(0); // NULL word + const prob_t uniform_alignment(1.0 / r.size()); + e.logeq(log_poisson(trg_.size() - trg_cov, r.size() - 1)); // p(trg len remaining | src len remaining) + for (unsigned j = trg_cov; j < trg_.size(); ++j) { + prob_t p; + for (unsigned i = 0; i < r.size(); ++i) + p += model1_(r[i], trg_[j]); + if (p.is_0()) { + cerr << "ERROR: p(" << TD::Convert(trg_[j]) << " | " << TD::GetString(r) << ") = 0!\n"; + abort(); + } + p *= uniform_alignment; + e *= p; + } + r.pop_back(); + const prob_t inv_uniform(1.0 / (trg_.size() - trg_cov + 1.0)); + prob_t inv; + inv.logeq(log_poisson(r.size(), trg_.size() - trg_cov)); + for (unsigned i = 0; i < r.size(); ++i) { + prob_t p; + for (unsigned j = trg_cov - 1; j < trg_.size(); ++j) + p += invmodel1_(j < trg_cov ? 0 : trg_[j], r[i]); + if (p.is_0()) { + cerr << "ERROR: p_inv(" << TD::Convert(r[i]) << " | " << TD::GetString(trg_) << ") = 0!\n"; + abort(); + } + p *= inv_uniform; + inv *= p; + } + prob_t x = pow(e * inv, 0.5); + e = x; + //cerr << "Forward: " << log(e) << "\tBackward: " << log(inv) << "\t prop: " << log(x) << endl; + } + return e; + } + const Model1& model1_; + const Model1& invmodel1_; + const vector& src_; + const vector& trg_; + mutable unordered_map, map, boost::hash > > cache_; +}; + +struct Particle { + Particle() : weight(prob_t::One()), src_cov(), trg_cov(), prev_pos(-1) {} + prob_t weight; + prob_t gamma_last; + vector src_jumps; + vector rules; + vector src_cv; + int src_cov; + int trg_cov; + int prev_pos; +}; + +ostream& operator<<(ostream& o, const vector& v) { + for (int i = 0; i < v.size(); ++i) + o << (v[i] ? '1' : '0'); + return o; +} +ostream& operator<<(ostream& o, const Particle& p) { + o << "[cv=" << p.src_cv << " src_cov=" << p.src_cov << " trg_cov=" << p.trg_cov << " last_pos=" << p.prev_pos << " num_rules=" << p.rules.size() << " w=" << log(p.weight) << ']'; + return o; +} + +void FilterCrapParticlesAndReweight(vector* pps) { + vector& ps = *pps; + SampleSet ss; + for (int i = 0; i < ps.size(); ++i) + ss.add(ps[i].weight); + vector nps; nps.reserve(ps.size()); + const prob_t uniform_weight(1.0 / ps.size()); + for (int i = 0; i < ps.size(); ++i) { + nps.push_back(ps[prng->SelectSample(ss)]); + nps[i].weight = uniform_weight; + } + nps.swap(ps); +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const unsigned kMAX_TRG_PHRASE = conf["max_trg_phrase"].as(); + const unsigned kMAX_SRC_PHRASE = conf["max_src_phrase"].as(); + const unsigned particles = conf["particles"].as(); + const unsigned samples = conf["samples"].as(); + const unsigned rejuv_freq = conf["filter_frequency"].as(); + + if (!conf.count("model1")) { + cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; + return 1; + } + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + cerr << "Reading corpus...\n"; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "F-corpus size: " << corpusf.size() << " sentences\t (" << vocabf.size() << " word types)\n"; + cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n"; + assert(corpusf.size() == corpuse.size()); + + const int kLHS = -TD::Convert("X"); + Model1 m1(conf["model1"].as()); + Model1 invm1(conf["inverse_model1"].as()); + +#if 0 + PhraseConditionalBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size()); + MyConditionalModel m(lp0); +#else + PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + MyJointModel m(lp0); +#endif + + cerr << "Initializing reachability limits...\n"; + vector ps(corpusf.size()); + vector reaches; reaches.reserve(corpusf.size()); + for (int ci = 0; ci < corpusf.size(); ++ci) + reaches.push_back(Reachability(corpusf[ci].size(), + corpuse[ci].size(), + kMAX_SRC_PHRASE, + kMAX_TRG_PHRASE)); + cerr << "Sampling...\n"; + vector tmp_p(10000); // work space + SampleSet pfss; + for (int SS=0; SS < samples; ++SS) { + for (int ci = 0; ci < corpusf.size(); ++ci) { + vector& src = corpusf[ci]; + vector& trg = corpuse[ci]; + m.DecrementRules(ps[ci].rules); + m.DecrementJumps(ps[ci].src_jumps, src.size()); + + //BackwardEstimate be(m1, src, trg); + BackwardEstimateSym be(m1, invm1, src, trg); + const Reachability& r = reaches[ci]; + vector lps(particles); + + for (int pi = 0; pi < particles; ++pi) { + Particle& p = lps[pi]; + p.src_cv.resize(src.size(), false); + } + + bool all_complete = false; + while(!all_complete) { + SampleSet ss; + + // all particles have now been extended a bit, we will reweight them now + if (lps[0].trg_cov > 0) + FilterCrapParticlesAndReweight(&lps); + + // loop over all particles and extend them + bool done_nothing = true; + for (int pi = 0; pi < particles; ++pi) { + Particle& p = lps[pi]; + int tic = 0; + while(p.trg_cov < trg.size() && tic < rejuv_freq) { + ++tic; + done_nothing = false; + ss.clear(); + TRule x; x.lhs_ = kLHS; + prob_t z; + int first_uncovered = src.size(); + int last_uncovered = -1; + for (int i = 0; i < src.size(); ++i) { + const bool is_uncovered = !p.src_cv[i]; + if (i < first_uncovered && is_uncovered) first_uncovered = i; + if (is_uncovered && i > last_uncovered) last_uncovered = i; + } + assert(last_uncovered > -1); + assert(first_uncovered < src.size()); + + for (int trg_len = 1; trg_len <= kMAX_TRG_PHRASE; ++trg_len) { + x.e_.push_back(trg[trg_len - 1 + p.trg_cov]); + for (int src_len = 1; src_len <= kMAX_SRC_PHRASE; ++src_len) { + if (!r.edges[p.src_cov][p.trg_cov][src_len][trg_len]) continue; + + const int last_possible_start = last_uncovered - src_len + 1; + assert(last_possible_start >= 0); + //cerr << src_len << "," << trg_len << " is allowed. E=" << TD::GetString(x.e_) << endl; + //cerr << " first_uncovered=" << first_uncovered << " last_possible_start=" << last_possible_start << endl; + for (int i = first_uncovered; i <= last_possible_start; ++i) { + if (p.src_cv[i]) continue; + assert(ss.size() < tmp_p.size()); // if fails increase tmp_p size + Particle& np = tmp_p[ss.size()]; + np = p; + x.f_.clear(); + int gap_add = 0; + bool bad = false; + prob_t jp = prob_t::One(); + int prev_pos = p.prev_pos; + for (int j = 0; j < src_len; ++j) { + if ((j + i + gap_add) == src.size()) { bad = true; break; } + while ((i+j+gap_add) < src.size() && p.src_cv[i + j + gap_add]) { ++gap_add; } + if ((j + i + gap_add) == src.size()) { bad = true; break; } + np.src_cv[i + j + gap_add] = true; + x.f_.push_back(src[i + j + gap_add]); + jp *= m.JumpProbability(i + j + gap_add - prev_pos, src.size()); + int jump = i + j + gap_add - prev_pos; + assert(jump != 0); + np.src_jumps.push_back(jump); + prev_pos = i + j + gap_add; + } + if (bad) continue; + np.prev_pos = prev_pos; + np.src_cov += x.f_.size(); + np.trg_cov += x.e_.size(); + if (x.f_.size() != src_len) continue; + prob_t rp = m.RuleProbability(x); + np.gamma_last = rp * jp; + const prob_t u = pow(np.gamma_last * be(np.src_cv, np.trg_cov), 0.2); + //cerr << "**rule=" << x << endl; + //cerr << " u=" << log(u) << " rule=" << rp << " jump=" << jp << endl; + ss.add(u); + np.rules.push_back(TRulePtr(new TRule(x))); + z += u; + + const bool completed = (p.trg_cov == trg.size()); + if (completed) { + int last_jump = src.size() - p.prev_pos; + assert(last_jump > 0); + p.src_jumps.push_back(last_jump); + p.weight *= m.JumpProbability(last_jump, src.size()); + } + } + } + } + cerr << "number of edges to consider: " << ss.size() << endl; + const int sampled = rng.SelectSample(ss); + prob_t q_n = ss[sampled] / z; + p = tmp_p[sampled]; + //m.IncrementRule(*p.rules.back()); + p.weight *= p.gamma_last / q_n; + cerr << "[w=" << log(p.weight) << "]\tsampled rule: " << p.rules.back()->AsString() << endl; + cerr << p << endl; + } + } // loop over particles (pi = 0 .. particles) + if (done_nothing) all_complete = true; + } + pfss.clear(); + for (int i = 0; i < lps.size(); ++i) + pfss.add(lps[i].weight); + const int sampled = rng.SelectSample(pfss); + ps[ci] = lps[sampled]; + m.IncrementRules(lps[sampled].rules); + m.IncrementJumps(lps[sampled].src_jumps, src.size()); + for (int i = 0; i < lps[sampled].rules.size(); ++i) { cerr << "S:\t" << lps[sampled].rules[i]->AsString() << "\n"; } + cerr << "tmp-LLH: " << log(m.Likelihood()) << endl; + } + cerr << "LLH: " << log(m.Likelihood()) << endl; + for (int sni = 0; sni < 5; ++sni) { + for (int i = 0; i < ps[sni].rules.size(); ++i) { cerr << "\t" << ps[sni].rules[i]->AsString() << endl; } + } + } + return 0; +} + diff --git a/gi/pf/pfdist.new.cc b/gi/pf/pfdist.new.cc new file mode 100644 index 00000000..3169eb75 --- /dev/null +++ b/gi/pf/pfdist.new.cc @@ -0,0 +1,620 @@ +#include +#include +#include + +#include +#include +#include + +#include "base_measures.h" +#include "reachability.h" +#include "viterbi.h" +#include "hg.h" +#include "trule.h" +#include "tdict.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp_nt.h" +#include "ccrp_onetable.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +shared_ptr prng; + +size_t hash_value(const TRule& r) { + size_t h = boost::hash_value(r.e_); + boost::hash_combine(h, -r.lhs_); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("particles,p",po::value()->default_value(25),"Number of particles") + ("input,i",po::value(),"Read parallel data from") + ("max_src_phrase",po::value()->default_value(5),"Maximum length of source language phrases") + ("max_trg_phrase",po::value()->default_value(5),"Maximum length of target language phrases") + ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("inverse_model1,M",po::value(),"Inverse Model 1 parameters (used in backward estimate)") + ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +#if 0 +struct MyConditionalModel { + MyConditionalModel(PhraseConditionalBase& rcp0) : rp0(&rcp0), base(prob_t::One()), src_phrases(1,1), src_jumps(200, CCRP_NoTable(1,1)) {} + + prob_t srcp0(const vector& src) const { + prob_t p(1.0 / 3000.0); + p.poweq(src.size()); + prob_t lenp; lenp.logeq(log_poisson(src.size(), 1.0)); + p *= lenp; + return p; + } + + void DecrementRule(const TRule& rule) { + const RuleCRPMap::iterator it = rules.find(rule.f_); + assert(it != rules.end()); + if (it->second.decrement(rule)) { + base /= (*rp0)(rule); + if (it->second.num_customers() == 0) + rules.erase(it); + } + if (src_phrases.decrement(rule.f_)) + base /= srcp0(rule.f_); + } + + void IncrementRule(const TRule& rule) { + RuleCRPMap::iterator it = rules.find(rule.f_); + if (it == rules.end()) + it = rules.insert(make_pair(rule.f_, CCRP_NoTable(1,1))).first; + if (it->second.increment(rule)) { + base *= (*rp0)(rule); + } + if (src_phrases.increment(rule.f_)) + base *= srcp0(rule.f_); + } + + void IncrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + IncrementRule(*rules[i]); + } + + void DecrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + DecrementRule(*rules[i]); + } + + void IncrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].increment(dist)) + base *= jp0(dist, src_len); + } + + void DecrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].decrement(dist)) + base /= jp0(dist, src_len); + } + + void IncrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + IncrementJump(js[i], src_len); + } + + void DecrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + DecrementJump(js[i], src_len); + } + + // p(jump = dist | src_len , z) + prob_t JumpProbability(int dist, unsigned src_len) { + const prob_t p0 = jp0(dist, src_len); + const double lp = src_jumps[src_len].logprob(dist, log(p0)); + prob_t q; q.logeq(lp); + return q; + } + + // p(rule.f_ | z) * p(rule.e_ | rule.f_ , z) + prob_t RuleProbability(const TRule& rule) const { + const prob_t p0 = (*rp0)(rule); + prob_t srcp; srcp.logeq(src_phrases.logprob(rule.f_, log(srcp0(rule.f_)))); + const RuleCRPMap::const_iterator it = rules.find(rule.f_); + if (it == rules.end()) return srcp * p0; + const double lp = it->second.logprob(rule, log(p0)); + prob_t q; q.logeq(lp); + return q * srcp; + } + + prob_t Likelihood() const { + prob_t p = base; + for (RuleCRPMap::const_iterator it = rules.begin(); + it != rules.end(); ++it) { + prob_t cl; cl.logeq(it->second.log_crp_prob()); + p *= cl; + } + for (unsigned l = 1; l < src_jumps.size(); ++l) { + if (src_jumps[l].num_customers() > 0) { + prob_t q; + q.logeq(src_jumps[l].log_crp_prob()); + p *= q; + } + } + return p; + } + + JumpBase jp0; + const PhraseConditionalBase* rp0; + prob_t base; + typedef unordered_map, CCRP_NoTable, boost::hash > > RuleCRPMap; + RuleCRPMap rules; + CCRP_NoTable > src_phrases; + vector > src_jumps; +}; + +#endif + +struct MyJointModel { + MyJointModel(PhraseJointBase& rcp0) : + rp0(rcp0), base(prob_t::One()), rules(1,1), src_jumps(200, CCRP_NoTable(1,1)) {} + + void DecrementRule(const TRule& rule) { + if (rules.decrement(rule)) + base /= rp0(rule); + } + + void IncrementRule(const TRule& rule) { + if (rules.increment(rule)) + base *= rp0(rule); + } + + void IncrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + IncrementRule(*rules[i]); + } + + void DecrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + DecrementRule(*rules[i]); + } + + void IncrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].increment(dist)) + base *= jp0(dist, src_len); + } + + void DecrementJump(int dist, unsigned src_len) { + assert(src_len > 0); + if (src_jumps[src_len].decrement(dist)) + base /= jp0(dist, src_len); + } + + void IncrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + IncrementJump(js[i], src_len); + } + + void DecrementJumps(const vector& js, unsigned src_len) { + for (unsigned i = 0; i < js.size(); ++i) + DecrementJump(js[i], src_len); + } + + // p(jump = dist | src_len , z) + prob_t JumpProbability(int dist, unsigned src_len) { + const prob_t p0 = jp0(dist, src_len); + const double lp = src_jumps[src_len].logprob(dist, log(p0)); + prob_t q; q.logeq(lp); + return q; + } + + // p(rule.f_ | z) * p(rule.e_ | rule.f_ , z) + prob_t RuleProbability(const TRule& rule) const { + prob_t p; p.logeq(rules.logprob(rule, log(rp0(rule)))); + return p; + } + + prob_t Likelihood() const { + prob_t p = base; + prob_t q; q.logeq(rules.log_crp_prob()); + p *= q; + for (unsigned l = 1; l < src_jumps.size(); ++l) { + if (src_jumps[l].num_customers() > 0) { + prob_t q; + q.logeq(src_jumps[l].log_crp_prob()); + p *= q; + } + } + return p; + } + + JumpBase jp0; + const PhraseJointBase& rp0; + prob_t base; + CCRP_NoTable rules; + vector > src_jumps; +}; + +struct BackwardEstimate { + BackwardEstimate(const Model1& m1, const vector& src, const vector& trg) : + model1_(m1), src_(src), trg_(trg) { + } + const prob_t& operator()(const vector& src_cov, unsigned trg_cov) const { + assert(src_.size() == src_cov.size()); + assert(trg_cov <= trg_.size()); + prob_t& e = cache_[src_cov][trg_cov]; + if (e.is_0()) { + if (trg_cov == trg_.size()) { e = prob_t::One(); return e; } + vector r(src_.size() + 1); r.clear(); + r.push_back(0); // NULL word + for (int i = 0; i < src_cov.size(); ++i) + if (!src_cov[i]) r.push_back(src_[i]); + const prob_t uniform_alignment(1.0 / r.size()); + e.logeq(log_poisson(trg_.size() - trg_cov, r.size() - 1)); // p(trg len remaining | src len remaining) + for (unsigned j = trg_cov; j < trg_.size(); ++j) { + prob_t p; + for (unsigned i = 0; i < r.size(); ++i) + p += model1_(r[i], trg_[j]); + if (p.is_0()) { + cerr << "ERROR: p(" << TD::Convert(trg_[j]) << " | " << TD::GetString(r) << ") = 0!\n"; + abort(); + } + p *= uniform_alignment; + e *= p; + } + } + return e; + } + const Model1& model1_; + const vector& src_; + const vector& trg_; + mutable unordered_map, map, boost::hash > > cache_; +}; + +struct BackwardEstimateSym { + BackwardEstimateSym(const Model1& m1, + const Model1& invm1, const vector& src, const vector& trg) : + model1_(m1), invmodel1_(invm1), src_(src), trg_(trg) { + } + const prob_t& operator()(const vector& src_cov, unsigned trg_cov) const { + assert(src_.size() == src_cov.size()); + assert(trg_cov <= trg_.size()); + prob_t& e = cache_[src_cov][trg_cov]; + if (e.is_0()) { + if (trg_cov == trg_.size()) { e = prob_t::One(); return e; } + vector r(src_.size() + 1); r.clear(); + for (int i = 0; i < src_cov.size(); ++i) + if (!src_cov[i]) r.push_back(src_[i]); + r.push_back(0); // NULL word + const prob_t uniform_alignment(1.0 / r.size()); + e.logeq(log_poisson(trg_.size() - trg_cov, r.size() - 1)); // p(trg len remaining | src len remaining) + for (unsigned j = trg_cov; j < trg_.size(); ++j) { + prob_t p; + for (unsigned i = 0; i < r.size(); ++i) + p += model1_(r[i], trg_[j]); + if (p.is_0()) { + cerr << "ERROR: p(" << TD::Convert(trg_[j]) << " | " << TD::GetString(r) << ") = 0!\n"; + abort(); + } + p *= uniform_alignment; + e *= p; + } + r.pop_back(); + const prob_t inv_uniform(1.0 / (trg_.size() - trg_cov + 1.0)); + prob_t inv; + inv.logeq(log_poisson(r.size(), trg_.size() - trg_cov)); + for (unsigned i = 0; i < r.size(); ++i) { + prob_t p; + for (unsigned j = trg_cov - 1; j < trg_.size(); ++j) + p += invmodel1_(j < trg_cov ? 0 : trg_[j], r[i]); + if (p.is_0()) { + cerr << "ERROR: p_inv(" << TD::Convert(r[i]) << " | " << TD::GetString(trg_) << ") = 0!\n"; + abort(); + } + p *= inv_uniform; + inv *= p; + } + prob_t x = pow(e * inv, 0.5); + e = x; + //cerr << "Forward: " << log(e) << "\tBackward: " << log(inv) << "\t prop: " << log(x) << endl; + } + return e; + } + const Model1& model1_; + const Model1& invmodel1_; + const vector& src_; + const vector& trg_; + mutable unordered_map, map, boost::hash > > cache_; +}; + +struct Particle { + Particle() : weight(prob_t::One()), src_cov(), trg_cov(), prev_pos(-1) {} + prob_t weight; + prob_t gamma_last; + vector src_jumps; + vector rules; + vector src_cv; + int src_cov; + int trg_cov; + int prev_pos; +}; + +ostream& operator<<(ostream& o, const vector& v) { + for (int i = 0; i < v.size(); ++i) + o << (v[i] ? '1' : '0'); + return o; +} +ostream& operator<<(ostream& o, const Particle& p) { + o << "[cv=" << p.src_cv << " src_cov=" << p.src_cov << " trg_cov=" << p.trg_cov << " last_pos=" << p.prev_pos << " num_rules=" << p.rules.size() << " w=" << log(p.weight) << ']'; + return o; +} + +void FilterCrapParticlesAndReweight(vector* pps) { + vector& ps = *pps; + SampleSet ss; + for (int i = 0; i < ps.size(); ++i) + ss.add(ps[i].weight); + vector nps; nps.reserve(ps.size()); + const prob_t uniform_weight(1.0 / ps.size()); + for (int i = 0; i < ps.size(); ++i) { + nps.push_back(ps[prng->SelectSample(ss)]); + nps[i].weight = uniform_weight; + } + nps.swap(ps); +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const unsigned kMAX_TRG_PHRASE = conf["max_trg_phrase"].as(); + const unsigned kMAX_SRC_PHRASE = conf["max_src_phrase"].as(); + const unsigned particles = conf["particles"].as(); + const unsigned samples = conf["samples"].as(); + + if (!conf.count("model1")) { + cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; + return 1; + } + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + cerr << "Reading corpus...\n"; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "F-corpus size: " << corpusf.size() << " sentences\t (" << vocabf.size() << " word types)\n"; + cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n"; + assert(corpusf.size() == corpuse.size()); + + const int kLHS = -TD::Convert("X"); + Model1 m1(conf["model1"].as()); + Model1 invm1(conf["inverse_model1"].as()); + +#if 0 + PhraseConditionalBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size()); + MyConditionalModel m(lp0); +#else + PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + MyJointModel m(lp0); +#endif + + cerr << "Initializing reachability limits...\n"; + vector ps(corpusf.size()); + vector reaches; reaches.reserve(corpusf.size()); + for (int ci = 0; ci < corpusf.size(); ++ci) + reaches.push_back(Reachability(corpusf[ci].size(), + corpuse[ci].size(), + kMAX_SRC_PHRASE, + kMAX_TRG_PHRASE)); + cerr << "Sampling...\n"; + vector tmp_p(10000); // work space + SampleSet pfss; + for (int SS=0; SS < samples; ++SS) { + for (int ci = 0; ci < corpusf.size(); ++ci) { + vector& src = corpusf[ci]; + vector& trg = corpuse[ci]; + m.DecrementRules(ps[ci].rules); + m.DecrementJumps(ps[ci].src_jumps, src.size()); + + //BackwardEstimate be(m1, src, trg); + BackwardEstimateSym be(m1, invm1, src, trg); + const Reachability& r = reaches[ci]; + vector lps(particles); + + for (int pi = 0; pi < particles; ++pi) { + Particle& p = lps[pi]; + p.src_cv.resize(src.size(), false); + } + + bool all_complete = false; + while(!all_complete) { + SampleSet ss; + + // all particles have now been extended a bit, we will reweight them now + if (lps[0].trg_cov > 0) + FilterCrapParticlesAndReweight(&lps); + + // loop over all particles and extend them + bool done_nothing = true; + for (int pi = 0; pi < particles; ++pi) { + Particle& p = lps[pi]; + int tic = 0; + const int rejuv_freq = 1; + while(p.trg_cov < trg.size() && tic < rejuv_freq) { + ++tic; + done_nothing = false; + ss.clear(); + TRule x; x.lhs_ = kLHS; + prob_t z; + int first_uncovered = src.size(); + int last_uncovered = -1; + for (int i = 0; i < src.size(); ++i) { + const bool is_uncovered = !p.src_cv[i]; + if (i < first_uncovered && is_uncovered) first_uncovered = i; + if (is_uncovered && i > last_uncovered) last_uncovered = i; + } + assert(last_uncovered > -1); + assert(first_uncovered < src.size()); + + for (int trg_len = 1; trg_len <= kMAX_TRG_PHRASE; ++trg_len) { + x.e_.push_back(trg[trg_len - 1 + p.trg_cov]); + for (int src_len = 1; src_len <= kMAX_SRC_PHRASE; ++src_len) { + if (!r.edges[p.src_cov][p.trg_cov][src_len][trg_len]) continue; + + const int last_possible_start = last_uncovered - src_len + 1; + assert(last_possible_start >= 0); + //cerr << src_len << "," << trg_len << " is allowed. E=" << TD::GetString(x.e_) << endl; + //cerr << " first_uncovered=" << first_uncovered << " last_possible_start=" << last_possible_start << endl; + for (int i = first_uncovered; i <= last_possible_start; ++i) { + if (p.src_cv[i]) continue; + assert(ss.size() < tmp_p.size()); // if fails increase tmp_p size + Particle& np = tmp_p[ss.size()]; + np = p; + x.f_.clear(); + int gap_add = 0; + bool bad = false; + prob_t jp = prob_t::One(); + int prev_pos = p.prev_pos; + for (int j = 0; j < src_len; ++j) { + if ((j + i + gap_add) == src.size()) { bad = true; break; } + while ((i+j+gap_add) < src.size() && p.src_cv[i + j + gap_add]) { ++gap_add; } + if ((j + i + gap_add) == src.size()) { bad = true; break; } + np.src_cv[i + j + gap_add] = true; + x.f_.push_back(src[i + j + gap_add]); + jp *= m.JumpProbability(i + j + gap_add - prev_pos, src.size()); + int jump = i + j + gap_add - prev_pos; + assert(jump != 0); + np.src_jumps.push_back(jump); + prev_pos = i + j + gap_add; + } + if (bad) continue; + np.prev_pos = prev_pos; + np.src_cov += x.f_.size(); + np.trg_cov += x.e_.size(); + if (x.f_.size() != src_len) continue; + prob_t rp = m.RuleProbability(x); + np.gamma_last = rp * jp; + const prob_t u = pow(np.gamma_last * be(np.src_cv, np.trg_cov), 0.2); + //cerr << "**rule=" << x << endl; + //cerr << " u=" << log(u) << " rule=" << rp << " jump=" << jp << endl; + ss.add(u); + np.rules.push_back(TRulePtr(new TRule(x))); + z += u; + + const bool completed = (p.trg_cov == trg.size()); + if (completed) { + int last_jump = src.size() - p.prev_pos; + assert(last_jump > 0); + p.src_jumps.push_back(last_jump); + p.weight *= m.JumpProbability(last_jump, src.size()); + } + } + } + } + cerr << "number of edges to consider: " << ss.size() << endl; + const int sampled = rng.SelectSample(ss); + prob_t q_n = ss[sampled] / z; + p = tmp_p[sampled]; + //m.IncrementRule(*p.rules.back()); + p.weight *= p.gamma_last / q_n; + cerr << "[w=" << log(p.weight) << "]\tsampled rule: " << p.rules.back()->AsString() << endl; + cerr << p << endl; + } + } // loop over particles (pi = 0 .. particles) + if (done_nothing) all_complete = true; + } + pfss.clear(); + for (int i = 0; i < lps.size(); ++i) + pfss.add(lps[i].weight); + const int sampled = rng.SelectSample(pfss); + ps[ci] = lps[sampled]; + m.IncrementRules(lps[sampled].rules); + m.IncrementJumps(lps[sampled].src_jumps, src.size()); + for (int i = 0; i < lps[sampled].rules.size(); ++i) { cerr << "S:\t" << lps[sampled].rules[i]->AsString() << "\n"; } + cerr << "tmp-LLH: " << log(m.Likelihood()) << endl; + } + cerr << "LLH: " << log(m.Likelihood()) << endl; + for (int sni = 0; sni < 5; ++sni) { + for (int i = 0; i < ps[sni].rules.size(); ++i) { cerr << "\t" << ps[sni].rules[i]->AsString() << endl; } + } + } + return 0; +} + diff --git a/gi/pf/pfnaive.cc b/gi/pf/pfnaive.cc new file mode 100644 index 00000000..43c604c3 --- /dev/null +++ b/gi/pf/pfnaive.cc @@ -0,0 +1,385 @@ +#include +#include +#include + +#include +#include +#include + +#include "base_measures.h" +#include "reachability.h" +#include "viterbi.h" +#include "hg.h" +#include "trule.h" +#include "tdict.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp_nt.h" +#include "ccrp_onetable.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +shared_ptr prng; + +size_t hash_value(const TRule& r) { + size_t h = boost::hash_value(r.e_); + boost::hash_combine(h, -r.lhs_); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("particles,p",po::value()->default_value(30),"Number of particles") + ("filter_frequency,f",po::value()->default_value(5),"Number of time steps between filterings") + ("input,i",po::value(),"Read parallel data from") + ("max_src_phrase",po::value()->default_value(5),"Maximum length of source language phrases") + ("max_trg_phrase",po::value()->default_value(5),"Maximum length of target language phrases") + ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("inverse_model1,M",po::value(),"Inverse Model 1 parameters (used in backward estimate)") + ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } + if (in != &cin) delete in; +} + +struct MyJointModel { + MyJointModel(PhraseJointBase& rcp0) : + rp0(rcp0), base(prob_t::One()), rules(1,1) {} + + void DecrementRule(const TRule& rule) { + if (rules.decrement(rule)) + base /= rp0(rule); + } + + void IncrementRule(const TRule& rule) { + if (rules.increment(rule)) + base *= rp0(rule); + } + + void IncrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + IncrementRule(*rules[i]); + } + + void DecrementRules(const vector& rules) { + for (int i = 0; i < rules.size(); ++i) + DecrementRule(*rules[i]); + } + + prob_t RuleProbability(const TRule& rule) const { + prob_t p; p.logeq(rules.logprob(rule, log(rp0(rule)))); + return p; + } + + prob_t Likelihood() const { + prob_t p = base; + prob_t q; q.logeq(rules.log_crp_prob()); + p *= q; + for (unsigned l = 1; l < src_jumps.size(); ++l) { + if (src_jumps[l].num_customers() > 0) { + prob_t q; + q.logeq(src_jumps[l].log_crp_prob()); + p *= q; + } + } + return p; + } + + const PhraseJointBase& rp0; + prob_t base; + CCRP_NoTable rules; + vector > src_jumps; +}; + +struct BackwardEstimateSym { + BackwardEstimateSym(const Model1& m1, + const Model1& invm1, const vector& src, const vector& trg) : + model1_(m1), invmodel1_(invm1), src_(src), trg_(trg) { + } + const prob_t& operator()(unsigned src_cov, unsigned trg_cov) const { + assert(src_cov <= src_.size()); + assert(trg_cov <= trg_.size()); + prob_t& e = cache_[src_cov][trg_cov]; + if (e.is_0()) { + if (trg_cov == trg_.size()) { e = prob_t::One(); return e; } + vector r(src_.size() + 1); r.clear(); + for (int i = src_cov; i < src_.size(); ++i) + r.push_back(src_[i]); + r.push_back(0); // NULL word + const prob_t uniform_alignment(1.0 / r.size()); + e.logeq(log_poisson(trg_.size() - trg_cov, r.size() - 1)); // p(trg len remaining | src len remaining) + for (unsigned j = trg_cov; j < trg_.size(); ++j) { + prob_t p; + for (unsigned i = 0; i < r.size(); ++i) + p += model1_(r[i], trg_[j]); + if (p.is_0()) { + cerr << "ERROR: p(" << TD::Convert(trg_[j]) << " | " << TD::GetString(r) << ") = 0!\n"; + abort(); + } + p *= uniform_alignment; + e *= p; + } + r.pop_back(); + const prob_t inv_uniform(1.0 / (trg_.size() - trg_cov + 1.0)); + prob_t inv; + inv.logeq(log_poisson(r.size(), trg_.size() - trg_cov)); + for (unsigned i = 0; i < r.size(); ++i) { + prob_t p; + for (unsigned j = trg_cov - 1; j < trg_.size(); ++j) + p += invmodel1_(j < trg_cov ? 0 : trg_[j], r[i]); + if (p.is_0()) { + cerr << "ERROR: p_inv(" << TD::Convert(r[i]) << " | " << TD::GetString(trg_) << ") = 0!\n"; + abort(); + } + p *= inv_uniform; + inv *= p; + } + prob_t x = pow(e * inv, 0.5); + e = x; + //cerr << "Forward: " << log(e) << "\tBackward: " << log(inv) << "\t prop: " << log(x) << endl; + } + return e; + } + const Model1& model1_; + const Model1& invmodel1_; + const vector& src_; + const vector& trg_; + mutable unordered_map > cache_; +}; + +struct Particle { + Particle() : weight(prob_t::One()), src_cov(), trg_cov() {} + prob_t weight; + prob_t gamma_last; + vector rules; + int src_cov; + int trg_cov; +}; + +ostream& operator<<(ostream& o, const vector& v) { + for (int i = 0; i < v.size(); ++i) + o << (v[i] ? '1' : '0'); + return o; +} +ostream& operator<<(ostream& o, const Particle& p) { + o << "[src_cov=" << p.src_cov << " trg_cov=" << p.trg_cov << " num_rules=" << p.rules.size() << " w=" << log(p.weight) << ']'; + return o; +} + +void FilterCrapParticlesAndReweight(vector* pps) { + vector& ps = *pps; + SampleSet ss; + for (int i = 0; i < ps.size(); ++i) + ss.add(ps[i].weight); + vector nps; nps.reserve(ps.size()); + const prob_t uniform_weight(1.0 / ps.size()); + for (int i = 0; i < ps.size(); ++i) { + nps.push_back(ps[prng->SelectSample(ss)]); + nps[i].weight = uniform_weight; + } + nps.swap(ps); +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const unsigned kMAX_TRG_PHRASE = conf["max_trg_phrase"].as(); + const unsigned kMAX_SRC_PHRASE = conf["max_src_phrase"].as(); + const unsigned particles = conf["particles"].as(); + const unsigned samples = conf["samples"].as(); + const unsigned rejuv_freq = conf["filter_frequency"].as(); + + if (!conf.count("model1")) { + cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; + return 1; + } + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + cerr << "Reading corpus...\n"; + ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "F-corpus size: " << corpusf.size() << " sentences\t (" << vocabf.size() << " word types)\n"; + cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n"; + assert(corpusf.size() == corpuse.size()); + + const int kLHS = -TD::Convert("X"); + Model1 m1(conf["model1"].as()); + Model1 invm1(conf["inverse_model1"].as()); + +#if 0 + PhraseConditionalBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size()); + MyConditionalModel m(lp0); +#else + PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + MyJointModel m(lp0); +#endif + + cerr << "Initializing reachability limits...\n"; + vector ps(corpusf.size()); + vector reaches; reaches.reserve(corpusf.size()); + for (int ci = 0; ci < corpusf.size(); ++ci) + reaches.push_back(Reachability(corpusf[ci].size(), + corpuse[ci].size(), + kMAX_SRC_PHRASE, + kMAX_TRG_PHRASE)); + cerr << "Sampling...\n"; + vector tmp_p(10000); // work space + SampleSet pfss; + for (int SS=0; SS < samples; ++SS) { + for (int ci = 0; ci < corpusf.size(); ++ci) { + vector& src = corpusf[ci]; + vector& trg = corpuse[ci]; + m.DecrementRules(ps[ci].rules); + + BackwardEstimateSym be(m1, invm1, src, trg); + const Reachability& r = reaches[ci]; + vector lps(particles); + + bool all_complete = false; + while(!all_complete) { + SampleSet ss; + + // all particles have now been extended a bit, we will reweight them now + if (lps[0].trg_cov > 0) + FilterCrapParticlesAndReweight(&lps); + + // loop over all particles and extend them + bool done_nothing = true; + for (int pi = 0; pi < particles; ++pi) { + Particle& p = lps[pi]; + int tic = 0; + while(p.trg_cov < trg.size() && tic < rejuv_freq) { + ++tic; + done_nothing = false; + ss.clear(); + TRule x; x.lhs_ = kLHS; + prob_t z; + + for (int trg_len = 1; trg_len <= kMAX_TRG_PHRASE; ++trg_len) { + x.e_.push_back(trg[trg_len - 1 + p.trg_cov]); + for (int src_len = 1; src_len <= kMAX_SRC_PHRASE; ++src_len) { + if (!r.edges[p.src_cov][p.trg_cov][src_len][trg_len]) continue; + + int i = p.src_cov; + assert(ss.size() < tmp_p.size()); // if fails increase tmp_p size + Particle& np = tmp_p[ss.size()]; + np = p; + x.f_.clear(); + for (int j = 0; j < src_len; ++j) + x.f_.push_back(src[i + j]); + np.src_cov += x.f_.size(); + np.trg_cov += x.e_.size(); + prob_t rp = m.RuleProbability(x); + np.gamma_last = rp; + const prob_t u = pow(np.gamma_last * pow(be(np.src_cov, np.trg_cov), 1.2), 0.1); + //cerr << "**rule=" << x << endl; + //cerr << " u=" << log(u) << " rule=" << rp << endl; + ss.add(u); + np.rules.push_back(TRulePtr(new TRule(x))); + z += u; + } + } + //cerr << "number of edges to consider: " << ss.size() << endl; + const int sampled = rng.SelectSample(ss); + prob_t q_n = ss[sampled] / z; + p = tmp_p[sampled]; + //m.IncrementRule(*p.rules.back()); + p.weight *= p.gamma_last / q_n; + //cerr << "[w=" << log(p.weight) << "]\tsampled rule: " << p.rules.back()->AsString() << endl; + //cerr << p << endl; + } + } // loop over particles (pi = 0 .. particles) + if (done_nothing) all_complete = true; + } + pfss.clear(); + for (int i = 0; i < lps.size(); ++i) + pfss.add(lps[i].weight); + const int sampled = rng.SelectSample(pfss); + ps[ci] = lps[sampled]; + m.IncrementRules(lps[sampled].rules); + for (int i = 0; i < lps[sampled].rules.size(); ++i) { cerr << "S:\t" << lps[sampled].rules[i]->AsString() << "\n"; } + cerr << "tmp-LLH: " << log(m.Likelihood()) << endl; + } + cerr << "LLH: " << log(m.Likelihood()) << endl; + } + return 0; +} + diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc new file mode 100644 index 00000000..73dd8d39 --- /dev/null +++ b/gi/pf/reachability.cc @@ -0,0 +1,64 @@ +#include "reachability.h" + +#include +#include + +using namespace std; + +struct SState { + SState() : prev_src_covered(), prev_trg_covered() {} + SState(int i, int j) : prev_src_covered(i), prev_trg_covered(j) {} + int prev_src_covered; + int prev_trg_covered; +}; + +void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) { + typedef boost::multi_array, 2> array_type; + array_type a(boost::extents[srclen + 1][trglen + 1]); + a[0][0].push_back(SState()); + for (int i = 0; i < srclen; ++i) { + for (int j = 0; j < trglen; ++j) { + if (a[i][j].size() == 0) continue; + const SState prev(i,j); + for (int k = 1; k <= src_max_phrase_len; ++k) { + if ((i + k) > srclen) continue; + for (int l = 1; l <= trg_max_phrase_len; ++l) { + if ((j + l) > trglen) continue; + a[i + k][j + l].push_back(prev); + } + } + } + } + a[0][0].clear(); + //cerr << "Final cell contains " << a[srclen][trglen].size() << " back pointers\n"; + if (a[srclen][trglen].size() == 0) { + cerr << "Sentence with length (" << srclen << ',' << trglen << ") violates reachability constraints\n"; + return; + } + + typedef boost::multi_array rarray_type; + rarray_type r(boost::extents[srclen + 1][trglen + 1]); + r[srclen][trglen] = true; + for (int i = srclen; i >= 0; --i) { + for (int j = trglen; j >= 0; --j) { + vector& prevs = a[i][j]; + if (!r[i][j]) { prevs.clear(); } + for (int k = 0; k < prevs.size(); ++k) { + r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true; + int src_delta = i - prevs[k].prev_src_covered; + edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true; + short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered]; + if (src_delta > msd) msd = src_delta; + } + } + } + assert(!edges[0][0][1][0]); + assert(!edges[0][0][0][1]); + assert(!edges[0][0][0][0]); + assert(max_src_delta[0][0] > 0); + //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n"; + //for (int i = 0; i < b[0][0].size(); ++i) { + // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n"; + //} + } + diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h new file mode 100644 index 00000000..98450ec1 --- /dev/null +++ b/gi/pf/reachability.h @@ -0,0 +1,28 @@ +#ifndef _REACHABILITY_H_ +#define _REACHABILITY_H_ + +#include "boost/multi_array.hpp" + +// determines minimum and maximum lengths of outgoing edges from all +// coverage positions such that the alignment path respects src and +// trg maximum phrase sizes +// +// runs in O(n^2 * src_max * trg_max) time but should be relatively fast +// +// currently forbids 0 -> n and n -> 0 alignments + +struct Reachability { + boost::multi_array edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring? + boost::multi_array max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid + + Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : + edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), + max_src_delta(boost::extents[srclen][trglen]) { + ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); + } + + private: + void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len); +}; + +#endif diff --git a/gi/pf/tpf.cc b/gi/pf/tpf.cc new file mode 100644 index 00000000..7348d21c --- /dev/null +++ b/gi/pf/tpf.cc @@ -0,0 +1,99 @@ +#include +#include +#include + +#include "sampler.h" + +using namespace std; +using namespace tr1; + +shared_ptr prng; + +struct Particle { + Particle() : weight(prob_t::One()) {} + vector states; + prob_t weight; + prob_t gamma_last; +}; + +ostream& operator<<(ostream& os, const Particle& p) { + os << "["; + for (int i = 0; i < p.states.size(); ++i) os << p.states[i] << ' '; + os << "| w=" << log(p.weight) << ']'; + return os; +} + +void Rejuvenate(vector& pps) { + SampleSet ss; + vector nps(pps.size()); + for (int i = 0; i < pps.size(); ++i) { +// cerr << pps[i] << endl; + ss.add(pps[i].weight); + } +// cerr << "REJUVINATING...\n"; + for (int i = 0; i < pps.size(); ++i) { + nps[i] = pps[prng->SelectSample(ss)]; + nps[i].weight = prob_t(1.0 / pps.size()); +// cerr << nps[i] << endl; + } + nps.swap(pps); +// exit(1); +} + +int main(int argc, char** argv) { + const unsigned particles = 100; + prng.reset(new MT19937); + MT19937& rng = *prng; + + // q(a) = 0.8 + // q(b) = 0.8 + // q(c) = 0.4 + SampleSet ssq; + ssq.add(0.4); + ssq.add(0.6); + ssq.add(0); + double qz = 1; + + // p(a) = 0.2 + // p(b) = 0.8 + vector p(3); + p[0] = 0.2; + p[1] = 0.8; + p[2] = 0; + + vector counts(3); + int tot = 0; + + vector pps(particles); + SampleSet ppss; + int LEN = 12; + int PP = 1; + while (pps[0].states.size() < LEN) { + for (int pi = 0; pi < particles; ++pi) { + Particle& prt = pps[pi]; + + bool redo = true; + const Particle savedp = prt; + while (redo) { + redo = false; + for (int i = 0; i < PP; ++i) { + int s = rng.SelectSample(ssq); + double gamma_last = p[s]; + if (!gamma_last) { redo = true; break; } + double q = ssq[s] / qz; + prt.states.push_back(s); + prt.weight *= prob_t(gamma_last / q); + } + if (redo) { prt = savedp; continue; } + } + } + Rejuvenate(pps); + } + ppss.clear(); + for (int i = 0; i < particles; ++i) { ppss.add(pps[i].weight); } + int sp = rng.SelectSample(ppss); + cerr << pps[sp] << endl; + + return 0; +} + diff --git a/m4/acx_pthread.m4 b/m4/acx_pthread.m4 new file mode 100644 index 00000000..2cf20de1 --- /dev/null +++ b/m4/acx_pthread.m4 @@ -0,0 +1,363 @@ +# This was retrieved from +# http://svn.0pointer.de/viewvc/trunk/common/acx_pthread.m4?revision=1277&root=avahi +# See also (perhaps for new versions?) +# http://svn.0pointer.de/viewvc/trunk/common/acx_pthread.m4?root=avahi +# +# We've rewritten the inconsistency check code (from avahi), to work +# more broadly. In particular, it no longer assumes ld accepts -zdefs. +# This caused a restructing of the code, but the functionality has only +# changed a little. + +dnl @synopsis ACX_PTHREAD([ACTION-IF-FOUND[, ACTION-IF-NOT-FOUND]]) +dnl +dnl @summary figure out how to build C programs using POSIX threads +dnl +dnl This macro figures out how to build C programs using POSIX threads. +dnl It sets the PTHREAD_LIBS output variable to the threads library and +dnl linker flags, and the PTHREAD_CFLAGS output variable to any special +dnl C compiler flags that are needed. (The user can also force certain +dnl compiler flags/libs to be tested by setting these environment +dnl variables.) +dnl +dnl Also sets PTHREAD_CC to any special C compiler that is needed for +dnl multi-threaded programs (defaults to the value of CC otherwise). +dnl (This is necessary on AIX to use the special cc_r compiler alias.) +dnl +dnl NOTE: You are assumed to not only compile your program with these +dnl flags, but also link it with them as well. e.g. you should link +dnl with $PTHREAD_CC $CFLAGS $PTHREAD_CFLAGS $LDFLAGS ... $PTHREAD_LIBS +dnl $LIBS +dnl +dnl If you are only building threads programs, you may wish to use +dnl these variables in your default LIBS, CFLAGS, and CC: +dnl +dnl LIBS="$PTHREAD_LIBS $LIBS" +dnl CFLAGS="$CFLAGS $PTHREAD_CFLAGS" +dnl CC="$PTHREAD_CC" +dnl +dnl In addition, if the PTHREAD_CREATE_JOINABLE thread-attribute +dnl constant has a nonstandard name, defines PTHREAD_CREATE_JOINABLE to +dnl that name (e.g. PTHREAD_CREATE_UNDETACHED on AIX). +dnl +dnl ACTION-IF-FOUND is a list of shell commands to run if a threads +dnl library is found, and ACTION-IF-NOT-FOUND is a list of commands to +dnl run it if it is not found. If ACTION-IF-FOUND is not specified, the +dnl default action will define HAVE_PTHREAD. +dnl +dnl Please let the authors know if this macro fails on any platform, or +dnl if you have any other suggestions or comments. This macro was based +dnl on work by SGJ on autoconf scripts for FFTW (www.fftw.org) (with +dnl help from M. Frigo), as well as ac_pthread and hb_pthread macros +dnl posted by Alejandro Forero Cuervo to the autoconf macro repository. +dnl We are also grateful for the helpful feedback of numerous users. +dnl +dnl @category InstalledPackages +dnl @author Steven G. Johnson +dnl @version 2006-05-29 +dnl @license GPLWithACException +dnl +dnl Checks for GCC shared/pthread inconsistency based on work by +dnl Marcin Owsiany + + +AC_DEFUN([ACX_PTHREAD], [ +AC_REQUIRE([AC_CANONICAL_HOST]) +AC_LANG_SAVE +AC_LANG_C +acx_pthread_ok=no + +# We used to check for pthread.h first, but this fails if pthread.h +# requires special compiler flags (e.g. on True64 or Sequent). +# It gets checked for in the link test anyway. + +# First of all, check if the user has set any of the PTHREAD_LIBS, +# etcetera environment variables, and if threads linking works using +# them: +if test x"$PTHREAD_LIBS$PTHREAD_CFLAGS" != x; then + save_CFLAGS="$CFLAGS" + CFLAGS="$CFLAGS $PTHREAD_CFLAGS" + save_LIBS="$LIBS" + LIBS="$PTHREAD_LIBS $LIBS" + AC_MSG_CHECKING([for pthread_join in LIBS=$PTHREAD_LIBS with CFLAGS=$PTHREAD_CFLAGS]) + AC_TRY_LINK_FUNC(pthread_join, acx_pthread_ok=yes) + AC_MSG_RESULT($acx_pthread_ok) + if test x"$acx_pthread_ok" = xno; then + PTHREAD_LIBS="" + PTHREAD_CFLAGS="" + fi + LIBS="$save_LIBS" + CFLAGS="$save_CFLAGS" +fi + +# We must check for the threads library under a number of different +# names; the ordering is very important because some systems +# (e.g. DEC) have both -lpthread and -lpthreads, where one of the +# libraries is broken (non-POSIX). + +# Create a list of thread flags to try. Items starting with a "-" are +# C compiler flags, and other items are library names, except for "none" +# which indicates that we try without any flags at all, and "pthread-config" +# which is a program returning the flags for the Pth emulation library. + +acx_pthread_flags="pthreads none -Kthread -kthread lthread -pthread -pthreads -mthreads pthread --thread-safe -mt pthread-config" + +# The ordering *is* (sometimes) important. Some notes on the +# individual items follow: + +# pthreads: AIX (must check this before -lpthread) +# none: in case threads are in libc; should be tried before -Kthread and +# other compiler flags to prevent continual compiler warnings +# -Kthread: Sequent (threads in libc, but -Kthread needed for pthread.h) +# -kthread: FreeBSD kernel threads (preferred to -pthread since SMP-able) +# lthread: LinuxThreads port on FreeBSD (also preferred to -pthread) +# -pthread: Linux/gcc (kernel threads), BSD/gcc (userland threads) +# -pthreads: Solaris/gcc +# -mthreads: Mingw32/gcc, Lynx/gcc +# -mt: Sun Workshop C (may only link SunOS threads [-lthread], but it +# doesn't hurt to check since this sometimes defines pthreads too; +# also defines -D_REENTRANT) +# ... -mt is also the pthreads flag for HP/aCC +# pthread: Linux, etcetera +# --thread-safe: KAI C++ +# pthread-config: use pthread-config program (for GNU Pth library) + +case "${host_cpu}-${host_os}" in + *solaris*) + + # On Solaris (at least, for some versions), libc contains stubbed + # (non-functional) versions of the pthreads routines, so link-based + # tests will erroneously succeed. (We need to link with -pthreads/-mt/ + # -lpthread.) (The stubs are missing pthread_cleanup_push, or rather + # a function called by this macro, so we could check for that, but + # who knows whether they'll stub that too in a future libc.) So, + # we'll just look for -pthreads and -lpthread first: + + acx_pthread_flags="-pthreads pthread -mt -pthread $acx_pthread_flags" + ;; +esac + +if test x"$acx_pthread_ok" = xno; then +for flag in $acx_pthread_flags; do + + case $flag in + none) + AC_MSG_CHECKING([whether pthreads work without any flags]) + ;; + + -*) + AC_MSG_CHECKING([whether pthreads work with $flag]) + PTHREAD_CFLAGS="$flag" + ;; + + pthread-config) + AC_CHECK_PROG(acx_pthread_config, pthread-config, yes, no) + if test x"$acx_pthread_config" = xno; then continue; fi + PTHREAD_CFLAGS="`pthread-config --cflags`" + PTHREAD_LIBS="`pthread-config --ldflags` `pthread-config --libs`" + ;; + + *) + AC_MSG_CHECKING([for the pthreads library -l$flag]) + PTHREAD_LIBS="-l$flag" + ;; + esac + + save_LIBS="$LIBS" + save_CFLAGS="$CFLAGS" + LIBS="$PTHREAD_LIBS $LIBS" + CFLAGS="$CFLAGS $PTHREAD_CFLAGS" + + # Check for various functions. We must include pthread.h, + # since some functions may be macros. (On the Sequent, we + # need a special flag -Kthread to make this header compile.) + # We check for pthread_join because it is in -lpthread on IRIX + # while pthread_create is in libc. We check for pthread_attr_init + # due to DEC craziness with -lpthreads. We check for + # pthread_cleanup_push because it is one of the few pthread + # functions on Solaris that doesn't have a non-functional libc stub. + # We try pthread_create on general principles. + AC_TRY_LINK([#include ], + [pthread_t th; pthread_join(th, 0); + pthread_attr_init(0); pthread_cleanup_push(0, 0); + pthread_create(0,0,0,0); pthread_cleanup_pop(0); ], + [acx_pthread_ok=yes]) + + LIBS="$save_LIBS" + CFLAGS="$save_CFLAGS" + + AC_MSG_RESULT($acx_pthread_ok) + if test "x$acx_pthread_ok" = xyes; then + break; + fi + + PTHREAD_LIBS="" + PTHREAD_CFLAGS="" +done +fi + +# Various other checks: +if test "x$acx_pthread_ok" = xyes; then + save_LIBS="$LIBS" + LIBS="$PTHREAD_LIBS $LIBS" + save_CFLAGS="$CFLAGS" + CFLAGS="$CFLAGS $PTHREAD_CFLAGS" + + # Detect AIX lossage: JOINABLE attribute is called UNDETACHED. + AC_MSG_CHECKING([for joinable pthread attribute]) + attr_name=unknown + for attr in PTHREAD_CREATE_JOINABLE PTHREAD_CREATE_UNDETACHED; do + AC_TRY_LINK([#include ], [int attr=$attr; return attr;], + [attr_name=$attr; break]) + done + AC_MSG_RESULT($attr_name) + if test "$attr_name" != PTHREAD_CREATE_JOINABLE; then + AC_DEFINE_UNQUOTED(PTHREAD_CREATE_JOINABLE, $attr_name, + [Define to necessary symbol if this constant + uses a non-standard name on your system.]) + fi + + AC_MSG_CHECKING([if more special flags are required for pthreads]) + flag=no + case "${host_cpu}-${host_os}" in + *-aix* | *-freebsd* | *-darwin*) flag="-D_THREAD_SAFE";; + *solaris* | *-osf* | *-hpux*) flag="-D_REENTRANT";; + esac + AC_MSG_RESULT(${flag}) + if test "x$flag" != xno; then + PTHREAD_CFLAGS="$flag $PTHREAD_CFLAGS" + fi + + LIBS="$save_LIBS" + CFLAGS="$save_CFLAGS" + # More AIX lossage: must compile with xlc_r or cc_r + if test x"$GCC" != xyes; then + AC_CHECK_PROGS(PTHREAD_CC, xlc_r cc_r, ${CC}) + else + PTHREAD_CC=$CC + fi + + # The next part tries to detect GCC inconsistency with -shared on some + # architectures and systems. The problem is that in certain + # configurations, when -shared is specified, GCC "forgets" to + # internally use various flags which are still necessary. + + # + # Prepare the flags + # + save_CFLAGS="$CFLAGS" + save_LIBS="$LIBS" + save_CC="$CC" + + # Try with the flags determined by the earlier checks. + # + # -Wl,-z,defs forces link-time symbol resolution, so that the + # linking checks with -shared actually have any value + # + # FIXME: -fPIC is required for -shared on many architectures, + # so we specify it here, but the right way would probably be to + # properly detect whether it is actually required. + CFLAGS="-shared -fPIC -Wl,-z,defs $CFLAGS $PTHREAD_CFLAGS" + LIBS="$PTHREAD_LIBS $LIBS" + CC="$PTHREAD_CC" + + # In order not to create several levels of indentation, we test + # the value of "$done" until we find the cure or run out of ideas. + done="no" + + # First, make sure the CFLAGS we added are actually accepted by our + # compiler. If not (and OS X's ld, for instance, does not accept -z), + # then we can't do this test. + if test x"$done" = xno; then + AC_MSG_CHECKING([whether to check for GCC pthread/shared inconsistencies]) + AC_TRY_LINK(,, , [done=yes]) + + if test "x$done" = xyes ; then + AC_MSG_RESULT([no]) + else + AC_MSG_RESULT([yes]) + fi + fi + + if test x"$done" = xno; then + AC_MSG_CHECKING([whether -pthread is sufficient with -shared]) + AC_TRY_LINK([#include ], + [pthread_t th; pthread_join(th, 0); + pthread_attr_init(0); pthread_cleanup_push(0, 0); + pthread_create(0,0,0,0); pthread_cleanup_pop(0); ], + [done=yes]) + + if test "x$done" = xyes; then + AC_MSG_RESULT([yes]) + else + AC_MSG_RESULT([no]) + fi + fi + + # + # Linux gcc on some architectures such as mips/mipsel forgets + # about -lpthread + # + if test x"$done" = xno; then + AC_MSG_CHECKING([whether -lpthread fixes that]) + LIBS="-lpthread $PTHREAD_LIBS $save_LIBS" + AC_TRY_LINK([#include ], + [pthread_t th; pthread_join(th, 0); + pthread_attr_init(0); pthread_cleanup_push(0, 0); + pthread_create(0,0,0,0); pthread_cleanup_pop(0); ], + [done=yes]) + + if test "x$done" = xyes; then + AC_MSG_RESULT([yes]) + PTHREAD_LIBS="-lpthread $PTHREAD_LIBS" + else + AC_MSG_RESULT([no]) + fi + fi + # + # FreeBSD 4.10 gcc forgets to use -lc_r instead of -lc + # + if test x"$done" = xno; then + AC_MSG_CHECKING([whether -lc_r fixes that]) + LIBS="-lc_r $PTHREAD_LIBS $save_LIBS" + AC_TRY_LINK([#include ], + [pthread_t th; pthread_join(th, 0); + pthread_attr_init(0); pthread_cleanup_push(0, 0); + pthread_create(0,0,0,0); pthread_cleanup_pop(0); ], + [done=yes]) + + if test "x$done" = xyes; then + AC_MSG_RESULT([yes]) + PTHREAD_LIBS="-lc_r $PTHREAD_LIBS" + else + AC_MSG_RESULT([no]) + fi + fi + if test x"$done" = xno; then + # OK, we have run out of ideas + AC_MSG_WARN([Impossible to determine how to use pthreads with shared libraries]) + + # so it's not safe to assume that we may use pthreads + acx_pthread_ok=no + fi + + CFLAGS="$save_CFLAGS" + LIBS="$save_LIBS" + CC="$save_CC" +else + PTHREAD_CC="$CC" +fi + +AC_SUBST(PTHREAD_LIBS) +AC_SUBST(PTHREAD_CFLAGS) +AC_SUBST(PTHREAD_CC) + +# Finally, execute ACTION-IF-FOUND/ACTION-IF-NOT-FOUND: +if test x"$acx_pthread_ok" = xyes; then + ifelse([$1],,AC_DEFINE(HAVE_PTHREAD,1,[Define if you have POSIX threads libraries and header files.]),[$1]) + : +else + acx_pthread_ok=no + $2 +fi +AC_LANG_RESTORE +])dnl ACX_PTHREAD diff --git a/utils/ccrp_nt.h b/utils/ccrp_nt.h new file mode 100644 index 00000000..63b6f4c2 --- /dev/null +++ b/utils/ccrp_nt.h @@ -0,0 +1,169 @@ +#ifndef _CCRP_NT_H_ +#define _CCRP_NT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "sampler.h" +#include "slice_sampler.h" + +// Chinese restaurant process (1 parameter) +template > +class CCRP_NoTable { + public: + explicit CCRP_NoTable(double conc) : + num_customers_(), + concentration_(conc), + concentration_prior_shape_(std::numeric_limits::quiet_NaN()), + concentration_prior_rate_(std::numeric_limits::quiet_NaN()) {} + + CCRP_NoTable(double c_shape, double c_rate, double c = 10.0) : + num_customers_(), + concentration_(c), + concentration_prior_shape_(c_shape), + concentration_prior_rate_(c_rate) {} + + double concentration() const { return concentration_; } + + bool has_concentration_prior() const { + return !std::isnan(concentration_prior_shape_); + } + + void clear() { + num_customers_ = 0; + custs_.clear(); + } + + unsigned num_customers() const { + return num_customers_; + } + + unsigned num_customers(const Dish& dish) const { + const typename std::tr1::unordered_map::const_iterator it = custs_.find(dish); + if (it == custs_.end()) return 0; + return it->second; + } + + int increment(const Dish& dish) { + int table_diff = 0; + if (++custs_[dish] == 1) + table_diff = 1; + ++num_customers_; + return table_diff; + } + + int decrement(const Dish& dish) { + int table_diff = 0; + int nc = --custs_[dish]; + if (nc == 0) { + custs_.erase(dish); + table_diff = -1; + } else if (nc < 0) { + std::cerr << "Dish counts dropped below zero for: " << dish << std::endl; + abort(); + } + --num_customers_; + return table_diff; + } + + double prob(const Dish& dish, const double& p0) const { + const unsigned at_table = num_customers(dish); + return (at_table + p0 * concentration_) / (num_customers_ + concentration_); + } + + double logprob(const Dish& dish, const double& logp0) const { + const unsigned at_table = num_customers(dish); + return log(at_table + exp(logp0 + log(concentration_))) - log(num_customers_ + concentration_); + } + + double log_crp_prob() const { + return log_crp_prob(concentration_); + } + + static double log_gamma_density(const double& x, const double& shape, const double& rate) { + assert(x >= 0.0); + assert(shape > 0.0); + assert(rate > 0.0); + const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape); + return lp; + } + + // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process + // does not include P_0's + double log_crp_prob(const double& concentration) const { + double lp = 0.0; + if (has_concentration_prior()) + lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_); + assert(lp <= 0.0); + if (num_customers_) { + lp += lgamma(concentration) - lgamma(concentration + num_customers_) + + custs_.size() * log(concentration); + assert(std::isfinite(lp)); + for (typename std::tr1::unordered_map::const_iterator it = custs_.begin(); + it != custs_.end(); ++it) { + lp += lgamma(it->second); + } + } + assert(std::isfinite(lp)); + return lp; + } + + void resample_hyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) { + assert(has_concentration_prior()); + ConcentrationResampler cr(*this); + for (int iter = 0; iter < nloop; ++iter) { + concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0, + std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); + } + } + + struct ConcentrationResampler { + ConcentrationResampler(const CCRP_NoTable& crp) : crp_(crp) {} + const CCRP_NoTable& crp_; + double operator()(const double& proposed_concentration) const { + return crp_.log_crp_prob(proposed_concentration); + } + }; + + void Print(std::ostream* out) const { + (*out) << "DP(alpha=" << concentration_ << ") customers=" << num_customers_ << std::endl; + int cc = 0; + for (typename std::tr1::unordered_map::const_iterator it = custs_.begin(); + it != custs_.end(); ++it) { + (*out) << " " << it->first << "(" << it->second << " eating)"; + ++cc; + if (cc > 10) { (*out) << " ..."; break; } + } + (*out) << std::endl; + } + + unsigned num_customers_; + std::tr1::unordered_map custs_; + + typedef typename std::tr1::unordered_map::const_iterator const_iterator; + const_iterator begin() const { + return custs_.begin(); + } + const_iterator end() const { + return custs_.end(); + } + + double concentration_; + + // optional gamma prior on concentration_ (NaN if no prior) + double concentration_prior_shape_; + double concentration_prior_rate_; +}; + +template +std::ostream& operator<<(std::ostream& o, const CCRP_NoTable& c) { + c.Print(&o); + return o; +} + +#endif diff --git a/utils/ccrp_onetable.h b/utils/ccrp_onetable.h new file mode 100644 index 00000000..a868af9a --- /dev/null +++ b/utils/ccrp_onetable.h @@ -0,0 +1,241 @@ +#ifndef _CCRP_ONETABLE_H_ +#define _CCRP_ONETABLE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "sampler.h" +#include "slice_sampler.h" + +// Chinese restaurant process (Pitman-Yor parameters) with one table approximation + +template > +class CCRP_OneTable { + typedef std::tr1::unordered_map DishMapType; + public: + CCRP_OneTable(double disc, double conc) : + num_tables_(), + num_customers_(), + discount_(disc), + concentration_(conc), + discount_prior_alpha_(std::numeric_limits::quiet_NaN()), + discount_prior_beta_(std::numeric_limits::quiet_NaN()), + concentration_prior_shape_(std::numeric_limits::quiet_NaN()), + concentration_prior_rate_(std::numeric_limits::quiet_NaN()) {} + + CCRP_OneTable(double d_alpha, double d_beta, double c_shape, double c_rate, double d = 0.9, double c = 1.0) : + num_tables_(), + num_customers_(), + discount_(d), + concentration_(c), + discount_prior_alpha_(d_alpha), + discount_prior_beta_(d_beta), + concentration_prior_shape_(c_shape), + concentration_prior_rate_(c_rate) {} + + double discount() const { return discount_; } + double concentration() const { return concentration_; } + void set_concentration(double c) { concentration_ = c; } + void set_discount(double d) { discount_ = d; } + + bool has_discount_prior() const { + return !std::isnan(discount_prior_alpha_); + } + + bool has_concentration_prior() const { + return !std::isnan(concentration_prior_shape_); + } + + void clear() { + num_tables_ = 0; + num_customers_ = 0; + dish_counts_.clear(); + } + + unsigned num_tables() const { + return num_tables_; + } + + unsigned num_tables(const Dish& dish) const { + const typename DishMapType::const_iterator it = dish_counts_.find(dish); + if (it == dish_counts_.end()) return 0; + return 1; + } + + unsigned num_customers() const { + return num_customers_; + } + + unsigned num_customers(const Dish& dish) const { + const typename DishMapType::const_iterator it = dish_counts_.find(dish); + if (it == dish_counts_.end()) return 0; + return it->second; + } + + // returns +1 or 0 indicating whether a new table was opened + int increment(const Dish& dish) { + unsigned& dc = dish_counts_[dish]; + ++dc; + ++num_customers_; + if (dc == 1) { + ++num_tables_; + return 1; + } else { + return 0; + } + } + + // returns -1 or 0, indicating whether a table was closed + int decrement(const Dish& dish) { + unsigned& dc = dish_counts_[dish]; + assert(dc > 0); + if (dc == 1) { + dish_counts_.erase(dish); + --num_tables_; + --num_customers_; + return -1; + } else { + assert(dc > 1); + --dc; + --num_customers_; + return 0; + } + } + + double prob(const Dish& dish, const double& p0) const { + const typename DishMapType::const_iterator it = dish_counts_.find(dish); + const double r = num_tables_ * discount_ + concentration_; + if (it == dish_counts_.end()) { + return r * p0 / (num_customers_ + concentration_); + } else { + return (it->second - discount_ + r * p0) / + (num_customers_ + concentration_); + } + } + + double log_crp_prob() const { + return log_crp_prob(discount_, concentration_); + } + + static double log_beta_density(const double& x, const double& alpha, const double& beta) { + assert(x > 0.0); + assert(x < 1.0); + assert(alpha > 0.0); + assert(beta > 0.0); + const double lp = (alpha-1)*log(x)+(beta-1)*log(1-x)+lgamma(alpha+beta)-lgamma(alpha)-lgamma(beta); + return lp; + } + + static double log_gamma_density(const double& x, const double& shape, const double& rate) { + assert(x >= 0.0); + assert(shape > 0.0); + assert(rate > 0.0); + const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape); + return lp; + } + + // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process + // does not include P_0's + double log_crp_prob(const double& discount, const double& concentration) const { + double lp = 0.0; + if (has_discount_prior()) + lp = log_beta_density(discount, discount_prior_alpha_, discount_prior_beta_); + if (has_concentration_prior()) + lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_); + assert(lp <= 0.0); + if (num_customers_) { + if (discount > 0.0) { + const double r = lgamma(1.0 - discount); + lp += lgamma(concentration) - lgamma(concentration + num_customers_) + + num_tables_ * log(discount) + lgamma(concentration / discount + num_tables_) + - lgamma(concentration / discount); + assert(std::isfinite(lp)); + for (typename DishMapType::const_iterator it = dish_counts_.begin(); + it != dish_counts_.end(); ++it) { + const unsigned& cur = it->second; + lp += lgamma(cur - discount) - r; + } + } else { + assert(!"not implemented yet"); + } + } + assert(std::isfinite(lp)); + return lp; + } + + void resample_hyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) { + assert(has_discount_prior() || has_concentration_prior()); + DiscountResampler dr(*this); + ConcentrationResampler cr(*this); + for (int iter = 0; iter < nloop; ++iter) { + if (has_concentration_prior()) { + concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0, + std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); + } + if (has_discount_prior()) { + discount_ = slice_sampler1d(dr, discount_, *rng, std::numeric_limits::min(), + 1.0, 0.0, niterations, 100*niterations); + } + } + concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0, + std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); + } + + struct DiscountResampler { + DiscountResampler(const CCRP_OneTable& crp) : crp_(crp) {} + const CCRP_OneTable& crp_; + double operator()(const double& proposed_discount) const { + return crp_.log_crp_prob(proposed_discount, crp_.concentration_); + } + }; + + struct ConcentrationResampler { + ConcentrationResampler(const CCRP_OneTable& crp) : crp_(crp) {} + const CCRP_OneTable& crp_; + double operator()(const double& proposed_concentration) const { + return crp_.log_crp_prob(crp_.discount_, proposed_concentration); + } + }; + + void Print(std::ostream* out) const { + (*out) << "PYP(d=" << discount_ << ",c=" << concentration_ << ") customers=" << num_customers_ << std::endl; + for (typename DishMapType::const_iterator it = dish_counts_.begin(); it != dish_counts_.end(); ++it) { + (*out) << " " << it->first << " = " << it->second << std::endl; + } + } + + typedef typename DishMapType::const_iterator const_iterator; + const_iterator begin() const { + return dish_counts_.begin(); + } + const_iterator end() const { + return dish_counts_.end(); + } + + unsigned num_tables_; + unsigned num_customers_; + DishMapType dish_counts_; + + double discount_; + double concentration_; + + // optional beta prior on discount_ (NaN if no prior) + double discount_prior_alpha_; + double discount_prior_beta_; + + // optional gamma prior on concentration_ (NaN if no prior) + double concentration_prior_shape_; + double concentration_prior_rate_; +}; + +template +std::ostream& operator<<(std::ostream& o, const CCRP_OneTable& c) { + c.Print(&o); + return o; +} + +#endif -- cgit v1.2.3 From 4671d578bd6d97105ac75b02e0144fe0df3abcb0 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 11 Oct 2011 12:56:49 +0100 Subject: missing numwords impl --- utils/tdict.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/utils/tdict.cc b/utils/tdict.cc index c21b2b48..de234323 100644 --- a/utils/tdict.cc +++ b/utils/tdict.cc @@ -13,6 +13,10 @@ using namespace std; Dict TD::dict_; +unsigned int TD::NumWords() { + return dict_.max(); +} + WordID TD::Convert(const std::string& s) { return dict_.Convert(s); } -- cgit v1.2.3 From 0af7d663194beddcde420349bbd91430e0b2e423 Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Tue, 11 Oct 2011 16:16:53 +0100 Subject: remove implicit conversion-to-double operator from LogVal that caused overflow errors, clean up some pf code --- decoder/aligner.cc | 2 +- decoder/cfg.cc | 2 +- decoder/cfg_format.h | 2 +- decoder/decoder.cc | 10 ++++---- decoder/hg.cc | 4 ++-- decoder/rule_lexer.l | 2 ++ decoder/trule.h | 15 +++++++++++- gi/pf/brat.cc | 11 --------- gi/pf/cbgi.cc | 10 -------- gi/pf/dpnaive.cc | 12 ---------- gi/pf/itg.cc | 11 --------- gi/pf/pfbrat.cc | 11 --------- gi/pf/pfdist.cc | 11 --------- gi/pf/pfnaive.cc | 11 --------- mteval/mbr_kbest.cc | 4 ++-- phrasinator/ccrp_nt.h | 24 +++++++++++++++---- training/mpi_batch_optimize.cc | 2 +- training/mpi_compute_cllh.cc | 51 +++++++++++++++++++---------------------- training/mpi_online_optimize.cc | 4 ++-- utils/logval.h | 10 ++++---- 20 files changed, 78 insertions(+), 131 deletions(-) diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 292ee123..53e059fb 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -165,7 +165,7 @@ inline void WriteProbGrid(const Array2D& m, ostream* pos) { if (m(i,j) == prob_t::Zero()) { os << "\t---X---"; } else { - snprintf(b, 1024, "%0.5f", static_cast(m(i,j))); + snprintf(b, 1024, "%0.5f", m(i,j).as_float()); os << '\t' << b; } } diff --git a/decoder/cfg.cc b/decoder/cfg.cc index 651978d2..cd7e66e9 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -639,7 +639,7 @@ void CFG::Print(std::ostream &o,CFGFormat const& f) const { o << '['<& src, SparseVector* trg) { for (SparseVector::const_iterator it = src.begin(); it != src.end(); ++it) - trg->set_value(it->first, it->second); + trg->set_value(it->first, it->second.as_float()); } }; @@ -788,10 +788,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { - const PRPair res = - Inside, - PRWeightFunction >(forest); - cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + const PRPair res = + Inside, + PRWeightFunction >(forest); + cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl; } if (conf.count("show_partition")) { diff --git a/decoder/hg.cc b/decoder/hg.cc index 3ad17f1a..180986d7 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -157,14 +157,14 @@ prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector* posts) co const ScaledEdgeProb weight(scale); const ScaledTransitionEventWeightFunction w2(scale); SparseVector pv; - const double inside = InsideOutside, ScaledTransitionEventWeightFunction>(*this, &pv, weight, w2); posts->resize(edges_.size()); for (int i = 0; i < edges_.size(); ++i) (*posts)[i] = prob_t(pv.value(i)); - return prob_t(inside); + return inside; } prob_t Hypergraph::ComputeBestPathThroughEdges(vector* post) const { diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l index 9331d8ed..083a5bb1 100644 --- a/decoder/rule_lexer.l +++ b/decoder/rule_lexer.l @@ -220,6 +220,8 @@ NT [^\t \[\],]+ std::cerr << "Line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); } + // const bool ignore_grammar_features = false; + // if (ignore_grammar_features) scfglex_num_feats = 0; TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als)); check_and_update_ctf_stack(rp); TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); diff --git a/decoder/trule.h b/decoder/trule.h index 4df4ec90..8eb2a059 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -5,7 +5,9 @@ #include #include #include -#include + +#include "boost/shared_ptr.hpp" +#include "boost/functional/hash.hpp" #include "sparse_vector.h" #include "wordid.h" @@ -162,4 +164,15 @@ class TRule { bool SanityCheck() const; }; +inline size_t hash_value(const TRule& r) { + size_t h = boost::hash_value(r.e_); + boost::hash_combine(h, -r.lhs_); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +inline bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + #endif diff --git a/gi/pf/brat.cc b/gi/pf/brat.cc index 4c6ba3ef..7b60ef23 100644 --- a/gi/pf/brat.cc +++ b/gi/pf/brat.cc @@ -25,17 +25,6 @@ static unsigned kMAX_SRC_PHRASE; static unsigned kMAX_TRG_PHRASE; struct FSTState; -size_t hash_value(const TRule& r) { - size_t h = 2 - r.lhs_; - boost::hash_combine(h, boost::hash_value(r.e_)); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - double log_poisson(unsigned x, const double& lambda) { assert(lambda > 0.0); return log(lambda) * x - lgamma(x + 1) - lambda; diff --git a/gi/pf/cbgi.cc b/gi/pf/cbgi.cc index 20204e8a..97f1ba34 100644 --- a/gi/pf/cbgi.cc +++ b/gi/pf/cbgi.cc @@ -27,16 +27,6 @@ double log_decay(unsigned x, const double& b) { return log(b - 1) - x * log(b); } -size_t hash_value(const TRule& r) { - // TODO fix hash function - size_t h = boost::hash_value(r.e_) * boost::hash_value(r.f_) * r.lhs_; - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - struct SimpleBase { SimpleBase(unsigned esize, unsigned fsize, unsigned ntsize = 144) : uniform_e(-log(esize)), diff --git a/gi/pf/dpnaive.cc b/gi/pf/dpnaive.cc index 582d1be7..608f73d5 100644 --- a/gi/pf/dpnaive.cc +++ b/gi/pf/dpnaive.cc @@ -20,18 +20,6 @@ namespace po = boost::program_options; static unsigned kMAX_SRC_PHRASE; static unsigned kMAX_TRG_PHRASE; -struct FSTState; - -size_t hash_value(const TRule& r) { - size_t h = 2 - r.lhs_; - boost::hash_combine(h, boost::hash_value(r.e_)); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); diff --git a/gi/pf/itg.cc b/gi/pf/itg.cc index 2c2a86f9..ac3c16a3 100644 --- a/gi/pf/itg.cc +++ b/gi/pf/itg.cc @@ -27,17 +27,6 @@ ostream& operator<<(ostream& os, const vector& p) { return os << ']'; } -size_t hash_value(const TRule& r) { - size_t h = boost::hash_value(r.e_); - boost::hash_combine(h, -r.lhs_); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - double log_poisson(unsigned x, const double& lambda) { assert(lambda > 0.0); return log(lambda) * x - lgamma(x + 1) - lambda; diff --git a/gi/pf/pfbrat.cc b/gi/pf/pfbrat.cc index 4c6ba3ef..7b60ef23 100644 --- a/gi/pf/pfbrat.cc +++ b/gi/pf/pfbrat.cc @@ -25,17 +25,6 @@ static unsigned kMAX_SRC_PHRASE; static unsigned kMAX_TRG_PHRASE; struct FSTState; -size_t hash_value(const TRule& r) { - size_t h = 2 - r.lhs_; - boost::hash_combine(h, boost::hash_value(r.e_)); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - double log_poisson(unsigned x, const double& lambda) { assert(lambda > 0.0); return log(lambda) * x - lgamma(x + 1) - lambda; diff --git a/gi/pf/pfdist.cc b/gi/pf/pfdist.cc index 18dfd03b..81abd61b 100644 --- a/gi/pf/pfdist.cc +++ b/gi/pf/pfdist.cc @@ -24,17 +24,6 @@ namespace po = boost::program_options; shared_ptr prng; -size_t hash_value(const TRule& r) { - size_t h = boost::hash_value(r.e_); - boost::hash_combine(h, -r.lhs_); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() diff --git a/gi/pf/pfnaive.cc b/gi/pf/pfnaive.cc index 43c604c3..c30e7c4f 100644 --- a/gi/pf/pfnaive.cc +++ b/gi/pf/pfnaive.cc @@ -24,17 +24,6 @@ namespace po = boost::program_options; shared_ptr prng; -size_t hash_value(const TRule& r) { - size_t h = boost::hash_value(r.e_); - boost::hash_combine(h, -r.lhs_); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 2867b36b..64a6a8bf 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -32,7 +32,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } struct LossComparer { - bool operator()(const pair, double>& a, const pair, double>& b) const { + bool operator()(const pair, prob_t>& a, const pair, prob_t>& b) const { return a.second < b.second; } }; @@ -108,7 +108,7 @@ int main(int argc, char** argv) { ScoreP s = scorer->ScoreCandidate(list[j].first); double loss = 1.0 - s->ComputeScore(); if (type == TER || type == AER) loss = 1.0 - loss; - double weighted_loss = loss * (joints[j] / marginal); + double weighted_loss = loss * (joints[j] / marginal).as_float(); wl_acc += weighted_loss; if ((!output_list) && wl_acc > mbr_loss) break; } diff --git a/phrasinator/ccrp_nt.h b/phrasinator/ccrp_nt.h index 163b643a..811bce73 100644 --- a/phrasinator/ccrp_nt.h +++ b/phrasinator/ccrp_nt.h @@ -50,15 +50,26 @@ class CCRP_NoTable { return it->second; } - void increment(const Dish& dish) { - ++custs_[dish]; + int increment(const Dish& dish) { + int table_diff = 0; + if (++custs_[dish] == 1) + table_diff = 1; ++num_customers_; + return table_diff; } - void decrement(const Dish& dish) { - if ((--custs_[dish]) == 0) + int decrement(const Dish& dish) { + int table_diff = 0; + int nc = --custs_[dish]; + if (nc == 0) { custs_.erase(dish); + table_diff = -1; + } else if (nc < 0) { + std::cerr << "Dish counts dropped below zero for: " << dish << std::endl; + abort(); + } --num_customers_; + return table_diff; } double prob(const Dish& dish, const double& p0) const { @@ -66,6 +77,11 @@ class CCRP_NoTable { return (at_table + p0 * concentration_) / (num_customers_ + concentration_); } + double logprob(const Dish& dish, const double& logp0) const { + const unsigned at_table = num_customers(dish); + return log(at_table + exp(logp0 + log(concentration_))) - log(num_customers_ + concentration_); + } + double log_crp_prob() const { return log_crp_prob(concentration_); } diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index 0ba8c530..046e921c 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -92,7 +92,7 @@ struct TrainingObserver : public DecoderObserver { void SetLocalGradientAndObjective(vector* g, double* o) const { *o = acc_obj; for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - (*g)[it->first] = it->second; + (*g)[it->first] = it->second.as_float(); } virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { diff --git a/training/mpi_compute_cllh.cc b/training/mpi_compute_cllh.cc index b496d196..d5caa745 100644 --- a/training/mpi_compute_cllh.cc +++ b/training/mpi_compute_cllh.cc @@ -1,6 +1,4 @@ -#include #include -#include #include #include #include @@ -12,6 +10,7 @@ #include #include +#include "sentence_metadata.h" #include "verbose.h" #include "hg.h" #include "prob.h" @@ -52,7 +51,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { return true; } -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* ids) { +void ReadInstances(const string& fname, int rank, int size, vector* c) { + assert(fname != "-"); ReadFile rf(fname); istream& in = *rf.stream(); string line; @@ -60,20 +60,16 @@ void ReadTrainingCorpus(const string& fname, int rank, int size, vector* while(in) { getline(in, line); if (!in) break; - if (lc % size == rank) { - c->push_back(line); - ids->push_back(lc); - } + if (lc % size == rank) c->push_back(line); ++lc; } } static const double kMINUS_EPSILON = -1e-6; -struct TrainingObserver : public DecoderObserver { - void Reset() { - acc_obj = 0; - } +struct ConditionalLikelihoodObserver : public DecoderObserver { + + ConditionalLikelihoodObserver() : trg_words(), acc_obj(), cur_obj() {} virtual void NotifyDecodingStart(const SentenceMetadata&) { cur_obj = 0; @@ -120,8 +116,10 @@ struct TrainingObserver : public DecoderObserver { } assert(!isnan(log_ref_z)); acc_obj += (cur_obj - log_ref_z); + trg_words += smeta.GetReference().size(); } + unsigned trg_words; double acc_obj; double cur_obj; int state; @@ -161,35 +159,32 @@ int main(int argc, char** argv) { if (conf.count("weights")) Weights::InitFromFile(conf["weights"].as(), &weights); - // freeze feature set - //const bool freeze_feature_set = conf.count("freeze_feature_set"); - //if (freeze_feature_set) FD::Freeze(); - - vector corpus; vector ids; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); + vector corpus; + ReadInstances(conf["training_data"].as(), rank, size, &corpus); assert(corpus.size() > 0); - assert(corpus.size() == ids.size()); - - TrainingObserver observer; - double objective = 0; - observer.Reset(); if (rank == 0) - cerr << "Each processor is decoding " << corpus.size() << " training examples...\n"; + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; - for (int i = 0; i < corpus.size(); ++i) { - decoder.SetId(ids[i]); + ConditionalLikelihoodObserver observer; + for (int i = 0; i < corpus.size(); ++i) decoder.Decode(corpus[i], &observer); - } + double objective = 0; + unsigned total_words = 0; #ifdef HAVE_MPI reduce(world, observer.acc_obj, objective, std::plus(), 0); + reduce(world, observer.trg_words, total_words, std::plus(), 0); #else objective = observer.acc_obj; #endif - if (rank == 0) - cout << "OBJECTIVE: " << objective << endl; + if (rank == 0) { + cout << "CONDITIONAL LOG_e LIKELIHOOD: " << objective << endl; + cout << "CONDITIONAL LOG_2 LIKELIHOOD: " << (objective/log(2)) << endl; + cout << " CONDITIONAL ENTROPY: " << (objective/log(2) / total_words) << endl; + cout << " PERPLEXITY: " << pow(2, (objective/log(2) / total_words)) << endl; + } return 0; } diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 2ef4a2e7..f87b7274 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -94,7 +94,7 @@ struct TrainingObserver : public DecoderObserver { void SetLocalGradientAndObjective(vector* g, double* o) const { *o = acc_obj; for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - (*g)[it->first] = it->second; + (*g)[it->first] = it->second.as_float(); } virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { @@ -158,7 +158,7 @@ struct TrainingObserver : public DecoderObserver { void GetGradient(SparseVector* g) const { g->clear(); for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - g->set_value(it->first, it->second); + g->set_value(it->first, it->second.as_float()); } int total_complete; diff --git a/utils/logval.h b/utils/logval.h index 6fdc2c42..8a59d0b1 100644 --- a/utils/logval.h +++ b/utils/logval.h @@ -25,12 +25,13 @@ class LogVal { typedef LogVal Self; LogVal() : s_(), v_(LOGVAL_LOG0) {} - explicit LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} + LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} + const Self& operator=(double x) { s_ = std::signbit(x); v_ = s_ ? std::log(-x) : std::log(x); return *this; } LogVal(init_minus_1) : s_(true),v_(0) { } LogVal(init_1) : s_(),v_(0) { } LogVal(init_0) : s_(),v_(LOGVAL_LOG0) { } - LogVal(int x) : s_(x<0), v_(s_ ? std::log(-x) : std::log(x)) {} - LogVal(unsigned x) : s_(0), v_(std::log(x)) { } + explicit LogVal(int x) : s_(x<0), v_(s_ ? std::log(-x) : std::log(x)) {} + explicit LogVal(unsigned x) : s_(0), v_(std::log(x)) { } LogVal(double lnx,bool sign) : s_(sign),v_(lnx) {} LogVal(double lnx,init_lnx) : s_(),v_(lnx) {} static Self exp(T lnx) { return Self(lnx,false); } @@ -141,9 +142,6 @@ class LogVal { return pow(1/root); } - operator T() const { - if (s_) return -std::exp(v_); else return std::exp(v_); - } T as_float() const { if (s_) return -std::exp(v_); else return std::exp(v_); } -- cgit v1.2.3 From 0e1ffb6c1528e44f63ae8bac466bd5163e973974 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Oct 2011 14:58:52 -0400 Subject: Trie fixes for SRI --- klm/lm/left.hh | 26 +++++++++++++++----------- klm/lm/search_trie.cc | 7 ++++++- klm/lm/search_trie.hh | 2 +- klm/lm/trie_sort.cc | 2 +- klm/util/scoped.hh | 2 +- 5 files changed, 24 insertions(+), 15 deletions(-) diff --git a/klm/lm/left.hh b/klm/lm/left.hh index effa0560..bb3f5539 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -117,7 +117,7 @@ inline size_t hash_value(const ChartState &state) { template class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } @@ -130,15 +130,22 @@ template class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - ProcessRet(model_.FullScore(copy, word, out_.right)); - if (out_.right.length != copy.length + 1) left_done_ = true; + FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + out_.left.pointers[out_.left.length++] = ret.extend_left; + if (out_.right.length != copy.length + 1) + left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob) { prob_ = prob; out_ = in; - left_write_ = out_.left.pointers + out_.left.length; left_done_ = in.full; } @@ -157,11 +164,10 @@ template class RuleScore { if (!out_.right.length) { out_.right = in.right; if (left_done_) return; - if (left_write_ != out_.left.pointers) { + if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; - left_write_ = out_.left.pointers + in.left.length; left_done_ = in.full; } return; @@ -214,8 +220,8 @@ template class RuleScore { } float Finish() { - out_.left.length = left_write_ - out_.left.pointers; - out_.full = left_done_; + // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. + out_.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } @@ -227,7 +233,7 @@ template class RuleScore { left_done_ = true; return; } - *(left_write_++) = ret.extend_left; + out_.left.pointers[out_.left.length++] = ret.extend_left; } const M &model_; @@ -236,8 +242,6 @@ template class RuleScore { bool left_done_; - uint64_t *left_write_; - float prob_; }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 6479813b..5d8c70db 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -151,6 +151,11 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); + // Sort requests in same order as files. + std::sort( + util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), + util::SizedIterator(util::SizedProxy(current_, entry_size_)), + util::SizedCompare(EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex)))); current_ = (uint8_t*)backing_.get(); } @@ -525,7 +530,7 @@ template void BuildTrie(const std::string &file_pre const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast(order) << "-gram has context"; + e << "A " << static_cast(order) << "-gram has context"; const WordIndex *ctx = reinterpret_cast(context.Data()); for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index c3e02a98..33ae8cff 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -36,7 +36,7 @@ template class TrieSearch { static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); - static const unsigned int kVersion = 0; + static const unsigned int kVersion = 1; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 86f28493..bb126f18 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine()); + MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine()); files.pop_front(); files.pop_front(); } diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 12e6652b..93e2e817 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -50,7 +50,7 @@ class scoped_malloc { void call_realloc(std::size_t to) { void *ret; - UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + UTIL_THROW_IF(!(ret = std::realloc(p_, to)) && to, util::ErrnoException, "realloc to " << to << " bytes failed."); p_ = ret; } -- cgit v1.2.3 From a32cd0131c6325e364c82e5f6bbefc03b61e437f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Oct 2011 14:59:12 -0400 Subject: util/murmur_hash.cc --- klm/util/murmur_hash.cc | 258 ++++++++++++++++++++++++------------------------ 1 file changed, 129 insertions(+), 129 deletions(-) diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util -- cgit v1.2.3 From ee84ab027c0be54800cac0c9bff62dd097354f6d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 12 Oct 2011 14:57:15 +0100 Subject: model lenght properly, clean up --- gi/pf/Makefile.am | 2 +- gi/pf/corpus.cc | 57 ++++++++++++++++++++++++ gi/pf/corpus.h | 19 ++++++++ gi/pf/dpnaive.cc | 95 +++++++++++----------------------------- gi/pf/monotonic_pseg.h | 88 +++++++++++++++++++++++++++++++++++++ gi/pf/pfnaive.cc | 116 +++++-------------------------------------------- utils/logval_test.cc | 14 +++--- 7 files changed, 209 insertions(+), 182 deletions(-) create mode 100644 gi/pf/corpus.cc create mode 100644 gi/pf/corpus.h create mode 100644 gi/pf/monotonic_pseg.h diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am index c9764ad5..42758939 100644 --- a/gi/pf/Makefile.am +++ b/gi/pf/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive noinst_LIBRARIES = libpf.a -libpf_a_SOURCES = base_measures.cc reachability.cc cfg_wfst_composer.cc +libpf_a_SOURCES = base_measures.cc reachability.cc cfg_wfst_composer.cc corpus.cc itg_SOURCES = itg.cc diff --git a/gi/pf/corpus.cc b/gi/pf/corpus.cc new file mode 100644 index 00000000..a408e7cf --- /dev/null +++ b/gi/pf/corpus.cc @@ -0,0 +1,57 @@ +#include "corpus.h" + +#include +#include +#include + +#include "tdict.h" +#include "filelib.h" + +using namespace std; + +namespace corpus { + +void ReadParallelCorpus(const string& filename, + vector >* f, + vector >* e, + set* vocab_f, + set* vocab_e) { + f->clear(); + e->clear(); + vocab_f->clear(); + vocab_e->clear(); + ReadFile rf(filename); + istream* in = rf.stream(); + assert(*in); + string line; + const WordID kDIV = TD::Convert("|||"); + vector tmp; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + e->push_back(vector()); + f->push_back(vector()); + vector& le = e->back(); + vector& lf = f->back(); + tmp.clear(); + TD::ConvertSentence(line, &tmp); + bool isf = true; + for (unsigned i = 0; i < tmp.size(); ++i) { + const int cur = tmp[i]; + if (isf) { + if (kDIV == cur) { isf = false; } else { + lf.push_back(cur); + vocab_f->insert(cur); + } + } else { + assert(cur != kDIV); + le.push_back(cur); + vocab_e->insert(cur); + } + } + assert(isf == false); + } +} + +} + diff --git a/gi/pf/corpus.h b/gi/pf/corpus.h new file mode 100644 index 00000000..e7febdb7 --- /dev/null +++ b/gi/pf/corpus.h @@ -0,0 +1,19 @@ +#ifndef _CORPUS_H_ +#define _CORPUS_H_ + +#include +#include +#include +#include "wordid.h" + +namespace corpus { + +void ReadParallelCorpus(const std::string& filename, + std::vector >* f, + std::vector >* e, + std::set* vocab_f, + std::set* vocab_e); + +} + +#endif diff --git a/gi/pf/dpnaive.cc b/gi/pf/dpnaive.cc index 608f73d5..c926487b 100644 --- a/gi/pf/dpnaive.cc +++ b/gi/pf/dpnaive.cc @@ -7,12 +7,14 @@ #include #include "base_measures.h" +#include "monotonic_pseg.h" #include "trule.h" #include "tdict.h" #include "filelib.h" #include "dict.h" #include "sampler.h" #include "ccrp_nt.h" +#include "corpus.h" using namespace std; using namespace std::tr1; @@ -52,57 +54,12 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -void ReadParallelCorpus(const string& filename, - vector >* f, - vector >* e, - set* vocab_e, - set* vocab_f) { - f->clear(); - e->clear(); - vocab_f->clear(); - vocab_e->clear(); - istream* in; - if (filename == "-") - in = &cin; - else - in = new ifstream(filename.c_str()); - assert(*in); - string line; - const WordID kDIV = TD::Convert("|||"); - vector tmp; - while(*in) { - getline(*in, line); - if (line.empty() && !*in) break; - e->push_back(vector()); - f->push_back(vector()); - vector& le = e->back(); - vector& lf = f->back(); - tmp.clear(); - TD::ConvertSentence(line, &tmp); - bool isf = true; - for (unsigned i = 0; i < tmp.size(); ++i) { - const int cur = tmp[i]; - if (isf) { - if (kDIV == cur) { isf = false; } else { - lf.push_back(cur); - vocab_f->insert(cur); - } - } else { - assert(cur != kDIV); - le.push_back(cur); - vocab_e->insert(cur); - } - } - assert(isf == false); - } - if (in != &cin) delete in; -} - shared_ptr prng; template struct ModelAndData { - explicit ModelAndData(const Base& b, const vector >& ce, const vector >& cf, const set& ve, const set& vf) : + explicit ModelAndData(MonotonicParallelSegementationModel& m, const Base& b, const vector >& ce, const vector >& cf, const set& ve, const set& vf) : + model(m), rng(&*prng), p0(b), baseprob(prob_t::One()), @@ -110,14 +67,12 @@ struct ModelAndData { corpusf(cf), vocabe(ve), vocabf(vf), - rules(1,1), mh_samples(), mh_rejects(), kX(-TD::Convert("X")), derivations(corpuse.size()) {} void ResampleHyperparameters() { - rules.resample_hyperparameters(&*prng); } void InstantiateRule(const pair& from, @@ -139,12 +94,10 @@ struct ModelAndData { TRule x; for (int i = 1; i < d.size(); ++i) { InstantiateRule(d[i], d[i-1], sentf, sente, &x); - //cerr << "REMOVE: " << x.AsString() << endl; - if (rules.decrement(x)) { - baseprob /= p0(x); - //cerr << " (REMOVED ONLY INSTANCE)\n"; - } + model.DecrementRule(x); + model.DecrementContinue(); } + model.DecrementStop(); } void PrintDerivation(const vector >& d, const vector& sentf, const vector& sente) { @@ -161,39 +114,38 @@ struct ModelAndData { TRule x; for (int i = 1; i < d.size(); ++i) { InstantiateRule(d[i], d[i-1], sentf, sente, &x); - if (rules.increment(x)) { - baseprob *= p0(x); - } + model.IncrementRule(x); + model.IncrementContinue(); } + model.IncrementStop(); } prob_t Likelihood() const { - prob_t p; - p.logeq(rules.log_crp_prob()); - return p * baseprob; + return model.Likelihood(); } prob_t DerivationProposalProbability(const vector >& d, const vector& sentf, const vector& sente) const { - prob_t p = prob_t::One(); + prob_t p = model.StopProbability(); if (d.size() < 2) return p; TRule x; + const prob_t p_cont = model.ContinueProbability(); for (int i = 1; i < d.size(); ++i) { InstantiateRule(d[i], d[i-1], sentf, sente, &x); - prob_t rp; rp.logeq(rules.logprob(x, log(p0(x)))); - p *= rp; + p *= p_cont; + p *= model.RuleProbability(x); } return p; } void Sample(); + MonotonicParallelSegementationModel& model; MT19937* rng; const Base& p0; prob_t baseprob; // cached value of generating the table table labels from p0 // this can't be used if we go to a hierarchical prior! const vector >& corpuse, corpusf; const set& vocabe, vocabf; - CCRP_NoTable rules; unsigned mh_samples, mh_rejects; const int kX; vector > > derivations; @@ -201,8 +153,8 @@ struct ModelAndData { template void ModelAndData::Sample() { - unsigned MAXK = 4; - unsigned MAXL = 4; + unsigned MAXK = kMAX_SRC_PHRASE; + unsigned MAXL = kMAX_TRG_PHRASE; TRule x; x.lhs_ = -TD::Convert("X"); for (int samples = 0; samples < 1000; ++samples) { @@ -228,6 +180,8 @@ void ModelAndData::Sample() { boost::multi_array a(boost::extents[sentf.size() + 1][sente.size() + 1]); boost::multi_array trans(boost::extents[sentf.size() + 1][sente.size() + 1][MAXK][MAXL]); a[0][0] = prob_t::One(); + const prob_t q_stop = model.StopProbability(); + const prob_t q_cont = model.ContinueProbability(); for (int i = 0; i < sentf.size(); ++i) { for (int j = 0; j < sente.size(); ++j) { const prob_t src_a = a[i][j]; @@ -239,7 +193,9 @@ void ModelAndData::Sample() { for (int l = 1; l <= MAXL; ++l) { if (j + l > sente.size()) break; x.e_.push_back(sente[j + l - 1]); - trans[i][j][k - 1][l - 1].logeq(rules.logprob(x, log(p0(x)))); + const bool stop_now = ((j + l) == sente.size()) && ((i + k) == sentf.size()); + const prob_t& cp = stop_now ? q_stop : q_cont; + trans[i][j][k - 1][l - 1] = model.RuleProbability(x) * cp; a[i + k][j + l] += src_a * trans[i][j][k - 1][l - 1]; } } @@ -319,7 +275,7 @@ int main(int argc, char** argv) { vector > corpuse, corpusf; set vocabe, vocabf; - ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + corpus::ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); cerr << "f-Corpus size: " << corpusf.size() << " sentences\n"; cerr << "f-Vocabulary size: " << vocabf.size() << " types\n"; cerr << "f-Corpus size: " << corpuse.size() << " sentences\n"; @@ -328,8 +284,9 @@ int main(int argc, char** argv) { Model1 m1(conf["model1"].as()); PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + MonotonicParallelSegementationModel m(lp0); - ModelAndData posterior(lp0, corpuse, corpusf, vocabe, vocabf); + ModelAndData posterior(m, lp0, corpuse, corpusf, vocabe, vocabf); posterior.Sample(); return 0; diff --git a/gi/pf/monotonic_pseg.h b/gi/pf/monotonic_pseg.h new file mode 100644 index 00000000..7e6af3fc --- /dev/null +++ b/gi/pf/monotonic_pseg.h @@ -0,0 +1,88 @@ +#ifndef _MONOTONIC_PSEG_H_ +#define _MONOTONIC_PSEG_H_ + +#include + +#include "prob.h" +#include "ccrp_nt.h" +#include "trule.h" +#include "base_measures.h" + +struct MonotonicParallelSegementationModel { + explicit MonotonicParallelSegementationModel(PhraseJointBase& rcp0) : + rp0(rcp0), base(prob_t::One()), rules(1,1), stop(1.0) {} + + void DecrementRule(const TRule& rule) { + if (rules.decrement(rule)) + base /= rp0(rule); + } + + void IncrementRule(const TRule& rule) { + if (rules.increment(rule)) + base *= rp0(rule); + } + + void IncrementRulesAndStops(const std::vector& rules) { + for (int i = 0; i < rules.size(); ++i) + IncrementRule(*rules[i]); + if (rules.size()) IncrementContinue(rules.size() - 1); + IncrementStop(); + } + + void DecrementRulesAndStops(const std::vector& rules) { + for (int i = 0; i < rules.size(); ++i) + DecrementRule(*rules[i]); + if (rules.size()) { + DecrementContinue(rules.size() - 1); + DecrementStop(); + } + } + + prob_t RuleProbability(const TRule& rule) const { + prob_t p; p.logeq(rules.logprob(rule, log(rp0(rule)))); + return p; + } + + prob_t Likelihood() const { + prob_t p = base; + prob_t q; q.logeq(rules.log_crp_prob()); + p *= q; + q.logeq(stop.log_crp_prob()); + p *= q; + return p; + } + + void IncrementStop() { + stop.increment(true); + } + + void IncrementContinue(int n = 1) { + for (int i = 0; i < n; ++i) + stop.increment(false); + } + + void DecrementStop() { + stop.decrement(true); + } + + void DecrementContinue(int n = 1) { + for (int i = 0; i < n; ++i) + stop.decrement(false); + } + + prob_t StopProbability() const { + return prob_t(stop.prob(true, 0.5)); + } + + prob_t ContinueProbability() const { + return prob_t(stop.prob(false, 0.5)); + } + + const PhraseJointBase& rp0; + prob_t base; + CCRP_NoTable rules; + CCRP_NoTable stop; +}; + +#endif + diff --git a/gi/pf/pfnaive.cc b/gi/pf/pfnaive.cc index c30e7c4f..33dc08c3 100644 --- a/gi/pf/pfnaive.cc +++ b/gi/pf/pfnaive.cc @@ -7,6 +7,7 @@ #include #include "base_measures.h" +#include "monotonic_pseg.h" #include "reachability.h" #include "viterbi.h" #include "hg.h" @@ -17,6 +18,7 @@ #include "sampler.h" #include "ccrp_nt.h" #include "ccrp_onetable.h" +#include "corpus.h" using namespace std; using namespace tr1; @@ -58,101 +60,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -void ReadParallelCorpus(const string& filename, - vector >* f, - vector >* e, - set* vocab_f, - set* vocab_e) { - f->clear(); - e->clear(); - vocab_f->clear(); - vocab_e->clear(); - istream* in; - if (filename == "-") - in = &cin; - else - in = new ifstream(filename.c_str()); - assert(*in); - string line; - const WordID kDIV = TD::Convert("|||"); - vector tmp; - while(*in) { - getline(*in, line); - if (line.empty() && !*in) break; - e->push_back(vector()); - f->push_back(vector()); - vector& le = e->back(); - vector& lf = f->back(); - tmp.clear(); - TD::ConvertSentence(line, &tmp); - bool isf = true; - for (unsigned i = 0; i < tmp.size(); ++i) { - const int cur = tmp[i]; - if (isf) { - if (kDIV == cur) { isf = false; } else { - lf.push_back(cur); - vocab_f->insert(cur); - } - } else { - assert(cur != kDIV); - le.push_back(cur); - vocab_e->insert(cur); - } - } - assert(isf == false); - } - if (in != &cin) delete in; -} - -struct MyJointModel { - MyJointModel(PhraseJointBase& rcp0) : - rp0(rcp0), base(prob_t::One()), rules(1,1) {} - - void DecrementRule(const TRule& rule) { - if (rules.decrement(rule)) - base /= rp0(rule); - } - - void IncrementRule(const TRule& rule) { - if (rules.increment(rule)) - base *= rp0(rule); - } - - void IncrementRules(const vector& rules) { - for (int i = 0; i < rules.size(); ++i) - IncrementRule(*rules[i]); - } - - void DecrementRules(const vector& rules) { - for (int i = 0; i < rules.size(); ++i) - DecrementRule(*rules[i]); - } - - prob_t RuleProbability(const TRule& rule) const { - prob_t p; p.logeq(rules.logprob(rule, log(rp0(rule)))); - return p; - } - - prob_t Likelihood() const { - prob_t p = base; - prob_t q; q.logeq(rules.log_crp_prob()); - p *= q; - for (unsigned l = 1; l < src_jumps.size(); ++l) { - if (src_jumps[l].num_customers() > 0) { - prob_t q; - q.logeq(src_jumps[l].log_crp_prob()); - p *= q; - } - } - return p; - } - - const PhraseJointBase& rp0; - prob_t base; - CCRP_NoTable rules; - vector > src_jumps; -}; - struct BackwardEstimateSym { BackwardEstimateSym(const Model1& m1, const Model1& invm1, const vector& src, const vector& trg) : @@ -264,7 +171,7 @@ int main(int argc, char** argv) { vector > corpuse, corpusf; set vocabe, vocabf; cerr << "Reading corpus...\n"; - ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + corpus::ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); cerr << "F-corpus size: " << corpusf.size() << " sentences\t (" << vocabf.size() << " word types)\n"; cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n"; assert(corpusf.size() == corpuse.size()); @@ -273,13 +180,8 @@ int main(int argc, char** argv) { Model1 m1(conf["model1"].as()); Model1 invm1(conf["inverse_model1"].as()); -#if 0 - PhraseConditionalBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size()); - MyConditionalModel m(lp0); -#else PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); - MyJointModel m(lp0); -#endif + MonotonicParallelSegementationModel m(lp0); cerr << "Initializing reachability limits...\n"; vector ps(corpusf.size()); @@ -296,7 +198,10 @@ int main(int argc, char** argv) { for (int ci = 0; ci < corpusf.size(); ++ci) { vector& src = corpusf[ci]; vector& trg = corpuse[ci]; - m.DecrementRules(ps[ci].rules); + m.DecrementRulesAndStops(ps[ci].rules); + const prob_t q_stop = m.StopProbability(); + const prob_t q_cont = m.ContinueProbability(); + cerr << "P(stop)=" << q_stop << "\tP(continue)=" <AsString() << "\n"; } cerr << "tmp-LLH: " << log(m.Likelihood()) << endl; } diff --git a/utils/logval_test.cc b/utils/logval_test.cc index 4aa452f2..6133f5ce 100644 --- a/utils/logval_test.cc +++ b/utils/logval_test.cc @@ -30,13 +30,13 @@ TEST_F(LogValTest,Negate) { LogVal x(-2.4); LogVal y(2.4); y.negate(); - EXPECT_FLOAT_EQ(x,y); + EXPECT_FLOAT_EQ(x.as_float(),y.as_float()); } TEST_F(LogValTest,Inverse) { LogVal x(1/2.4); LogVal y(2.4); - EXPECT_FLOAT_EQ(x,y.inverse()); + EXPECT_FLOAT_EQ(x.as_float(),y.inverse().as_float()); } TEST_F(LogValTest,Minus) { @@ -45,9 +45,9 @@ TEST_F(LogValTest,Minus) { LogVal z1 = x - y; LogVal z2 = x; z2 -= y; - EXPECT_FLOAT_EQ(z1, z2); - EXPECT_FLOAT_EQ(z1, 10.0); - EXPECT_FLOAT_EQ(y - x, -10.0); + EXPECT_FLOAT_EQ(z1.as_float(), z2.as_float()); + EXPECT_FLOAT_EQ(z1.as_float(), 10.0); + EXPECT_FLOAT_EQ((y - x).as_float(), -10.0); } TEST_F(LogValTest,TestOps) { @@ -62,8 +62,8 @@ TEST_F(LogValTest,TestOps) { LogVal bb(-0.3); cerr << (aa + bb) << endl; cerr << (bb + aa) << endl; - EXPECT_FLOAT_EQ((aa + bb), (bb + aa)); - EXPECT_FLOAT_EQ((aa + bb), -0.1); + EXPECT_FLOAT_EQ((aa + bb).as_float(), (bb + aa).as_float()); + EXPECT_FLOAT_EQ((aa + bb).as_float(), -0.1); } TEST_F(LogValTest,TestSizes) { -- cgit v1.2.3 From 171027795ba3a01ba2ed82d7036610ac397e1fe8 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 14 Oct 2011 11:51:12 +0100 Subject: remove FSA integration code. will have to be resurrected another day --- decoder/Makefile.am | 1 - decoder/apply_fsa_models.cc | 798 ---------------------------------------- decoder/cdec_ff.cc | 13 - decoder/feature_accum.h | 129 ------- decoder/ff_factory.h | 2 - decoder/ff_from_fsa.h | 304 --------------- decoder/ff_fsa.h | 401 -------------------- decoder/ff_fsa_data.h | 131 ------- decoder/ff_fsa_dynamic.h | 208 ----------- decoder/ff_lm.cc | 48 --- decoder/ff_lm_fsa.h | 140 ------- decoder/ff_register.h | 38 -- decoder/hg_test.cc | 16 +- training/mpi_online_optimize.cc | 2 + 14 files changed, 10 insertions(+), 2221 deletions(-) delete mode 100755 decoder/apply_fsa_models.cc delete mode 100755 decoder/feature_accum.h delete mode 100755 decoder/ff_from_fsa.h delete mode 100755 decoder/ff_fsa.h delete mode 100755 decoder/ff_fsa_data.h delete mode 100755 decoder/ff_fsa_dynamic.h delete mode 100755 decoder/ff_lm_fsa.h diff --git a/decoder/Makefile.am b/decoder/Makefile.am index ede1cff0..6b9360d8 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -42,7 +42,6 @@ libcdec_a_SOURCES = \ cfg.cc \ dwarf.cc \ ff_dwarf.cc \ - apply_fsa_models.cc \ rule_lexer.cc \ fst_translator.cc \ csplit.cc \ diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc deleted file mode 100755 index 3e93cadd..00000000 --- a/decoder/apply_fsa_models.cc +++ /dev/null @@ -1,798 +0,0 @@ -//see apply_fsa_models.README for notes on the l2r earley fsa+cfg intersection -//implementation in this file (also some comments in this file) -#define SAFE_VALGRIND 1 - -#include "apply_fsa_models.h" -#include -#include -#include -#include - -#include "writer.h" -#include "hg.h" -#include "ff_fsa_dynamic.h" -#include "ff_from_fsa.h" -#include "feature_vector.h" -#include "stringlib.h" -#include "apply_models.h" -#include "cfg.h" -#include "hg_cfg.h" -#include "utoa.h" -#include "hash.h" -#include "value_array.h" -#include "d_ary_heap.h" -#include "agenda.h" -#include "show.h" -#include "string_to.h" - - -#define DFSA(x) x -//fsa earley chart - -#define DPFSA(x) x -//prefix trie - -#define DBUILDTRIE(x) - -#define PRINT_PREFIX 1 -#if PRINT_PREFIX -# define IF_PRINT_PREFIX(x) x -#else -# define IF_PRINT_PREFIX(x) -#endif -// keep backpointers in prefix trie so you can print a meaningful node id - -static const unsigned FSA_AGENDA_RESERVE=10; // TODO: increase to 1<<24 (16M) - -using namespace std; - -//impl details (not exported). flat namespace for my ease. - -typedef CFG::RHS RHS; -typedef CFG::BinRhs BinRhs; -typedef CFG::NTs NTs; -typedef CFG::NT NT; -typedef CFG::NTHandle NTHandle; -typedef CFG::Rules Rules; -typedef CFG::Rule Rule; -typedef CFG::RuleHandle RuleHandle; - -namespace { - -/* - -1) A -> x . * (trie) - -this is somewhat nice. cost pushed for best first, of course. similar benefit as left-branching binarization without the explicit predict/complete steps? - -vs. just - -2) * -> x . y - -here you have to potentially list out all A -> . x y as items * -> . x y immediately, and shared rhs seqs won't be shared except at the usual single-NT predict/complete. of course, the prediction of items -> . x y can occur lazy best-first. - -vs. - -3) * -> x . * - -with 3, we predict all sorts of useless items - that won't give us our goal A and may not partcipate in any parse. this is not a good option at all. - -I'm using option 1. -*/ - -// if we don't greedy-binarize, we want to encode recognized prefixes p (X -> p . rest) efficiently. if we're doing this, we may as well also push costs so we can best-first select rules in a lazy fashion. this is effectively left-branching binarization, of course. - -template -struct fsa_map_type { - typedef std::map type; // change to HASH_MAP ? -}; -//template typedef - and macro to make it less painful -#define FSA_MAP(k,v) fsa_map_type >::type - -struct PrefixTrieNode; -typedef PrefixTrieNode *NodeP; -typedef PrefixTrieNode const *NodePc; - -// for debugging prints only -struct TrieBackP { - WordID w; - NodePc from; - TrieBackP(WordID w=0,NodePc from=0) : w(w),from(from) { } -}; - -FsaFeatureFunction const* print_fsa=0; -CFG const* print_cfg=0; -inline ostream& print_cfg_rhs(std::ostream &o,WordID w,CFG const*pcfg=print_cfg) { - if (pcfg) - pcfg->print_rhs_name(o,w); - else - CFG::static_print_rhs_name(o,w); - return o; -} - -inline std::string nt_name(WordID n,CFG const*pcfg=print_cfg) { - if (pcfg) return pcfg->nt_name(n); - return CFG::static_nt_name(n); -} - -template -ostream& print_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") { - o< "< -ostream& print_map_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") { - o<first,pcfg) << " -> "<second<<"\n"; - } - return o; -} - -struct PrefixTrieEdge { - PrefixTrieEdge() - // : dest(0),w(TD::max_wordid) - {} - PrefixTrieEdge(WordID w,NodeP dest) - : dest(dest),w(w) - {} -// explicit PrefixTrieEdge(best_t p) : p(p),dest(0) { } - - best_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob. note: for final edge, set this. - //DPFSA() - // we can probably just store deltas, but for debugging remember the full p - // best_t delta; // - NodeP dest; - bool is_final() const { return dest==0; } - best_t p_dest() const; - WordID w; // for root and and is_final(), this will be (negated) NTHandle. - - // for sorting most probable first in adj; actually >(p) - inline bool operator <(PrefixTrieEdge const& o) const { - return o.p"< BPs; - void back_vec(BPs &ns) const { - IF_PRINT_PREFIX(if(backp.from) { ns.push_back(backp); backp.from->back_vec(ns); }) - } - - BPs back_vec() const { - BPs ret; - back_vec(ret); - return ret; - } - - unsigned size() const { - unsigned a=adj.size(); - unsigned e=edge_for.size(); - return a>e?a:e; - } - - void print_back_str(std::ostream &o) const { - BPs back=back_vec(); - unsigned i=back.size(); - if (!i) { - o<<"PrefixTrieNode@"<<(uintptr_t)this; - return; - } - bool first=true; - while (i--<=0) { - if (!first) o<<','; - first=false; - WordID w=back[i].w; - print_cfg_rhs(o,w); - } - } - std::string back_str() const { - std::ostringstream o; - print_back_str(o); - return o.str(); - } - -// best_t p_final; // additional prob beyond what we already paid. while building, this is the total prob -// instead of storing final, we'll say that an edge with a NULL dest is a final edge. this way it gets sorted into the list of adj. - - // instead of completed map, we have trie start w/ lhs. - NTHandle lhs; // nonneg. - instead of storing this in Item. - IF_PRINT_PREFIX(BP backp;) - - enum { ROOT=-1 }; - explicit PrefixTrieNode(NTHandle lhs=ROOT,best_t p=1) : p(p),lhs(lhs),IF_PRINT_PREFIX(backp()) { - //final=false; - } - bool is_root() const { return lhs==ROOT; } // means adj are the nonneg lhs indices, and we have the index edge_for still available - - // outgoing edges will be ordered highest p to worst p - - typedef FSA_MAP(WordID,PrefixTrieEdge) PrefixTrieEdgeFor; -public: - PrefixTrieEdgeFor edge_for; //TODO: move builder elsewhere? then need 2nd hash or edge include pointer to builder. just clear this later - bool have_adj() const { - return adj.size()>=edge_for.size(); - } - bool no_adj() const { - return adj.empty(); - } - - void index_adj() { - index_adj(edge_for); - } - template - void index_adj(M &m) { - assert(have_adj()); - m.clear(); - for (int i=0;i - void index_lhs(PV &v) { - for (int i=0,e=adj.size();i!=e;++i) { - PrefixTrieEdge const& edge=adj[i]; - // assert(edge.p.is_1()); // actually, after done_building, e will have telescoped dest->p/p. - NTHandle n=-edge.w; - assert(n>=0); -// SHOWM3(DPFSA,"index_lhs",i,edge,n); - v[n]=edge.dest; - } - } - - template - void done_root(PV &v) { - assert(is_root()); - SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_map_by_nt,edge_for)); - done_building_r(); //sets adj - SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_by_nt,adj)); -// SHOWM1(DBUILDTRIE,done_root,adj); -// index_adj(); // we want an index for the root node?. don't think so - index_lhs handles it. also we stopped clearing edge_for. - index_lhs(v); // uses adj - } - - // call only once. - void done_building_r() { - done_building(); - for (int i=0;idone_building_r(); - } - - // for done_building; compute incremental (telescoped) edge p - PrefixTrieEdge /*const&*/ operator()(PrefixTrieEdgeFor::value_type & pair) const { - PrefixTrieEdge &e=pair.second;//const_cast(pair.second); - e.p=e.p_dest()/p; - return e; - } - - // call only once. - void done_building() { - SHOWM3(DBUILDTRIE,"done_building",edge_for.size(),adj.size(),1); -#if 1 - adj.reinit_map(edge_for,*this); -#else - adj.reinit(edge_for.size()); - SHOWM3(DBUILDTRIE,"done_building_reinit",edge_for.size(),adj.size(),2); - Adj::iterator o=adj.begin(); - for (PrefixTrieEdgeFor::iterator i=edge_for.begin(),e=edge_for.end();i!=e;++i) { - SHOWM3(DBUILDTRIE,"edge_for",o-adj.begin(),i->first,i->second); - PrefixTrieEdge &edge=i->second; - edge.p=(edge.dest->p)/p; - *o++=edge; -// (*this)(*i); - } -#endif - SHOWM1(DBUILDTRIE,"done building adj",prange(adj.begin(),adj.end(),true)); - assert(adj.size()==edge_for.size()); -// if (final) p_final/=p; - std::sort(adj.begin(),adj.end()); - //TODO: store adjacent differences on edges (compared to - } - - typedef ValueArray Adj; -// typedef vector Adj; - Adj adj; - - typedef WordID W; - - // let's compute p_min so that every rule reachable from the created node has p at least this low. - NodeP improve_edge(PrefixTrieEdge const& e,best_t rulep) { - NodeP d=e.dest; - maybe_improve(d->p,rulep); - return d; - } - - inline NodeP build(W w,best_t rulep) { - return build(lhs,w,rulep); - } - inline NodeP build_lhs(NTHandle n,best_t rulep) { - return build(n,-n,rulep); - } - - NodeP build(NTHandle lhs_,W w,best_t rulep) { - PrefixTrieEdgeFor::iterator i=edge_for.find(w); - if (i!=edge_for.end()) - return improve_edge(i->second,rulep); - NodeP r=new PrefixTrieNode(lhs_,rulep); - IF_PRINT_PREFIX(r->backp=BP(w,this)); -// edge_for.insert(i,PrefixTrieEdgeFor::value_type(w,PrefixTrieEdge(w,r))); - add(edge_for,w,PrefixTrieEdge(w,r)); - SHOWM4(DBUILDTRIE,"built node",this,w,*r,r); - return r; - } - - void set_final(NTHandle lhs_,best_t pf) { - assert(no_adj()); -// final=true; - PrefixTrieEdge &e=edge_for[null_wordid]; - e.p=pf; - e.dest=0; - e.w=lhs_; - maybe_improve(p,pf); - } - -private: - void destroy_children() { - assert(adj.size()>=edge_for.size()); - for (int i=0,e=adj.size();i" << p; - o << ',' << size() << ','; - print_back_str(o); - } - PRINT_SELF(PrefixTrieNode) -}; - -inline best_t PrefixTrieEdge::p_dest() const { - return dest ? dest->p : p; // for final edge, p was set (no sentinel node) -} - - -//Trie starts with lhs (nonneg index), then continues w/ rhs (mixed >0 word, else NT) -// trie ends with final edge, which points to a per-lhs prefix node -struct PrefixTrie { - void print(std::ostream &o) const { - o << cfgp << ' ' << root; - } - PRINT_SELF(PrefixTrie); - CFG *cfgp; - Rules const* rulesp; - Rules const& rules() const { return *rulesp; } - CFG const& cfg() const { return *cfgp; } - PrefixTrieNode root; - typedef std::vector LhsToTrie; // will have to check lhs2[lhs].p for best cost of some rule with that lhs, then use edge deltas after? they're just caching a very cheap computation, really - LhsToTrie lhs2; // no reason to use a map or hash table; every NT in the CFG will have some rule rhses. lhs_to_trie[i]=root.edge_for[i], i.e. we still have a root trie node conceptually, we just access through this since it's faster. - typedef LhsToTrie LhsToComplete; - LhsToComplete lhs2complete; // the sentinel "we're completing" node (dot at end) for that lhs. special case of suffix-set=same trie minimization (aka right branching binarization) // these will be used to track kbest completions, along with a l state (r state will be in the list) - PrefixTrie(CFG &cfg) : cfgp(&cfg),rulesp(&cfg.rules),lhs2(cfg.nts.size(),0),lhs2complete(cfg.nts.size()) { -// cfg.SortLocalBestFirst(); // instead we'll sort in done_building_r - print_cfg=cfgp; - SHOWM2(DBUILDTRIE,"PrefixTrie()",rulesp->size(),lhs2.size()); - cfg.VisitRuleIds(*this); - root.done_root(lhs2); - SHOWM3(DBUILDTRIE,"done w/ PrefixTrie: ",root,root.adj.size(),lhs2.size()); - DBUILDTRIE(print_by_nt(cerr,lhs2,cfgp)); - SHOWM1(DBUILDTRIE,"lhs2",OSTRF2(print_by_nt,lhs2,cfgp)); - } - - void operator()(int ri) { - Rule const& r=rules()[ri]; - NTHandle lhs=r.lhs; - best_t p=r.p; -// NodeP n=const_cast(root).build_lhs(lhs,p); - NodeP n=root.build_lhs(lhs,p); - SHOWM4(DBUILDTRIE,"Prefixtrie rule id, root",ri,root,p,*n); - for (RHS::const_iterator i=r.rhs.begin(),e=r.rhs.end();;++i) { - SHOWM2(DBUILDTRIE,"PrefixTrie build or final",i-r.rhs.begin(),*n); - if (i==e) { - n->set_final(lhs,p); - break; - } - n=n->build(*i,p); - SHOWM2(DBUILDTRIE,"PrefixTrie built",*i,*n); - } -// root.build(lhs,r.p)->build(r.rhs,r.p); - } - inline NodeP lhs2_ex(NTHandle n) const { - NodeP r=lhs2[n]; - if (!r) throw std::runtime_error("PrefixTrie: no CFG rule w/ lhs "+cfgp->nt_name(n)); - return r; - } -private: - PrefixTrie(PrefixTrie const& o); -}; - - - -typedef std::size_t ItemHash; - - -struct ItemKey { - explicit ItemKey(NodeP start,Bytes const& start_state) : dot(start),q(start_state),r(start_state) { } - explicit ItemKey(NodeP dot) : dot(dot) { } - NodeP dot; // dot is a function of the stuff already recognized, and gives a set of suffixes y to complete to finish a rhs for lhs() -> dot y. for a lhs A -> . *, this will point to lh2[A] - Bytes q,r; // (q->r are the fsa states; if r is empty it means - bool operator==(ItemKey const& o) const { - return dot==o.dot && q==o.q && r==o.r; - } - inline ItemHash hash() const { - ItemHash h=GOLDEN_MEAN_FRACTION*(ItemHash)(dot-NULL); // i.e. lower order bits of ptr are nonrandom - using namespace boost; - hash_combine(h,q); - hash_combine(h,r); - return h; - } - template - void print(O &o) const { - o<<"lhs="<print_back_str(o); - if (print_fsa) { - o<<'/'; - print_fsa->print_state(o,&q[0]); - o<<"->"; - print_fsa->print_state(o,&r[0]); - } - } - NTHandle lhs() const { return dot->lhs; } - PRINT_SELF(ItemKey) -}; -inline ItemHash hash_value(ItemKey const& x) { - return x.hash(); -} -ItemKey null_item((PrefixTrieNode*)0); - -struct Item; -typedef Item *ItemP; - -/* we use a single type of item so it can live in a single best-first queue. we hold them by pointer so they can have mutable state, e.g. priority/location, but also lists of predictions and kbest completions (i.e. completions[L,r] = L -> * (r,s), by 1best for each possible s. we may discover more s later. we could use different subtypes since we hold by pointer, but for now everything will be packed as variants of Item */ -#undef INIT_LOCATION -#if D_ARY_TRACK_OUT_OF_HEAP -# define INIT_LOCATION , location(D_ARY_HEAP_NULL_INDEX) -#elif !defined(NDEBUG) || SAFE_VALGRIND - // avoid spurious valgrind warning - FIXME: still complains??? -# define INIT_LOCATION , location() -#else -# define INIT_LOCATION -#endif - -// these should go in a global best-first queue -struct ItemPrio { - // NOTE: sum = viterbi (max) - ItemPrio() : priority(init_0()),inner(init_0()) { } - explicit ItemPrio(best_t priority) : priority(priority),inner(init_0()) { } - best_t priority; // includes inner prob. (forward) - /* The forward probability alpha_i(X[k]->x.y) is the sum of the probabilities of all - constrained paths of length i that end in state X[k]->x.y*/ - best_t inner; - /* The inner probability beta_i(X[k]->x.y) is the sum of the probabilities of all - paths of length i-k that start in state X[k,k]->.xy and end in X[k,i]->x.y, and generate the input symbols x[k,...,i-1] */ - template - void print(O &o) const { - o<=0; - } - explicit Item(FFState const& state,NodeP dot,best_t prio,int next=0) : ItemPrio(prio),ItemKey(dot,state),trienext(next),from(0) - INIT_LOCATION - { -// t=ADJ; -// if (dot->adj.size()) - dot->p_delta(next,priority); -// SHOWM1(DFSA,"Item(state,dot,prio)",prio); - } - typedef std::queue Predicted; -// Predicted predicted; // this is empty, unless this is a predicted L -> .asdf item, or a to-complete L -> asdf . - int trienext; // index of dot->adj to complete (if dest==0), or predict (if NT), or scan (if word). note: we could store pointer inside adj since it and trie are @ fixed addrs. less pointer arith, more space. - ItemP from; //backpointer - 0 for L -> . asdf for the rest; L -> a .sdf, it's the L -> .asdf item. - ItemP predicted_from() const { - ItemP p=(ItemP)this; - while(p->from) p=p->from; - return p; - } - template - void print(O &o) const { - o<< '['; - o< -struct ApplyFsa { - ApplyFsa(HgCFG &i, - const SentenceMetadata& smeta, - const FsaFeatureFunction& fsa, - DenseWeightVector const& weights, - ApplyFsaBy const& by, - Hypergraph* oh - ) - :hgcfg(i),smeta(smeta),fsa(fsa),weights(weights),by(by),oh(oh) - { - stateless=!fsa.state_bytes(); - } - void Compute() { - if (by.IsBottomUp() || stateless) - ApplyBottomUp(); - else - ApplyEarley(); - } - void ApplyBottomUp(); - void ApplyEarley(); - CFG const& GetCFG(); -private: - CFG cfg; - HgCFG &hgcfg; - SentenceMetadata const& smeta; - FsaFF const& fsa; -// WeightVector weight_vector; - DenseWeightVector weights; - ApplyFsaBy by; - Hypergraph* oh; - std::string cfg_out; - bool stateless; -}; - -template -void ApplyFsa::ApplyBottomUp() -{ - assert(by.IsBottomUp()); - FeatureFunctionFromFsa buff(&fsa); - buff.Init(); // mandatory to call this (normally factory would do it) - vector ffs(1,&buff); - ModelSet models(weights, ffs); - IntersectionConfiguration i(stateless ? BU_FULL : by.BottomUpAlgorithm(),by.pop_limit); - ApplyModelSet(hgcfg.ih,smeta,models,i,oh); -} - -template -void ApplyFsa::ApplyEarley() -{ - hgcfg.GiveCFG(cfg); - print_cfg=&cfg; - print_fsa=&fsa; - Chart chart(cfg,smeta,fsa); - // don't need to uniq - option to do that already exists in cfg_options - //TODO: - chart.best_first(); - *oh=hgcfg.ih; -} - - -void ApplyFsaModels(HgCFG &i, - const SentenceMetadata& smeta, - const FsaFeatureFunction& fsa, - DenseWeightVector const& weight_vector, - ApplyFsaBy const& by, - Hypergraph* oh) -{ - ApplyFsa a(i,smeta,fsa,weight_vector,by,oh); - a.Compute(); -} - -/* -namespace { -char const* anames[]={ - "BU_CUBE", - "BU_FULL", - "EARLEY", - 0 -}; -} -*/ - -//TODO: named enum type in boost? - -std::string ApplyFsaBy::name() const { -// return anames[algorithm]; - return GetName(algorithm); -} - -std::string ApplyFsaBy::all_names() { - return FsaByNames(" "); - /* - std::ostringstream o; - for (int i=0;i=N_ALGORITHMS) - throw std::runtime_error("Unknown ApplyFsaBy type id: "+itos(i)+" - legal types: "+all_names()); -*/ - GetName(i); // checks validity - algorithm=i; -} - -int ApplyFsaBy::BottomUpAlgorithm() const { - assert(IsBottomUp()); - return algorithm==BU_CUBE ? - IntersectionConfiguration::CUBE - :IntersectionConfiguration::FULL; -} - -void ApplyFsaModels(Hypergraph const& ih, - const SentenceMetadata& smeta, - const FsaFeatureFunction& fsa, - DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this) - ApplyFsaBy const& cfg, - Hypergraph* out) -{ - HgCFG i(ih); - ApplyFsaModels(i,smeta,fsa,weights,cfg,out); -} diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 69f40c93..4ce5749e 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -12,8 +12,6 @@ #include "ff_rules.h" #include "ff_ruleshape.h" #include "ff_bleu.h" -#include "ff_lm_fsa.h" -#include "ff_sample_fsa.h" #include "ff_source_syntax.h" #include "ff_register.h" #include "ff_charset.h" @@ -31,15 +29,6 @@ void register_feature_functions() { } registered = true; - //TODO: these are worthless example target FSA ffs. remove later - RegisterFsaImpl(true); - RegisterFsaImpl(true); - RegisterFsaImpl(true); -// ff_registry.Register("LanguageModelFsaDynamic",new FFFactory > >); // to test correctness of FsaFeatureFunctionDynamic erasure - RegisterFsaDynToFF(); - RegisterFsaImpl(true); // same as LM but using fsa wrapper - RegisterFsaDynToFF(); - RegisterFF(); RegisterFF(); @@ -47,8 +36,6 @@ void register_feature_functions() { RegisterFF(); RegisterFF(); - ff_registry.Register(new FFFactory); // same as WordPenalty, but implemented using ff_fsa - //TODO: use for all features the new Register which requires static FF::usage(false,false) give name #ifdef HAVE_RANDLM ff_registry.Register("RandLM", new FFFactory); diff --git a/decoder/feature_accum.h b/decoder/feature_accum.h deleted file mode 100755 index 4b8338eb..00000000 --- a/decoder/feature_accum.h +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef FEATURE_ACCUM_H -#define FEATURE_ACCUM_H - -#include "ff.h" -#include "sparse_vector.h" -#include "value_array.h" - -struct SparseFeatureAccumulator : public FeatureVector { - typedef FeatureVector State; - SparseFeatureAccumulator() { assert(!"this code is disabled"); } - template - FeatureVector const& describe(FF const& ) { return *this; } - void Store(FeatureVector *fv) const { -//NO fv->set_from(*this); - } - template - void Store(FF const& /* ff */,FeatureVector *fv) const { -//NO fv->set_from(*this); - } - template - void Add(FF const& /* ff */,FeatureVector const& fv) { - (*this)+=fv; - } - void Add(FeatureVector const& fv) { - (*this)+=fv; - } - /* - SparseFeatureAccumulator(FeatureVector const& fv) : State(fv) {} - FeatureAccumulator(Features const& fids) {} - FeatureAccumulator(Features const& fids,FeatureVector const& fv) : State(fv) {} - void Add(Features const& fids,FeatureVector const& fv) { - *this += fv; - } - */ - void Add(int i,Featval v) { -//NO (*this)[i]+=v; - } - void Add(Features const& fids,int i,Featval v) { -//NO (*this)[i]+=v; - } -}; - -struct SingleFeatureAccumulator { - typedef Featval State; - typedef SingleFeatureAccumulator Self; - State v; - /* - void operator +=(State const& o) { - v+=o; - } - */ - void operator +=(Self const& s) { - v+=s.v; - } - SingleFeatureAccumulator() : v() {} - template - State const& describe(FF const& ) const { return v; } - - template - void Store(FF const& ff,FeatureVector *fv) const { - fv->set_value(ff.fid_,v); - } - void Store(Features const& fids,FeatureVector *fv) const { - assert(fids.size()==1); - fv->set_value(fids[0],v); - } - /* - SingleFeatureAccumulator(Features const& fids) { assert(fids.size()==1); } - SingleFeatureAccumulator(Features const& fids,FeatureVector const& fv) - { - assert(fids.size()==1); - v=fv.get_singleton(); - } - */ - - template - void Add(FF const& ff,FeatureVector const& fv) { - v+=fv.get(ff.fid_); - } - void Add(FeatureVector const& fv) { - v+=fv.get_singleton(); - } - - void Add(Features const& fids,FeatureVector const& fv) { - v += fv.get(fids[0]); - } - void Add(Featval dv) { - v+=dv; - } - void Add(int,Featval dv) { - v+=dv; - } - void Add(FeatureVector const& fids,int i,Featval dv) { - assert(fids.size()==1 && i==0); - v+=dv; - } -}; - - -#if 0 -// omitting this so we can default construct an accum. might be worth resurrecting in the future -struct ArrayFeatureAccumulator : public ValueArray { - typedef ValueArray State; - template - ArrayFeatureAccumulator(Fsa const& fsa) : State(fsa.features_.size()) { } - ArrayFeatureAccumulator(Features const& fids) : State(fids.size()) { } - ArrayFeatureAccumulator(Features const& fids) : State(fids.size()) { } - ArrayFeatureAccumulator(Features const& fids,FeatureVector const& fv) : State(fids.size()) { - for (int i=0,e=iset_value(fids[i],(*this)[i]); - } - void Add(Features const& fids,FeatureVector const& fv) { - for (int i=0,e=i -#include "ff_fsa_dynamic.h" - class FeatureFunction; class FsaFeatureFunction; diff --git a/decoder/ff_from_fsa.h b/decoder/ff_from_fsa.h deleted file mode 100755 index f8d79e03..00000000 --- a/decoder/ff_from_fsa.h +++ /dev/null @@ -1,304 +0,0 @@ -#ifndef FF_FROM_FSA_H -#define FF_FROM_FSA_H - -#include "ff_fsa.h" - -#ifndef TD__none -// replacing dependency on SRILM -#define TD__none -1 -#endif - -#ifndef FSA_FF_DEBUG -# define FSA_FF_DEBUG 0 -#endif -#if FSA_FF_DEBUG -# define FSAFFDBG(e,x) FSADBGif(debug(),e,x) -# define FSAFFDBGnl(e) FSADBGif_nl(debug(),e) -#else -# define FSAFFDBG(e,x) -# define FSAFFDBGnl(e) -#endif - -/* regular bottom up scorer from Fsa feature - uses guarantee about markov order=N to score ASAP - encoding of state: if less than N-1 (ctxlen) words - - usage: - typedef FeatureFunctionFromFsa LanguageModelFromFsa; -*/ - -template -class FeatureFunctionFromFsa : public FeatureFunction { - typedef void const* SP; - typedef WordID *W; - typedef WordID const* WP; -public: - template - FeatureFunctionFromFsa(I const& param) : ff(param) { - debug_=true; // because factory won't set until after we construct. - } - template - FeatureFunctionFromFsa(I & param) : ff(param) { - debug_=true; // because factory won't set until after we construct. - } - - static std::string usage(bool args,bool verbose) { - return Impl::usage(args,verbose); - } - void init_name_debug(std::string const& n,bool debug) { - FeatureFunction::init_name_debug(n,debug); - ff.init_name_debug(n,debug); - } - - // this should override - Features features() const { - DBGINIT("FeatureFunctionFromFsa features() name="<=1) - for (int j=0,ee=e.size();;++j) { // items in target side of rule - for(;;++j) { - if (j>=ee) goto rhs_done; // j may go 1 past ee due to k possibly getting to end - if (RHS_WORD(j)) break; - } - // word @j - int k=j; - while(k{"<") - FSAFFDBG(edge," end="<{"< -# define FSADBG(e,x) FSADBGif(d().debug(),e,x) -# define FSADBGnl(e) FSADBGif_nl(d().debug(),e,x) -#else -# define FSADBG(e,x) -# define FSADBGnl(e) -#endif - -#include "fast_lexical_cast.hpp" -#include -#include -#include "ff.h" -#include "sparse_vector.h" -#include "tdict.h" -#include "hg.h" -#include "ff_fsa_data.h" - -/* -usage: see ff_sample_fsa.h or ff_lm_fsa.h - - then, to decode, see ff_from_fsa.h (or TODO: left->right target-earley style rescoring) - - */ - - -template -struct FsaFeatureFunctionBase : public FsaFeatureFunctionData { - Impl const& d() const { return static_cast(*this); } - Impl & d() { return static_cast(*this); } - - // this will get called by factory - override if you have multiple or dynamically named features. note: may be called repeatedly - void Init() { - Init(name()); - DBGINIT("base (single feature) FsaFeatureFunctionBase::Init name="<set_value(fid,val) possibly with duplicates. state and next_state will never be the same memory. - //TODO: decide if we want to require you to support dest same as src, since that's how we use it most often in ff_from_fsa bottom-up wrapper (in l->r scoring, however, distinct copies will be the rule), and it probably wouldn't be too hard for most people to support. however, it's good to hide the complexity here, once (see overly clever FsaScan loop that swaps src/dest addresses repeatedly to scan a sequence by effectively swapping) - -protected: - // overrides have different name because of inheritance method hiding; - - // simple/common case; 1 fid. these need not be overriden if you have multiple feature ids - Featval Scan1(WordID w,void const* state,void *next_state) const { - assert(0); - return 0; - } - Featval Scan1Meta(SentenceMetadata const& /* smeta */,Hypergraph::Edge const& /* edge */, - WordID w,void const* state,void *next_state) const { - return d().Scan1(w,state,next_state); - } -public: - - // must override this or Scan1Meta or Scan1 - template - inline void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const { - Add(d().Scan1Meta(smeta,edge,w,state,next_state),a); - } - - // bounce back and forth between two state vars starting at cs, returning end state location. if we required src=dest addr safe state updating, this concept wouldn't need to exist. - // required that you override this if you score phrases differently than word-by-word, however, you can just use the SCAN_PHRASE_ACCUM_OVERRIDE macro to do that in terms of ScanPhraseAccum - template - void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const { - // extra code - IT'S FOR EFFICIENCY, MAN! IT'S OK! definitely no bugs here. - if (!ssz) { - for (;io - odd: - d().ScanAccum(smeta,edge,i[0],os,es,accum); // o->e - } - return es; - } - - - static const bool simple_phrase_score=true; // if d().simple_phrase_score_, then you should expect different Phrase scores for phrase length > M. so, set this false if you provide ScanPhraseAccum (SCAN_PHRASE_ACCUM_OVERRIDE macro does this) - - // override this (and use SCAN_PHRASE_ACCUM_OVERRIDE ) if you want e.g. maximum possible order ngram scores with markov_order < n-1. in the future SparseFeatureAccumulator will probably be the only option for type-erased FSA ffs. - // note you'll still have to override ScanAccum - template - void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *accum) const { - if (!ssz) { - for (;i \ - void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const { \ - ScanPhraseAccum(smeta,edge,i,end,cs,ns,accum); \ - return ns; \ - } \ - template \ - void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, \ - WordID const* i, WordID const* end, \ - void const* state,Accum *accum) const { \ - char s2[ssz]; ScanPhraseAccum(smeta,edge,i,end,state,(void*)s2,accum); \ - } - - // override this or bounce along with above. note: you can just call ScanPhraseAccum - // doesn't set state (for heuristic in ff_from_fsa) - template - void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *accum) const { - char s1[ssz]; - char s2[ssz]; - state_copy(s1,state); - d().ScanPhraseAccumBounce(smeta,edge,i,end,(void*)s1,(void*)s2,accum); - } - - // for single-feat only. but will work for different accums - template - inline void Add(Featval v,Accum *a) const { - a->Add(fid_,v); - } - inline void set_feat(FeatureVector *features,Featval v) const { - features->set_value(fid_,v); - } - - // don't set state-bytes etc. in ctor because it may depend on parsing param string - FsaFeatureFunctionBase(int statesz=0,Sentence const& end_sentence_phrase=Sentence()) - : FsaFeatureFunctionData(statesz,end_sentence_phrase) - { - name_=name(); // should allow FsaDynamic wrapper to get name copied to it with sync - } - -}; - -template -struct MultipleFeatureFsa : public FsaFeatureFunctionBase { - typedef SparseFeatureAccumulator Accum; -}; - - - - -// if State is pod. sets state size and allocs start, h_start -// usage: -// struct ShorterThanPrev : public FsaTypedBase -// i.e. Impl is a CRTP -template -struct FsaTypedBase : public FsaFeatureFunctionBase { - Impl const& d() const { return static_cast(*this); } - Impl & d() { return static_cast(*this); } -protected: - typedef FsaFeatureFunctionBase Base; - typedef St State; - static inline State & state(void *state) { - return *(State*)state; - } - static inline State const& state(void const* state) { - return *(State const*)state; - } - void set_starts(State const& s,State const& heuristic_s) { - if (0) { // already in ctor - Base::start.resize(sizeof(State)); - Base::h_start.resize(sizeof(State)); - } - assert(Base::start.size()==sizeof(State)); - assert(Base::h_start.size()==sizeof(State)); - state(Base::start.begin())=s; - state(Base::h_start.begin())=heuristic_s; - } - FsaTypedBase(St const& start_st=St() - ,St const& h_start_st=St() - ,Sentence const& end_sentence_phrase=Sentence()) - : Base(sizeof(State),end_sentence_phrase) { - set_starts(start_st,h_start_st); - } -public: - void print_state(std::ostream &o,void const*st) const { - o< - inline void ScanT(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID w,St const& prev_st,St &new_st,Accum *a) const { - Add(d().ScanT1(smeta,edge,w,prev_st,new_st),a); - } - - // note: you're on your own when it comes to Phrase overrides. see FsaFeatureFunctionBase. sorry. - - template - inline void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID w,void const* st,void *next_state,Accum *a) const { - Impl const& im=d(); - FSADBG(edge,"Scan "<describe(im)<<" "<"< -struct FsaScanner { -// enum {ALIGN=8}; - static const int ALIGN=8; - FF const& ff; - SentenceMetadata const& smeta; - int ssz; - Bytes states; // first is at begin, second is at (char*)begin+stride - void *st0; // states - void *st1; // states+stride - void *cs; // initially st0, alternates between st0 and st1 - inline void *nexts() const { - return (cs==st0)?st1:st0; - } - Hypergraph::Edge const& edge; - FsaScanner(FF const& ff,SentenceMetadata const& smeta,Hypergraph::Edge const& edge) : ff(ff),smeta(smeta),edge(edge) - { - ssz=ff.state_bytes(); - int stride=((ssz+ALIGN-1)/ALIGN)*ALIGN; // round up to multiple of ALIGN - states.resize(stride+ssz); - st0=states.begin(); - st1=(char*)st0+stride; -// for (int i=0;i<2;++i) st[i]=cs+(i*stride); - } - void reset(void const* state) { - cs=st0; - std::memcpy(st0,state,ssz); - } - template - void scan(WordID w,Accum *a) { - void *ns=nexts(); - ff.ScanAccum(smeta,edge,w,cs,ns,a); - cs=ns; - } - template - void scan(WordID const* i,WordID const* end,Accum *a) { - // faster. and allows greater-order excursions - cs=ff.ScanPhraseAccumBounce(smeta,edge,i,end,cs,nexts(),a); - } -}; - - -//TODO: combine 2 FsaFeatures typelist style (can recurse for more) - - - - -#endif diff --git a/decoder/ff_fsa_data.h b/decoder/ff_fsa_data.h deleted file mode 100755 index d215e940..00000000 --- a/decoder/ff_fsa_data.h +++ /dev/null @@ -1,131 +0,0 @@ -#ifndef FF_FSA_DATA_H -#define FF_FSA_DATA_H - -#include //C99 -#include -#include "sentences.h" -#include "feature_accum.h" -#include "value_array.h" -#include "ff.h" //debug -typedef ValueArray Bytes; - -// stuff I see no reason to have virtual. but because it's impossible (w/o virtual inheritance to have dynamic fsa ff know where the impl's data starts, implemented a sync (copy) method that needs to be called. init_name_debug was already necessary to keep state in sync between ff and ff_from_fsa, so no sync should be needed after it. supposing all modifications were through setters, then no explicit sync call would ever be needed; updates could be mirrored. -struct FsaFeatureFunctionData -{ - void init_name_debug(std::string const& n,bool debug) { - name_=n; - debug_=debug; - } - //HACK for diamond inheritance (w/o costing performance) - FsaFeatureFunctionData *sync_to_; - - void sync() const { // call this if you modify any fields after your constructor is done - if (sync_to_) { - DBGINIT("sync to "<<*sync_to_); - *sync_to_=*this; - DBGINIT("synced result="<<*sync_to_<< " from this="<<*this); - } else { - DBGINIT("nobody to sync to - from FeatureFunctionData this="<<*this); - } - } - - friend std::ostream &operator<<(std::ostream &o,FsaFeatureFunctionData const& d) { - o << "[FSA "< - static inline T* state_as(void *p) { return (T*)p; } - template - static inline T const* state_as(void const* p) { return (T*)p; } - std::string describe_features(FeatureVector const& feats) { - std::ostringstream o; - o<" for lm. -protected: - int ssz; // don't forget to set this. default 0 (it may depend on params of course) - // this can be called instead or after constructor (also set bytes and end_phrase_) - void set_state_bytes(int sb=0) { - if (start.size()!=sb) start.resize(sb); - if (h_start.size()!=sb) h_start.resize(sb); - ssz=sb; - } - void set_end_phrase(WordID single) { - end_phrase_=singleton_sentence(single); - } - - inline void static to_state(void *state,char const* begin,char const* end) { - std::memcpy(state,begin,end-begin); - } - inline void static to_state(void *state,char const* begin,int n) { - std::memcpy(state,begin,n); - } - template - inline void static to_state(void *state,T const* begin,int n=1) { - to_state(state,(char const*)begin,n*sizeof(T)); - } - template - inline void static to_state(void *state,T const* begin,T const* end) { - to_state(state,(char const*)begin,(char const*)end); - } - inline static char hexdigit(int i) { - int j=i-10; - return j>=0?'a'+j:'0'+i; - } - inline static void print_hex_byte(std::ostream &o,unsigned c) { - o<>4); - o<Add(v); - } - -}; - -#endif diff --git a/decoder/ff_fsa_dynamic.h b/decoder/ff_fsa_dynamic.h deleted file mode 100755 index 6f75bbe5..00000000 --- a/decoder/ff_fsa_dynamic.h +++ /dev/null @@ -1,208 +0,0 @@ -#ifndef FF_FSA_DYNAMIC_H -#define FF_FSA_DYNAMIC_H - -struct SentenceMetadata; - -#include "ff_fsa_data.h" -#include "hg.h" // can't forward declare nested Hypergraph::Edge class -#include - -// the type-erased interface - -//FIXME: diamond inheritance problem. make a copy of the fixed data? or else make the dynamic version not wrap but rather be templated CRTP base (yuck) -struct FsaFeatureFunction : public FsaFeatureFunctionData { - static const bool simple_phrase_score=false; - virtual int markov_order() const = 0; - - // see ff_fsa.h - FsaFeatureFunctionBase gives you reasonable impls of these if you override just ScanAccum - virtual void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const = 0; - virtual void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *accum) const = 0; - virtual void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *accum) const = 0; - virtual void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const = 0; - - virtual int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const { return 0; } - // called after constructor, before use - virtual void Init() = 0; - virtual std::string usage_v(bool param,bool verbose) const { - return FeatureFunction::usage_helper("unnamed_dynamic_fsa_feature","","",param,verbose); - } - virtual void init_name_debug(std::string const& n,bool debug) { - FsaFeatureFunctionData::init_name_debug(n,debug); - } - - virtual void print_state(std::ostream &o,void const*state) const { - FsaFeatureFunctionData::print_state(o,state); - } - virtual std::string describe() const { return "[FSA unnamed_dynamic_fsa_feature]"; } - - //end_phrase() - virtual ~FsaFeatureFunction() {} - - // no need to override: - std::string describe_state(void const* state) const { - std::ostringstream o; - print_state(o,state); - return o.str(); - } -}; - -// conforming to above interface, type erases FsaImpl -// you might be wondering: why do this? answer: it's cool, and it means that the bottom-up ff over ff_fsa wrapper doesn't go through multiple layers of dynamic dispatch -// usage: typedef FsaFeatureFunctionDynamic MyFsaDyn; -template -struct FsaFeatureFunctionDynamic : public FsaFeatureFunction { - static const bool simple_phrase_score=Impl::simple_phrase_score; - Impl& d() { return impl;//static_cast(*this); - } - Impl const& d() const { return impl; - //static_cast(*this); - } - int markov_order() const { return d().markov_order(); } - - std::string describe() const { - return d().describe(); - } - - virtual void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const { - return d().ScanAccum(smeta,edge,w,state,next_state,a); - } - - virtual void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *a) const { - return d().ScanPhraseAccum(smeta,edge,i,end,state,next_state,a); - } - - virtual void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *a) const { - return d().ScanPhraseAccumOnly(smeta,edge,i,end,state,a); - } - - virtual void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *a) const { - return d().ScanPhraseAccumBounce(smeta,edge,i,end,cs,ns,a); - } - - virtual int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const { - return d().early_score_words(smeta,edge,i,end,accum); - } - - static std::string usage(bool param,bool verbose) { - return Impl::usage(param,verbose); - } - - std::string usage_v(bool param,bool verbose) const { - return Impl::usage(param,verbose); - } - - virtual void print_state(std::ostream &o,void const*state) const { - return d().print_state(o,state); - } - - void init_name_debug(std::string const& n,bool debug) { - FsaFeatureFunction::init_name_debug(n,debug); - d().init_name_debug(n,debug); - } - - virtual void Init() { - d().sync_to_=(FsaFeatureFunctionData*)this; - d().Init(); - d().sync(); - } - - template - FsaFeatureFunctionDynamic(I const& param) : impl(param) { - Init(); - } -private: - Impl impl; -}; - -// constructor takes ptr or shared_ptr to Impl, otherwise same as above - note: not virtual -template -struct FsaFeatureFunctionPimpl : public FsaFeatureFunctionData { - typedef boost::shared_ptr Pimpl; - static const bool simple_phrase_score=Impl::simple_phrase_score; - Impl const& d() const { return *p_; } - int markov_order() const { return d().markov_order(); } - - std::string describe() const { - return d().describe(); - } - - void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const { - return d().ScanAccum(smeta,edge,w,state,next_state,a); - } - - void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *a) const { - return d().ScanPhraseAccum(smeta,edge,i,end,state,next_state,a); - } - - void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *a) const { - return d().ScanPhraseAccumOnly(smeta,edge,i,end,state,a); - } - - void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *a) const { - return d().ScanPhraseAccumBounce(smeta,edge,i,end,cs,ns,a); - } - - int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const { - return d().early_score_words(smeta,edge,i,end,accum); - } - - static std::string usage(bool param,bool verbose) { - return Impl::usage(param,verbose); - } - - std::string usage_v(bool param,bool verbose) const { - return Impl::usage(param,verbose); - } - - void print_state(std::ostream &o,void const*state) const { - return d().print_state(o,state); - } - -#if 0 - // this and Init() don't touch p_ because we want to leave the original alone. - void init_name_debug(std::string const& n,bool debug) { - FsaFeatureFunctionData::init_name_debug(n,debug); - } -#endif - void Init() { - p_=hold_pimpl_.get(); -#if 0 - d().sync_to_=static_cast(this); - d().Init(); -#endif - *static_cast(this)=d(); - } - - FsaFeatureFunctionPimpl(Impl const* const p) : hold_pimpl_(p,null_deleter()) { - Init(); - } - FsaFeatureFunctionPimpl(Pimpl const& p) : hold_pimpl_(p) { - Init(); - } -private: - Impl const* p_; - Pimpl hold_pimpl_; -}; - -typedef FsaFeatureFunctionPimpl FsaFeatureFunctionFwd; // allow ff_from_fsa for an existing dynamic-type ff (as opposed to usual register a wrapped known-type FSA in ff_register, which is more efficient) -//typedef FsaFeatureFunctionDynamic DynamicFsaFeatureFunctionFwd; //if you really need to have a dynamic fsa facade that's also a dynamic fsa - -//TODO: combine 2 (or N) FsaFeatureFunction (type erased) - - -#endif diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc index afa36b96..5e16d4e3 100644 --- a/decoder/ff_lm.cc +++ b/decoder/ff_lm.cc @@ -46,7 +46,6 @@ char const* usage_verbose="-n determines the name of the feature (and its weight #endif #include "ff_lm.h" -#include "ff_lm_fsa.h" #include #include @@ -69,10 +68,6 @@ char const* usage_verbose="-n determines the name of the feature (and its weight using namespace std; -string LanguageModelFsa::usage(bool param,bool verbose) { - return FeatureFunction::usage_helper("LanguageModelFsa",usage_short,usage_verbose,param,verbose); -} - string LanguageModel::usage(bool param,bool verbose) { return FeatureFunction::usage_helper(usage_name,usage_short,usage_verbose,param,verbose); } @@ -524,49 +519,6 @@ LanguageModel::LanguageModel(const string& param) { SetStateSize(LanguageModelImpl::OrderToStateSize(order)); } -//TODO: decide whether to waste a word of space so states are always none-terminated for SRILM. otherwise we have to copy -void LanguageModelFsa::set_ngram_order(int i) { - assert(i>0); - ngram_order_=i; - ctxlen_=i-1; - set_state_bytes(ctxlen_*sizeof(WordID)); - WordID *ss=(WordID*)start.begin(); - WordID *hs=(WordID*)h_start.begin(); - if (ctxlen_) { // avoid segfault in case of unigram lm (0 state) - set_end_phrase(TD::Convert("")); -// se is pretty boring in unigram case, just adds constant prob. check that this is what we want - ss[0]=TD::Convert(""); // start-sentence context (length 1) - hs[0]=0; // empty context - for (int i=1;ifloor_; - set_ngram_order(lmorder); -} - -void LanguageModelFsa::print_state(ostream &o,void const* st) const { - WordID const *wst=(WordID const*)st; - o<<'['; - bool sp=false; - for (int i=ctxlen_;i>0;sp=true) { - --i; - WordID w=wst[i]; - if (w==0) continue; - if (sp) o<<' '; - o << TD::Convert(w); - } - o<<']'; -} - Features LanguageModel::features() const { return single_feature(fid_); } diff --git a/decoder/ff_lm_fsa.h b/decoder/ff_lm_fsa.h deleted file mode 100755 index 85b7ef44..00000000 --- a/decoder/ff_lm_fsa.h +++ /dev/null @@ -1,140 +0,0 @@ -#ifndef FF_LM_FSA_H -#define FF_LM_FSA_H - -//FIXME: when FSA_LM_PHRASE 1, 3gram fsa has differences, especially with unk words, in about the 4th decimal digit (about .05%), compared to regular ff_lm. this is USUALLY a bug (there's way more actual precision in there). this was with #define LM_FSA_SHORTEN_CONTEXT 1 and 0 (so it's not that). also, LM_FSA_SHORTEN_CONTEXT gives identical scores with FSA_LM_PHRASE 0 - -// enabling for now - retest unigram+ more, solve above puzzle - -// some impls in ff_lm.cc - -#define FSA_LM_PHRASE 1 - -#define FSA_LM_DEBUG 0 -#if FSA_LM_DEBUG -# define FSALMDBG(e,x) FSADBGif(debug(),e,x) -# define FSALMDBGnl(e) FSADBGif_nl(debug(),e) -#else -# define FSALMDBG(e,x) -# define FSALMDBGnl(e) -#endif - -#include "ff_fsa.h" -#include "ff_lm.h" - -#ifndef TD__none -// replacing dependency on SRILM -#define TD__none -1 -#endif - -namespace { -WordID empty_context=TD__none; -} - -struct LanguageModelFsa : public FsaFeatureFunctionBase { - typedef WordID * W; - typedef WordID const* WP; - - // overrides; implementations in ff_lm.cc - typedef SingleFeatureAccumulator Accum; - static std::string usage(bool,bool); - LanguageModelFsa(std::string const& param); - int markov_order() const { return ctxlen_; } - void print_state(std::ostream &,void const *) const; - inline Featval floored(Featval p) const { - return pleft;--e) - if (e[-1]!=TD__none) break; - //post: [left,e] are the seen left words - return e; - } - - template - void ScanAccum(SentenceMetadata const& /* smeta */,Hypergraph::Edge const& edge,WordID w,void const* old_st,void *new_st,Accum *a) const { -#if USE_INFO_EDGE - Hypergraph::Edge &de=(Hypergraph::Edge &)edge; -#endif - if (!ctxlen_) { - Add(floored(pimpl_->WordProb(w,&empty_context)),a); - } else { - WordID ctx[ngram_order_]; //alloca if you don't have C99 - state_copy(ctx,old_st); - ctx[ctxlen_]=TD__none; - Featval p=floored(pimpl_->WordProb(w,ctx)); - FSALMDBG(de,"p("<ShortenContext(nst,ctxlen_); -#endif - Add(p,a); - } - } - -#if FSA_LM_PHRASE - //FIXME: there is a bug in here somewhere, or else the 3gram LM we use gives different scores for phrases (impossible? BOW nonzero when shortening context past what LM has?) - template - void ScanPhraseAccum(SentenceMetadata const& /* smeta */,const Hypergraph::Edge&edge,WordID const* begin,WordID const* end,void const* old_st,void *new_st,Accum *a) const { - Hypergraph::Edge &de=(Hypergraph::Edge &)edge;(void)de; - if (begin==end) return; // otherwise w/ shortening it's possible to end up with no words at all. - /* // this is forcing unigram prob always. we will instead build the phrase - if (!ctxlen_) { - Featval p=0; - for (;iWordProb(*i,e&mpty_context)); - Add(p,a); - return; - } */ - int nw=end-begin; - WP st=(WP)old_st; - WP st_end=st+ctxlen_; // may include some null already (or none if full) - int nboth=nw+ctxlen_; - WordID ctx[nboth+1]; - ctx[nboth]=TD__none; - // reverse order - state at very end of context, then [i,end) in rev order ending at ctx[0] - W ctx_score_end=wordcpy_reverse(ctx,begin,end); - wordcpy(ctx_score_end,st,st_end); // st already reversed. - assert(ctx_score_end==ctx+nw); - // we could just copy the filled state words, but it probably doesn't save much time (and might cost some to scan to find the nones. most contexts are full except for the shortest source spans. - FSALMDBG(de," scan.r->l("<ctx;--ctx_score_end) - p+=floored(pimpl_->WordProb(ctx_score_end[-1],ctx_score_end)); - //TODO: look for score discrepancy - - // i had some idea that maybe shortencontext would return a different prob if the length provided was > ctxlen_; however, since the same disagreement happens with LM_FSA_SHORTEN_CONTEXT 0 anyway, it's not that. perhaps look to SCAN_PHRASE_ACCUM_OVERRIDE - make sure they do the right thing. -#if LM_FSA_SHORTEN_CONTEXT - p+=pimpl_->ShortenContext(ctx,nboth - need to use factory rather than ctor. -#if 0 -template -inline void RegisterFsa(bool ff_also=true,bool fsa_prefix_ff=true) { - assert(!ff_also); -// global_fsa_ff_registry->RegisterFsa(); -//if (ff_also) ff_registry.RegisterFF >(prefix_fsa(DynFsa::usage(false,false)),fsa_prefix_ff); -} -#endif - -//TODO: ff from fsa that uses pointer to fsa impl? e.g. in LanguageModel we share underlying lm file by recognizing same param, but without that effort, otherwise stateful ff may duplicate state if we enable both fsa and ff_from_fsa -template -inline void RegisterFsaImpl(bool ff_also=true,bool fsa_prefix_ff=false) { - typedef FsaFeatureFunctionDynamic DynFsa; - typedef FeatureFunctionFromFsa FFFrom; - std::string name=FsaImpl::usage(false,false); - fsa_ff_registry.Register(new FsaFactory); - if (ff_also) - ff_registry.Register(prefix_fsa(name,fsa_prefix_ff),new FFFactory); -} template inline void RegisterFF() { ff_registry.Register(new FFFactory); } -template -inline void RegisterFsaDynToFF(std::string name,bool prefix=true) { - typedef FsaFeatureFunctionDynamic DynFsa; - ff_registry.Register(prefix?"DynamicFsa"+name:name,new FFFactory >); -} - -template -inline void RegisterFsaDynToFF(bool prefix=true) { - RegisterFsaDynToFF(FsaImpl::usage(false,false),prefix); -} - void register_feature_functions(); #endif diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 3be5b82d..5d1910fb 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -57,7 +57,7 @@ TEST_F(HGTest,Union) { c3 = ViterbiESentence(hg1, &t3); int l3 = ViterbiPathLength(hg1); cerr << c3 << "\t" << TD::GetString(t3) << endl; - EXPECT_FLOAT_EQ(c2, c3); + EXPECT_FLOAT_EQ(c2.as_float(), c3.as_float()); EXPECT_EQ(TD::GetString(t2), TD::GetString(t3)); EXPECT_EQ(l2, l3); @@ -117,7 +117,7 @@ TEST_F(HGTest,InsideScore) { cerr << "cost: " << cost << "\n"; hg.PrintGraphviz(); prob_t inside = Inside(hg); - EXPECT_FLOAT_EQ(1.7934048, inside); // computed by hand + EXPECT_FLOAT_EQ(1.7934048, inside.as_float()); // computed by hand vector post; inside = hg.ComputeBestPathThroughEdges(&post); EXPECT_FLOAT_EQ(-0.3, log(inside)); // computed by hand @@ -282,13 +282,13 @@ TEST_F(HGTest, TestGenericInside) { hg.Reweight(wts); vector inside; prob_t ins = Inside(hg, &inside); - EXPECT_FLOAT_EQ(1.7934048, ins); // computed by hand + EXPECT_FLOAT_EQ(1.7934048, ins.as_float()); // computed by hand vector outside; Outside(hg, inside, &outside); EXPECT_EQ(3, outside.size()); - EXPECT_FLOAT_EQ(1.7934048, outside[0]); - EXPECT_FLOAT_EQ(1.3114071, outside[1]); - EXPECT_FLOAT_EQ(1.0, outside[2]); + EXPECT_FLOAT_EQ(1.7934048, outside[0].as_float()); + EXPECT_FLOAT_EQ(1.3114071, outside[1].as_float()); + EXPECT_FLOAT_EQ(1.0, outside[2].as_float()); } TEST_F(HGTest,TestGenericInside2) { @@ -327,8 +327,8 @@ TEST_F(HGTest,TestAddExpectations) { SparseVector feat_exps; prob_t z = InsideOutside, EdgeFeaturesAndProbWeightFunction>(hg, &feat_exps); - EXPECT_FLOAT_EQ(-2.5439765, feat_exps.value(FD::Convert("f1")) / z); - EXPECT_FLOAT_EQ(-2.6357865, feat_exps.value(FD::Convert("f2")) / z); + EXPECT_FLOAT_EQ(-2.5439765, (feat_exps.value(FD::Convert("f1")) / z).as_float()); + EXPECT_FLOAT_EQ(-2.6357865, (feat_exps.value(FD::Convert("f2")) / z).as_float()); cerr << feat_exps << endl; cerr << "Z=" << z << endl; } diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index f87b7274..993627f0 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -9,6 +9,7 @@ #include #include +#include "stringlib.h" #include "verbose.h" #include "hg.h" #include "prob.h" @@ -204,6 +205,7 @@ bool LoadAgenda(const string& file, vector >* a) { } int main(int argc, char** argv) { + cerr << "THIS SOFTWARE IS DEPRECATED YOU SHOULD USE mpi_flex_optimize\n"; #ifdef HAVE_MPI mpi::environment env(argc, argv); mpi::communicator world; -- cgit v1.2.3 From eb4b8a6ca070794db1a01b04570e9aaf346881ae Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 14 Oct 2011 11:53:48 +0100 Subject: one more to remove --- decoder/cdec-fsa.ini | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100755 decoder/cdec-fsa.ini diff --git a/decoder/cdec-fsa.ini b/decoder/cdec-fsa.ini deleted file mode 100755 index 05aaefd4..00000000 --- a/decoder/cdec-fsa.ini +++ /dev/null @@ -1,10 +0,0 @@ -cubepruning_pop_limit=200 -feature_function=WordPenalty -feature_function=ArityPenalty -feature_function=WordPenaltyFsa -#feature_function=LongerThanPrev -feature_function=ShorterThanPrev debug -add_pass_through_rules=true -formalism=scfg -grammar=mt09.grammar.gz -weights=weights-fsa -- cgit v1.2.3 From f036d4ec5c79db95df3470adb7cd317ff258ab7d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 14 Oct 2011 22:39:37 +0100 Subject: le optimizer --- training/Makefile.am | 4 + training/mpi_flex_optimize.cc | 346 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 350 insertions(+) create mode 100644 training/mpi_flex_optimize.cc diff --git a/training/Makefile.am b/training/Makefile.am index 0b598fd5..2a11ae52 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -12,6 +12,7 @@ bin_PROGRAMS = \ mpi_extract_reachable \ mpi_extract_features \ mpi_online_optimize \ + mpi_flex_optimize \ mpi_batch_optimize \ mpi_compute_cllh \ augment_grammar @@ -25,6 +26,9 @@ TESTS = lbfgs_test optimize_test mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc online_optimizer.cc optimize.cc +mpi_flex_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc new file mode 100644 index 00000000..87c5f331 --- /dev/null +++ b/training/mpi_flex_optimize.cc @@ -0,0 +1,346 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "stringlib.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" +#include "sampler.h" + +#ifdef HAVE_MPI +#include +#include +namespace mpi = boost::mpi; +#endif + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("cdec_config,c",po::value(),"Decoder configuration file") + ("weights,w",po::value(),"Initial feature weights") + ("training_data,d",po::value(),"Training data") + ("minibatch_size_per_proc,s", po::value()->default_value(6), "Number of training instances evaluated per processor in each minibatch") + ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (options: lbfgs, sgd, rprop)") + ("minibatch_iterations,i", po::value()->default_value(10), "Number of optimization iterations per minibatch (1 = standard SGD)") + ("iterations,I", po::value()->default_value(50), "Number of passes through the training data before termination") + ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") + ("lbfgs_memory_buffers,M", po::value()->default_value(10), "Number of memory buffers for LBFGS history") + ("eta_0,e", po::value()->default_value(0.1), "Initial learning rate for SGD") + ("L1,1","Use L1 regularization") + ("L2,2","Use L2 regularization") + ("regularization_strength,C", po::value()->default_value(1.0), "Regularization strength (C)"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("cdec_config")) { + cerr << "General-purpose minibatch online optimizer (MPI support " +#if HAVE_MPI + << "enabled" +#else + << "not enabled" +#endif + << ")\n" << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* order) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int id = 0; + while(in) { + getline(in, line); + if (!in) break; + if (id % size == rank) { + c->push_back(line); + order->push_back(id); + } + ++id; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct CopyHGsObserver : public DecoderObserver { + Hypergraph* hg_; + Hypergraph* gold_hg_; + + // this can free up some memory + void RemoveRules(Hypergraph* h) { + for (unsigned i = 0; i < h->edges_.size(); ++i) + h->edges_[i].rule_.reset(); + } + + void SetCurrentHypergraphs(Hypergraph* h, Hypergraph* gold_h) { + hg_ = h; + gold_hg_ = gold_h; + } + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + *hg_ = *hg; + RemoveRules(hg_); + assert(state == 1); + state = 2; + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata&, Hypergraph* hg) { + assert(state == 2); + state = 3; + *gold_hg_ = *hg; + RemoveRules(gold_hg_); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata&) { + if (state == 3) { + } else { + hg_->clear(); + gold_hg_->clear(); + } + } + + int state; +}; + +void ReadConfig(const string& ini, istringstream* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + ostringstream os; + while(in) { + string line; + getline(in, line); + if (!in) continue; + os << line << endl; + } + out->str(os.str()); +} + +#ifdef HAVE_MPI +namespace boost { namespace mpi { + template<> + struct is_commutative >, SparseVector > + : mpl::true_ { }; +} } // end namespace boost::mpi +#endif + +void AddGrad(const SparseVector x, double s, SparseVector* acc) { + for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) + acc->add_value(it->first, it->second.as_float() * s); +} + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + MT19937* rng = NULL; + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return 1; + + boost::shared_ptr o; + const unsigned lbfgs_memory_buffers = conf["lbfgs_memory_buffers"].as(); + + istringstream ins; + ReadConfig(conf["cdec_config"].as(), &ins); + Decoder decoder(&ins); + + // load initial weights + vector init_weights; + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &init_weights); + + vector corpus; + vector ids; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); + assert(corpus.size() > 0); + + const unsigned size_per_proc = conf["minibatch_size_per_proc"].as(); + if (size_per_proc > corpus.size()) { + cerr << "Minibatch size must be smaller than corpus size!\n"; + return 1; + } + + size_t total_corpus_size = 0; +#ifdef HAVE_MPI + reduce(world, corpus.size(), total_corpus_size, std::plus(), 0); +#else + total_corpus_size = corpus.size(); +#endif + + if (conf.count("random_seed")) + rng = new MT19937(conf["random_seed"].as()); + else + rng = new MT19937; + + const unsigned minibatch_iterations = conf["minibatch_iterations"].as(); + + if (rank == 0) { + cerr << "Total corpus size: " << total_corpus_size << endl; + const unsigned batch_size = size_per_proc * size; + } + + SparseVector x; + Weights::InitSparseVector(init_weights, &x); + CopyHGsObserver observer; + + int write_weights_every_ith = 100; // TODO configure + int titer = -1; + + vector& lambdas = decoder.CurrentWeightVector(); + lambdas.swap(init_weights); + init_weights.clear(); + + int iter = -1; + bool converged = false; + while (!converged) { +#ifdef HAVE_MPI + mpi::timer timer; +#endif + x.init_vector(&lambdas); + ++iter; ++titer; +#if 0 + if (rank == 0) { + converged = (iter == max_iteration); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); + string fname = "weights.cur.gz"; + if (iter % write_weights_every_ith == 0) { + ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; + fname = o.str(); + } + if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); + const string svv = vv.str(); + cerr << svv << endl; + Weights::WriteToFile(fname, lambdas, true, &svv); + } +#endif + + vector hgs(size_per_proc); + vector gold_hgs(size_per_proc); + for (int i = 0; i < size_per_proc; ++i) { + int ei = corpus.size() * rng->next(); + int id = ids[ei]; + observer.SetCurrentHypergraphs(&hgs[i], &gold_hgs[i]); + decoder.SetId(id); + decoder.Decode(corpus[ei], &observer); + } + + SparseVector local_grad, g; + double local_obj = 0; + o.reset(); + for (unsigned mi = 0; mi < minibatch_iterations; ++mi) { + local_grad.clear(); + g.clear(); + local_obj = 0; + + for (unsigned i = 0; i < size_per_proc; ++i) { + Hypergraph& hg = hgs[i]; + Hypergraph& hg_gold = gold_hgs[i]; + if (hg.edges_.size() < 2) continue; + + hg.Reweight(lambdas); + hg_gold.Reweight(lambdas); + SparseVector model_exp, gold_exp; + const prob_t z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(hg, &model_exp); + local_obj += log(z); + model_exp /= z; + AddGrad(model_exp, 1.0, &local_grad); + model_exp.clear(); + + const prob_t goldz = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(hg_gold, &gold_exp); + local_obj -= log(goldz); + + if (log(z) - log(goldz) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_gold_z: " << log(z) << " " << log(goldz) << endl; + return 1; + } + + gold_exp /= goldz; + AddGrad(gold_exp, -1.0, &local_grad); + } + + double obj = 0; +#ifdef HAVE_MPI + // TODO obj + reduce(world, local_grad, g, std::plus >(), 0); +#else + obj = local_obj; + g.swap(local_grad); +#endif + local_grad.clear(); + if (rank == 0) { + g /= (size_per_proc * size); + if (!o) + o.reset(new LBFGSOptimizer(FD::NumFeats(), lbfgs_memory_buffers)); + vector gg(FD::NumFeats()); + if (gg.size() != lambdas.size()) { lambdas.resize(gg.size()); } + for (SparseVector::const_iterator it = g.begin(); it != g.end(); ++it) + if (it->first) { gg[it->first] = it->second; } + cerr << "OBJ: " << obj << endl; + o->Optimize(obj, gg, &lambdas); + } +#ifdef HAVE_MPI + broadcast(world, x, 0); + broadcast(world, converged, 0); + world.barrier(); + if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } +#endif + } + } + return 0; +} -- cgit v1.2.3 From 957d90991b4ec80b9877126c736bd60768b094aa Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 17 Oct 2011 16:58:26 +0100 Subject: Chris, I'd like you to review this for use with your rules that contain and . --- decoder/ff_klm.cc | 72 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 6d9aca54..658aef80 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -71,6 +71,8 @@ string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } +namespace { + struct VMapper : public lm::ngram::EnumerateVocab { VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { @@ -83,66 +85,90 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; -template -class KLanguageModelImpl { +#pragma pack(push) +#pragma pack(1) - static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { - return *static_cast(state); +struct BoundaryAnnotatedState { + lm::ngram::ChartState state; + bool seen_bos, seen_eos; +}; + +#pragma pack(pop) + +void BoundaryCheck(bool &annotated, bool sub, double &ret) { + if (!sub) return; + if (annotated) { + ret -= 100.0; + } else { + annotated = true; } +} +} // namespace + +template +class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - lm::ngram::RuleScore ruleScore(*ngram_, *static_cast(remnant)); + BoundaryAnnotatedState &annotated = *static_cast(remnant); + lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); + annotated.seen_bos = false; + annotated.seen_eos = false; unsigned i = 0; + double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); + annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(sub.state, 0.0f); + annotated.seen_bos = sub.seen_bos; + annotated.seen_eos = sub.seen_eos; ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); - ruleScore.NonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); + ruleScore.NonTerminal(sub.state, 0.0f); + BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); + BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; + BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - double ret = ruleScore.Finish(); - static_cast(remnant)->ZeroRemaining(); + ret += ruleScore.Finish(); + annotated.state.ZeroRemaining(); return ret; } // this assumes no target words on final unary -> goal rule. is that ok? // for (n-1 left words) and (n-1 right words) - double FinalTraversalCost(const void* state, double* oovs) { + double FinalTraversalCost(const void* state_void, double* oovs) { + const BoundaryAnnotatedState &annotated = *static_cast(state_void); if (add_sos_eos_) { // rules do not produce , so do it here + assert(!annotated.seen_bos); + assert(!annotated.seen_eos); lm::ngram::ChartState cstate; lm::ngram::RuleScore ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.NonTerminal(annotated.state, 0.0f); ruleScore.Terminal(kEOS_); return ruleScore.Finish(); } else { // rules DO produce ... - double p = 0; - cerr << "not implemented"; abort(); // TODO - //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - //if (UnscoredSize(state) > 0) { // are there unscored words - // if (kSOS_ != IthUnscoredWord(0, state)) { - // p -= 100 * UnscoredSize(state); - // } - //} - return p; + double ret = 0.0; + if (!annotated.seen_bos) ret -= 100.0; + if (!annotated.seen_eos) ret -= 100.0; + return ret; } } @@ -230,7 +256,7 @@ class KLanguageModelImpl { delete ngram_; } - int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } + int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); } private: const WordID kCDEC_UNK; -- cgit v1.2.3 From 3d1ed02a4e5d81aace80b0e004e96351d116630f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 18 Oct 2011 10:25:56 +0100 Subject: Revised and handling --- decoder/ff_klm.cc | 84 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 658aef80..3c941fbf 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,8 +12,8 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#define NEW_KENLM #undef NEW_KENLM -#ifdef NEW_KENLM #include "lm/left.hh" @@ -95,14 +95,58 @@ struct BoundaryAnnotatedState { #pragma pack(pop) -void BoundaryCheck(bool &annotated, bool sub, double &ret) { - if (!sub) return; - if (annotated) { - ret -= 100.0; - } else { - annotated = true; - } -} +template class BoundaryRuleScore { + public: + BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) : + back_(m, state.state), + bos_(state.seen_bos), + eos_(state.seen_eos), + penalty_(0.0), + end_sentence_(m.GetVocabulary().EndSentence()) { + bos_ = false; + eos_ = false; + } + + void BeginSentence() { + back_.BeginSentence(); + bos_ = true; + } + + void BeginNonTerminal(const BoundaryAnnotatedState &sub) { + back_.BeginNonTerminal(sub.state, 0.0f); + bos_ = sub.seen_bos; + eos_ = sub.seen_eos; + } + + void NonTerminal(const BoundaryAnnotatedState &sub) { + back_.NonTerminal(sub.state, 0.0f); + // cdec only calls this if there's content. + if (sub.seen_bos) { + bos_ = true; + penalty_ -= 100.0f; + } + if (eos_) penalty_ -= 100.0f; + eos_ |= sub.seen_eos; + } + + void Terminal(lm::WordIndex word) { + back_.Terminal(word); + if (eos_) penalty_ -= 100.0f; + if (word == end_sentence_) eos_ = true; + } + + float Finish() { + return penalty_ + back_.Finish(); + } + + private: + lm::ngram::RuleScore back_; + bool &bos_, &eos_; + + float penalty_; + + lm::WordIndex end_sentence_; +}; } // namespace @@ -112,42 +156,30 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - BoundaryAnnotatedState &annotated = *static_cast(remnant); - lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); - annotated.seen_bos = false; - annotated.seen_eos = false; + BoundaryRuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; - double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); - annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(sub.state, 0.0f); - annotated.seen_bos = sub.seen_bos; - annotated.seen_eos = sub.seen_eos; + ruleScore.BeginNonTerminal(*static_cast(ant_states[-e[0]])); ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); - ruleScore.NonTerminal(sub.state, 0.0f); - BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); - BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); + ruleScore.NonTerminal(*static_cast(ant_states[-e[i]])); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; - BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - ret += ruleScore.Finish(); - annotated.state.ZeroRemaining(); + double ret = ruleScore.Finish(); + static_cast(remnant)->state.ZeroRemaining(); return ret; } -- cgit v1.2.3 From 04e38a57b19ea012895ac2efb39382c2e77833a9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 18 Oct 2011 14:19:09 +0100 Subject: incorporate kenneth's fixes --- decoder/ff_klm.cc | 464 ------------------------------------------------------ 1 file changed, 464 deletions(-) diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3c941fbf..ed6f731e 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,9 +12,6 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" -#define NEW_KENLM -#undef NEW_KENLM - #include "lm/left.hh" using namespace std; @@ -395,464 +392,3 @@ std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } -#else - -using namespace std; - -static const unsigned char HAS_FULL_CONTEXT = 1; -static const unsigned char HAS_EOS_ON_RIGHT = 2; -static const unsigned char MASK = 7; - -// -x : rules include and -// -n NAME : feature id is NAME -bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { - vector const& argv=SplitOnWhitespace(in); - *explicit_markers = false; - *featname="LanguageModel"; - *mapfile = ""; -#define LMSPEC_NEXTARG if (i==argv.end()) { \ - cerr << "Missing argument for "<<*last<<". "; goto usage; \ - } else { ++i; } - - for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { - string const& s=*i; - if (s[0]=='-') { - if (s.size()>2) goto fail; - switch (s[1]) { - case 'x': - *explicit_markers = true; - break; - case 'm': - LMSPEC_NEXTARG; *mapfile=*i; - break; - case 'n': - LMSPEC_NEXTARG; *featname=*i; - break; -#undef LMSPEC_NEXTARG - default: - fail: - cerr<<"Unknown KLanguageModel option "<empty()) - *filename=s; - else { - cerr<<"More than one filename provided. "; - goto usage; - } - } - } - if (!filename->empty()) - return true; -usage: - cerr << "KLanguageModel is incorrect!\n"; - return false; -} - -template -string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { - return "KLanguageModel"; -} - -struct VMapper : public lm::ngram::EnumerateVocab { - VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_->size()) - out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); - (*out_)[cdec_id] = index; - } - vector* out_; - const lm::WordIndex kLM_UNKNOWN_TOKEN; -}; - -template -class KLanguageModelImpl { - - // returns the number of unscored words at the left edge of a span - inline int UnscoredSize(const void* state) const { - return *(static_cast(state) + unscored_size_offset_); - } - - inline void SetUnscoredSize(int size, void* state) const { - *(static_cast(state) + unscored_size_offset_) = size; - } - - static inline const lm::ngram::State& RemnantLMState(const void* state) { - return *static_cast(state); - } - - inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } - - lm::WordIndex IthUnscoredWord(int i, const void* state) const { - const lm::WordIndex* const mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); - return mem[i]; - } - - void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const { - lm::WordIndex* mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); - mem[i] = index; - } - - inline bool GetFlag(const void *state, unsigned char flag) const { - return (*(static_cast(state) + is_complete_offset_) & flag); - } - - inline void SetFlag(bool on, unsigned char flag, void *state) const { - if (on) { - *(static_cast(state) + is_complete_offset_) |= flag; - } else { - *(static_cast(state) + is_complete_offset_) &= (MASK ^ flag); - } - } - - inline bool HasFullContext(const void *state) const { - return GetFlag(state, HAS_FULL_CONTEXT); - } - - inline void SetHasFullContext(bool flag, void *state) const { - SetFlag(flag, HAS_FULL_CONTEXT, state); - } - - public: - double LookupWords(const TRule& rule, const vector& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) { - double sum = 0.0; - double est_sum = 0.0; - int num_scored = 0; - int num_estimated = 0; - if (oovs) *oovs = 0; - if (est_oovs) *est_oovs = 0; - bool saw_eos = false; - bool has_some_history = false; - lm::ngram::State state = ngram_->NullContextState(); - const vector& e = rule.e(); - bool context_complete = false; - for (int j = 0; j < e.size(); ++j) { - if (e[j] < 1) { // handle non-terminal substitution - const void* astate = (ant_states[-e[j]]); - int unscored_ant_len = UnscoredSize(astate); - for (int k = 0; k < unscored_ant_len; ++k) { - const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const bool is_oov = (cur_word == 0); - double p = 0; - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); - if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 - state = RemnantLMState(astate); - context_complete = true; - } - } else { // handle terminal - const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future, - // maybe handle emission - const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - double p = 0; - const bool is_oov = (cur_word == 0); - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - } - if (pest_sum) *pest_sum = est_sum; - if (remnant) { - state.ZeroRemaining(); - SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); - SetRemnantLMState(state, remnant); - SetUnscoredSize(num_estimated, remnant); - SetHasFullContext(context_complete || (num_scored >= order_), remnant); - } - return sum; - } - - // this assumes no target words on final unary -> goal rule. is that ok? - // for (n-1 left words) and (n-1 right words) - double FinalTraversalCost(const void* state, double* oovs) { - if (add_sos_eos_) { // rules do not produce , so do it here - SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); - SetHasFullContext(1, dummy_state_); - SetUnscoredSize(0, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL); - } else { // rules DO produce ... - double p = 0; - if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - if (UnscoredSize(state) > 0) { // are there unscored words - if (kSOS_ != IthUnscoredWord(0, state)) { - p -= 100 * UnscoredSize(state); - } - } - return p; - } - } - - // if this is not a class-based LM, returns w untransformed, - // otherwise returns a word class mapping of w, - // returns TD::Convert("") if there is no mapping for w - WordID ClassifyWordIfNecessary(WordID w) const { - if (word2class_map_.empty()) return w; - if (w >= word2class_map_.size()) - return kCDEC_UNK; - else - return word2class_map_[w]; - } - - // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 - lm::WordIndex MapWord(WordID w) const { - if (w >= cdec2klm_map_.size()) - return 0; - else - return cdec2klm_map_[w]; - } - - public: - KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : - kCDEC_UNK(TD::Convert("")) , - add_sos_eos_(!explicit_markers) { - { - VMapper vm(&cdec2klm_map_); - lm::ngram::Config conf; - conf.enumerate_vocab = &vm; - ngram_ = new Model(filename.c_str(), conf); - } - order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); - unscored_size_offset_ = ngram_->StateSize(); - is_complete_offset_ = unscored_size_offset_ + 1; - unscored_words_offset_ = is_complete_offset_ + 1; - - // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); - kSOS_ = MapWord(TD::Convert("")); - assert(kSOS_ > 0); - kEOS_ = MapWord(TD::Convert("")); - assert(kEOS_ > 0); - assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant - - // handle class-based LMs (unambiguous word->class mapping reqd.) - if (mapfile.size()) - LoadWordClasses(mapfile); - } - - void LoadWordClasses(const string& file) { - ReadFile rf(file); - istream& in = *rf.stream(); - string line; - vector dummy; - int lc = 0; - cerr << " Loading word classes from " << file << " ...\n"; - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - while(in) { - getline(in, line); - if (!in) continue; - dummy.clear(); - TD::ConvertSentence(line, &dummy); - ++lc; - if (dummy.size() != 2) { - cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; - abort(); - } - AddWordToClassMapping_(dummy[0], dummy[1]); - } - } - - void AddWordToClassMapping_(WordID word, WordID cls) { - if (word2class_map_.size() <= word) { - word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); - assert(word2class_map_.size() > word); - } - if(word2class_map_[word] != kCDEC_UNK) { - cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; - abort(); - } - word2class_map_[word] = cls; - } - - ~KLanguageModelImpl() { - delete ngram_; - delete[] dummy_state_; - } - - int ReserveStateSize() const { return state_size_; } - - private: - const WordID kCDEC_UNK; - lm::WordIndex kSOS_; // - requires special handling. - lm::WordIndex kEOS_; // - Model* ngram_; - const bool add_sos_eos_; // flag indicating whether the hypergraph produces and - // if this is true, FinalTransitionFeatures will "add" and - // if false, FinalTransitionFeatures will score anything with the - // markers in the right place (i.e., the beginning and end of - // the sentence) with 0, and anything else with -100 - - int order_; - int state_size_; - int unscored_size_offset_; - int is_complete_offset_; - int unscored_words_offset_; - char* dummy_state_; - vector dummy_ants_; - vector cdec2klm_map_; - vector word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; -}; - -template -KLanguageModel::KLanguageModel(const string& param) { - string filename, mapfile, featname; - bool explicit_markers; - if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { - abort(); - } - try { - pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); - } catch (std::exception &e) { - std::cerr << e.what() << std::endl; - abort(); - } - fid_ = FD::Convert(featname); - oov_fid_ = FD::Convert(featname+"_OOV"); - cerr << "FID: " << oov_fid_ << endl; - SetStateSize(pimpl_->ReserveStateSize()); -} - -template -Features KLanguageModel::features() const { - return single_feature(fid_); -} - -template -KLanguageModel::~KLanguageModel() { - delete pimpl_; -} - -template -void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, - const Hypergraph::Edge& edge, - const vector& ant_states, - SparseVector* features, - SparseVector* estimated_features, - void* state) const { - double est = 0; - double oovs = 0; - double est_oovs = 0; - features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state)); - estimated_features->set_value(fid_, est); - if (oov_fid_) { - if (oovs) features->set_value(oov_fid_, oovs); - if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); - } -} - -template -void KLanguageModel::FinalTraversalFeatures(const void* ant_state, - SparseVector* features) const { - double oovs = 0; - double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); - features->set_value(fid_, lm); - if (oov_fid_ && oovs) - features->set_value(oov_fid_, oovs); -} - -template boost::shared_ptr CreateModel(const std::string ¶m) { - KLanguageModel *ret = new KLanguageModel(param); - ret->Init(); - return boost::shared_ptr(ret); -} - -boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { - using namespace lm::ngram; - std::string filename, ignored_map; - bool ignored_markers; - std::string ignored_featname; - ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); - ModelType m; - if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; - - switch (m) { - case HASH_PROBING: - return CreateModel(param); - case TRIE_SORTED: - return CreateModel(param); - case ARRAY_TRIE_SORTED: - return CreateModel(param); - case QUANT_TRIE_SORTED: - return CreateModel(param); - case QUANT_ARRAY_TRIE_SORTED: - return CreateModel(param); - default: - UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); - } -} - -std::string KLanguageModelFactory::usage(bool params,bool verbose) const { - return KLanguageModel::usage(params, verbose); -} - -#endif -- cgit v1.2.3