From dc078281c2fd83db9fc4860e8b24979ab2a1e407 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 16 Aug 2012 17:02:56 -0400 Subject: KenLM update. Remove a couple of segfaults for weird input. Other oddball stuff. --- klm/lm/vocab.hh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'klm/lm/vocab.hh') diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c3efcb4a..a25432f9 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -13,11 +13,11 @@ #include namespace lm { -class ProbBackoff; +struct ProbBackoff; class EnumerateVocab; namespace ngram { -class Config; +struct Config; namespace detail { uint64_t HashForVocab(const char *str, std::size_t len); -- cgit v1.2.3 From 8882e9ebe158aef382bb5544559ef7f2a553db62 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Sep 2012 14:23:39 +0100 Subject: Update kenlm and build system --- bjam | 4 +- jam-files/fail/Jamroot | 4 + jam-files/sanity.jam | 9 ++ klm/lm/Jamfile | 11 +-- klm/lm/bhiksha.cc | 2 +- klm/lm/bhiksha.hh | 4 +- klm/lm/binary_format.cc | 4 +- klm/lm/binary_format.hh | 4 +- klm/lm/build_binary.cc | 9 +- klm/lm/max_order.hh | 2 +- klm/lm/model.cc | 21 +++-- klm/lm/model.hh | 2 +- klm/lm/partial.hh | 167 ++++++++++++++++++++++++++++++++++ klm/lm/partial_test.cc | 199 +++++++++++++++++++++++++++++++++++++++++ klm/lm/quantize.hh | 8 +- klm/lm/read_arpa.cc | 23 +++-- klm/lm/search_hashed.hh | 6 +- klm/lm/search_trie.hh | 4 +- klm/lm/state.hh | 2 + klm/lm/trie.cc | 4 +- klm/lm/trie.hh | 8 +- klm/lm/vocab.cc | 4 +- klm/lm/vocab.hh | 4 +- klm/util/Jamfile | 14 +-- klm/util/ersatz_progress.cc | 10 +-- klm/util/ersatz_progress.hh | 10 ++- klm/util/exception.cc | 3 + klm/util/exception.hh | 22 +++++ klm/util/file.cc | 7 +- klm/util/file_piece.cc | 1 - klm/util/probing_hash_table.hh | 5 +- 31 files changed, 502 insertions(+), 75 deletions(-) create mode 100644 jam-files/fail/Jamroot create mode 100644 klm/lm/partial.hh create mode 100644 klm/lm/partial_test.cc (limited to 'klm/lm/vocab.hh') diff --git a/bjam b/bjam index d1ac8a55..2b0232c8 100755 --- a/bjam +++ b/bjam @@ -4,8 +4,8 @@ if bjam="$(which bjam 2>/dev/null)" && #exists [ ${#bjam} != 0 ] && #paranoia about which printing nothing then returning true ! grep UFIHGUFIHBDJKNCFZXAEVA "${bjam}" /dev/null && #bjam in path isn't this script - "${bjam}" --help >/dev/null 2>/dev/null && #bjam in path isn't broken (i.e. has boost-build) - "${bjam}" --version |grep "Boost.Build 201" >/dev/null 2>/dev/null #It's recent enough. + "${bjam}" --sanity-test 2>/dev/null |grep Sane >/dev/null && #The test in jam-files/sanity.jam passes + (cd jam-files/fail && ! "${bjam}") >/dev/null #Returns non-zero on failure then #Delegate to system bjam exec "${bjam}" "$@" diff --git a/jam-files/fail/Jamroot b/jam-files/fail/Jamroot new file mode 100644 index 00000000..c3584d89 --- /dev/null +++ b/jam-files/fail/Jamroot @@ -0,0 +1,4 @@ +actions fail { + false +} +make fail : : fail ; diff --git a/jam-files/sanity.jam b/jam-files/sanity.jam index 8ccfc65d..086f20ae 100644 --- a/jam-files/sanity.jam +++ b/jam-files/sanity.jam @@ -15,6 +15,13 @@ rule _shell ( cmd : extras * ) { return [ trim-nl [ SHELL $(cmd) : $(extras) ] ] ; } +rule shell_or_fail ( cmd ) { + local ret = [ SHELL $(cmd) : exit-status ] ; + if $(ret[2]) != 0 { + exit $(cmd) failed : 1 ; + } +} + cxxflags = [ os.environ "CXXFLAGS" ] ; cflags = [ os.environ "CFLAGS" ] ; ldflags = [ os.environ "LDFLAGS" ] ; @@ -275,3 +282,5 @@ if [ option.get "sanity-test" : : "yes" ] { EXIT "Bad" : 1 ; } } + +use-project /top : . ; diff --git a/klm/lm/Jamfile b/klm/lm/Jamfile index b1971d88..dd620068 100644 --- a/klm/lm/Jamfile +++ b/klm/lm/Jamfile @@ -2,13 +2,14 @@ lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quant import testing ; -run left_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa ; -run model_test.cc ../util//kenutil kenlm ../..//boost_unit_test_framework : : test.arpa test_nounk.arpa ; +run left_test.cc ../util//kenutil kenlm /top//boost_unit_test_framework : : test.arpa ; +run model_test.cc ../util//kenutil kenlm /top//boost_unit_test_framework : : test.arpa test_nounk.arpa ; exe query : ngram_query.cc kenlm ../util//kenutil ; exe build_binary : build_binary.cc kenlm ../util//kenutil ; +exe kenlm_max_order : max_order.cc : .. ; -install legacy : build_binary query - : $(TOP)/klm/lm EXE on shared:$(TOP)/klm/lm shared:LIB ; +alias programs : query build_binary kenlm_max_order ; -alias programs : build_binary query ; +install legacy : build_binary query kenlm_max_order + : $(TOP)/lm EXE on shared:$(TOP)/lm shared:LIB ; diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index 870a4eee..088ea98d 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -50,7 +50,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con } } // namespace -std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { +uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; } diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 9734f3ab..8ff88654 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -33,7 +33,7 @@ class DontBhiksha { static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} - static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { return util::RequiredBits(max_next); @@ -67,7 +67,7 @@ class ArrayBhiksha { static void UpdateConfigFromBinary(int fd, Config &config); - static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index a56e998e..fd841e59 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -200,10 +200,10 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { const uint64_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. - std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; + std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); if (file_size != util::kBadSize && static_cast(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index dd795f62..bf699d5f 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -70,7 +70,7 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet void SeekPastHeader(int fd, const Parameters ¶ms); -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -90,7 +90,7 @@ template void LoadLM(const char *file, const Config &config, To &to) new_config.probing_multiplier = params.fixed.probing_multiplier; detail::SeekPastHeader(backing.file.get(), params); To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); - std::size_t memory_size = To::Size(params.counts, new_config); + uint64_t memory_size = To::Size(params.counts, new_config); uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); to.InitializeFromBinary(start, params, new_config, backing.file.get()); } else { diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index c2ca1101..efe99899 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -8,7 +8,6 @@ #include #include -#include #ifdef WIN32 #include "util/getopt.hh" @@ -86,16 +85,16 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[6]; + uint64_t sizes[6]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = RestProbingModel::Size(counts, config); sizes[2] = TrieModel::Size(counts, config); sizes[3] = QuantTrieModel::Size(counts, config); sizes[4] = ArrayTrieModel::Size(counts, config); sizes[5] = QuantArrayTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); - std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); - std::size_t divide; + uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t divide; char prefix; if (min_length < (1 << 10) * 10) { prefix = ' '; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index bc8687cd..989f8324 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -8,5 +8,5 @@ #define KENLM_MAX_ORDER 6 #endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "Edit klm/lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index b46333a4..40af8a63 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -12,6 +12,7 @@ #include #include #include +#include namespace lm { namespace ngram { @@ -19,17 +20,18 @@ namespace detail { template const ModelType GenericModel::kModelType = Search::kModelType; -template size_t GenericModel::Size(const std::vector &counts, const Config &config) { +template uint64_t GenericModel::Size(const std::vector &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } template void GenericModel::SetupMemory(void *base, const std::vector &counts, const Config &config) { + size_t goal_size = util::CheckOverflow(Size(counts, config)); uint8_t *start = static_cast(base); size_t allocated = VocabularyT::Size(counts[0], config); vocab_.SetupMemory(start, allocated, counts[0], config); start += allocated; start = search_.SetupMemory(start, counts, config); - if (static_cast(start - static_cast(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast(base)) << " but Size says they should take " << Size(counts, config)); + if (static_cast(start - static_cast(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast(base)) << " but Size says they should take " << goal_size); } template GenericModel::GenericModel(const char *file, const Config &config) { @@ -49,13 +51,18 @@ template GenericModel::Ge } namespace { -void CheckMaxOrder(size_t order) { - UTIL_THROW_IF(order > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << order << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); +void CheckCounts(const std::vector &counts) { + UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); + if (sizeof(uint64_t) > sizeof(std::size_t)) { + for (std::vector::const_iterator i = counts.begin(); i != counts.end(); ++i) { + UTIL_THROW_IF(*i > static_cast(std::numeric_limits::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines."); + } + } } } // namespace template void GenericModel::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { - CheckMaxOrder(params.counts.size()); + CheckCounts(params.counts); SetupMemory(start, params.counts, config); vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); search_.LoadedBinary(); @@ -68,11 +75,11 @@ template void GenericModel counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. ReadARPACounts(f, counts); - CheckMaxOrder(counts.size()); + CheckCounts(counts); if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - std::size_t vocab_size = VocabularyT::Size(counts[0], config); + std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 6dee9419..13ff864e 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -41,7 +41,7 @@ template class GenericModel : public base::Mod * does not include small non-mapped control structures, such as this class * itself. */ - static size_t Size(const std::vector &counts, const Config &config = Config()); + static uint64_t Size(const std::vector &counts, const Config &config = Config()); /* Load the model from a file. It may be an ARPA or binary file. Binary * files must have the format expected by this class or you'll get an diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh new file mode 100644 index 00000000..1dede359 --- /dev/null +++ b/klm/lm/partial.hh @@ -0,0 +1,167 @@ +#ifndef LM_PARTIAL__ +#define LM_PARTIAL__ + +#include "lm/return.hh" +#include "lm/state.hh" + +#include + +#include + +namespace lm { +namespace ngram { + +struct ExtendReturn { + float adjust; + bool make_full; + unsigned char next_use; +}; + +template ExtendReturn ExtendLoop( + const Model &model, + unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start, + const uint64_t *pointers, const uint64_t *pointers_end, + uint64_t *&pointers_write, + float *backoff_write) { + unsigned char add_length = add_rend - add_rbegin; + + float backoff_buf[2][KENLM_MAX_ORDER - 1]; + float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1]; + std::copy(backoff_start, backoff_start + add_length, backoff_in); + + ExtendReturn value; + value.make_full = false; + value.adjust = 0.0; + value.next_use = add_length; + + unsigned char i = 0; + unsigned char length = pointers_end - pointers; + // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities. + if (pointers_write) { + // Using full context, writing to new left state. + for (; i < length; ++i) { + FullScoreReturn ret(model.ExtendLeft( + add_rbegin, add_rbegin + value.next_use, + backoff_in, + pointers[i], i + seen + 1, + backoff_out, + value.next_use)); + std::swap(backoff_in, backoff_out); + if (ret.independent_left) { + value.adjust += ret.prob; + value.make_full = true; + ++i; + break; + } + value.adjust += ret.rest; + *pointers_write++ = ret.extend_left; + if (value.next_use != add_length) { + value.make_full = true; + ++i; + break; + } + } + } + // Using some of the new context. + for (; i < length && value.next_use; ++i) { + FullScoreReturn ret(model.ExtendLeft( + add_rbegin, add_rbegin + value.next_use, + backoff_in, + pointers[i], i + seen + 1, + backoff_out, + value.next_use)); + std::swap(backoff_in, backoff_out); + value.adjust += ret.prob; + } + float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1); + // Using none of the new context. + value.adjust += unrest; + + std::copy(backoff_in, backoff_in + value.next_use, backoff_write); + return value; +} + +template float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) { + assert(seen < reveal.length || reveal_full); + uint64_t *pointers_write = reveal_full ? NULL : left.pointers; + float backoff_buffer[KENLM_MAX_ORDER - 1]; + ExtendReturn value(ExtendLoop( + model, + seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen, + left.pointers, left.pointers + left.length, + pointers_write, + left.full ? backoff_buffer : (right.backoff + right.length))); + if (reveal_full) { + left.length = 0; + value.make_full = true; + } else { + left.length = pointers_write - left.pointers; + value.make_full |= (left.length == model.Order() - 1); + } + if (left.full) { + for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; + } else { + // If left wasn't full when it came in, put words into right state. + std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length); + right.length += value.next_use; + left.full = value.make_full || (right.length == model.Order() - 1); + } + return value.adjust; +} + +template float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) { + assert(seen < reveal.length || reveal.full); + uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length); + ExtendReturn value(ExtendLoop( + model, + seen, right.words, right.words + right.length, right.backoff, + reveal.pointers + seen, reveal.pointers + reveal.length, + pointers_write, + right.backoff)); + if (reveal.full) { + for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i]; + right.length = 0; + value.make_full = true; + } else { + right.length = value.next_use; + value.make_full |= (right.length == model.Order() - 1); + } + if (!left.full) { + left.length = pointers_write - left.pointers; + left.full = value.make_full || (left.length == model.Order() - 1); + } + return value.adjust; +} + +template float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) { + assert(first_right.length < KENLM_MAX_ORDER); + assert(second_left.length < KENLM_MAX_ORDER); + assert(between_length < KENLM_MAX_ORDER - 1); + uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length); + float backoff_buffer[KENLM_MAX_ORDER - 1]; + ExtendReturn value(ExtendLoop( + model, + between_length, first_right.words, first_right.words + first_right.length, first_right.backoff, + second_left.pointers, second_left.pointers + second_left.length, + pointers_write, + second_left.full ? backoff_buffer : (second_right.backoff + second_right.length))); + if (second_left.full) { + for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; + } else { + std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length); + second_right.length += value.next_use; + value.make_full |= (second_right.length == model.Order() - 1); + } + if (!first_left.full) { + first_left.length = pointers_write - first_left.pointers; + first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1); + } + assert(first_left.length < KENLM_MAX_ORDER); + assert(second_right.length < KENLM_MAX_ORDER); + return value.adjust; +} + +} // namespace ngram +} // namespace lm + +#endif // LM_PARTIAL__ diff --git a/klm/lm/partial_test.cc b/klm/lm/partial_test.cc new file mode 100644 index 00000000..8d309c85 --- /dev/null +++ b/klm/lm/partial_test.cc @@ -0,0 +1,199 @@ +#include "lm/partial.hh" + +#include "lm/left.hh" +#include "lm/model.hh" +#include "util/tokenize_piece.hh" + +#define BOOST_TEST_MODULE PartialTest +#include +#include + +namespace lm { +namespace ngram { +namespace { + +const char *TestLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "test.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[1]; +} + +Config SilentConfig() { + Config config; + config.arpa_complain = Config::NONE; + config.messages = NULL; + return config; +} + +struct ModelFixture { + ModelFixture() : m(TestLocation(), SilentConfig()) {} + + RestProbingModel m; +}; + +BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture) + +BOOST_AUTO_TEST_CASE(SimpleBefore) { + Left left; + left.full = false; + left.length = 0; + Right right; + right.length = 0; + + Right reveal; + reveal.length = 1; + WordIndex period = m.GetVocabulary().Index("."); + reveal.words[0] = period; + reveal.backoff[0] = -0.845098; + + BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001); + BOOST_CHECK_EQUAL(0, left.length); + BOOST_CHECK(!left.full); + BOOST_CHECK_EQUAL(1, right.length); + BOOST_CHECK_EQUAL(period, right.words[0]); + BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); + + WordIndex more = m.GetVocabulary().Index("more"); + reveal.words[1] = more; + reveal.backoff[1] = -0.4771212; + reveal.length = 2; + BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001); + BOOST_CHECK_EQUAL(0, left.length); + BOOST_CHECK(!left.full); + BOOST_CHECK_EQUAL(2, right.length); + BOOST_CHECK_EQUAL(period, right.words[0]); + BOOST_CHECK_EQUAL(more, right.words[1]); + BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); + BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001); +} + +BOOST_AUTO_TEST_CASE(AlsoWouldConsider) { + WordIndex would = m.GetVocabulary().Index("would"); + WordIndex consider = m.GetVocabulary().Index("consider"); + + ChartState current; + current.left.length = 1; + current.left.pointers[0] = would; + current.left.full = false; + current.right.length = 1; + current.right.words[0] = would; + current.right.backoff[0] = -0.30103; + + Left after; + after.full = false; + after.length = 1; + after.pointers[0] = consider; + + // adjustment for would consider + BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001); + + BOOST_CHECK_EQUAL(2, current.left.length); + BOOST_CHECK_EQUAL(would, current.left.pointers[0]); + BOOST_CHECK_EQUAL(false, current.left.full); + + WordIndex also = m.GetVocabulary().Index("also"); + Right before; + before.length = 1; + before.words[0] = also; + before.backoff[0] = -0.30103; + // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)] + // p(also -> would) = -2, p(also would -> consider) = -3 + BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001); + BOOST_CHECK_EQUAL(0, current.left.length); + BOOST_CHECK(current.left.full); + BOOST_CHECK_EQUAL(2, current.right.length); + BOOST_CHECK_EQUAL(would, current.right.words[0]); + BOOST_CHECK_EQUAL(also, current.right.words[1]); +} + +BOOST_AUTO_TEST_CASE(EndSentence) { + WordIndex loin = m.GetVocabulary().Index("loin"); + WordIndex period = m.GetVocabulary().Index("."); + WordIndex eos = m.GetVocabulary().EndSentence(); + + ChartState between; + between.left.length = 1; + between.left.pointers[0] = eos; + between.left.full = true; + between.right.length = 0; + + Right before; + before.words[0] = period; + before.words[1] = loin; + before.backoff[0] = -0.845098; + before.backoff[1] = 0.0; + + before.length = 1; + BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001); + BOOST_CHECK_EQUAL(0, between.left.length); +} + +float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) { + RuleScore scorer(model, out); + for (unsigned int *i = begin; i < end; ++i) { + scorer.Terminal(*i); + } + return scorer.Finish(); +} + +void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) { + Right before(before_in); + Left after(after_in); + after.full = false; + float got = 0.0; + for (unsigned int i = 1; i < 5; ++i) { + if (before_in.length >= i) { + before.length = i; + got += RevealBefore(model, before, i - 1, false, between.left, between.right); + } + if (after_in.length >= i) { + after.length = i; + got += RevealAfter(model, between.left, between.right, after, i - 1); + } + } + if (after_in.full) { + after.full = true; + got += RevealAfter(model, between.left, between.right, after, after.length); + } + if (before_full) { + got += RevealBefore(model, before, before.length, true, between.left, between.right); + } + // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this. + BOOST_CHECK(fabs(expect - got) < 0.001); +} + +void FullDivide(const RestProbingModel &model, StringPiece str) { + std::vector indices; + for (util::TokenIter i(str, ' '); i; ++i) { + indices.push_back(model.GetVocabulary().Index(*i)); + } + ChartState full_state; + float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state); + + ChartState before_state; + before_state.left.full = false; + RuleScore before_scorer(model, before_state); + float before_score = 0.0; + for (unsigned int before = 0; before < indices.size(); ++before) { + for (unsigned int after = before; after <= indices.size(); ++after) { + ChartState after_state, between_state; + float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state); + float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state); + CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left); + } + before_scorer.Terminal(indices[before]); + before_score = before_scorer.Finish(); + } +} + +BOOST_AUTO_TEST_CASE(Strings) { + FullDivide(m, "also would consider"); + FullDivide(m, "looking on a little more loin . "); + FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); +} + +BOOST_AUTO_TEST_SUITE_END() +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index abed0112..8ce2378a 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -24,7 +24,7 @@ class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast(0); static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} - static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -138,9 +138,9 @@ class SeparatelyQuantize { static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); - static std::size_t Size(uint8_t order, const Config &config) { - size_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); - size_t middle_table = (static_cast(1) << static_cast(config.backoff_bits)) * sizeof(float) + longest_table; + static uint64_t Size(uint8_t order, const Config &config) { + uint64_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); + uint64_t middle_table = (static_cast(1) << static_cast(config.backoff_bits)) * sizeof(float) + longest_table; // unigrams are currently not quantized so no need for a table. return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; } diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 70727e4c..174bd3a3 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -2,12 +2,13 @@ #include "lm/blank.hh" +#include #include #include +#include #include #include -#include #include #include @@ -31,6 +32,16 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) { const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; +// strtoull isn't portable enough :-( +uint64_t ReadCount(const std::string &from) { + std::stringstream stream(from); + uint64_t ret; + stream >> ret; + UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from); + UTIL_THROW_IF(static_cast(stream.tellg()) != from.size(), FormatLoadException, "Extra content in count: '" << from << "'"); + return ret; +} + } // namespace void ReadARPACounts(util::FilePiece &in, std::vector &number) { @@ -52,15 +63,11 @@ void ReadARPACounts(util::FilePiece &in, std::vector &number) { // So strtol doesn't go off the end of line. std::string remaining(line.data() + 6, line.size() - 6); char *end_ptr; - unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); + unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10); if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); ++end_ptr; - const char *start = end_ptr; - long int count = std::strtol(start, &end_ptr, 10); - if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); - if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); - number.push_back(count); + number.push_back(ReadCount(end_ptr)); } } @@ -103,7 +110,7 @@ void ReadBackoff(util::FilePiece &in, float &backoff) { int float_class = _fpclass(backoff); UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff); #else - int float_class = fpclassify(backoff); + int float_class = std::fpclassify(backoff); UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff); #endif } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 7e8c1220..3bcde921 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -74,8 +74,8 @@ template class HashedSearch { // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} - static std::size_t Size(const std::vector &counts, const Config &config) { - std::size_t ret = Unigram::Size(counts[0]); + static uint64_t Size(const std::vector &counts, const Config &config) { + uint64_t ret = Unigram::Size(counts[0]); for (unsigned char n = 1; n < counts.size() - 1; ++n) { ret += Middle::Size(counts[n], config.probing_multiplier); } @@ -160,7 +160,7 @@ template class HashedSearch { #endif {} - static std::size_t Size(uint64_t count) { + static uint64_t Size(uint64_t count) { return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate } diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 10b22ab1..1264baf5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -44,8 +44,8 @@ template class TrieSearch { Bhiksha::UpdateConfigFromBinary(fd, config); } - static std::size_t Size(const std::vector &counts, const Config &config) { - std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + static uint64_t Size(const std::vector &counts, const Config &config) { + uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 830e40aa..551510a8 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -47,6 +47,8 @@ class State { unsigned char length; }; +typedef State Right; + inline uint64_t hash_value(const State &state, uint64_t seed = 0) { return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed); } diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 0f1ca574..d9895f89 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -36,7 +36,7 @@ bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_ } } // namespace -std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { +uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits; // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes @@ -57,7 +57,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template uint64_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 034a1414..9ea3c546 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -49,7 +49,7 @@ class Unigram { unigram_ = static_cast(start); } - static std::size_t Size(uint64_t count) { + static uint64_t Size(uint64_t count) { // +1 in case unknown doesn't appear. +1 for the final next. return (count + 2) * sizeof(UnigramValue); } @@ -84,7 +84,7 @@ class BitPacked { } protected: - static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); + static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); @@ -99,7 +99,7 @@ class BitPacked { template class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); + static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); @@ -128,7 +128,7 @@ template class BitPackedMiddle : public BitPacked { class BitPackedLongest : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { return BaseSize(entries, max_vocab, quant_bits); } diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 5de68f16..398475be 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -87,7 +87,7 @@ void WriteWordsWrapper::Write(int fd) { SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} -std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { +uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { // Lead with the number of entries. return sizeof(uint64_t) + sizeof(uint64_t) * entries; } @@ -165,7 +165,7 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} -std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { +uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) { return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index a25432f9..074cd446 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -62,7 +62,7 @@ class SortedVocabulary : public base::Vocabulary { } // Size for purposes of file writing - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. WordIndex Bound() const { return bound_; } @@ -129,7 +129,7 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } - static size_t Size(std::size_t entries, const Config &config); + static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } diff --git a/klm/util/Jamfile b/klm/util/Jamfile index 3ee2c2c2..a939265f 100644 --- a/klm/util/Jamfile +++ b/klm/util/Jamfile @@ -1,10 +1,10 @@ -lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc ../..//z : .. : : .. ; +lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc /top//z : .. : : .. ; import testing ; -unit-test bit_packing_test : bit_packing_test.cc kenutil ../..///boost_unit_test_framework ; -run file_piece_test.cc kenutil ../..///boost_unit_test_framework : : file_piece.cc ; -unit-test joint_sort_test : joint_sort_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil ../..///boost_unit_test_framework ; -unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil ../..///boost_unit_test_framework ; +unit-test bit_packing_test : bit_packing_test.cc kenutil /top//boost_unit_test_framework ; +run file_piece_test.cc kenutil /top//boost_unit_test_framework : : file_piece.cc ; +unit-test joint_sort_test : joint_sort_test.cc kenutil /top//boost_unit_test_framework ; +unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil /top//boost_unit_test_framework ; +unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil /top//boost_unit_test_framework ; +unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil /top//boost_unit_test_framework ; diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc index 07b14e26..eb635ad8 100644 --- a/klm/util/ersatz_progress.cc +++ b/klm/util/ersatz_progress.cc @@ -9,16 +9,16 @@ namespace util { namespace { const unsigned char kWidth = 100; } -ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits::max()), complete_(next_), out_(NULL) {} +ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits::max()), complete_(next_), out_(NULL) {} ErsatzProgress::~ErsatzProgress() { if (out_) Finished(); } -ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std::string &message) +ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::string &message) : current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) { if (!out_) { - next_ = std::numeric_limits::max(); + next_ = std::numeric_limits::max(); return; } if (!message.empty()) *out_ << message << '\n'; @@ -28,14 +28,14 @@ ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std void ErsatzProgress::Milestone() { if (!out_) { current_ = 0; return; } if (!complete_) return; - unsigned char stone = std::min(static_cast(kWidth), (current_ * kWidth) / complete_); + unsigned char stone = std::min(static_cast(kWidth), (current_ * kWidth) / complete_); for (; stones_written_ < stone; ++stones_written_) { (*out_) << '*'; } if (stone == kWidth) { (*out_) << std::endl; - next_ = std::numeric_limits::max(); + next_ = std::numeric_limits::max(); out_ = NULL; } else { next_ = std::max(next_, (stone * complete_) / kWidth); diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index f709dc51..ff4d590f 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,6 +4,8 @@ #include #include +#include + // Ersatz version of boost::progress so core language model doesn't depend on // boost. Also adds option to print nothing. @@ -14,7 +16,7 @@ class ErsatzProgress { ErsatzProgress(); // Null means no output. The null value is useful for passing along the ostream pointer from another caller. - explicit ErsatzProgress(std::size_t complete, std::ostream *to = &std::cerr, const std::string &message = ""); + explicit ErsatzProgress(uint64_t complete, std::ostream *to = &std::cerr, const std::string &message = ""); ~ErsatzProgress(); @@ -23,12 +25,12 @@ class ErsatzProgress { return *this; } - ErsatzProgress &operator+=(std::size_t amount) { + ErsatzProgress &operator+=(uint64_t amount) { if ((current_ += amount) >= next_) Milestone(); return *this; } - void Set(std::size_t to) { + void Set(uint64_t to) { if ((current_ = to) >= next_) Milestone(); Milestone(); } @@ -40,7 +42,7 @@ class ErsatzProgress { private: void Milestone(); - std::size_t current_, next_, complete_; + uint64_t current_, next_, complete_; unsigned char stones_written_; std::ostream *out_; diff --git a/klm/util/exception.cc b/klm/util/exception.cc index c4f8c04c..3806e6de 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -84,4 +84,7 @@ EndOfFileException::EndOfFileException() throw() { } EndOfFileException::~EndOfFileException() throw() {} +OverflowException::OverflowException() throw() {} +OverflowException::~OverflowException() throw() {} + } // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 6d6a37cb..83f99cd6 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -2,9 +2,12 @@ #define UTIL_EXCEPTION__ #include +#include #include #include +#include + namespace util { template typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data); @@ -111,6 +114,25 @@ class EndOfFileException : public Exception { ~EndOfFileException() throw(); }; +class OverflowException : public Exception { + public: + OverflowException() throw(); + ~OverflowException() throw(); +}; + +template inline std::size_t CheckOverflowInternal(uint64_t value) { + UTIL_THROW_IF(value > static_cast(std::numeric_limits::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code."); + return value; +} + +template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) { + return value; +} + +inline std::size_t CheckOverflow(uint64_t value) { + return CheckOverflowInternal(value); +} + } // namespace util #endif // UTIL_EXCEPTION__ diff --git a/klm/util/file.cc b/klm/util/file.cc index 98f13983..ff5e64c9 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -119,8 +119,13 @@ void FSyncOrThrow(int fd) { } namespace { -void InternalSeek(int fd, off_t off, int whence) { +void InternalSeek(int fd, int64_t off, int whence) { +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed"); + +#else UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed"); +#endif } } // namespace diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index af341d6d..19a68728 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -11,7 +11,6 @@ #include #include -#include #include #include #include diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 3354b68e..770faa7e 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -8,6 +8,7 @@ #include #include +#include namespace util { @@ -42,8 +43,8 @@ template (multiplier * static_cast(entries))); + static uint64_t Size(uint64_t entries, float multiplier) { + uint64_t buckets = std::max(entries + 1, static_cast(multiplier * static_cast(entries))); return buckets * sizeof(Entry); } -- cgit v1.2.3 From 0ff82d648446645df245decc1e9eafad304eb327 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 22 Oct 2012 14:04:27 +0100 Subject: Update search, make it compile --- Makefile.am | 1 + configure.ac | 6 +- decoder/Makefile.am | 3 +- decoder/decoder.cc | 8 +- decoder/incremental.cc | 184 +++++++++++++++++++++++++++++++++++++++ decoder/incremental.h | 11 +++ decoder/lazy.cc | 178 -------------------------------------- decoder/lazy.h | 11 --- dtrain/Makefile.am | 2 +- klm/alone/Jamfile | 4 - klm/alone/assemble.cc | 76 ---------------- klm/alone/assemble.hh | 21 ----- klm/alone/graph.hh | 87 ------------------- klm/alone/just_vocab.cc | 14 --- klm/alone/labeled_edge.hh | 30 ------- klm/alone/main.cc | 85 ------------------ klm/alone/read.cc | 118 ------------------------- klm/alone/read.hh | 29 ------- klm/alone/threading.cc | 80 ----------------- klm/alone/threading.hh | 129 --------------------------- klm/alone/vocab.cc | 19 ---- klm/alone/vocab.hh | 34 -------- klm/lm/model.cc | 2 +- klm/lm/vocab.cc | 4 +- klm/lm/vocab.hh | 2 +- klm/search/Jamfile | 2 +- klm/search/Makefile.am | 11 +++ klm/search/arity.hh | 8 -- klm/search/context.hh | 10 +-- klm/search/edge.hh | 53 ++++++++---- klm/search/edge_generator.cc | 144 ++++++++++++++----------------- klm/search/edge_generator.hh | 49 ++++++----- klm/search/edge_queue.cc | 25 ------ klm/search/edge_queue.hh | 73 ---------------- klm/search/final.hh | 41 ++++----- klm/search/header.hh | 57 ++++++++++++ klm/search/source.hh | 48 ----------- klm/search/types.hh | 8 +- klm/search/vertex.cc | 10 +-- klm/search/vertex.hh | 55 ++++++------ klm/search/vertex_generator.cc | 97 ++++++++++++--------- klm/search/vertex_generator.hh | 33 +++---- klm/util/Makefile.am | 2 + klm/util/ersatz_progress.hh | 2 +- klm/util/exception.hh | 2 +- klm/util/pool.cc | 35 ++++++++ klm/util/pool.hh | 45 ++++++++++ klm/util/probing_hash_table.hh | 2 +- klm/util/string_piece.cc | 192 +++++++++++++++++++++++++++++++++++++++++ klm/util/tokenize_piece.hh | 12 +++ mira/Makefile.am | 2 +- training/Makefile.am | 38 ++++---- 52 files changed, 838 insertions(+), 1356 deletions(-) create mode 100644 decoder/incremental.cc create mode 100644 decoder/incremental.h delete mode 100644 decoder/lazy.cc delete mode 100644 decoder/lazy.h delete mode 100644 klm/alone/Jamfile delete mode 100644 klm/alone/assemble.cc delete mode 100644 klm/alone/assemble.hh delete mode 100644 klm/alone/graph.hh delete mode 100644 klm/alone/just_vocab.cc delete mode 100644 klm/alone/labeled_edge.hh delete mode 100644 klm/alone/main.cc delete mode 100644 klm/alone/read.cc delete mode 100644 klm/alone/read.hh delete mode 100644 klm/alone/threading.cc delete mode 100644 klm/alone/threading.hh delete mode 100644 klm/alone/vocab.cc delete mode 100644 klm/alone/vocab.hh create mode 100644 klm/search/Makefile.am delete mode 100644 klm/search/arity.hh delete mode 100644 klm/search/edge_queue.cc delete mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/header.hh delete mode 100644 klm/search/source.hh create mode 100644 klm/util/pool.cc create mode 100644 klm/util/pool.hh create mode 100644 klm/util/string_piece.cc (limited to 'klm/lm/vocab.hh') diff --git a/Makefile.am b/Makefile.am index 3e0103a8..fefc470d 100644 --- a/Makefile.am +++ b/Makefile.am @@ -6,6 +6,7 @@ SUBDIRS = \ mteval \ klm/util \ klm/lm \ + klm/search \ decoder \ training \ training/liblbfgs \ diff --git a/configure.ac b/configure.ac index 03a0ee87..cb132d66 100644 --- a/configure.ac +++ b/configure.ac @@ -12,6 +12,7 @@ AC_PROG_CXX AC_LANG_CPLUSPLUS BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS +BOOST_SYSTEM BOOST_TEST AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) @@ -73,9 +74,9 @@ fi #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" -LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" +LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SYSTEM_LDFLAGS" # $BOOST_THREAD_LDFLAGS" -LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS" +LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SYSTEM_LIBS" # $BOOST_THREAD_LIBS" AC_CHECK_HEADER(google/dense_hash_map, @@ -123,6 +124,7 @@ AC_CONFIG_FILES([rampion/Makefile]) AC_CONFIG_FILES([minrisk/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) +AC_CONFIG_FILES([klm/search/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) AC_CONFIG_FILES([example_extff/Makefile]) diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 5c0a1964..f8f427d3 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -17,7 +17,7 @@ trule_test_SOURCES = trule_test.cc trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz cdec_SOURCES = cdec.cc -cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm @@ -73,6 +73,7 @@ libcdec_a_SOURCES = \ ff_source_syntax.cc \ ff_bleu.cc \ ff_factory.cc \ + incremental.cc \ lexalign.cc \ lextrans.cc \ tagger.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 052823ca..fe812011 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -39,7 +39,7 @@ #include "sampler.h" #include "forest_writer.h" // TODO this section should probably be handled by an Observer -#include "lazy.h" +#include "incremental.h" #include "hg_io.h" #include "aligner.h" @@ -412,7 +412,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_target_graph", po::value(), "Directory to write the target hypergraphs to") - ("lazy_search", po::value(), "Run lazy search with this language model file") + ("incremental_search", po::value(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") @@ -828,8 +828,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("show_target_graph")) HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); - if (conf.count("lazy_search")) { - PassToLazy(conf["lazy_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); + if (conf.count("incremental_search")) { + PassToIncremental(conf["incremental_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); o->NotifyDecodingComplete(smeta); return true; } diff --git a/decoder/incremental.cc b/decoder/incremental.cc new file mode 100644 index 00000000..768bbd65 --- /dev/null +++ b/decoder/incremental.cc @@ -0,0 +1,184 @@ +#include "incremental.h" + +#include "hg.h" +#include "fdict.h" +#include "tdict.h" + +#include "lm/enumerate_vocab.hh" +#include "lm/model.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/rule.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" +#include "util/exception.hh" + +#include +#include + +#include +#include + +namespace { + +struct MapVocab : public lm::EnumerateVocab { + public: + MapVocab() {} + + // Do not call after Lookup. + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); + out_[cdec_id] = index; + } + + // Assumes Add has been called and will never be called again. + lm::WordIndex FromCDec(WordID id) const { + return out_[out_.size() > id ? id : 0]; + } + + private: + std::vector out_; +}; + +class IncrementalBase { + public: + IncrementalBase(const std::vector &weights) : + cdec_weights_(weights), + weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + } + + virtual ~IncrementalBase() {} + + virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + + static IncrementalBase *Load(const char *model_file, const std::vector &weights); + + protected: + lm::ngram::Config GetConfig() { + lm::ngram::Config ret; + ret.enumerate_vocab = &vocab_; + return ret; + } + + MapVocab vocab_; + + const std::vector &cdec_weights_; + + const search::Weights weights_; +}; + +template class Incremental : public IncrementalBase { + public: + Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} + + void Search(unsigned int pop_limit, const Hypergraph &hg) const; + + private: + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + + const Model m_; +}; + +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); + } +} + +void PrintFinal(const Hypergraph &hg, const search::Final final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); + const search::Final *child(final.Children()); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + if (*i > 0) { + std::cout << TD::Convert(*i) << ' '; + } else { + PrintFinal(hg, *child++); + } + } +} + +template void Incremental::Search(unsigned int pop_limit, const Hypergraph &hg) const { + boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); + search::Config config(weights_, pop_limit); + search::Context context(config, m_); + + for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { + search::EdgeGenerator gen; + const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; + for (unsigned int j = 0; j < down_edges.size(); ++j) { + unsigned int edge_index = down_edges[j]; + ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + } + search::VertexGenerator vertex_gen(context, out_vertices[i]); + gen.Search(context, vertex_gen); + } + const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (top.Valid()) { + std::cout << "NO PATH FOUND" << std::endl; + } else { + PrintFinal(hg, top); + std::cout << "||| " << top.GetScore() << std::endl; + } +} + +template void Incremental::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { + const std::vector &e = in.rule_->e(); + std::vector words; + words.reserve(e.size()); + std::vector nts; + unsigned int terminals = 0; + float score = 0.0; + for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { + if (*word <= 0) { + nts.push_back(vertices[in.tail_nodes_[-*word]].RootPartial()); + if (nts.back().Empty()) return; + score += nts.back().Bound(); + words.push_back(lm::kMaxWordIndex); + } else { + ++terminals; + words.push_back(vocab_.FromCDec(*word)); + } + } + + if (final) { + words.push_back(m_.GetVocabulary().EndSentence()); + } + + search::PartialEdge out(gen.AllocateEdge(nts.size())); + + memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); + + search::Note note; + note.vp = ∈ + out.SetNote(note); + + score += in.rule_->GetFeatureValues().dot(cdec_weights_); + score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + score += search::ScoreRule(context, words, final, out.Between()); + out.SetScore(score); + + gen.AddEdge(out); +} + +boost::scoped_ptr AwfulGlobalIncremental; + +} // namespace + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { + if (!AwfulGlobalIncremental.get()) { + std::cerr << "Pop limit " << pop_limit << std::endl; + AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); + } + AwfulGlobalIncremental->Search(pop_limit, hg); +} diff --git a/decoder/incremental.h b/decoder/incremental.h new file mode 100644 index 00000000..180383ce --- /dev/null +++ b/decoder/incremental.h @@ -0,0 +1,11 @@ +#ifndef _INCREMENTAL_H_ +#define _INCREMENTAL_H_ + +#include "weights.h" +#include + +class Hypergraph; + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); + +#endif // _INCREMENTAL_H_ diff --git a/decoder/lazy.cc b/decoder/lazy.cc deleted file mode 100644 index 1e6a94fe..00000000 --- a/decoder/lazy.cc +++ /dev/null @@ -1,178 +0,0 @@ -#include "hg.h" -#include "lazy.h" -#include "fdict.h" -#include "tdict.h" - -#include "lm/enumerate_vocab.hh" -#include "lm/model.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "search/edge.hh" -#include "search/edge_queue.hh" -#include "search/vertex.hh" -#include "search/vertex_generator.hh" -#include "util/exception.hh" - -#include -#include - -#include -#include - -namespace { - -struct MapVocab : public lm::EnumerateVocab { - public: - MapVocab() {} - - // Do not call after Lookup. - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); - out_[cdec_id] = index; - } - - // Assumes Add has been called and will never be called again. - lm::WordIndex FromCDec(WordID id) const { - return out_[out_.size() > id ? id : 0]; - } - - private: - std::vector out_; -}; - -class LazyBase { - public: - LazyBase(const std::vector &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { - std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; - } - - virtual ~LazyBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static LazyBase *Load(const char *model_file, const std::vector &weights); - - protected: - lm::ngram::Config GetConfig() { - lm::ngram::Config ret; - ret.enumerate_vocab = &vocab_; - return ret; - } - - MapVocab vocab_; - - const std::vector &cdec_weights_; - - const search::Weights weights_; -}; - -template class Lazy : public LazyBase { - public: - Lazy(const char *model_file, const std::vector &weights) : LazyBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - unsigned char ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const; - - const Model m_; -}; - -LazyBase *LazyBase::Load(const char *model_file, const std::vector &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Lazy(model_file, weights); - case lm::ngram::REST_PROBING: - return new Lazy(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -void PrintFinal(const Hypergraph &hg, const search::Final &final) { - const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); - boost::array::const_iterator child(final.Children().begin()); - for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { - if (*i > 0) { - std::cout << TD::Convert(*i) << ' '; - } else { - PrintFinal(hg, **child++); - } - } -} - -template void Lazy::Search(unsigned int pop_limit, const Hypergraph &hg) const { - boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - search::Config config(weights_, pop_limit); - search::Context context(config, m_); - - for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { - search::EdgeQueue queue(context.PopLimit()); - const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; - for (unsigned int j = 0; j < down_edges.size(); ++j) { - unsigned int edge_index = down_edges[j]; - unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge()); - search::Note note; - note.vp = &hg.edges_[edge_index]; - if (arity != 255) queue.AddEdge(arity, note); - } - search::VertexGenerator vertex_gen(context, out_vertices[i]); - queue.Search(context, vertex_gen); - } - const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild(); - if (!top) { - std::cout << "NO PATH FOUND" << std::endl; - } else { - PrintFinal(hg, *top); - std::cout << "||| " << top->Bound() << std::endl; - } -} - -template unsigned char Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const { - const std::vector &e = in.rule_->e(); - std::vector words; - unsigned int terminals = 0; - unsigned char nt = 0; - out.score = 0.0; - for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { - if (*word <= 0) { - out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial(); - if (out.nt[nt].Empty()) return 255; - out.score += out.nt[nt].Bound(); - ++nt; - words.push_back(lm::kMaxWordIndex); - } else { - ++terminals; - words.push_back(vocab_.FromCDec(*word)); - } - } - for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) { - out.nt[fill] = search::kBlankPartialVertex; - } - - if (final) { - words.push_back(m_.GetVocabulary().EndSentence()); - } - - out.score += in.rule_->GetFeatureValues().dot(cdec_weights_); - out.score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - out.score += search::ScoreRule(context, words, final, out.between); - return nt; -} - -boost::scoped_ptr AwfulGlobalLazy; - -} // namespace - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalLazy.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights)); - } - AwfulGlobalLazy->Search(pop_limit, hg); -} diff --git a/decoder/lazy.h b/decoder/lazy.h deleted file mode 100644 index 94895b19..00000000 --- a/decoder/lazy.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef _LAZY_H_ -#define _LAZY_H_ - -#include "weights.h" -#include - -class Hypergraph; - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); - -#endif // _LAZY_H_ diff --git a/dtrain/Makefile.am b/dtrain/Makefile.am index 64fef489..ca9581f5 100644 --- a/dtrain/Makefile.am +++ b/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc -dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile deleted file mode 100644 index 2cc90c05..00000000 --- a/klm/alone/Jamfile +++ /dev/null @@ -1,4 +0,0 @@ -lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : .. : : .. ../search//search ../lm//kenlm ; - -exe decode : main.cc standalone main.cc : multi:..//boost_thread ; -exe just_vocab : just_vocab.cc standalone : multi:..//boost_thread ; diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc deleted file mode 100644 index 2ae72ce9..00000000 --- a/klm/alone/assemble.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "alone/assemble.hh" - -#include "alone/labeled_edge.hh" -#include "search/final.hh" - -#include - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final) { - const std::vector &words = static_cast(final.From()).Words(); - if (words.empty()) return o; - const search::Final *const *child = final.Children().data(); - std::vector::const_iterator i(words.begin()); - for (; i != words.end() - 1; ++i) { - if (*i) { - o << **i << ' '; - } else { - o << **child << ' '; - ++child; - } - } - - if (*i) { - if (**i != "") { - o << **i; - } - } else { - o << **child; - } - - return o; -} - -namespace { - -void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) { - for (unsigned int i = 0; i < level; ++i) - o << indent_str; -} - -void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) { - o << "(\n"; - MakeIndent(o, indent_str, indent); - const std::vector &words = static_cast(final.From()).Words(); - const search::Final *const *child = final.Children().data(); - for (std::vector::const_iterator i(words.begin()); i != words.end(); ++i) { - if (*i) { - o << **i; - if (i == words.end() - 1) { - o << '\n'; - MakeIndent(o, indent_str, indent); - } else { - o << ' '; - } - } else { - // One extra indent from the line we're currently on. - o << indent_str; - DetailedFinalInternal(o, **child, indent_str, indent + 1); - for (unsigned int i = 0; i < indent; ++i) o << indent_str; - ++child; - } - } - o << ")=" << final.Bound() << '\n'; -} -} // namespace - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) { - DetailedFinalInternal(o, final, indent_str, 0); -} - -void PrintFinal(const search::Final &final) { - std::cout << final << std::endl; -} - -} // namespace alone diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh deleted file mode 100644 index e6b0ad5c..00000000 --- a/klm/alone/assemble.hh +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef ALONE_ASSEMBLE__ -#define ALONE_ASSEMBLE__ - -#include - -namespace search { -class Final; -} // namespace search - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final); - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " "); - -// This isn't called anywhere but makes it easy to print from gdb. -void PrintFinal(const search::Final &final); - -} // namespace alone - -#endif // ALONE_ASSEMBLE__ diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh deleted file mode 100644 index 788352c9..00000000 --- a/klm/alone/graph.hh +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef ALONE_GRAPH__ -#define ALONE_GRAPH__ - -#include "alone/labeled_edge.hh" -#include "search/rule.hh" -#include "search/types.hh" -#include "search/vertex.hh" -#include "util/exception.hh" - -#include -#include -#include - -namespace alone { - -template class FixedAllocator : boost::noncopyable { - public: - FixedAllocator() : current_(NULL), end_(NULL) {} - - void Init(std::size_t count) { - assert(!current_); - array_.reset(new T[count]); - current_ = array_.get(); - end_ = current_ + count; - } - - T &operator[](std::size_t idx) { - return array_.get()[idx]; - } - - T *New() { - T *ret = current_++; - UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end"); - return ret; - } - - std::size_t Size() const { - return end_ - array_.get(); - } - - private: - boost::scoped_array array_; - T *current_, *end_; -}; - -class Graph : boost::noncopyable { - public: - typedef LabeledEdge Edge; - typedef search::Vertex Vertex; - - Graph() {} - - void SetCounts(std::size_t vertices, std::size_t edges) { - vertices_.Init(vertices); - edges_.Init(edges); - } - - Vertex *NewVertex() { - return vertices_.New(); - } - - std::size_t VertexSize() const { return vertices_.Size(); } - - Vertex &MutableVertex(std::size_t index) { - return vertices_[index]; - } - - Edge *NewEdge() { - return edges_.New(); - } - - std::size_t EdgeSize() const { return edges_.Size(); } - - void SetRoot(Vertex *root) { root_ = root; } - - Vertex &Root() { return *root_; } - - private: - FixedAllocator vertices_; - FixedAllocator edges_; - - Vertex *root_; -}; - -} // namespace alone - -#endif // ALONE_GRAPH__ diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc deleted file mode 100644 index 35aea5ed..00000000 --- a/klm/alone/just_vocab.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "alone/read.hh" -#include "util/file_piece.hh" - -#include - -int main() { - util::FilePiece f(0, "stdin", &std::cerr); - while (true) { - try { - alone::JustVocab(f, std::cout); - } catch (const util::EndOfFileException &e) { break; } - std::cout << '\n'; - } -} diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh deleted file mode 100644 index 94d8cbdf..00000000 --- a/klm/alone/labeled_edge.hh +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef ALONE_LABELED_EDGE__ -#define ALONE_LABELED_EDGE__ - -#include "search/edge.hh" - -#include -#include - -namespace alone { - -class LabeledEdge : public search::Edge { - public: - LabeledEdge() {} - - void AppendWord(const std::string *word) { - words_.push_back(word); - } - - const std::vector &Words() const { - return words_; - } - - private: - // NULL for non-terminals. - std::vector words_; -}; - -} // namespace alone - -#endif // ALONE_LABELED_EDGE__ diff --git a/klm/alone/main.cc b/klm/alone/main.cc deleted file mode 100644 index e09ab01d..00000000 --- a/klm/alone/main.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "alone/threading.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "util/exception.hh" -#include "util/file_piece.hh" -#include "util/usage.hh" - -#include - -#include -#include - -namespace alone { - -template void ReadLoop(const std::string &graph_prefix, Control &control) { - for (unsigned int sentence = 0; ; ++sentence) { - std::stringstream name; - name << graph_prefix << '/' << sentence; - std::auto_ptr file; - try { - file.reset(new util::FilePiece(name.str().c_str())); - } catch (const util::ErrnoException &e) { - if (e.Error() == ENOENT) return; - throw; - } - control.Add(file.release()); - } -} - -template void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - Model model(model_file); - search::Weights weights(weight_str); - search::Config config(weights, pop_limit); - - if (threads > 1) { -#ifdef WITH_THREADS - Controller controller(config, model, threads, std::cout); - ReadLoop(graph_prefix, controller); -#else - UTIL_THROW(util::Exception, "Threading support not compiled in."); -#endif - } else { - InThread controller(config, model, std::cout); - ReadLoop(graph_prefix, controller); - } -} - -void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - case lm::ngram::REST_PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -} // namespace alone - -int main(int argc, char *argv[]) { - if (argc < 5 || argc > 6) { - std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; - return 1; - } - -#ifdef WITH_THREADS - unsigned thread_count = boost::thread::hardware_concurrency(); -#else - unsigned thread_count = 1; -#endif - if (argc == 6) { - thread_count = boost::lexical_cast(argv[5]); - UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); - } - UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line."); - alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast(argv[4]), thread_count); - - util::PrintUsage(std::cerr); - return 0; -} diff --git a/klm/alone/read.cc b/klm/alone/read.cc deleted file mode 100644 index 0b20be35..00000000 --- a/klm/alone/read.cc +++ /dev/null @@ -1,118 +0,0 @@ -#include "alone/read.hh" - -#include "alone/graph.hh" -#include "alone/vocab.hh" -#include "search/arity.hh" -#include "search/context.hh" -#include "search/weights.hh" -#include "util/file_piece.hh" - -#include -#include - -#include - -namespace alone { - -namespace { - -template Graph::Edge &ReadEdge(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) { - Graph::Edge *ret = to.NewEdge(); - - StringPiece got; - - std::vector words; - unsigned long int terminals = 0; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) { - // non-terminal - char *end_ptr; - unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10); - UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got); - UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?"); - ret->Add(to.MutableVertex(child)); - words.push_back(lm::kMaxWordIndex); - ret->AppendWord(NULL); - } else { - const std::pair &found = vocab.FindOrAdd(got); - words.push_back(found.second); - ret->AppendWord(&found.first); - ++terminals; - } - } - if (final) { - // This is not counted for the word penalty. - words.push_back(vocab.EndSentence().second); - ret->AppendWord(&vocab.EndSentence().first); - } - // Hard-coded word penalty. - float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast(terminals) / M_LN10; - ret->InitRule().Init(context, additive, words, final); - unsigned int arity = ret->GetRule().Arity(); - UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity); - return *ret; -} - -} // namespace - -// TODO: refactor -void JustVocab(util::FilePiece &from, std::ostream &out) { - boost::unordered_set seen; - unsigned long int vertices = from.ReadULong(); - from.ReadULong(); // edges - UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero"); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - std::string temp; - for (unsigned long int i = 0; i < vertices; ++i) { - unsigned long int edge_count = from.ReadULong(); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - StringPiece got; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue; - temp.assign(got.data(), got.size()); - if (seen.insert(temp).second) out << temp << ' '; - } - from.ReadLine(); // weights - } - } - // Eat sentence - from.ReadLine(); -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab) { - unsigned long int vertices; - try { - vertices = from.ReadULong(); - } catch (const util::EndOfFileException &e) { return false; } - unsigned long int edges = from.ReadULong(); - UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices); - UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges); - --vertices; - --edges; - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - to.SetCounts(vertices, edges); - Graph::Vertex *vertex; - for (unsigned long int i = 0; ; ++i) { - vertex = to.NewVertex(); - unsigned long int edge_count = from.ReadULong(); - bool root = (i == vertices - 1); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - vertex->Add(ReadEdge(context, from, to, vocab, root)); - } - vertex->FinishedAdding(); - if (root) break; - } - to.SetRoot(vertex); - StringPiece str = from.ReadLine(); - UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root"); - // The edge - from.ReadLine(); - return true; -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone diff --git a/klm/alone/read.hh b/klm/alone/read.hh deleted file mode 100644 index 10769a86..00000000 --- a/klm/alone/read.hh +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef ALONE_READ__ -#define ALONE_READ__ - -#include "util/exception.hh" - -#include - -namespace util { class FilePiece; } - -namespace search { template class Context; } - -namespace alone { - -class Graph; -class Vocab; - -class FormatException : public util::Exception { - public: - FormatException() {} - ~FormatException() throw() {} -}; - -void JustVocab(util::FilePiece &from, std::ostream &to); - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone - -#endif // ALONE_READ__ diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc deleted file mode 100644 index 475386b6..00000000 --- a/klm/alone/threading.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "alone/threading.hh" - -#include "alone/assemble.hh" -#include "alone/graph.hh" -#include "alone/read.hh" -#include "alone/vocab.hh" -#include "lm/model.hh" -#include "search/context.hh" -#include "search/vertex_generator.hh" - -#include -#include -#include - -#include - -namespace alone { -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) { - search::Context context(config, model); - Graph graph; - Vocab vocab(model.GetVocabulary()); - { - boost::scoped_ptr in(in_ptr); - ReadCDec(context, *in, graph, vocab); - } - - for (std::size_t i = 0; i < graph.VertexSize(); ++i) { - search::VertexGenerator(context, graph.MutableVertex(i)); - } - search::PartialVertex top = graph.Root().RootPartial(); - if (top.Empty()) { - out << "NO PATH FOUND"; - } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - out << top.End() << " ||| " << top.End().Bound() << std::endl; - } -} - -template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); -template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); - -#ifdef WITH_THREADS -template void DecodeHandler::operator()(Input message) { - std::stringstream assemble; - Decode(config_, model_, message.file, assemble); - Produce(message.sentence_id, assemble.str()); -} - -template void DecodeHandler::Produce(unsigned int sentence_id, const std::string &str) { - Output out; - out.sentence_id = sentence_id; - out.str = new std::string(str); - out_.Produce(out); -} - -void PrintHandler::operator()(Output message) { - unsigned int relative = message.sentence_id - done_; - if (waiting_.size() <= relative) waiting_.resize(relative + 1); - waiting_[relative] = message.str; - for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) { - out_ << *lead; - delete lead; - } -} - -template Controller::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) : - sentence_id_(0), - printer_(decode_workers, 1, boost::ref(to), Output::Poison()), - decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {} - -template class Controller; -template class Controller; - -#endif - -} // namespace alone diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh deleted file mode 100644 index 0ab0f739..00000000 --- a/klm/alone/threading.hh +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef ALONE_THREADING__ -#define ALONE_THREADING__ - -#ifdef WITH_THREADS -#include "util/pcqueue.hh" -#include "util/pool.hh" -#endif - -#include -#include -#include - -namespace util { -class FilePiece; -} // namespace util - -namespace search { -class Config; -template class Context; -} // namespace search - -namespace alone { - -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out); - -class Graph; - -#ifdef WITH_THREADS -struct SentenceID { - unsigned int sentence_id; - bool operator==(const SentenceID &other) const { - return sentence_id == other.sentence_id; - } -}; - -struct Input : public SentenceID { - util::FilePiece *file; - static Input Poison() { - Input ret; - ret.sentence_id = static_cast(-1); - ret.file = NULL; - return ret; - } -}; - -struct Output : public SentenceID { - std::string *str; - static Output Poison() { - Output ret; - ret.sentence_id = static_cast(-1); - ret.str = NULL; - return ret; - } -}; - -template class DecodeHandler { - public: - typedef Input Request; - - DecodeHandler(const search::Config &config, const Model &model, util::PCQueue &out) : config_(config), model_(model), out_(out) {} - - void operator()(Input message); - - private: - void Produce(unsigned int sentence_id, const std::string &str); - - const search::Config &config_; - - const Model &model_; - - util::PCQueue &out_; -}; - -class PrintHandler { - public: - typedef Output Request; - - explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {} - - void operator()(Output message); - - private: - std::ostream &out_; - std::deque waiting_; - unsigned int done_; -}; - -template class Controller { - public: - // This config must remain valid. - explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to); - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Input input; - input.sentence_id = sentence_id_++; - input.file = in; - decoder_.Produce(input); - } - - private: - unsigned int sentence_id_; - - util::Pool printer_; - - util::Pool > decoder_; -}; -#endif - -// Same API as controller. -template class InThread { - public: - InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {} - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Decode(config_, model_, in, to_); - } - - private: - const search::Config &config_; - - const Model &model_; - - std::ostream &to_; -}; - -} // namespace alone -#endif // ALONE_THREADING__ diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc deleted file mode 100644 index ffe55301..00000000 --- a/klm/alone/vocab.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "alone/vocab.hh" - -#include "lm/virtual_interface.hh" -#include "util/string_piece.hh" - -namespace alone { - -Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("")) {} - -const std::pair &Vocab::FindOrAdd(const StringPiece &str) { - Map::const_iterator i(FindStringPiece(map_, str)); - if (i != map_.end()) return *i; - std::pair to_ins; - to_ins.first.assign(str.data(), str.size()); - to_ins.second = backing_.Index(str); - return *map_.insert(to_ins).first; -} - -} // namespace alone diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh deleted file mode 100644 index 3ac0f542..00000000 --- a/klm/alone/vocab.hh +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef ALONE_VOCAB__ -#define ALONE_VOCAB__ - -#include "lm/word_index.hh" -#include "util/string_piece.hh" - -#include -#include - -#include - -namespace lm { namespace base { class Vocabulary; } } - -namespace alone { - -class Vocab { - public: - explicit Vocab(const lm::base::Vocabulary &backing); - - const std::pair &FindOrAdd(const StringPiece &str); - - const std::pair &EndSentence() const { return end_sentence_; } - - private: - typedef boost::unordered_map Map; - Map map_; - - const lm::base::Vocabulary &backing_; - - const std::pair &end_sentence_; -}; - -} // namespace alone -#endif // ALONE_VCOAB__ diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 40af8a63..2fd20481 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template void GenericModel.. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; import testing ; diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..ccc5b7f6 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,11 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + edge_generator.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc \ + weights.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. + diff --git a/klm/search/arity.hh b/klm/search/arity.hh deleted file mode 100644 index 09c2c671..00000000 --- a/klm/search/arity.hh +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef SEARCH_ARITY__ -#define SEARCH_ARITY__ -namespace search { - -const unsigned int kMaxArity = 2; - -} // namespace search -#endif // SEARCH_ARITY__ diff --git a/klm/search/context.hh b/klm/search/context.hh index 27940053..62163144 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -7,6 +7,7 @@ #include "search/types.hh" #include "search/vertex.hh" #include "util/exception.hh" +#include "util/pool.hh" #include #include @@ -21,10 +22,8 @@ class ContextBase { public: explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - Final *NewFinal() { - Final *ret = final_pool_.construct(); - assert(ret); - return ret; + util::Pool &FinalPool() { + return final_pool_; } VertexNode *NewVertexNode() { @@ -42,7 +41,8 @@ class ContextBase { const Weights &GetWeights() const { return weights_; } private: - boost::object_pool final_pool_; + util::Pool final_pool_; + boost::object_pool vertex_node_pool_; unsigned int pop_limit_; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 77ab0ade..187904bf 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -2,30 +2,53 @@ #define SEARCH_EDGE__ #include "lm/state.hh" -#include "search/arity.hh" -#include "search/rule.hh" +#include "search/header.hh" #include "search/types.hh" #include "search/vertex.hh" +#include "util/pool.hh" -#include +#include + +#include namespace search { -struct PartialEdge { - Score score; - // Terminals - lm::ngram::ChartState between[kMaxArity + 1]; - // Non-terminals - PartialVertex nt[kMaxArity]; +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. + PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} - const lm::ngram::ChartState &CompletedState() const { - return between[0]; - } + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast(After()); + } + PartialVertex *NT() { + return reinterpret_cast(After()); + } - bool operator<(const PartialEdge &other) const { - return score < other.score; - } + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } }; + } // namespace search #endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 56239dfb..260159b1 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -4,117 +4,107 @@ #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" -#include "search/vertex_generator.hh" #include namespace search { -EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { -/* for (unsigned char i = 0; i < edge.Arity(); ++i) { - root.nt[i] = edge.GetVertex(i).RootPartial(); - } - for (unsigned char i = edge.Arity(); i < 2; ++i) { - root.nt[i] = kBlankPartialVertex; - }*/ - generate_.push(&root); - top_score_ = root.score; -} - namespace { -template float FastScore(const Context &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { - memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); - - float ret = 0.0; - lm::ngram::ChartState *before, *after; - if (victim == 0) { - before = &update.between[0]; - after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; - } else { - assert(victim == 1); - assert(arity == 2); - before = &update.between[previous.nt[0].Complete() ? 0 : 1]; - after = &update.between[2]; - } - const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); - const PartialVertex &update_nt = update.nt[victim]; +template void FastScore(const Context &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; const lm::ngram::ChartState &update_reveal = update_nt.State(); - float just_after = 0.0; if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { - just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); } - if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { - ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); } if (update_nt.Complete()) { if (update_reveal.left.full) { before->left.full = true; } else { assert(update_reveal.left.length == update_reveal.right.length); - ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); } - if (victim == 0) { - update.between[0].right = after->right; - } else { - update.between[2].left = before->left; + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); } } - return previous.score + (ret + just_after) * context.GetWeights().LM(); + update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); } } // namespace -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { +template PartialEdge EdgeGenerator::Pop(Context &context) { assert(!generate_.empty()); - PartialEdge &top = *generate_.top(); + PartialEdge top = generate_.top(); generate_.pop(); - unsigned int victim = 0; - unsigned char lowest_length = 255; - for (unsigned char i = 0; i != arity_; ++i) { - if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { - lowest_length = top.nt[i].Length(); - victim = i; + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. + { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } } - } - if (lowest_length == 255) { - // All states report complete. - top.between[0].right = top.between[arity_].right; - // Now top.between[0] is the full edge state. - top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return ⊤ + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; } - unsigned int stay = !victim; - PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); - float old_bound = top.nt[victim].Bound(); - // The alternate's score will change because alternate.nt[victim] changes. - bool split = top.nt[victim].Split(continuation.nt[victim]); - // top is now the alternate. + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); - continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, arity_, top, continuation); - // TODO: dedupe? - generate_.push(&continuation); + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); - if (split) { - // We have an alternate. - top.score += top.nt[victim].Bound() - old_bound; // TODO: dedupe? - generate_.push(&top); - } else { - partial_edge_pool.free(&top); + generate_.push(alternate); } - top_score_ = generate_.top()->score; - return NULL; + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. + return PartialEdge(); } -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 875ccc5e..582c78b7 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -3,11 +3,8 @@ #include "search/edge.hh" #include "search/note.hh" +#include "search/types.hh" -#include -#include - -#include #include namespace lm { @@ -20,38 +17,40 @@ namespace search { template class Context; -class VertexGenerator; - -struct PartialEdgePointerLess : std::binary_function { - bool operator()(const PartialEdge *first, const PartialEdge *second) const { - return *first < *second; - } -}; - class EdgeGenerator { public: - EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); + EdgeGenerator() {} - Score TopScore() const { - return top_score_; + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); } - Note GetNote() const { - return note_; + void AddEdge(PartialEdge edge) { + generate_.push(edge); } - // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. - template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template PartialEdge Pop(Context &context); + + template void Search(Context &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } private: - Score top_score_; - - unsigned char arity_; + util::Pool partial_edge_pool_; - typedef std::priority_queue, PartialEdgePointerLess> Generate; + typedef std::priority_queue Generate; Generate generate_; - - Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc deleted file mode 100644 index e3ae6ebf..00000000 --- a/klm/search/edge_queue.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "search/edge_queue.hh" - -#include "lm/left.hh" -#include "search/context.hh" - -#include - -namespace search { - -EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { - take_ = static_cast(partial_edge_pool_.malloc()); -} - -/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { - // Ignore empty edges. - for (unsigned char i = 0; i < edge.Arity(); ++i) { - PartialVertex root(edge.GetVertex(i).RootPartial()); - if (root.Empty()) return; - total_score += root.Bound(); - } - PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); - allocated.score = total_score; -}*/ - -} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh deleted file mode 100644 index 187eaed7..00000000 --- a/klm/search/edge_queue.hh +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef SEARCH_EDGE_QUEUE__ -#define SEARCH_EDGE_QUEUE__ - -#include "search/edge.hh" -#include "search/edge_generator.hh" -#include "search/note.hh" - -#include -#include - -#include - -namespace search { - -template class Context; - -class EdgeQueue { - public: - explicit EdgeQueue(unsigned int pop_limit_hint); - - PartialEdge &InitializeEdge() { - return *take_; - } - - void AddEdge(unsigned char arity, Note note) { - generate_.push(edge_pool_.construct(*take_, arity, note)); - take_ = static_cast(partial_edge_pool_.malloc()); - } - - bool Empty() const { return generate_.empty(); } - - /* Generate hypotheses and send them to output. Normally, output is a - * VertexGenerator, but the decoder may want to route edges to different - * vertices i.e. if they have different LHS non-terminal labels. - */ - template void Search(Context &context, Output &output) { - int to_pop = context.PopLimit(); - while (to_pop > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - PartialEdge *ret = top->Pop(context, partial_edge_pool_); - if (ret) { - output.NewHypothesis(*ret, top->GetNote()); - --to_pop; - if (top->TopScore() != -kScoreInf) { - generate_.push(top); - } - } else { - generate_.push(top); - } - } - output.FinishedSearch(); - } - - private: - boost::object_pool edge_pool_; - - struct LessByTopScore : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->TopScore() < second->TopScore(); - } - }; - - typedef std::priority_queue, LessByTopScore> Generate; - Generate generate_; - - boost::pool<> partial_edge_pool_; - - PartialEdge *take_; -}; - -} // namespace search -#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 1b3092ac..50e62cf2 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,37 +1,34 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/arity.hh" -#include "search/note.hh" -#include "search/types.hh" - -#include +#include "search/header.hh" +#include "util/pool.hh" namespace search { -class Final { +// A full hypothesis with pointers to children. +class Final : public Header { public: - typedef boost::array ChildArray; + Final() {} - void Reset(Score bound, Note note, const Final &left, const Final &right) { - bound_ = bound; - note_ = note; - children_[0] = &left; - children_[1] = &right; + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); } - const ChildArray &Children() const { return children_; } - - Note GetNote() const { return note_; } - - Score Bound() const { return bound_; } + // These are arrays of length GetArity(). + Final *Children() { + return reinterpret_cast(After()); + } + const Final *Children() const { + return reinterpret_cast(After()); + } private: - Score bound_; - - Note note_; - - ChildArray children_; + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } }; } // namespace search diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..25550dbe --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast(base_); + } + void SetScore(Score to) { + *reinterpret_cast(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast(base)) { + *reinterpret_cast(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/source.hh b/klm/search/source.hh deleted file mode 100644 index 11839f7b..00000000 --- a/klm/search/source.hh +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef SEARCH_SOURCE__ -#define SEARCH_SOURCE__ - -#include "search/types.hh" - -#include -#include - -namespace search { - -template class Source { - public: - Source() : bound_(kScoreInf) {} - - Index Size() const { - return final_.size(); - } - - Score Bound() const { - return bound_; - } - - const Final &operator[](Index index) const { - return *final_[index]; - } - - Score ScoreOrBound(Index index) const { - return Size() > index ? final_[index]->Total() : Bound(); - } - - protected: - void AddFinal(const Final &store) { - final_.push_back(&store); - } - - void SetBound(Score to) { - assert(to <= bound_ + 0.001); - bound_ = to; - } - - private: - std::vector final_; - - Score bound_; -}; - -} // namespace search -#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh index 9726379f..06eb5bfa 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -1,17 +1,13 @@ #ifndef SEARCH_TYPES__ #define SEARCH_TYPES__ -#include +#include namespace search { typedef float Score; -const Score kScoreInf = INFINITY; -// This could have been an enum but gcc wants 4 bytes. -typedef bool ExtendDirection; -const ExtendDirection kExtendLeft = 0; -const ExtendDirection kExtendRight = 1; +typedef uint32_t Arity; } // namespace search diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index cc53c0dd..11f4631f 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_functionBound(); + bound_ = end_.GetScore(); return; } if (extend_.size() == 1 && parent_ptr) { @@ -39,10 +39,4 @@ void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { bound_ = extend_.front()->Bound(); } -namespace { -VertexNode kBlankVertexNode; -} // namespace - -PartialVertex kBlankPartialVertex(kBlankVertexNode); - } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index e1a9ad11..52bc1dfe 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() : end_(NULL) {} + VertexNode() {} void InitRoot() { extend_.clear(); @@ -26,8 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - bound_ = -kScoreInf; - end_ = NULL; + end_ = Final(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -37,19 +36,20 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final *end) { end_ = end; } + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } - Final &MutableEnd() { return *end_; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return !end_.Valid() && extend_.empty(); } bool Complete() const { - return end_; + return end_.Valid(); } const lm::ngram::ChartState &State() const { return state_; } @@ -63,8 +63,8 @@ class VertexNode { return state_.left.length + state_.right.length; } - // May be NULL. - const Final *End() const { return end_; } + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -81,7 +81,7 @@ class VertexNode { bool right_full_; Score bound_; - Final *end_; + Final end_; }; class PartialVertex { @@ -97,7 +97,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -105,20 +105,24 @@ class PartialVertex { return index_ + 1 < back_->Size(); } - // Split into continuation and alternative, rendering this the alternative. - bool Split(PartialVertex &continuation) { + // Split into continuation and alternative, rendering this the continuation. + bool Split(PartialVertex &alternative) { assert(!Complete()); - continuation.back_ = &((*back_)[index_]); - continuation.index_ = 0; + bool ret; if (index_ + 1 < back_->Size()) { - ++index_; - return true; + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; } - return false; + back_ = &((*back_)[index_]); + index_ = 0; + return ret; } - const Final &End() const { - return *back_->End(); + const Final End() const { + return back_->End(); } private: @@ -126,25 +130,22 @@ class PartialVertex { unsigned int index_; }; -extern PartialVertex kBlankPartialVertex; - class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final *BestChild() const { + const Final BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return NULL; + return Final(); } else { PartialVertex continuation; while (!top.Complete()) { top.Split(continuation); - top = continuation; } - return &top.End(); + return top.End(); } } diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index d94e6e06..0945fe55 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -10,74 +10,85 @@ namespace search { VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); - root_.under = &gen.root_; } namespace { + const uint64_t kCompleteAdd = static_cast(-1); -} // namespace -void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); - if (!got.second) { - // Found it already. - Final &exists = *got.first->second; - if (exists.Bound() < partial.score) { - exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - } - return; +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + unsigned char left = 0, right = 0; - Trie *node = &root_; + Trie *node = &root; while (true) { if (left == state.left.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); for (; right < state.right.length; ++right) { - node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); } break; } - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); left++; if (right == state.right.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); for (; left < state.left.length; ++left) { - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); } break; } - node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); right++; } - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, note, partial); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); } -VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { - VertexGenerator::Trie &next = node.extend[added]; - if (!next.under) { - next.under = context_.NewVertexNode(); - lm::ngram::ChartState &writing = next.under->MutableState(); - writing = state; - writing.left.full &= left_full && state.left.full; - next.under->MutableRightFull() = right_full && state.left.full; - writing.left.length = left; - writing.right.length = right; - node.under->AddExtend(next.under); - } - return next; -} +} // namespace -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { - VertexNode &node = *starter.under; - assert(node.State().left.full == state.left.full); - assert(!node.End()); - Final *final = context_.NewFinal(); - final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - node.SetEnd(final); - return final; +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); } } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 6b98da3e..60e86112 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,13 +1,11 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/note.hh" +#include "search/edge.hh" #include "search/vertex.hh" #include -#include - namespace lm { namespace ngram { class ChartState; @@ -18,40 +16,29 @@ namespace search { class ContextBase; class Final; -struct PartialEdge; class VertexGenerator { public: VertexGenerator(ContextBase &context, Vertex &gen); - void NewHypothesis(const PartialEdge &partial, Note note); - - void FinishedSearch() { - root_.under->SortAndSet(context_, NULL); + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (!ret.second && ret.first->second < partial) { + ret.first->second = partial; + } } + void FinishedSearch(); + const Vertex &Generating() const { return gen_; } private: - // Parallel structure to VertexNode. - struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; - }; - - Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); - ContextBase &context_; Vertex &gen_; - Trie root_; - - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; }; diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5ceccf2c..5306850f 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \ file_piece.cc \ mmap.cc \ murmur_hash.cc \ + pool.cc \ + string_piece.cc \ usage.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index ff4d590f..9909736d 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,7 +4,7 @@ #include #include -#include +#include // Ersatz version of boost::progress so core language model doesn't depend on // boost. Also adds option to print nothing. diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 83f99cd6..053a850b 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -6,7 +6,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/pool.cc b/klm/util/pool.cc new file mode 100644 index 00000000..2dffd06f --- /dev/null +++ b/klm/util/pool.cc @@ -0,0 +1,35 @@ +#include "util/pool.hh" + +#include + +namespace util { + +Pool::Pool() { + current_ = NULL; + current_end_ = NULL; +} + +Pool::~Pool() { + FreeAll(); +} + +void Pool::FreeAll() { + for (std::vector::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) { + free(*i); + } + free_list_.clear(); + current_ = NULL; + current_end_ = NULL; +} + +void *Pool::More(std::size_t size) { + std::size_t amount = std::max(static_cast(32) << free_list_.size(), size); + uint8_t *ret = static_cast(malloc(amount)); + if (!ret) throw std::bad_alloc(); + free_list_.push_back(ret); + current_ = ret + size; + current_end_ = ret + amount; + return ret; +} + +} // namespace util diff --git a/klm/util/pool.hh b/klm/util/pool.hh new file mode 100644 index 00000000..72f8a0c8 --- /dev/null +++ b/klm/util/pool.hh @@ -0,0 +1,45 @@ +// Very simple pool. It can only allocate memory. And all of the memory it +// allocates must be freed at the same time. + +#ifndef UTIL_POOL__ +#define UTIL_POOL__ + +#include + +#include + +namespace util { + +class Pool { + public: + Pool(); + + ~Pool(); + + void *Allocate(std::size_t size) { + void *ret = current_; + current_ += size; + if (current_ < current_end_) { + return ret; + } else { + return More(size); + } + } + + void FreeAll(); + + private: + void *More(std::size_t size); + + std::vector free_list_; + + uint8_t *current_, *current_end_; + + // no copying + Pool(const Pool &); + Pool &operator=(const Pool &); +}; + +} // namespace util + +#endif // UTIL_POOL__ diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 770faa7e..4a8aff35 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -8,7 +8,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc new file mode 100644 index 00000000..b422cefc --- /dev/null +++ b/klm/util/string_piece.cc @@ -0,0 +1,192 @@ +// Copyright 2004 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in string_piece.hh. + +#include "util/string_piece.hh" + +#include + +#include + +#ifndef HAVE_ICU + +typedef StringPiece::size_type size_type; + +void StringPiece::CopyToString(std::string* target) const { + target->assign(ptr_, length_); +} + +size_type StringPiece::find(const StringPiece& s, size_type pos) const { + if (length_ < 0 || pos > static_cast(length_)) + return npos; + + const char* result = std::search(ptr_ + pos, ptr_ + length_, + s.ptr_, s.ptr_ + s.length_); + const size_type xpos = result - ptr_; + return xpos + s.length_ <= length_ ? xpos : npos; +} + +size_type StringPiece::find(char c, size_type pos) const { + if (length_ <= 0 || pos >= static_cast(length_)) { + return npos; + } + const char* result = std::find(ptr_ + pos, ptr_ + length_, c); + return result != ptr_ + length_ ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(const StringPiece& s, size_type pos) const { + if (length_ < s.length_) return npos; + const size_t ulen = length_; + if (s.length_ == 0) return std::min(ulen, pos); + + const char* last = ptr_ + std::min(ulen - s.length_, pos) + s.length_; + const char* result = std::find_end(ptr_, last, s.ptr_, s.ptr_ + s.length_); + return result != last ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(char c, size_type pos) const { + if (length_ <= 0) return npos; + for (int i = std::min(pos, static_cast(length_ - 1)); + i >= 0; --i) { + if (ptr_[i] == c) { + return i; + } + } + return npos; +} + +// For each character in characters_wanted, sets the index corresponding +// to the ASCII code of that character to 1 in table. This is used by +// the find_.*_of methods below to tell whether or not a character is in +// the lookup table in constant time. +// The argument `table' must be an array that is large enough to hold all +// the possible values of an unsigned char. Thus it should be be declared +// as follows: +// bool table[UCHAR_MAX + 1] +static inline void BuildLookupTable(const StringPiece& characters_wanted, + bool* table) { + const size_type length = characters_wanted.length(); + const char* const data = characters_wanted.data(); + for (size_type i = 0; i < length; ++i) { + table[static_cast(data[i])] = true; + } +} + +size_type StringPiece::find_first_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + if (s.length_ == 0) + return 0; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (!lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (; pos < length_; ++pos) { + if (ptr_[pos] != c) { + return pos; + } + } + return npos; +} + +size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + size_type i = std::min(pos, length_ - 1); + if (s.length_ == 0) + return i; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (; ; --i) { + if (!lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (ptr_[i] != c) + return i; + if (i == 0) + break; + } + return npos; +} + +StringPiece StringPiece::substr(size_type pos, size_type n) const { + if (pos > length_) pos = length_; + if (n > length_ - pos) n = length_ - pos; + return StringPiece(ptr_ + pos, n); +} + +const size_type StringPiece::npos = size_type(-1); + +#endif // !HAVE_ICU diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index c7e1c863..4a7f5460 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -54,6 +54,18 @@ class AnyCharacter { StringPiece chars_; }; +class AnyCharacterLast { + public: + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} + + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::find_end(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1); + } + + private: + StringPiece chars_; +}; + template class TokenIter : public boost::iterator_facade, const StringPiece, boost::forward_traversal_tag> { public: TokenIter() {} diff --git a/mira/Makefile.am b/mira/Makefile.am index 7b4a4e12..3f8f17cd 100644 --- a/mira/Makefile.am +++ b/mira/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = kbest_mira kbest_mira_SOURCES = kbest_mira.cc -kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/Makefile.am b/training/Makefile.am index 5254333a..f9c25391 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -32,60 +32,60 @@ libtraining_a_SOURCES = \ risk.cc mpi_online_optimize_SOURCES = mpi_online_optimize.cc -mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc -mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc -mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_features_SOURCES = mpi_extract_features.cc -mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc -mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc -mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz augment_grammar_SOURCES = augment_grammar.cc -augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz test_ngram_SOURCES = test_ngram.cc -test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz fast_align_SOURCES = fast_align.cc ttables.cc -fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbl_model_SOURCES = lbl_model.cc -lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz grammar_convert_SOURCES = grammar_convert.cc -grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz optimize_test_SOURCES = optimize_test.cc -optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz collapse_weights_SOURCES = collapse_weights.cc -collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbfgs_test_SOURCES = lbfgs_test.cc -lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc -mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc -mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc -mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc -mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz plftools_SOURCES = plftools.cc -plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm -- cgit v1.2.3