From 3106cf8eca76df8b46d139b8f5ce5002200d660d Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 24 Oct 2011 18:17:24 +0100 Subject: KenLM update. EnumerateVocab moved up a namespace. Fix trie building when bigrams are pruned. Make Chris feel better about MurmurHashNative. --- BUILDING | 2 +- decoder/ff_csplit.cc | 2 +- decoder/ff_klm.cc | 2 +- klm/lm/config.hh | 6 ++- klm/lm/enumerate_vocab.hh | 2 - klm/lm/left.hh | 8 +-- klm/lm/model.cc | 4 +- klm/lm/read_arpa.cc | 2 + klm/lm/search_hashed.cc | 16 +++--- klm/lm/search_trie.cc | 2 +- klm/lm/sri.cc | 108 ---------------------------------------- klm/lm/sri.hh | 102 ------------------------------------- klm/lm/vocab.hh | 2 +- klm/util/mmap.cc | 2 +- klm/util/murmur_hash.cc | 15 ++++-- klm/util/tokenize_piece.hh | 75 ++++++++++++++++++++++++++++ klm/util/tokenize_piece_test.cc | 94 ++++++++++++++++++++++++++++++++++ training/augment_grammar.cc | 2 +- training/test_ngram.cc | 2 +- 19 files changed, 208 insertions(+), 240 deletions(-) delete mode 100644 klm/lm/sri.cc delete mode 100644 klm/lm/sri.hh create mode 100644 klm/util/tokenize_piece_test.cc diff --git a/BUILDING b/BUILDING index b7535d70..c7b954c7 100644 --- a/BUILDING +++ b/BUILDING @@ -33,7 +33,7 @@ Instructions for building If you're building on cygwin, their libtool is buggy; this make command works for now: - make LIBS+="-loolm -ldstruct -lmisc -lz -lboost_program_options" \ + make LIBS+="-lz -lboost_program_options" \ CFLAGS+="-Wno-sign-compare" 5) Test diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc index dee6f4f9..3991d38f 100644 --- a/decoder/ff_csplit.cc +++ b/decoder/ff_csplit.cc @@ -155,7 +155,7 @@ void BasicCSplitFeatures::TraversalFeaturesImpl( } namespace { -struct CSVMapper : public lm::ngram::EnumerateVocab { +struct CSVMapper : public lm::EnumerateVocab { CSVMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { const WordID cdec_id = TD::Convert(str.as_string()); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index ed6f731e..a4b26f7c 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -70,7 +70,7 @@ string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { namespace { -struct VMapper : public lm::ngram::EnumerateVocab { +struct VMapper : public lm::EnumerateVocab { VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { const WordID cdec_id = TD::Convert(str.as_string()); diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 227b8512..8564661b 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -8,10 +8,12 @@ /* Configuration for ngram model. Separate header to reduce pollution. */ -namespace lm { namespace ngram { - +namespace lm { + class EnumerateVocab; +namespace ngram { + struct Config { // EFFECTIVE FOR BOTH ARPA AND BINARY READS diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh index e734316b..27263621 100644 --- a/klm/lm/enumerate_vocab.hh +++ b/klm/lm/enumerate_vocab.hh @@ -5,7 +5,6 @@ #include "util/string_piece.hh" namespace lm { -namespace ngram { /* If you need the actual strings in the vocabulary, inherit from this class * and implement Add. Then put a pointer in Config.enumerate_vocab; it does @@ -23,7 +22,6 @@ class EnumerateVocab { EnumerateVocab() {} }; -} // namespace ngram } // namespace lm #endif // LM_ENUMERATE_VOCAB__ diff --git a/klm/lm/left.hh b/klm/lm/left.hh index bb3f5539..15464c82 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -176,16 +176,18 @@ template class RuleScore { float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; float *back = backoffs, *back2 = backoffs2; unsigned char next_use; - FullScoreReturn ret; - ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); + + // First word + ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); if (!next_use) { left_done_ = true; out_.right = in.right; return; } + // Words after the first, so extending a bigram to begin with unsigned char extend_length = 2; for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { - ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); + ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); if (!next_use) { left_done_ = true; out_.right = in.right; diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 25f1ab7c..e4c1ec1d 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -91,8 +91,8 @@ template void GenericModel FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); - if (ret.ngram_length - 1 < in_state.length) { - ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob); + for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { + ret.prob += *i; } return ret; } diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 455bc4ba..dce73f77 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -38,6 +38,8 @@ void ReadARPACounts(util::FilePiece &in, std::vector &number) { } if (static_cast(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); + UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file. Did you forget to pass --text yes to compile-lm?"); + UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file. You need an ARPA file. Run\n compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst."); UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\."); } while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 334adf12..247832b0 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -87,14 +87,14 @@ template void ReadNGrams( ReadNGramHeader(f, n); // vocab ids of words in reverse order - WordIndex vocab_ids[n]; - uint64_t keys[n - 1]; + std::vector vocab_ids(n); + std::vector keys(n-1); typename Store::Packing::Value value; typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { - ReadNGram(f, n, vocab, vocab_ids, value, warn); + ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn); - keys[0] = detail::CombineWordHash(static_cast(*vocab_ids), vocab_ids[1]); + keys[0] = detail::CombineWordHash(static_cast(vocab_ids.front()), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } @@ -106,9 +106,9 @@ template void ReadNGrams( util::FloatEnc fix_prob; for (lower = n - 3; ; --lower) { if (lower == -1) { - fix_prob.f = unigrams[vocab_ids[0]].prob; + fix_prob.f = unigrams[vocab_ids.front()].prob; fix_prob.i &= ~util::kSignBit; - unigrams[vocab_ids[0]].prob = fix_prob.f; + unigrams[vocab_ids.front()].prob = fix_prob.f; break; } if (middle[lower].UnsafeMutableFind(keys[lower], found)) { @@ -120,8 +120,8 @@ template void ReadNGrams( break; } } - if (lower != static_cast(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle); - activate(vocab_ids, n); + if (lower != static_cast(n) - 3) FixSRI(lower, fix_prob.f, n, &*keys.begin(), &*vocab_ids.begin(), unigrams, middle); + activate(&*vocab_ids.begin(), n); } store.FinishedInserting(); diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 5d8c70db..e3cf9547 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -493,7 +493,7 @@ template void BuildTrie(const std::string &file_pre util::scoped_FILE unigram_file; { std::string name(file_prefix + "unigrams"); - unigram_file.reset(OpenOrThrow(name.c_str(), "r")); + unigram_file.reset(OpenOrThrow(name.c_str(), "r+")); util::RemoveOrThrow(name.c_str()); } sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc deleted file mode 100644 index 825f699b..00000000 --- a/klm/lm/sri.cc +++ /dev/null @@ -1,108 +0,0 @@ -#include "lm/lm_exception.hh" -#include "lm/sri.hh" - -#include -#include - -#include - -namespace lm { -namespace sri { - -Vocabulary::Vocabulary() : sri_(new Vocab) {} - -Vocabulary::~Vocabulary() {} - -WordIndex Vocabulary::Index(const char *str) const { - WordIndex ret = sri_->getIndex(str); - // NGram wants the index of Vocab_Unknown for unknown words, but for some reason SRI returns Vocab_None here :-(. - if (ret == Vocab_None) { - return not_found_; - } else { - return ret; - } -} - -const char *Vocabulary::Word(WordIndex index) const { - return sri_->getWord(index); -} - -void Vocabulary::FinishedLoading() { - SetSpecial( - sri_->ssIndex(), - sri_->seIndex(), - sri_->unkIndex()); -} - -namespace { -Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) { - sri_vocab.unkIsWord() = true; - std::auto_ptr ret(new Ngram(sri_vocab, ngram_length)); - File file(file_name, "r"); - errno = 0; - if (!ret->read(file)) { - UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed."); - } - return ret.release(); -} -} // namespace - -Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) { - if (!sri_->setorder()) { - UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0."); - } - vocab_.FinishedLoading(); - State begin_state = State(); - begin_state.valid_length_ = 1; - if (kMaxOrder > 1) { - begin_state.history_[0] = vocab_.BeginSentence(); - if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None; - } - State null_state = State(); - null_state.valid_length_ = 0; - if (kMaxOrder > 1) null_state.history_[0] = Vocab_None; - Init(begin_state, null_state, vocab_, sri_->setorder()); - not_found_ = vocab_.NotFound(); -} - -Model::~Model() {} - -namespace { - -/* Argh SRI's wordProb knows the ngram length but doesn't return it. One more - * reason you should use my model. */ -// TODO(stolcke): fix SRILM so I don't have to do this. -unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) { - unsigned int out_length = 0; - // This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0. - model.contextID(new_word, const_history, out_length); - return out_length + 1; -} - -} // namespace - -FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - // If you get a compiler in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h. - const SRIVocabIndex *const_history; - SRIVocabIndex local_history[Order()]; - if (in_state.valid_length_ < kMaxOrder - 1) { - const_history = in_state.history_; - } else { - std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history); - local_history[in_state.valid_length_] = Vocab_None; - const_history = local_history; - } - FullScoreReturn ret; - ret.ngram_length = MatchedLength(*sri_, new_word, const_history); - out_state.history_[0] = new_word; - out_state.valid_length_ = std::min(ret.ngram_length, Order() - 1); - std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); - if (out_state.valid_length_ < kMaxOrder - 1) { - out_state.history_[out_state.valid_length_] = Vocab_None; - } - ret.prob = sri_->wordProb(new_word, const_history); - return ret; -} - -} // namespace sri -} // namespace lm diff --git a/klm/lm/sri.hh b/klm/lm/sri.hh deleted file mode 100644 index b57e9b73..00000000 --- a/klm/lm/sri.hh +++ /dev/null @@ -1,102 +0,0 @@ -#ifndef LM_SRI__ -#define LM_SRI__ - -#include "lm/facade.hh" -#include "util/murmur_hash.hh" - -#include -#include -#include - -class Ngram; -class Vocab; - -/* The ngram length reported uses some random API I found and may be wrong. - * - * See ngram, which should return equivalent results. - */ - -namespace lm { -namespace sri { - -static const unsigned int kMaxOrder = 6; - -/* This should match VocabIndex found in SRI's Vocab.h - * The reason I define this here independently is that SRI's headers - * pollute and increase compile time. - * It's difficult to extract this from their header and anyway would - * break packaging. - * If these differ there will be a compiler error in ActuallyCall. - */ -typedef unsigned int SRIVocabIndex; - -class State { - public: - // You shouldn't need to touch these, but they're public so State will be a POD. - // If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None. - SRIVocabIndex history_[kMaxOrder - 1]; - unsigned char valid_length_; -}; - -inline bool operator==(const State &left, const State &right) { - if (left.valid_length_ != right.valid_length_) { - return false; - } - for (const SRIVocabIndex *l = left.history_, *r = right.history_; - l != left.history_ + left.valid_length_; - ++l, ++r) { - if (*l != *r) return false; - } - return true; -} - -inline size_t hash_value(const State &state) { - return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_); -} - -class Vocabulary : public base::Vocabulary { - public: - Vocabulary(); - - ~Vocabulary(); - - WordIndex Index(const StringPiece &str) const { - std::string temp(str.data(), str.length()); - return Index(temp.c_str()); - } - WordIndex Index(const std::string &str) const { - return Index(str.c_str()); - } - WordIndex Index(const char *str) const; - - const char *Word(WordIndex index) const; - - private: - friend class Model; - void FinishedLoading(); - - // The parent class isn't copyable so auto_ptr is the same as scoped_ptr - // but without the boost dependence. - mutable std::auto_ptr sri_; -}; - -class Model : public base::ModelFacade { - public: - Model(const char *file_name, unsigned int ngram_length); - - ~Model(); - - FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; - - private: - Vocabulary vocab_; - - mutable std::auto_ptr sri_; - - WordIndex not_found_; -}; - -} // namespace sri -} // namespace lm - -#endif // LM_SRI__ diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 41e97052..4cf68196 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -15,10 +15,10 @@ namespace lm { class ProbBackoff; +class EnumerateVocab; namespace ngram { class Config; -class EnumerateVocab; namespace detail { uint64_t HashForVocab(const char *str, std::size_t len); diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 5ce7adc9..279bafa8 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -15,7 +15,7 @@ namespace util { scoped_mmap::~scoped_mmap() { if (data_ != (void*)-1) { - // Thanks Denis Filimonov for pointing on NFS likes msync first. + // Thanks Denis Filimonov for pointing out NFS likes msync first. if (msync(data_, size_, MS_SYNC) || munmap(data_, size_)) { std::cerr << "msync or mmap failed for " << size_ << " bytes." << std::endl; abort(); diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index fec47fd9..ef5783fe 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -117,13 +117,18 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) return h; } +// Trick to test for 64-bit architecture at compile time. +namespace { +template uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, unsigned int seed) { + return MurmurHash64A(key, len, seed); +} +template <> uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, unsigned int seed) { + return MurmurHash64B(key, len, seed); +} +} // namespace uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } + return MurmurHashNativeBackend(key, len, seed); } } // namespace util diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index ee1c7ab2..413bda0b 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -5,6 +5,9 @@ #include +#include +#include + /* Usage: * * for (PieceIterator<' '> i(" foo \r\n bar "); i; ++i) { @@ -64,6 +67,78 @@ template class PieceIterator : public boost::iterator_facade class TokenIter : public boost::iterator_facade, const StringPiece, boost::forward_traversal_tag> { + public: + TokenIter() {} + + TokenIter(const StringPiece &str, const Find &finder) : after_(str), finder_(finder) { + increment(); + } + + bool operator!() const { + return current_.data() == 0; + } + operator bool() const { + return current_.data() != 0; + } + + static TokenIter end() { + return TokenIter(); + } + + private: + friend class boost::iterator_core_access; + + void increment() { + do { + StringPiece found(finder_.Find(after_)); + current_ = StringPiece(after_.data(), found.data() - after_.data()); + if (found.data() == after_.data() + after_.size()) { + after_ = StringPiece(NULL, 0); + } else { + after_ = StringPiece(found.data() + found.size(), after_.data() - found.data() + after_.size() - found.size()); + } + } while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false. + } + + bool equal(const TokenIter &other) const { + return after_.data() == other.after_.data(); + } + + const StringPiece &dereference() const { + return current_; + } + + StringPiece current_; + StringPiece after_; + + Find finder_; +}; + } // namespace util #endif // UTIL_TOKENIZE_PIECE__ diff --git a/klm/util/tokenize_piece_test.cc b/klm/util/tokenize_piece_test.cc new file mode 100644 index 00000000..e07ebcf5 --- /dev/null +++ b/klm/util/tokenize_piece_test.cc @@ -0,0 +1,94 @@ +#include "util/tokenize_piece.hh" +#include "util/string_piece.hh" + +#define BOOST_TEST_MODULE TokenIteratorTest +#include + +#include + +namespace util { +namespace { + +BOOST_AUTO_TEST_CASE(simple) { + PieceIterator<' '> it("single spaced words."); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("single"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("spaced"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("words."), *it); + ++it; + BOOST_CHECK(!it); +} + +BOOST_AUTO_TEST_CASE(null_delimiter) { + const char str[] = "\0first\0\0second\0\0\0third\0fourth\0\0\0"; + PieceIterator<'\0'> it(StringPiece(str, sizeof(str) - 1)); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("first"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("second"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("third"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("fourth"), *it); + ++it; + BOOST_CHECK(!it); +} + +BOOST_AUTO_TEST_CASE(null_entries) { + const char str[] = "\0split\0\0 \0me\0 "; + PieceIterator<' '> it(StringPiece(str, sizeof(str) - 1)); + BOOST_REQUIRE(it); + const char first[] = "\0split\0\0"; + BOOST_CHECK_EQUAL(StringPiece(first, sizeof(first) - 1), *it); + ++it; + BOOST_REQUIRE(it); + const char second[] = "\0me\0"; + BOOST_CHECK_EQUAL(StringPiece(second, sizeof(second) - 1), *it); + ++it; + BOOST_CHECK(!it); +} + +/*BOOST_AUTO_TEST_CASE(pipe_pipe_none) { + const char str[] = "nodelimit at all"; + TokenIter it(str, MultiCharacter("|||")); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(str), *it); + ++it; + BOOST_CHECK(!it); +} +BOOST_AUTO_TEST_CASE(pipe_pipe_two) { + const char str[] = "|||"; + TokenIter it(str, MultiCharacter("|||")); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(), *it); + ++it; + BOOST_CHECK(!it); +} + +BOOST_AUTO_TEST_CASE(remove_empty) { + const char str[] = "|||"; + TokenIter it(str, MultiCharacter("|||")); + BOOST_CHECK(!it); +}*/ + +BOOST_AUTO_TEST_CASE(remove_empty_keep) { + const char str[] = " |||"; + TokenIter it(str, MultiCharacter("|||")); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(" "), *it); + ++it; + BOOST_CHECK(!it); +} + +} // namespace +} // namespace util diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc index e89a92d5..1e5af9a1 100644 --- a/training/augment_grammar.cc +++ b/training/augment_grammar.cc @@ -18,7 +18,7 @@ using namespace std; vector word_map; lm::ngram::ProbingModel* ngram; -struct VMapper : public lm::ngram::EnumerateVocab { +struct VMapper : public lm::EnumerateVocab { VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { const WordID cdec_id = TD::Convert(str.as_string()); diff --git a/training/test_ngram.cc b/training/test_ngram.cc index c481b564..4597cc01 100644 --- a/training/test_ngram.cc +++ b/training/test_ngram.cc @@ -12,7 +12,7 @@ namespace po = boost::program_options; using namespace std; lm::ngram::ProbingModel* ngram; -struct GetVocab : public lm::ngram::EnumerateVocab { +struct GetVocab : public lm::EnumerateVocab { GetVocab(vector* out) : out_(out) { } void Add(lm::WordIndex index, const StringPiece &str) { out_->push_back(index); -- cgit v1.2.3 From b036e03e9db226fde7e6b0e69d86bdb5741f8006 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 3 Nov 2011 19:53:53 +0000 Subject: Bugfix trie building --- klm/lm/search_trie.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index e3cf9547..633bcdf4 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -24,10 +24,8 @@ #include #include #include +#include "util/portability.hh" -#include -#include -#include namespace lm { namespace ngram { @@ -271,7 +269,7 @@ template class WriteEntries { contexts_(contexts), unigrams_(unigrams), middle_(middle), - longest_(longest), + longest_(longest), bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)), order_(order), sri_(sri) {} @@ -334,6 +332,7 @@ template class BlankManager { void Visit(const WordIndex *to, unsigned char length, float prob) { basis_[length - 1] = prob; + // Try to match everything except the last word, which is expected to be different. unsigned char overlap = std::min(length - 1, been_length_); const WordIndex *cur; WordIndex *pre; @@ -350,14 +349,15 @@ template class BlankManager { UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); const float *lower_basis; for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {} + assert(*lower_basis != kBadProb); unsigned char based_on = lower_basis - basis_ + 1; for (; cur != to + length - 1; ++blank, ++cur, ++pre) { - assert(*lower_basis != kBadProb); doing_.MiddleBlank(blank, to, based_on, *lower_basis); *pre = *cur; // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. basis_[blank - 1] = kBadProb; } + *pre = *cur; been_length_ = length; } -- cgit v1.2.3 From 635a8d31de50b5514cb471cb79bbe2cd3f23b0b5 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 3 Nov 2011 19:58:37 +0000 Subject: Oops introduced some of Hieu's windows stuff --- klm/lm/search_trie.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 633bcdf4..4bd3f4ee 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -24,8 +24,10 @@ #include #include #include -#include "util/portability.hh" +#include +#include +#include namespace lm { namespace ngram { @@ -269,7 +271,7 @@ template class WriteEntries { contexts_(contexts), unigrams_(unigrams), middle_(middle), - longest_(longest), + longest_(longest), bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)), order_(order), sri_(sri) {} @@ -332,7 +334,6 @@ template class BlankManager { void Visit(const WordIndex *to, unsigned char length, float prob) { basis_[length - 1] = prob; - // Try to match everything except the last word, which is expected to be different. unsigned char overlap = std::min(length - 1, been_length_); const WordIndex *cur; WordIndex *pre; @@ -349,9 +350,9 @@ template class BlankManager { UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); const float *lower_basis; for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {} - assert(*lower_basis != kBadProb); unsigned char based_on = lower_basis - basis_ + 1; for (; cur != to + length - 1; ++blank, ++cur, ++pre) { + assert(*lower_basis != kBadProb); doing_.MiddleBlank(blank, to, based_on, *lower_basis); *pre = *cur; // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. -- cgit v1.2.3 From bdd7fe7b513ade0b979fc050766e375044e84e86 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 3 Nov 2011 20:08:43 +0000 Subject: Mostly minor changes like a missing header guard and bad documentation --- klm/lm/bhiksha.hh | 5 +++++ klm/lm/build_binary.cc | 2 +- klm/lm/left.hh | 39 ++++++++++++++++++++++++--------------- klm/lm/vocab.cc | 1 + klm/lm/vocab.hh | 1 - klm/util/probing_hash_table.hh | 4 ++-- 6 files changed, 33 insertions(+), 19 deletions(-) diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index bc705959..3df43dda 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -10,6 +10,9 @@ * Currently only used for next pointers. */ +#ifndef LM_BHIKSHA__ +#define LM_BHIKSHA__ + #include #include @@ -108,3 +111,5 @@ class ArrayBhiksha { } // namespace trie } // namespace ngram } // namespace lm + +#endif // LM_BHIKSHA__ diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index b7aee4de..fdb62a71 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,7 +15,7 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" "-u sets the log10 probability for if the ARPA file does not have one.\n" " Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 15464c82..41f71f84 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -175,24 +175,14 @@ template class RuleScore { float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use; + unsigned char next_use = out_.right.length; // First word - ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); - if (!next_use) { - left_done_ = true; - out_.right = in.right; - return; - } + if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + // Words after the first, so extending a bigram to begin with - unsigned char extend_length = 2; - for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { - ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); - if (!next_use) { - left_done_ = true; - out_.right = in.right; - return; - } + for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { + if (ExtendLeft(in, next_use, extend_length, back, back2)) return; std::swap(back, back2); } @@ -228,6 +218,25 @@ template class RuleScore { } private: + bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { + ProcessRet(model_.ExtendLeft( + out_.right.words, out_.right.words + next_use, // Words to extend into + back_in, // Backoffs to use + in.left.pointers[extend_length - 1], extend_length, // Words to be extended + back_out, // Backoffs for the next score + next_use)); // Length of n-gram to use in next scoring. + if (next_use != out_.right.length) { + left_done_ = true; + if (!next_use) { + out_.right = in.right; + // Early exit. + return true; + } + } + // Continue scoring. + return false; + } + void ProcessRet(const FullScoreReturn &ret) { prob_ += ret.prob; if (left_done_) return; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 03b0767a..ffec41ca 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -135,6 +135,7 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { end_ = begin_ + *(reinterpret_cast(begin_) - 1); ReadWords(fd, to); SetSpecial(Index(""), Index(""), 0); + bound_ = end_ - begin_ + 1; } namespace { diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 4cf68196..3c3414fb 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -66,7 +66,6 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. - // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 2ec342a6..8122d69c 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -61,14 +61,14 @@ template void Insert(const T &t) { + template MutableIterator Insert(const T &t) { if (++entries_ >= buckets_) UTIL_THROW(ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); #ifdef DEBUG assert(initialized_); #endif for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { - if (equal_(i->GetKey(), invalid_)) { *i = t; return; } + if (equal_(i->GetKey(), invalid_)) { *i = t; return i; } if (++i == end_) { i = begin_; } } } -- cgit v1.2.3 From bcda3258ab35cba2f71e28e1c93863958f5aca8b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 7 Nov 2011 18:09:47 -0500 Subject: updates to pro to support regularization to previous weight vectors, regualarization normalization, disable broken regularization tuning --- pro-train/dist-pro.pl | 22 +++++++++++-- pro-train/mr_pro_reduce.cc | 82 ++++++++++++++++++++++++++-------------------- 2 files changed, 66 insertions(+), 38 deletions(-) diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index dbfa329a..4bc9cfe3 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -41,6 +41,7 @@ my $lines_per_mapper = 30; my $iteration = 1; my $run_local = 0; my $best_weights; +my $psi = 1; my $max_iterations = 30; my $decode_nodes = 15; # number of decode nodes my $pmem = "4g"; @@ -62,7 +63,8 @@ my $cpbin=1; # regularization strength my $tune_regularizer = 0; -my $reg = 1e-2; +my $reg = 10; +my $reg_previous = 0; # Process command-line options Getopt::Long::Configure("no_auto_abbrev"); @@ -73,10 +75,12 @@ if (GetOptions( "use-fork" => \$usefork, "dry-run" => \$dryrun, "epsilon=s" => \$epsilon, + "interpolate-with-weights=f" => \$psi, "help" => \$help, "weights=s" => \$initial_weights, "tune-regularizer" => \$tune_regularizer, "reg=f" => \$reg, + "reg-previous=f" => \$reg_previous, "local" => \$run_local, "use-make=i" => \$use_make, "max-iterations=i" => \$max_iterations, @@ -91,6 +95,8 @@ if (GetOptions( exit; } +die "--tune-regularizer is no longer supported with --reg-previous and --reg. Please tune manually.\n" if $tune_regularizer; + if ($usefork) { $usefork = "--use-fork"; } else { $usefork = ''; } if ($metric =~ /^(combi|ter)$/i) { @@ -411,7 +417,7 @@ while (1){ } print STDERR "\nRUNNING CLASSIFIER (REDUCER)\n"; print STDERR unchecked_output("date"); - $cmd="cat @dev_outs | $REDUCER -w $dir/weights.$im1 -s $reg"; + $cmd="cat @dev_outs | $REDUCER -w $dir/weights.$im1 -C $reg -y $reg_previous --interpolate_with_weights $psi"; if ($tune_regularizer) { $cmd .= " -T -t $dev_test_file"; } @@ -605,11 +611,21 @@ General options: Regularization options: + --interpolate-with-weights + [deprecated] At each iteration the resulting weights are + interpolated with the weights from the previous iteration, with + this factor. + --tune-regularizer Hold out one third of the tuning data and used this to tune the - regularization parameter. + regularization parameter. [this doesn't work well] --reg + l2 regularization strength + + --reg-previous + l2 penalty for moving away from the weights from the previous + iteration. Help } diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index aff410a0..98cddba2 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -23,11 +23,12 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("weights,w", po::value(), "Weights from previous iteration (used as initialization and interpolation") - ("interpolation,p",po::value()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev") + ("regularize_to_weights,y",po::value()->default_value(0.0), "Differences in learned weights to previous weights are penalized with an l2 penalty with this strength; 0.0 = no effect") + ("interpolate_with_weights,p",po::value()->default_value(1.0), "Output weights are p*w + (1-p)*w_prev; 1.0 = no effect") ("memory_buffers,m",po::value()->default_value(200), "Number of memory buffers (LBFGS)") - ("sigma_squared,s",po::value()->default_value(0.1), "Sigma squared for Gaussian prior") - ("min_reg,r",po::value()->default_value(1e-8), "When tuning (-T) regularization strength, minimum regularization strenght") - ("max_reg,R",po::value()->default_value(10.0), "When tuning (-T) regularization strength, maximum regularization strenght") + ("regularization_strength,C",po::value()->default_value(1.0), "l2 regularization strength") + ("min_reg,r",po::value()->default_value(0.01), "When tuning (-T) regularization strength, minimum regularization strenght") + ("max_reg,R",po::value()->default_value(1e6), "When tuning (-T) regularization strength, maximum regularization strenght") ("testset,t",po::value(), "Optional held-out test set") ("tune_regularizer,T", "Use the held out test set (-t) to tune the regularization strength") ("help,h", "Help"); @@ -95,6 +96,27 @@ void GradAdd(const SparseVector& v, const double scale, vector& weights, + const vector& prev_weights, + vector* g) { + assert(weights.size() == g->size()); + double reg = 0; + for (size_t i = 0; i < weights.size(); ++i) { + const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); + const double& w_i = weights[i]; + double& g_i = (*g)[i]; + reg += C * w_i * w_i; + g_i += 2 * C * w_i; + + const double diff_i = w_i - prev_w_i; + reg += T * diff_i * diff_i; + g_i += 2 * T * diff_i; + } + return reg; +} + double TrainingInference(const vector& x, const vector > >& corpus, vector* g = NULL) { @@ -134,8 +156,10 @@ double TrainingInference(const vector& x, // return held-out log likelihood double LearnParameters(const vector > >& training, const vector > >& testing, - const double sigsq, + const double C, + const double T, const unsigned memory_buffers, + const vector& prev_x, vector* px) { vector& x = *px; vector vg(FD::NumFeats(), 0.0); @@ -157,26 +181,12 @@ double LearnParameters(const vector > >& train } // handle regularizer -#if 1 - double norm = 0; - for (int i = 1; i < x.size(); ++i) { - const double mean_i = 0.0; - const double param = (x[i] - mean_i); - norm += param * param; - vg[i] += param / sigsq; - } - const double reg = norm / (2.0 * sigsq); -#else - double reg = 0; -#endif + double reg = ApplyRegularizationTerms(C, T, x, prev_x, &vg); cll += reg; - cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t TEST_PPL=" << tppl << "\t"; + cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t TEST_PPL=" << tppl << "\t" << endl; try { - vector old_x = x; - do { - opt.Optimize(cll, vg, &x); - converged = opt.HasConverged(); - } while (!converged && x == old_x); + opt.Optimize(cll, vg, &x); + converged = opt.HasConverged(); } catch (...) { cerr << "Exception caught, assuming convergence is close enough...\n"; converged = true; @@ -201,13 +211,14 @@ int main(int argc, char** argv) { } const double min_reg = conf["min_reg"].as(); const double max_reg = conf["max_reg"].as(); - double sigsq = conf["sigma_squared"].as(); // will be overridden if parameter is tuned - assert(sigsq > 0.0); + double C = conf["regularization_strength"].as(); // will be overridden if parameter is tuned + const double T = conf["regularize_to_weights"].as(); + assert(C > 0.0); assert(min_reg > 0.0); assert(max_reg > 0.0); assert(max_reg > min_reg); - const double psi = conf["interpolation"].as(); - if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } + const double psi = conf["interpolate_with_weights"].as(); + if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; return 1; } ReadCorpus(&cin, &training); if (conf.count("testset")) { ReadFile rf(conf["testset"].as()); @@ -231,14 +242,15 @@ int main(int argc, char** argv) { vector > sp; vector smoothed; if (tune_regularizer) { - sigsq = min_reg; + C = min_reg; const double steps = 18; double sweep_factor = exp((log(max_reg) - log(min_reg)) / steps); cerr << "SWEEP FACTOR: " << sweep_factor << endl; - while(sigsq < max_reg) { - tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); - sp.push_back(make_pair(sigsq, tppl)); - sigsq *= sweep_factor; + while(C < max_reg) { + cerr << "C=" << C << "\tT=" <(), prev_x, &x); + sp.push_back(make_pair(C, tppl)); + C *= sweep_factor; } smoothed.resize(sp.size(), 0); smoothed[0] = sp[0].second; @@ -257,16 +269,16 @@ int main(int argc, char** argv) { best_i = i; } } - sigsq = sp[best_i].first; + C = sp[best_i].first; } // tune regularizer - tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); + tppl = LearnParameters(training, testing, C, T, conf["memory_buffers"].as(), prev_x, &x); if (conf.count("weights")) { for (int i = 1; i < x.size(); ++i) { x[i] = (x[i] * psi) + prev_x[i] * (1.0 - psi); } } cout.precision(15); - cout << "# sigma^2=" << sigsq << "\theld out perplexity="; + cout << "# C=" << C << "\theld out perplexity="; if (tppl) { cout << tppl << endl; } else { cout << "N/A\n"; } if (sp.size()) { cout << "# Parameter sweep:\n"; -- cgit v1.2.3 From 8b13335bbd877d337320efa0aa549f97cf303ae5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 10 Nov 2011 20:19:09 +0000 Subject: better defaults for pro --- pro-train/dist-pro.pl | 45 +++++++++++++++++++++------------------------ pro-train/mr_pro_reduce.cc | 8 ++++---- vest/parallelize.pl | 2 +- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index 4bc9cfe3..6332563f 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -63,8 +63,8 @@ my $cpbin=1; # regularization strength my $tune_regularizer = 0; -my $reg = 10; -my $reg_previous = 0; +my $reg = 500; +my $reg_previous = 5000; # Process command-line options Getopt::Long::Configure("no_auto_abbrev"); @@ -547,16 +547,12 @@ sub enseg { sub print_help { my $executable = check_output("basename $0"); chomp $executable; - print << "Help"; + print << "Help"; Usage: $executable [options] $executable [options] - Runs a complete MERT optimization and test set decoding, using - the decoder configuration in ini file. Note that many of the - options have default values that are inferred automatically - based on certain conventions. For details, refer to descriptions - of the options --decoder, --weights, and --workdir. + Runs a complete PRO optimization using the ini file specified. Required: @@ -576,6 +572,10 @@ General options: --local Run the decoder and optimizer locally with a single thread. + --use-make + Use make -j to run the optimizer commands (useful on large + shared-memory machines where qsub is unavailable). + --decode-nodes Number of decoder processes to run in parallel. [default=15] @@ -584,7 +584,7 @@ General options: --max-iterations Maximum number of iterations to run. If not specified, defaults - to 10. + to 30. --metric Metric to optimize. @@ -597,10 +597,6 @@ General options: --pmem Amount of physical memory requested for parallel decoding jobs. - --use-make - Use make -j to run the optimizer commands (useful on large - shared-memory machines where qsub is unavailable). - --workdir Directory for intermediate and output files. If not specified, the name is derived from the ini filename. Assuming that the ini @@ -611,21 +607,22 @@ General options: Regularization options: - --interpolate-with-weights - [deprecated] At each iteration the resulting weights are - interpolated with the weights from the previous iteration, with - this factor. - - --tune-regularizer - Hold out one third of the tuning data and used this to tune the - regularization parameter. [this doesn't work well] - --reg - l2 regularization strength + l2 regularization strength [default=500]. The greater this value, + the closer to zero the weights will be. --reg-previous l2 penalty for moving away from the weights from the previous - iteration. + iteration. [default=5000]. The greater this value, the closer + to the previous iteration's weights the next iteration's weights + will be. + +Deprecated options: + + --interpolate-with-weights + [deprecated] At each iteration the resulting weights are + interpolated with the weights from the previous iteration, with + this factor. [default=1.0, i.e., no effect] Help } diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 98cddba2..6362ce47 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -23,14 +23,14 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("weights,w", po::value(), "Weights from previous iteration (used as initialization and interpolation") - ("regularize_to_weights,y",po::value()->default_value(0.0), "Differences in learned weights to previous weights are penalized with an l2 penalty with this strength; 0.0 = no effect") - ("interpolate_with_weights,p",po::value()->default_value(1.0), "Output weights are p*w + (1-p)*w_prev; 1.0 = no effect") - ("memory_buffers,m",po::value()->default_value(200), "Number of memory buffers (LBFGS)") - ("regularization_strength,C",po::value()->default_value(1.0), "l2 regularization strength") + ("regularization_strength,C",po::value()->default_value(500.0), "l2 regularization strength") + ("regularize_to_weights,y",po::value()->default_value(5000.0), "Differences in learned weights to previous weights are penalized with an l2 penalty with this strength; 0.0 = no effect") + ("memory_buffers,m",po::value()->default_value(100), "Number of memory buffers (LBFGS)") ("min_reg,r",po::value()->default_value(0.01), "When tuning (-T) regularization strength, minimum regularization strenght") ("max_reg,R",po::value()->default_value(1e6), "When tuning (-T) regularization strength, maximum regularization strenght") ("testset,t",po::value(), "Optional held-out test set") ("tune_regularizer,T", "Use the held out test set (-t) to tune the regularization strength") + ("interpolate_with_weights,p",po::value()->default_value(1.0), "[deprecated] Output weights are p*w + (1-p)*w_prev; 1.0 = no effect") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); diff --git a/vest/parallelize.pl b/vest/parallelize.pl index 869f430b..7d0365cc 100755 --- a/vest/parallelize.pl +++ b/vest/parallelize.pl @@ -396,7 +396,7 @@ usage: $name [options] options: - --fork + --use-fork Instead of using qsub, use fork. -e, --error-dir -- cgit v1.2.3 From b4fd470d2cb80b0c88d4210f7e5bb10d2aa4531d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 11 Nov 2011 14:13:26 -0500 Subject: new improved distributed operation for PRO, MERT --- environment/LocalConfig.pm | 43 ++++++++++++++++----- pro-train/dist-pro.pl | 77 +++++++++++++++++++------------------ vest/dist-vest.pl | 96 +++++++++++++++++++--------------------------- 3 files changed, 112 insertions(+), 104 deletions(-) diff --git a/environment/LocalConfig.pm b/environment/LocalConfig.pm index 4e5e0d5f..831a3a43 100644 --- a/environment/LocalConfig.pm +++ b/environment/LocalConfig.pm @@ -4,7 +4,7 @@ use strict; use warnings; use base 'Exporter'; -our @EXPORT = qw( qsub_args mert_memory environment_name ); +our @EXPORT = qw( qsub_args mert_memory environment_name env_default_jobs has_qsub ); use Net::Domain qw(hostname hostfqdn hostdomain domainname); @@ -14,47 +14,62 @@ my $host = domainname; my $CCONFIG = { 'StarCluster' => { 'HOST_REGEXP' => qr/compute-\d+\.internal$/, + 'JobControl' => 'qsub', 'QSubMemFlag' => '-l mem', + 'DefaultJobs' => 20, }, 'LTICluster' => { 'HOST_REGEXP' => qr/^cluster\d+\.lti\.cs\.cmu\.edu$/, + 'JobControl' => 'qsub', 'QSubMemFlag' => '-l h_vmem=', 'QSubExtraFlags' => '-l walltime=0:45:00', + 'DefaultJobs' => 15, #'QSubQueue' => '-q long', }, 'UMIACS' => { 'HOST_REGEXP' => qr/^d.*\.umiacs\.umd\.edu$/, + 'JobControl' => 'qsub', 'QSubMemFlag' => '-l pmem=', 'QSubQueue' => '-q batch', 'QSubExtraFlags' => '-l walltime=144:00:00', + 'DefaultJobs' => 15, }, 'CLSP' => { 'HOST_REGEXP' => qr/\.clsp\.jhu\.edu$/, + 'JobControl' => 'qsub', 'QSubMemFlag' => '-l mem_free=', 'MERTMem' => '9G', + 'DefaultJobs' => 15, }, 'Valhalla' => { 'HOST_REGEXP' => qr/^(thor|tyr)\.inf\.ed\.ac\.uk$/, + 'JobControl' => 'fork', + 'DefaultJobs' => 8, }, 'Blacklight' => { 'HOST_REGEXP' => qr/^(tg-login1.blacklight.psc.teragrid.org|blacklight.psc.edu|bl1.psc.teragrid.org|bl0.psc.teragrid.org)$/, - 'QSubMemFlag' => '-l pmem=', + 'JobControl' => 'fork', + 'DefaultJobs' => 32, }, 'Barrow/Chicago' => { 'HOST_REGEXP' => qr/^(barrow|chicago).lti.cs.cmu.edu$/, - 'QSubMemFlag' => '-l pmem=', + 'JobControl' => 'fork', + 'DefaultJobs' => 8, }, 'OxfordDeathSnakes' => { 'HOST_REGEXP' => qr/^(taipan|tiger).cs.ox.ac.uk$/, - 'QSubMemFlag' => '-l pmem=', + 'JobControl' => 'fork', + 'DefaultJobs' => 12, }, - 'LOCAL' => { - 'HOST_REGEXP' => qr/local\./, + 'LOCAL' => { # LOCAL must be last in the list!!! + 'HOST_REGEXP' => qr//, 'QSubMemFlag' => ' ', + 'JobControl' => 'fork', + 'DefaultJobs' => 2, }, }; -our $senvironment_name; +our $senvironment_name = 'LOCAL'; for my $config_key (keys %$CCONFIG) { my $re = $CCONFIG->{$config_key}->{'HOST_REGEXP'}; die "Can't find HOST_REGEXP for $config_key" unless $re; @@ -63,15 +78,23 @@ for my $config_key (keys %$CCONFIG) { } } -die "NO ENVIRONMENT INFO FOR HOST: $host\nPLEASE EDIT LocalConfig.pm\n" unless $senvironment_name; - our %CONFIG = %{$CCONFIG->{$senvironment_name}}; -print STDERR "**Environment: $senvironment_name\n"; +print STDERR "**Environment: $senvironment_name"; +print STDERR " (has qsub)" if has_qsub(); +print STDERR "\n"; + +sub has_qsub { + return ($CONFIG{'JobControl'} eq 'qsub'); +} sub environment_name { return $senvironment_name; } +sub env_default_jobs { + return 1 * $CONFIG{'DefaultJobs'}; +} + sub qsub_args { my $mem = shift @_; die "qsub_args requires a memory amount as a parameter, e.g. 4G" unless $mem; diff --git a/pro-train/dist-pro.pl b/pro-train/dist-pro.pl index 6332563f..5db053de 100755 --- a/pro-train/dist-pro.pl +++ b/pro-train/dist-pro.pl @@ -10,6 +10,7 @@ use Getopt::Long; use IPC::Open2; use POSIX ":sys_wait_h"; my $QSUB_CMD = qsub_args(mert_memory()); +my $default_jobs = env_default_jobs(); my $VEST_DIR="$SCRIPT_DIR/../vest"; require "$VEST_DIR/libcall.pl"; @@ -39,11 +40,11 @@ die "Can't find $libcall" unless -e $libcall; my $decoder = $cdec; my $lines_per_mapper = 30; my $iteration = 1; -my $run_local = 0; my $best_weights; my $psi = 1; -my $max_iterations = 30; -my $decode_nodes = 15; # number of decode nodes +my $default_max_iter = 30; +my $max_iterations = $default_max_iter; +my $jobs = $default_jobs; # number of decode nodes my $pmem = "4g"; my $disable_clean = 0; my %seen_weights; @@ -55,8 +56,8 @@ my $metric = "ibm_bleu"; my $dir; my $iniFile; my $weights; -my $use_make; # use make to parallelize -my $usefork; +my $use_make = 1; # use make to parallelize +my $useqsub = 0; my $initial_weights; my $pass_suffix = ''; my $cpbin=1; @@ -69,10 +70,10 @@ my $reg_previous = 5000; # Process command-line options Getopt::Long::Configure("no_auto_abbrev"); if (GetOptions( - "decode-nodes=i" => \$decode_nodes, + "jobs=i" => \$jobs, "dont-clean" => \$disable_clean, "pass-suffix=s" => \$pass_suffix, - "use-fork" => \$usefork, + "qsub" => \$useqsub, "dry-run" => \$dryrun, "epsilon=s" => \$epsilon, "interpolate-with-weights=f" => \$psi, @@ -81,7 +82,6 @@ if (GetOptions( "tune-regularizer" => \$tune_regularizer, "reg=f" => \$reg, "reg-previous=f" => \$reg_previous, - "local" => \$run_local, "use-make=i" => \$use_make, "max-iterations=i" => \$max_iterations, "pmem=s" => \$pmem, @@ -97,7 +97,16 @@ if (GetOptions( die "--tune-regularizer is no longer supported with --reg-previous and --reg. Please tune manually.\n" if $tune_regularizer; -if ($usefork) { $usefork = "--use-fork"; } else { $usefork = ''; } +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); +if (!defined $srcFile) { push @missing_args, "--source-file"; } +if (!defined $refFiles) { push @missing_args, "--ref-files"; } +if (!defined $initial_weights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); if ($metric =~ /^(combi|ter)$/i) { $lines_per_mapper = 5; @@ -254,13 +263,10 @@ while (1){ `rm -f $dir/hgs/*.gz`; my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; my $pcmd; - if ($run_local) { - $pcmd = "cat $srcFile |"; - } elsif ($use_make) { - # TODO: Throw error when decode_nodes is specified along with use_make - $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $use_make --"; + if ($use_make) { + $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; } else { - $pcmd = "cat $srcFile | $parallelize $usefork -p $pmem -e $logdir -j $decode_nodes --"; + $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; } my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; print STDERR "COMMAND:\n$cmd\n"; @@ -333,10 +339,7 @@ while (1){ push @mapoutputs, "$dir/splag.$im1/$mapoutput"; $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep -w $inweights -K $dir/kbest < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; - if ($run_local) { - print STDERR "COMMAND:\n$script\n"; - check_bash_call($script); - } elsif ($use_make) { + if ($use_make) { my $script_file = "$dir/scripts/map.$shard"; open F, ">$script_file" or die "Can't write $script_file: $!"; print F "#!/bin/bash\n"; @@ -382,12 +385,10 @@ while (1){ } else { @dev_outs = @mapoutputs; } - if ($run_local) { - print STDERR "\nCompleted extraction of training exemplars.\n"; - } elsif ($use_make) { + if ($use_make) { print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; close $mkfile; - my $mcmd = "make -j $use_make -f $mkfilename"; + my $mcmd = "make -j $jobs -f $mkfilename"; print STDERR "\nExecuting: $mcmd\n"; check_call($mcmd); } else { @@ -498,7 +499,7 @@ sub write_config { print $fh "REFS (DEV): $refFiles\n"; print $fh "EVAL METRIC: $metric\n"; print $fh "MAX ITERATIONS: $max_iterations\n"; - print $fh "DECODE NODES: $decode_nodes\n"; + print $fh "JOBS: $jobs\n"; print $fh "HEAD NODE: $host\n"; print $fh "PMEM (DECODING): $pmem\n"; print $fh "CLEANUP: $cleanup\n"; @@ -569,22 +570,12 @@ Required: General options: - --local - Run the decoder and optimizer locally with a single thread. - - --use-make - Use make -j to run the optimizer commands (useful on large - shared-memory machines where qsub is unavailable). - - --decode-nodes - Number of decoder processes to run in parallel. [default=15] - --help Print this message and exit. --max-iterations Maximum number of iterations to run. If not specified, defaults - to 30. + to $default_max_iter. --metric Metric to optimize. @@ -594,9 +585,6 @@ General options: If the decoder is doing multi-pass decoding, the pass suffix "2", "3", etc., is used to control what iteration of weights is set. - --pmem - Amount of physical memory requested for parallel decoding jobs. - --workdir Directory for intermediate and output files. If not specified, the name is derived from the ini filename. Assuming that the ini @@ -617,6 +605,19 @@ Regularization options: to the previous iteration's weights the next iteration's weights will be. +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + Deprecated options: --interpolate-with-weights diff --git a/vest/dist-vest.pl b/vest/dist-vest.pl index b7a862c4..11e791c1 100755 --- a/vest/dist-vest.pl +++ b/vest/dist-vest.pl @@ -16,6 +16,7 @@ require "libcall.pl"; # Default settings my $srcFile; my $refFiles; +my $default_jobs = env_default_jobs(); my $bin_dir = $SCRIPT_DIR; die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; my $FAST_SCORE="$bin_dir/../mteval/fast_score"; @@ -39,11 +40,10 @@ my $decoder = $cdec; my $lines_per_mapper = 400; my $rand_directions = 15; my $iteration = 1; -my $run_local = 0; my $best_weights; my $max_iterations = 15; my $optimization_iters = 6; -my $decode_nodes = 15; # number of decode nodes +my $jobs = $default_jobs; # number of decode nodes my $pmem = "9g"; my $disable_clean = 0; my %seen_weights; @@ -64,28 +64,25 @@ my $maxsim=0; my $oraclen=0; my $oracleb=20; my $bleu_weight=1; -my $use_make; # use make to parallelize line search +my $use_make = 1; # use make to parallelize line search my $dirargs=''; my $density_prune; -my $usefork; +my $useqsub; my $pass_suffix = ''; my $cpbin=1; # Process command-line options Getopt::Long::Configure("no_auto_abbrev"); if (GetOptions( "decoder=s" => \$decoderOpt, - "decode-nodes=i" => \$decode_nodes, + "jobs=i" => \$jobs, "density-prune=f" => \$density_prune, "dont-clean" => \$disable_clean, "pass-suffix=s" => \$pass_suffix, - "use-fork" => \$usefork, "dry-run" => \$dryrun, "epsilon=s" => \$epsilon, "help" => \$help, "interval" => \$interval, - "iteration=i" => \$iteration, - "local" => \$run_local, - "use-make=i" => \$use_make, + "qsub" => \$useqsub, "max-iterations=i" => \$max_iterations, "normalize=s" => \$normalize, "pmem=s" => \$pmem, @@ -114,7 +111,16 @@ if (defined $density_prune) { die "--density_prune n: n must be greater than 1.0\n" unless $density_prune > 1.0; } -if ($usefork) { $usefork = "--use-fork"; } else { $usefork = ''; } +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); +if (!defined $srcFile) { push @missing_args, "--source-file"; } +if (!defined $refFiles) { push @missing_args, "--ref-files"; } +if (!defined $initialWeights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); if ($metric =~ /^(combi|ter)$/i) { $lines_per_mapper = 40; @@ -276,17 +282,11 @@ while (1){ my $im1 = $iteration - 1; my $weightsFile="$dir/weights.$im1"; my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; - if ($density_prune) { - $decoder_cmd .= " --density_prune $density_prune"; - } my $pcmd; - if ($run_local) { - $pcmd = "cat $srcFile |"; - } elsif ($use_make) { - # TODO: Throw error when decode_nodes is specified along with use_make - $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $use_make --"; + if ($use_make) { + $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; } else { - $pcmd = "cat $srcFile | $parallelize $usefork -p $pmem -e $logdir -j $decode_nodes --"; + $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; } my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; print STDERR "COMMAND:\n$cmd\n"; @@ -365,10 +365,7 @@ while (1){ push @mapoutputs, "$dir/splag.$im1/$mapoutput"; $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; my $script = "$MAPPER -s $srcFile -l $metric $refs_comma_sep < $dir/splag.$im1/$shard | sort -t \$'\\t' -k 1 > $dir/splag.$im1/$mapoutput"; - if ($run_local) { - print STDERR "COMMAND:\n$script\n"; - check_bash_call($script); - } elsif ($use_make) { + if ($use_make) { my $script_file = "$dir/scripts/map.$shard"; open F, ">$script_file" or die "Can't write $script_file: $!"; print F "#!/bin/bash\n"; @@ -398,12 +395,10 @@ while (1){ else {$joblist = $joblist . "\|" . $jobid; } } } - if ($run_local) { - print STDERR "\nProcessing line search complete.\n"; - } elsif ($use_make) { + if ($use_make) { print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; close $mkfile; - my $mcmd = "make -j $use_make -f $mkfilename"; + my $mcmd = "make -j $jobs -f $mkfilename"; print STDERR "\nExecuting: $mcmd\n"; check_call($mcmd); } else { @@ -558,7 +553,7 @@ sub write_config { print $fh "EVAL METRIC: $metric\n"; print $fh "START ITERATION: $iteration\n"; print $fh "MAX ITERATIONS: $max_iterations\n"; - print $fh "DECODE NODES: $decode_nodes\n"; + print $fh "PARALLEL JOBS: $jobs\n"; print $fh "HEAD NODE: $host\n"; print $fh "PMEM (DECODING): $pmem\n"; print $fh "CLEANUP: $cleanup\n"; @@ -612,37 +607,15 @@ sub print_help { Usage: $executable [options] $executable [options] - Runs a complete MERT optimization and test set decoding, using - the decoder configuration in ini file. Note that many of the - options have default values that are inferred automatically - based on certain conventions. For details, refer to descriptions - of the options --decoder, --weights, and --workdir. + Runs a complete MERT optimization using the decoder configuration + in . Required options are --weights, --source-file, and + --ref-files. Options: - --local - Run the decoder and optimizer locally with a single thread. - - --use-make - Use make -j to run the optimizer commands (useful on large - shared-memory machines where qsub is unavailable). - - --decode-nodes - Number of decoder processes to run in parallel. [default=15] - - --decoder - Decoder binary to use. - - --density-prune - Limit the density of the hypergraph on each iteration to N times - the number of edges on the Viterbi path. - --help Print this message and exit. - --iteration - Starting iteration number. If not specified, defaults to 1. - --max-iterations Maximum number of iterations to run. If not specified, defaults to 10. @@ -651,9 +624,6 @@ Options: If the decoder is doing multi-pass decoding, the pass suffix "2", "3", etc., is used to control what iteration of weights is set. - --pmem - Amount of physical memory requested for parallel decoding jobs. - --ref-files Dev set ref files. This option takes only a single string argument. To use multiple files (including file globbing), this argument should @@ -678,6 +648,7 @@ Options: A file specifying initial feature weights. The format is FeatureName_1 value1 FeatureName_2 value2 + **All and only the weights listed in will be optimized!** --workdir Directory for intermediate and output files. If not specified, the @@ -687,6 +658,19 @@ Options: the filename. E.g. an ini file named decoder.foo.ini would have a default working directory name foo. +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + Help } -- cgit v1.2.3 From 105a52a8d37497fe69a01a7de771ef9b9300cd71 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 11 Nov 2011 17:12:39 -0500 Subject: optionally sample from forest to get training instances, rather than k-best it --- decoder/Makefile.am | 1 + decoder/hg_sampler.cc | 73 +++++++++++++++++++++++++++++++++++++++++++++++ decoder/hg_sampler.h | 27 ++++++++++++++++++ mira/kbest_mira.cc | 79 ++++++++++++++++++++++++++++++++++++++------------- 4 files changed, 161 insertions(+), 19 deletions(-) create mode 100644 decoder/hg_sampler.cc create mode 100644 decoder/hg_sampler.h diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 6b9360d8..30eaf04d 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -51,6 +51,7 @@ libcdec_a_SOURCES = \ hg_io.cc \ decoder.cc \ hg_intersect.cc \ + hg_sampler.cc \ factored_lexicon_helper.cc \ viterbi.cc \ lattice.cc \ diff --git a/decoder/hg_sampler.cc b/decoder/hg_sampler.cc new file mode 100644 index 00000000..cdf0ec3c --- /dev/null +++ b/decoder/hg_sampler.cc @@ -0,0 +1,73 @@ +#include "hg_sampler.h" + +#include + +#include "viterbi.h" +#include "inside_outside.h" + +using namespace std; + +struct SampledDerivationWeightFunction { + typedef double Weight; + explicit SampledDerivationWeightFunction(const vector& sampled) : sampled_edges(sampled) {} + double operator()(const Hypergraph::Edge& e) const { + return static_cast(sampled_edges[e.id_]); + } + const vector& sampled_edges; +}; + +void HypergraphSampler::sample_hypotheses(const Hypergraph& hg, + unsigned n, + MT19937* rng, + vector* hypos) { + hypos->clear(); + hypos->resize(n); + + // compute inside probabilities + vector node_probs; + Inside(hg, &node_probs, EdgeProb()); + + vector sampled_edges(hg.edges_.size()); + queue q; + SampleSet ss; + for (unsigned i = 0; i < n; ++i) { + fill(sampled_edges.begin(), sampled_edges.end(), false); + // sample derivation top down + assert(q.empty()); + Hypothesis& hyp = (*hypos)[i]; + SparseVector& deriv_features = hyp.fmap; + q.push(hg.nodes_.size() - 1); + prob_t& model_score = hyp.model_score; + model_score = prob_t::One(); + while(!q.empty()) { + unsigned cur_node_id = q.front(); + q.pop(); + const Hypergraph::Node& node = hg.nodes_[cur_node_id]; + const unsigned num_in_edges = node.in_edges_.size(); + unsigned sampled_edge_idx = 0; + if (num_in_edges == 1) { + sampled_edge_idx = node.in_edges_[0]; + } else { + assert(num_in_edges > 1); + ss.clear(); + for (unsigned j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; + prob_t p = edge.edge_prob_; // edge weight + for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k) + p *= node_probs[edge.tail_nodes_[k]]; // tail node inside weight + ss.add(p); + } + sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)]; + } + sampled_edges[sampled_edge_idx] = true; + const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx]; + deriv_features += sampled_edge.feature_values_; + model_score *= sampled_edge.edge_prob_; + //sampled_deriv->push_back(sampled_edge_idx); + for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j) { + q.push(sampled_edge.tail_nodes_[j]); + } + } + Viterbi(hg, &hyp.words, ESentenceTraversal(), SampledDerivationWeightFunction(sampled_edges)); + } +} diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h new file mode 100644 index 00000000..bf4e1eb0 --- /dev/null +++ b/decoder/hg_sampler.h @@ -0,0 +1,27 @@ +#ifndef _HG_SAMPLER_H_ +#define _HG_SAMPLER_H_ + + +#include +#include "sparse_vector.h" +#include "sampler.h" +#include "wordid.h" + +class Hypergraph; + +struct HypergraphSampler { + + struct Hypothesis { + std::vector words; + SparseVector fmap; + prob_t model_score; // log unnormalized probability + }; + + static void + sample_hypotheses(const Hypergraph& hg, + unsigned n, // how many samples to draw + MT19937* rng, + std::vector* hypos); +}; + +#endif diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc index 459a5e6f..904eba74 100644 --- a/mira/kbest_mira.cc +++ b/mira/kbest_mira.cc @@ -10,6 +10,7 @@ #include #include +#include "hg_sampler.h" #include "sentence_metadata.h" #include "scorer.h" #include "verbose.h" @@ -54,6 +55,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_step_size,C", po::value()->default_value(0.01), "regularization strength (C)") ("mt_metric_scale,s", po::value()->default_value(1.0), "Amount to scale MT loss function by") ("k_best_size,k", po::value()->default_value(250), "Size of hypothesis list to search for oracles") + ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest") + ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples") ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") ("decoder_config,c",po::value(),"Decoder configuration file"); po::options_description clo("Command line options"); @@ -91,11 +94,12 @@ struct GoodBadOracle { }; struct TrainingObserver : public DecoderObserver { - TrainingObserver(const int k, const DocScorer& d, vector* o) : ds(d), oracles(*o), kbest_size(k) {} + TrainingObserver(const int k, const DocScorer& d, bool sf, vector* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {} const DocScorer& ds; vector& oracles; shared_ptr cur_best; const int kbest_size; + const bool sample_forest; const HypothesisInfo& GetCurrentBestHypothesis() const { return *cur_best; @@ -116,24 +120,43 @@ struct TrainingObserver : public DecoderObserver { shared_ptr& cur_good = oracles[sent_id].good; shared_ptr& cur_bad = oracles[sent_id].bad; cur_bad.reset(); // TODO get rid of?? - KBest::KBestDerivations, ESentenceTraversal> kbest(forest, kbest_size); - for (int i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); - if (invert_score) sentscore *= -1.0; - // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; - if (i == 0) - cur_best = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_good || sentscore > cur_good->mt_metric) - cur_good = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_bad || sentscore < cur_bad->mt_metric) - cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + + if (sample_forest) { + vector cur_prediction; + ViterbiESentence(forest, &cur_prediction); + float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore(); + cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore); + + vector samples; + HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples); + for (unsigned i = 0; i < samples.size(); ++i) { + sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore); + } + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(forest, kbest_size); + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; + if (i == 0) + cur_best = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + } + //cerr << "GOOD: " << cur_good->mt_metric << endl; + //cerr << " CUR: " << cur_best->mt_metric << endl; + //cerr << " BAD: " << cur_bad->mt_metric << endl; } - //cerr << "GOOD: " << cur_good->mt_metric << endl; - //cerr << " CUR: " << cur_best->mt_metric << endl; - //cerr << " BAD: " << cur_bad->mt_metric << endl; } }; @@ -164,6 +187,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937(conf["random_seed"].as())); else rng.reset(new MT19937); + const bool sample_forest = conf.count("sample_forest") > 0; + const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0; + if (sample_forest_unit_weight_vector && !sample_forest) { + cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl; + return 1; + } vector corpus; ReadTrainingCorpus(conf["source"].as(), &corpus); const string metric_name = conf["mt_metric"].as(); @@ -195,7 +224,7 @@ int main(int argc, char** argv) { assert(corpus.size() > 0); vector oracles(corpus.size()); - TrainingObserver observer(conf["k_best_size"].as(), ds, &oracles); + TrainingObserver observer(conf["k_best_size"].as(), ds, sample_forest, &oracles); int cur_sent = 0; int lcount = 0; int normalizer = 0; @@ -234,7 +263,19 @@ int main(int argc, char** argv) { cerr << "PASS " << (lcount / corpus.size() + 1) << endl; } decoder.SetId(order[cur_sent]); + double sc = 1.0; + if (sample_forest_unit_weight_vector) { + sc = lambdas.l2norm(); + if (sc > 0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] /= sc; + } + } decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles + if (sc && sc != 1.0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] *= sc; + } const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good; const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad; -- cgit v1.2.3