From b9d7da0413403805f035479a0a426c27102032f6 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 3 Jun 2011 20:49:52 -0400 Subject: Add exception catcher around constructor --- decoder/ff_klm.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 35b35d36..71ba9f30 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -385,7 +385,12 @@ KLanguageModel::KLanguageModel(const string& param) { if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { abort(); } - pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + try { + pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } fid_ = FD::Convert(featname); oov_fid_ = FD::Convert(featname+"_OOV"); cerr << "FID: " << oov_fid_ << endl; -- cgit v1.2.3 From dcf4447590277887d65b0bdec7e6818081869a9a Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 3 Jun 2011 20:56:37 -0400 Subject: Code cleanup for vocabulary mapping --- decoder/ff_klm.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 71ba9f30..a3bd0c5f 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -282,11 +282,10 @@ class KLanguageModelImpl { KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : kCDEC_UNK(TD::Convert("")) , add_sos_eos_(!explicit_markers) { - if (true) { - boost::scoped_ptr vm; - vm.reset(new VMapper(&cdec2klm_map_)); + { + VMapper vm(&cdec2klm_map_); lm::ngram::Config conf; - conf.enumerate_vocab = vm.get(); + conf.enumerate_vocab = &vm; ngram_ = new Model(filename.c_str(), conf); } order_ = ngram_->Order(); -- cgit v1.2.3 From 205893513c8343fdc55789e427fab4c8b536dc12 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 26 Jun 2011 18:40:15 -0400 Subject: Quantization --- BUILDING | 5 +- decoder/cdec_ff.cc | 1 + decoder/ff_klm.cc | 1 + klm/compile.sh | 4 +- klm/lm/Makefile.am | 1 + klm/lm/binary_format.cc | 17 +++- klm/lm/binary_format.hh | 10 ++- klm/lm/blank.hh | 4 + klm/lm/build_binary.cc | 83 +++++++++++++---- klm/lm/config.cc | 2 + klm/lm/config.hh | 5 ++ klm/lm/model.cc | 25 +++--- klm/lm/model.hh | 21 +++-- klm/lm/model_test.cc | 13 +-- klm/lm/quantize.cc | 84 ++++++++++++++++++ klm/lm/quantize.hh | 207 +++++++++++++++++++++++++++++++++++++++++++ klm/lm/search_hashed.cc | 38 ++++++-- klm/lm/search_hashed.hh | 122 ++++++++++++------------- klm/lm/search_trie.cc | 121 ++++++++++++++++++++----- klm/lm/search_trie.hh | 132 +++++++++++++++------------ klm/lm/trie.cc | 123 ++++++++++--------------- klm/lm/trie.hh | 33 +++---- klm/lm/vocab.cc | 18 ++-- klm/lm/vocab.hh | 35 +++++--- klm/util/bit_packing.cc | 4 +- klm/util/bit_packing.hh | 39 +++++--- klm/util/bit_packing_test.cc | 25 ++++-- klm/util/sorted_uniform.hh | 120 +++++++++++++++++-------- 28 files changed, 921 insertions(+), 372 deletions(-) create mode 100644 klm/lm/quantize.cc create mode 100644 klm/lm/quantize.hh (limited to 'decoder') diff --git a/BUILDING b/BUILDING index dcb3d45b..b7535d70 100644 --- a/BUILDING +++ b/BUILDING @@ -1,6 +1,5 @@ To build cdec, you'll need: - * SRILM (register and download from http://www.speech.sri.com/projects/srilm/) * Google c++ testing framework (http://code.google.com/p/googletest/) * boost headers & boost program_options (you may need to install a package like boost-devel) @@ -9,7 +8,7 @@ To build cdec, you'll need: Instructions for building ----------------------------------- - 1) Download and build SRILM + 1) Optional: Download and build SRILM 2) Download, build, and install Google Test (optional, this is necessary to build unit tests that may be useful in development; system tests @@ -22,7 +21,7 @@ Instructions for building 4) Configure and build. Your command will look something like this. - ./configure --with-srilm=/home/me/software/srilm-1.5.9 --disable-gtest + ./configure --disable-gtest make If you get errors during configure about missing BOOST macros, then step 3 diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 37aa655b..31f88a4f 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,6 +55,7 @@ void register_feature_functions() { ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new FFFactory >()); ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); + ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index a3bd0c5f..9b7fe2d3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -437,4 +437,5 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, // instantiate templates template class KLanguageModel; template class KLanguageModel; +template class KLanguageModel; diff --git a/klm/compile.sh b/klm/compile.sh index 49e04db8..6ca85e1f 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -3,11 +3,9 @@ #If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU #I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh -#don't need to use if compiling with moses Makefiles already - set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 61d98d97..395494bc 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -15,6 +15,7 @@ libklm_a_SOURCES = \ binary_format.cc \ config.cc \ lm_exception.cc \ + quantize.cc \ model.cc \ ngram_query.cc \ read_arpa.cc \ diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 34d9ffca..92b1008b 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -80,6 +80,14 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace +void SeekOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + +void AdvanceOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; @@ -156,7 +164,7 @@ bool IsBinaryFormat(int fd) { } void ReadHeader(int fd, Parameters &out) { - if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file"); + SeekOrThrow(fd, sizeof(Sanity)); ReadLoop(fd, &out.fixed, sizeof(out.fixed)); if (out.fixed.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); @@ -173,6 +181,10 @@ void MatchCheck(ModelType model_type, const Parameters ¶ms) { } } +void SeekPastHeader(int fd, const Parameters ¶ms) { + SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +} + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { const off_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. @@ -186,8 +198,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); if (config.enumerate_vocab) { - if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); + SeekOrThrow(backing.file.get(), total_map); } return reinterpret_cast(backing.search.get()) + TotalHeaderSize(params.counts.size()); } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 1fc71be4..2b32b450 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,7 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType; +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -48,6 +48,10 @@ struct Backing { util::scoped_memory search; }; +void SeekOrThrow(int fd, off_t off); +// Seek forward +void AdvanceOrThrow(int fd, off_t off); + // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. @@ -65,6 +69,8 @@ void ReadHeader(int fd, Parameters ¶ms); void MatchCheck(ModelType model_type, const Parameters ¶ms); +void SeekPastHeader(int fd, const Parameters ¶ms); + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -83,6 +89,8 @@ template void LoadLM(const char *file, const Config &config, To &to) // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; + detail::SeekPastHeader(backing.file.get(), params); + To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); std::size_t memory_size = To::Size(params.counts, new_config); uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); to.InitializeFromBinary(start, params, new_config, backing.file.get()); diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 4615a09e..162411a9 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -22,6 +22,8 @@ namespace ngram { */ const float kNoExtensionBackoff = -0.0; const float kExtensionBackoff = 0.0; +const uint64_t kNoExtensionQuant = 0; +const uint64_t kExtensionQuant = 1; inline void SetExtension(float &backoff) { if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff; @@ -47,6 +49,8 @@ inline bool HasExtension(const float &backoff) { */ const float kBlankProb = -std::numeric_limits::infinity(); const float kBlankBackoff = kNoExtensionBackoff; +const uint32_t kBlankProbQuant = 0; +const uint32_t kBlankBackoffQuant = 0; } // namespace ngram } // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 91ad2fb9..4552c419 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,22 +15,21 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" "-u sets the default log10 probability for if the ARPA file does not have\n" "one.\n" "-s allows models to be built even if they do not have and .\n" "-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is one of probing, trie, or sorted:\n\n" +"type is either probing or trie:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" "-t is the temporary directory prefix. Default is the output file name.\n" -"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n" -/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" -"It uses more memory than trie and is also slower, so there's no real reason to\n" -"use it.\n\n"*/ +"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" +"-q turns quantization on and sets the number of bits (e.g. -q 8).\n" +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" "See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" "Passing only an input file will print memory usage of each data structure.\n" "If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; @@ -51,19 +50,53 @@ unsigned long int ParseUInt(const char *from) { return ret; } +uint8_t ParseBitCount(const char *from) { + unsigned long val = ParseUInt(from); + if (val > 25) { + util::ParseNumberException e(from); + e << " bit counts are limited to 256."; + } + return val; +} + void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t probing_size = ProbingModel::Size(counts, config); - // probing is always largest so use it to determine number of columns. - long int length = std::max(5, lrint(ceil(log10(probing_size)))); + std::size_t sizes[3]; + sizes[0] = ProbingModel::Size(counts, config); + sizes[1] = TrieModel::Size(counts, config); + sizes[2] = QuantTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + 3); + std::size_t min_length = *std::max_element(sizes, sizes + 3); + std::size_t divide; + char prefix; + if (min_length < (1 << 10) * 10) { + prefix = ' '; + divide = 1; + } else if (min_length < (1 << 20) * 10) { + prefix = 'k'; + divide = 1 << 10; + } else if (min_length < (1ULL << 30) * 10) { + prefix = 'M'; + divide = 1 << 20; + } else { + prefix = 'G'; + divide = 1 << 30; + } + long int length = std::max(2, lrint(ceil(log10(max_length / divide)))); std::cout << "Memory estimate:\ntype "; // right align bytes. - for (long int i = 0; i < length - 5; ++i) std::cout << ' '; - std::cout << "bytes\n" - "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"; + for (long int i = 0; i < length - 2; ++i) std::cout << ' '; + std::cout << prefix << "B\n" + "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; +} + +void ProbingQuantizationUnsupported() { + std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; + exit(1); } } // namespace ngram @@ -73,11 +106,21 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { int main(int argc, char *argv[]) { using namespace lm::ngram; + bool quantize = false, set_backoff_bits = false; try { lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { switch(opt) { + case 'q': + config.prob_bits = ParseBitCount(optarg); + if (!set_backoff_bits) config.backoff_bits = config.prob_bits; + quantize = true; + break; + case 'b': + config.backoff_bits = ParseBitCount(optarg); + set_backoff_bits = true; + break; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -100,19 +143,29 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } + if (!quantize && set_backoff_bits) { + std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl; + abort(); + } if (optind + 1 == argc) { ShowSizes(argv[optind], config); } else if (optind + 2 == argc) { config.write_mmap = argv[optind + 1]; + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(argv[optind], config); } else if (optind + 3 == argc) { const char *model_type = argv[optind]; const char *from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { - TrieModel(from_file, config); + if (quantize) { + QuantTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } else { Usage(argv[0]); } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index cee8fce2..08e1af5c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -18,6 +18,8 @@ Config::Config() : arpa_complain(ALL), write_mmap(NULL), include_vocab(true), + prob_bits(8), + backoff_bits(8), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 6c7fe39b..dcc7cf35 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -71,6 +71,11 @@ struct Config { // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; + // Quantization options. Only effective for QuantTrieModel. One value is + // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used + // to quantize. + uint8_t prob_bits, backoff_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index f0579c0c..a1d10b3d 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -44,17 +44,13 @@ template GenericModel::Ge begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; State null_context = State(); null_context.valid_length_ = 0; - P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2); + P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } template void GenericModel::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { SetupMemory(start, params.counts, config); vocab_.LoadedBinary(fd, config.enumerate_vocab); - search_.unigram.LoadedBinary(); - for (typename std::vector::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) { - i->LoadedBinary(); - } - search_.longest.LoadedBinary(); + search_.LoadedBinary(); } template void GenericModel::InitializeFromARPA(const char *file, const Config &config) { @@ -116,8 +112,9 @@ template FullScoreReturn GenericModel void GenericModel FullScoreReturn GenericModel::const_iterator mid_iter = search_.middle.begin(); + const typename Search::Middle *mid_iter = search_.MiddleBegin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. @@ -192,7 +189,7 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel; -template class GenericModel; -template class GenericModel; +template class GenericModel; // HASH_PROBING +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index b85ccdcc..1f49a382 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -5,6 +5,7 @@ #include "lm/config.hh" #include "lm/facade.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" #include "lm/vocab.hh" @@ -70,9 +71,10 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: - // Get the size of memory that will be mapped given ngram counts. This - // does not include small non-mapped control structures, such as this class - // itself. + /* Get the size of memory that will be mapped given ngram counts. This + * does not include small non-mapped control structures, such as this class + * itself. + */ static size_t Size(const std::vector &counts, const Config &config = Config()); /* Load the model from a file. It may be an ARPA or binary file. Binary @@ -111,6 +113,11 @@ template class GenericModel : public base::Mod private: friend void LoadLM<>(const char *file, const Config &config, GenericModel &to); + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); + Search::UpdateConfigFromBinary(fd, counts, config); + } + float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const; FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; @@ -130,9 +137,7 @@ template class GenericModel : public base::Mod VocabularyT vocab_; - typedef typename Search::Unigram Unigram; typedef typename Search::Middle Middle; - typedef typename Search::Longest Longest; Search search_; }; @@ -141,13 +146,15 @@ template class GenericModel : public base::Mod // These must also be instantiated in the cc file. typedef ::lm::ngram::ProbingVocabulary Vocabulary; -typedef detail::GenericModel ProbingModel; +typedef detail::GenericModel ProbingModel; // HASH_PROBING // Default implementation. No real reason for it to be the default. typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel TrieModel; +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED + +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 548c098d..8bf040ff 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -243,13 +243,14 @@ BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } -/*BOOST_AUTO_TEST_CASE(sorted) { - LoadingTest(); -}*/ BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(quant) { + LoadingTest(); +} + template void BinaryTest() { Config config; config.write_mmap = "test.binary"; @@ -275,12 +276,12 @@ template void BinaryTest() { BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest(); } -/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) { - BinaryTest(); -}*/ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc new file mode 100644 index 00000000..b4d76893 --- /dev/null +++ b/klm/lm/quantize.cc @@ -0,0 +1,84 @@ +#include "lm/quantize.hh" + +#include "lm/lm_exception.hh" + +#include +#include + +#include + +namespace lm { +namespace ngram { + +/* Quantize into bins of equal size as described in + * M. Federico and N. Bertoldi. 2006. How many bits are needed + * to store probabilities for phrase-based translation? In Proc. + * of the Workshop on Statistical Machine Translation, pages + * 94–101, New York City, June. Association for Computa- + * tional Linguistics. + */ + +namespace { + +void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { + std::sort(values, values_end); + const float *start = values, *finish; + for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) { + finish = values + (((values_end - values) * static_cast(i + 1)) / bins); + if (finish == start) { + // zero length bucket. + *centers = i ? *(centers - 1) : -std::numeric_limits::infinity(); + } else { + *centers = std::accumulate(start, finish, 0.0) / static_cast(finish - start); + } + } +} + +const char kSeparatelyQuantizeVersion = 1; + +} // namespace + +void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector &/*counts*/, Config &config) { + char version; + if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1) + UTIL_THROW(util::ErrnoException, "Failed to read header for quantization."); + if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); +} + +void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { + // Reserve 8 byte header for bit counts. + start_ = reinterpret_cast(static_cast(start) + 8); + prob_bits_ = config.prob_bits; + backoff_bits_ = config.backoff_bits; + // We need the reserved values. + if (config.prob_bits == 0) UTIL_THROW(ConfigException, "You can't quantize probability to zero"); + if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero"); + if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast(config.prob_bits) << " bits."); + if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast(config.backoff_bits) << " bits."); +} + +void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vector &backoff) { + TrainProb(order, prob); + + // Backoff + float *centers = start_ + TableStart(order) + ProbTableLength(); + *(centers++) = kNoExtensionBackoff; + *(centers++) = kExtensionBackoff; + MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2); +} + +void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { + float *centers = start_ + TableStart(order); + *(centers++) = kBlankProb; + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); +} + +void SeparatelyQuantize::FinishedLoading(const Config &config) { + uint8_t *actual_base = reinterpret_cast(start_) - 8; + *(actual_base++) = kSeparatelyQuantizeVersion; // version + *(actual_base++) = config.prob_bits; + *(actual_base++) = config.backoff_bits; +} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh new file mode 100644 index 00000000..aae72b34 --- /dev/null +++ b/klm/lm/quantize.hh @@ -0,0 +1,207 @@ +#ifndef LM_QUANTIZE_H__ +#define LM_QUANTIZE_H__ + +#include "lm/binary_format.hh" // for ModelType +#include "lm/blank.hh" +#include "lm/config.hh" +#include "util/bit_packing.hh" + +#include +#include + +#include + +#include + +namespace lm { +namespace ngram { + +class Config; + +/* Store values directly and don't quantize. */ +class DontQuantize { + public: + static const ModelType kModelType = TRIE_SORTED; + static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} + static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint8_t MiddleBits(const Config &/*config*/) { return 63; } + static uint8_t LongestBits(const Config &/*config*/) { return 31; } + + struct Middle { + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + util::WriteFloat32(base, bit_offset + 31, backoff); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + backoff = util::ReadFloat32(base, bit_offset + 31); + } + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = util::ReadFloat32(base, bit_offset + 31); + } + uint8_t TotalBits() const { return 63; } + }; + + struct Longest { + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + } + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } + uint8_t TotalBits() const { return 31; } + }; + + DontQuantize() {} + + void SetupMemory(void * /*start*/, const Config & /*config*/) {} + + static const bool kTrain = false; + // These should never be called because kTrain is false. + void Train(uint8_t /*order*/, std::vector &/*prob*/, std::vector &/*backoff*/) {} + void TrainProb(uint8_t, std::vector &/*prob*/) {} + + void FinishedLoading(const Config &) {} + + Middle Mid(uint8_t /*order*/) const { return Middle(); } + Longest Long(uint8_t /*order*/) const { return Longest(); } +}; + +class SeparatelyQuantize { + private: + class Bins { + public: + // Sigh C++ default constructor + Bins() {} + + Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + + uint64_t EncodeProb(float value) const { + return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + } + + uint64_t EncodeBackoff(float value) const { + if (value == 0.0) { + return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant; + } + return Encode(value, 2); + } + + float Decode(std::size_t off) const { return begin_[off]; } + + uint8_t Bits() const { return bits_; } + + uint64_t Mask() const { return mask_; } + + private: + uint64_t Encode(float value, size_t reserved) const { + const float *above = std::lower_bound(begin_ + reserved, end_, value); + if (above == begin_ + reserved) return reserved; + if (above == end_) return end_ - begin_ - 1; + return above - begin_ - (value - *(above - 1) < *above - value); + } + + const float *begin_; + const float *end_; + uint8_t bits_; + uint64_t mask_; + }; + + public: + static const ModelType kModelType = QUANT_TRIE_SORTED; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); + + static std::size_t Size(uint8_t order, const Config &config) { + size_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); + size_t middle_table = (static_cast(1) << static_cast(config.backoff_bits)) * sizeof(float) + longest_table; + // unigrams are currently not quantized so no need for a table. + return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; + } + + static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } + static uint8_t LongestBits(const Config &config) { return config.prob_bits; } + + class Middle { + public: + Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : + total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteInt57(base, bit_offset, total_bits_, + (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); + prob = prob_.Decode(both >> backoff_.Bits()); + backoff = backoff_.Decode(both & backoff_.Mask()); + } + + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); + } + + uint8_t TotalBits() const { + return total_bits_; + } + + private: + const uint8_t total_bits_; + const uint64_t total_mask_; + const Bins prob_; + const Bins backoff_; + }; + + class Longest { + public: + // Sigh C++ default constructor + Longest() {} + + Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); + } + + uint8_t TotalBits() const { return prob_.Bits(); } + + private: + Bins prob_; + }; + + SeparatelyQuantize() {} + + void SetupMemory(void *start, const Config &config); + + static const bool kTrain = true; + // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + void Train(uint8_t order, std::vector &prob, std::vector &backoff); + // Train just probabilities (for longest order). + void TrainProb(uint8_t order, std::vector &prob); + + void FinishedLoading(const Config &config); + + Middle Mid(uint8_t order) const { + const float *table = start_ + TableStart(order); + return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); + } + + Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } + + private: + size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast(order - 2); } + size_t ProbTableLength() const { return (1ULL << prob_bits_); } + + float *start_; + uint8_t prob_bits_, backoff_bits_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_QUANTIZE_H__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index eaad59ab..c56ba7b8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -80,6 +80,21 @@ template void ReadNGrams( } // namespace namespace detail { + +template uint8_t *TemplateHashedSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + std::size_t allocated = Unigram::Size(counts[0]); + unigram = Unigram(start, allocated); + start += allocated; + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle_.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest = Longest(start, allocated); + start += allocated; + return start; +} template template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. @@ -92,15 +107,15 @@ template template void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0], warn); + ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle(middle[n-3]), middle[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle(middle.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -108,13 +123,18 @@ template template void TemplateHashe ReadEnd(f); } -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, SortedVocabulary &vocab, Backing &backing); - -SortedHashedSearch::SortedHashedSearch() { - UTIL_THROW(util::Exception, "Sorted is broken at the moment, sorry"); +template void TemplateHashedSearch::LoadedBinary() { + unigram.LoadedBinary(); + for (typename std::vector::iterator i = middle_.begin(); i != middle_.end(); ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); } +template class TemplateHashedSearch; + +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); + } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 6dc11fb3..f3acdefc 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -8,7 +8,6 @@ #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" -#include "util/sorted_uniform.hh" #include #include @@ -62,73 +61,71 @@ struct HashedSearch { } }; -template struct TemplateHashedSearch : public HashedSearch { - typedef MiddleT Middle; - std::vector middle; +template class TemplateHashedSearch : public HashedSearch { + public: + typedef MiddleT Middle; - typedef LongestT Longest; - Longest longest; + typedef LongestT Longest; + Longest longest; - static std::size_t Size(const std::vector &counts, const Config &config) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char n = 1; n < counts.size() - 1; ++n) { - ret += Middle::Size(counts[n], config.probing_multiplier); - } - return ret + Longest::Size(counts.back(), config.probing_multiplier); - } + // TODO: move probing_multiplier here with next binary file format update. + static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} - uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { - std::size_t allocated = Unigram::Size(counts[0]); - unigram = Unigram(start, allocated); - start += allocated; - for (unsigned int n = 2; n < counts.size(); ++n) { - allocated = Middle::Size(counts[n - 1], config.probing_multiplier); - middle.push_back(Middle(start, allocated)); - start += allocated; + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Unigram::Size(counts[0]); + for (unsigned char n = 1; n < counts.size() - 1; ++n) { + ret += Middle::Size(counts[n], config.probing_multiplier); + } + return ret + Longest::Size(counts.back(), config.probing_multiplier); } - allocated = Longest::Size(counts.back(), config.probing_multiplier); - longest = Longest(start, allocated); - start += allocated; - return start; - } - template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; - backoff = found->GetValue().backoff; - return true; - } + template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); - bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - backoff = found->GetValue().backoff; - return true; - } + const Middle *MiddleBegin() const { return &*middle_.begin(); } + const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupLongest(WordIndex word, float &prob, Node &node) const { - node = CombineWordHash(node, word); - typename Longest::ConstIterator found; - if (!longest.Find(node, found)) return false; - prob = found->GetValue().prob; - return true; - } + bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + prob = found->GetValue().prob; + backoff = found->GetValue().backoff; + return true; + } + + void LoadedBinary(); - // Geenrate a node without necessarily checking that it actually exists. - // Optionally return false if it's know to not exist. - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - assert(begin != end); - node = static_cast(*begin); - for (const WordIndex *i = begin + 1; i < end; ++i) { - node = CombineWordHash(node, *i); + bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + backoff = found->GetValue().backoff; + return true; } - return true; - } + + bool LookupLongest(WordIndex word, float &prob, Node &node) const { + node = CombineWordHash(node, word); + typename Longest::ConstIterator found; + if (!longest.Find(node, found)) return false; + prob = found->GetValue().prob; + return true; + } + + // Geenrate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + assert(begin != end); + node = static_cast(*begin); + for (const WordIndex *i = begin + 1; i < end; ++i) { + node = CombineWordHash(node, *i); + } + return true; + } + + private: + std::vector middle_; }; // std::identity is an SGI extension :-( @@ -143,15 +140,6 @@ struct ProbingHashedSearch : public TemplateHashedSearch< static const ModelType kModelType = HASH_PROBING; }; -struct SortedHashedSearch : public TemplateHashedSearch< - util::SortedUniformMap >, - util::SortedUniformMap > > { - - SortedHashedSearch(); - - static const ModelType kModelType = HASH_SORTED; -}; - } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 7c57072b..1ce4d278 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -4,6 +4,7 @@ #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" #include "lm/vocab.hh" @@ -21,6 +22,7 @@ #include #include #include +#include #include #include @@ -579,7 +581,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W // Phase to count n-grams, including blanks inserted because they were pruned but have extensions class JustCount { public: - JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) : counts_(counts), longest_counts_(counts + order - 1) {} void Unigrams(WordIndex begin, WordIndex end) { @@ -608,9 +610,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -647,14 +649,14 @@ class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; - BitPackedLongest &longest_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; BitPacked &bigram_pack_; }; template class RecursiveInsert { public: - RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { } @@ -775,7 +777,51 @@ void SanityCheckCounts(const std::vector &initial, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + +template void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + ProbBackoff weights; + std::vector probs, backoffs; + probs.reserve(count); + backoffs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; + } + } + quant.Train(order, probs, backoffs); +} + +template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + Prob weights; + std::vector probs, backoffs; + probs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + ++progress; + } + } + quant.TrainProb(order, probs); +} + +} // namespace + +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -791,7 +837,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { @@ -800,7 +846,16 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + + if (Quant::kTrain) { + util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + for (unsigned char i = 2; i < counts.size(); ++i) { + TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + } + TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); + quant.FinishedLoading(config); + } for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -808,7 +863,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -845,23 +900,44 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. - if (!out.middle.empty()) { - for (size_t i = 0; i < out.middle.size() - 1; ++i) { - out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); + if (out.middle_begin_ != out.middle_end_) { + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex()); } - out.middle.back().FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); } } -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + quant_.SetupMemory(start, config); + start += Quant::Size(counts.size(), config); + unigram.Init(start); + start += Unigram::Size(counts[0]); + FreeMiddles(); + middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); + middle_end_ = middle_begin_ + (counts.size() - 2); + for (unsigned char i = counts.size() - 1; i >= 2; --i) { + new (middle_begin_ + i - 2) Middle( + start, + quant_.Mid(i), + counts[0], + counts[i], + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + longest.Init(start, quant_.Long(counts.size()), counts[0]); + return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -} // namespace +template void TrieSearch::LoadedBinary() { + unigram.LoadedBinary(); + for (Middle *i = middle_begin_; i != middle_end_; ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); +} -void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -885,12 +961,15 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } +template class TrieSearch; +template class TrieSearch; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0f720217..0a52acb5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,72 +13,88 @@ struct Backing; class SortedVocabulary; namespace trie { -struct TrieSearch { - typedef NodeRange Node; +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; +template class TrieSearch { + public: + typedef NodeRange Node; - typedef trie::BitPackedMiddle Middle; - std::vector middle; + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram; - typedef trie::BitPackedLongest Longest; - Longest longest; + typedef trie::BitPackedMiddle Middle; - static const ModelType kModelType = TRIE_SORTED; + typedef trie::BitPackedLongest Longest; + Longest longest; - static std::size_t Size(const std::vector &counts, const Config &/*config*/) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(counts[i], counts[0], counts[i+1]); + static const ModelType kModelType = Quant::kModelType; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + Quant::UpdateConfigFromBinary(fd, counts, config); } - return ret + Longest::Size(counts.back(), counts[0]); - } - - uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &/*config*/) { - unigram.Init(start); - start += Unigram::Size(counts[0]); - middle.resize(counts.size() - 2); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - middle[i-1].Init( - start, - counts[0], - counts[i+1], - (i == counts.size() - 2) ? static_cast(longest) : static_cast(middle[i])); - start += Middle::Size(counts[i], counts[0], counts[i+1]); + + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + } + return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } - longest.Init(start, counts[0]); - return start + Longest::Size(counts.back(), counts[0]); - } - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); - } - - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); - } - - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); - } - - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); - } - - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. - assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); - for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; + + TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} + + ~TrieSearch() { FreeMiddles(); } + + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); + + void LoadedBinary(); + + const Middle *MiddleBegin() const { return middle_begin_; } + const Middle *MiddleEnd() const { return middle_end_; } + + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + return unigram.Find(word, prob, backoff, node); + } + + bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { + return mid.Find(word, prob, backoff, node); } - return true; - } + + bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { + return mid.FindNoProb(word, backoff, node); + } + + bool LookupLongest(WordIndex word, float &prob, const Node &node) const { + return longest.Find(word, prob, node); + } + + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + // TODO: don't decode backoff. + assert(begin != end); + float ignored_prob, ignored_backoff; + LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + for (const WordIndex *i = begin + 1; i < end; ++i) { + if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + } + return true; + } + + private: + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + + // Middles are managed manually so we can delay construction and they don't have to be copyable. + void FreeMiddles() { + for (const Middle *i = middle_begin_; i != middle_end_; ++i) { + i->~Middle(); + } + free(middle_begin_); + } + + Middle *middle_begin_, *middle_end_; + Quant quant_; }; } // namespace trie diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 2c633613..63c2a612 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,8 +1,8 @@ #include "lm/trie.hh" +#include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" -#include "util/proxy_iterator.hh" #include "util/sorted_uniform.hh" #include @@ -12,53 +12,32 @@ namespace ngram { namespace trie { namespace { -// Assumes key is first. -class JustKeyProxy { +class KeyAccessor { public: - JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {} + KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) + : base_(reinterpret_cast(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - operator uint64_t() const { return GetKey(); } + typedef uint64_t Key; - uint64_t GetKey() const { - uint64_t bit_off = inner_ * static_cast(total_bits_); - return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_); + Key operator()(uint64_t index) const { + return util::ReadInt57(base_, index * static_cast(total_bits_), key_bits_, key_mask_); } private: - friend class util::ProxyIterator; - friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); - - JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) - : inner_(index), base_(static_cast(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - - // This is a read-only iterator. - JustKeyProxy &operator=(const JustKeyProxy &other); - - typedef uint64_t value_type; - - typedef uint64_t InnerIterator; - uint64_t &Inner() { return inner_; } - const uint64_t &Inner() const { return inner_; } - - // The address in bits is base_ * 8 + inner_ * total_bits_. - uint64_t inner_; const uint8_t *const base_; - const uint64_t key_mask_; + const WordIndex key_mask_; const uint8_t key_bits_, total_bits_; }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { - util::ProxyIterator begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); - util::ProxyIterator end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); - util::ProxyIterator out; - if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; - at_index = out.Inner(); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) { + KeyAccessor accessor(base, key_mask, key_bits, total_bits); + if (!util::BoundedSortedUniformFind::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false; return true; } } // namespace std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { - uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits; + uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits; // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes // +sizeof(uint64_t) so that ReadInt57 etc don't go segfault. @@ -71,100 +50,96 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) word_bits_ = util::RequiredBits(max_vocab); word_mask_ = (1ULL << word_bits_) - 1ULL; if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions."); - prob_bits_ = 31; - total_bits_ = word_bits_ + prob_bits_ + remaining_bits; + total_bits_ = word_bits_ + remaining_bits; base_ = static_cast(base); insert_index_ = 0; + max_vocab_ = max_vocab; } -std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { + return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); } -void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) { - next_source_ = &next_source; - backoff_bits_ = 32; - next_bits_ = util::RequiredBits(max_next); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - next_mask_ = (1ULL << next_bits_) - 1; - - BaseInit(base, max_vocab, backoff_bits_ + next_bits_); + BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); } -void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); + util::WriteInt57(base_, at_pointer, word_bits_, word); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); - at_pointer += prob_bits_; - util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); - at_pointer += backoff_bits_; + quant_.Write(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); assert(next <= next_mask_); - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); + util::WriteInt57(base_, at_pointer, next_bits_, next); ++insert_index_; } -bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) { + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } at_pointer *= total_bits_; at_pointer += word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.Read(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); + + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.ReadBackoff(base_, at_pointer, backoff); + at_pointer += quant_.TotalBits(); + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -void BitPackedMiddle::FinishedLoading(uint64_t next_end) { +template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { assert(next_end <= next_mask_); uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); + util::WriteInt57(base_, last_next_write, next_bits_, next_end); } -void BitPackedLongest::Insert(WordIndex index, float prob) { +template void BitPackedLongest::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index); + util::WriteInt57(base_, at_pointer, word_bits_, index); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); + quant_.Write(base_, at_pointer, prob); ++insert_index_; } -bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { +template bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer = at_pointer * total_bits_ + word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); + quant_.Read(base_, at_pointer, prob); return true; } +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedLongest; +template class BitPackedLongest; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 6aef050c..8fa21aaf 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -74,23 +74,21 @@ class BitPacked { void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); - uint8_t word_bits_, prob_bits_; + uint8_t word_bits_; uint8_t total_bits_; uint64_t word_mask_; uint8_t *base_; - uint64_t insert_index_; + uint64_t insert_index_, max_vocab_; }; -class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - BitPackedMiddle() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); // next_source need not be initialized. - void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); void Insert(WordIndex word, float prob, float backoff); @@ -101,28 +99,33 @@ class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end); private: - uint8_t backoff_bits_, next_bits_; + Quant quant_; + uint8_t next_bits_; uint64_t next_mask_; const BitPacked *next_source_; }; -class BitPackedLongest : public BitPacked { +template class BitPackedLongest : public BitPacked { public: - BitPackedLongest() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab) { - return BaseSize(entries, max_vocab, 0); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + return BaseSize(entries, max_vocab, quant_bits); } - void Init(void *base, uint64_t max_vocab) { - return BaseInit(base, max_vocab, 0); + BitPackedLongest() {} + + void Init(void *base, const Quant &quant, uint64_t max_vocab) { + quant_ = quant; + BaseInit(base, max_vocab, quant_.TotalBits()); } void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; + + private: + Quant quant_; }; } // namespace trie diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 515af5db..7defd5c1 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -28,8 +28,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("", 5); // Sadly some LMs have . const uint64_t kUnknownCapHash = detail::HashForVocab("", 5); -void ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return; +WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { + if (!enumerate) return std::numeric_limits::max(); const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); @@ -38,7 +38,7 @@ void ReadWords(int fd, EnumerateVocab *enumerate) { while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return; + if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; @@ -87,13 +87,13 @@ SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { // Lead with the number of entries. - return sizeof(uint64_t) + sizeof(Entry) * entries; + return sizeof(uint64_t) + sizeof(uint64_t) * entries; } void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) { assert(allocated >= Size(entries, config)); // Leave space for number of entries. - begin_ = reinterpret_cast(reinterpret_cast(start) + 1); + begin_ = reinterpret_cast(start) + 1; end_ = begin_; saw_unk_ = false; } @@ -112,7 +112,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } - end_->key = hashed; + *end_ = hashed; if (enumerate_) { strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); } @@ -134,8 +134,10 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { util::JointSort(begin_, end_, reorder_vocab + 1); } SetSpecial(Index(""), Index(""), 0); - // Save size. + // Save size. Excludes UNK. *(reinterpret_cast(begin_) - 1) = end_ - begin_; + // Includes UNK. + bound_ = end_ - begin_ + 1; } void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { @@ -183,7 +185,7 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { lookup_.LoadedBinary(); - ReadWords(fd, to); + available_ = ReadWords(fd, to); SetSpecial(Index(""), Index(""), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 546c1649..c92518e4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -9,6 +9,7 @@ #include "util/sorted_uniform.hh" #include "util/string_piece.hh" +#include #include #include @@ -44,22 +45,16 @@ class WriteWordsWrapper : public EnumerateVocab { // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. class SortedVocabulary : public base::Vocabulary { - private: - // Sorted uniform requires a GetKey function. - struct Entry { - uint64_t GetKey() const { return key; } - uint64_t key; - bool operator<(const Entry &other) const { - return key < other.key; - } - }; - public: SortedVocabulary(); WordIndex Index(const StringPiece &str) const { - const Entry *found; - if (util::SortedUniformFind(begin_, end_, detail::HashForVocab(str), found)) { + const uint64_t *found; + if (util::BoundedSortedUniformFind, util::Pivot64>( + util::IdentityAccessor(), + begin_ - 1, 0, + end_, std::numeric_limits::max(), + detail::HashForVocab(str), found)) { return found - begin_ + 1; // +1 because is 0 and does not appear in the lookup table. } else { return 0; @@ -68,6 +63,10 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. + // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. + WordIndex Bound() const { return bound_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -83,7 +82,11 @@ class SortedVocabulary : public base::Vocabulary { void LoadedBinary(int fd, EnumerateVocab *to); private: - Entry *begin_, *end_; + uint64_t *begin_, *end_; + + WordIndex bound_; + + WordIndex highest_value_; bool saw_unk_; @@ -105,6 +108,12 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()). + // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. + // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. + // Specifically, the binary file format does not currently indicate whether is in count or not. + WordIndex Bound() const { return available_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index 681da5f2..41999b72 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -28,10 +28,10 @@ void BitPackingSanity() { memset(mem, 0, sizeof(mem)); const uint64_t test57 = 0x123456789abcdefULL; for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + b / 8, b % 8, 57, test57); + WriteInt57(mem, b, 57, test57); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)) + if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1)) UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler."); } // TODO: more checks. diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 5c71c792..b35d80c8 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -42,47 +42,62 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { #error "Bit packing code isn't written for your byte order." #endif +inline uint64_t ReadOff(const void *base, uint64_t bit_off) { + return *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)); +} + /* Pack integers up to 57 bits using their least significant digits. * The length is specified using mask: * Assumes mask == (1 << length) - 1 where length <= 57. */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { - return (*reinterpret_cast(base) >> BitPackShift(bit, length)) & mask; +inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) { + return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask; } /* Assumes value < (1 << length) and length <= 57. * Assumes the memory is zero initially. */ -inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { - *reinterpret_cast(base) |= (value << BitPackShift(bit, length)); +inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { + *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); +} + +/* Same caveats as above, but for a 25 bit limit. */ +inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { + return (*reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +} + +inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { + *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); } typedef union { float f; uint32_t i; } FloatEnc; -inline float ReadFloat32(const void *base, uint8_t bit) { +inline float ReadFloat32(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 32); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32); return encoded.f; } -inline void WriteFloat32(void *base, uint8_t bit, float value) { +inline void WriteFloat32(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; - WriteInt57(base, bit, 32, encoded.i); + WriteInt57(base, bit_off, 32, encoded.i); } const uint32_t kSignBit = 0x80000000; -inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { +inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 31); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); // Sign bit set means negative. encoded.i |= kSignBit; return encoded.f; } -inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { +inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; encoded.i &= ~kSignBit; - WriteInt57(base, bit, 31, encoded.i); + WriteInt57(base, bit_off, 31, encoded.i); } void BitPackingSanity(); diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc index c578ddd1..4edc2004 100644 --- a/klm/util/bit_packing_test.cc +++ b/klm/util/bit_packing_test.cc @@ -9,15 +9,16 @@ namespace util { namespace { const uint64_t test57 = 0x123456789abcdefULL; +const uint32_t test25 = 0x1234567; -BOOST_AUTO_TEST_CASE(ZeroBit) { +BOOST_AUTO_TEST_CASE(ZeroBit57) { char mem[16]; memset(mem, 0, sizeof(mem)); WriteInt57(mem, 0, 57, test57); BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1)); } -BOOST_AUTO_TEST_CASE(EachBit) { +BOOST_AUTO_TEST_CASE(EachBit57) { char mem[16]; for (uint8_t b = 0; b < 8; ++b) { memset(mem, 0, sizeof(mem)); @@ -26,15 +27,27 @@ BOOST_AUTO_TEST_CASE(EachBit) { } } -BOOST_AUTO_TEST_CASE(Consecutive) { +BOOST_AUTO_TEST_CASE(Consecutive57) { char mem[57+8]; memset(mem, 0, sizeof(mem)); for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + (b / 8), b % 8, 57, test57); - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + WriteInt57(mem, b, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); + } +} + +BOOST_AUTO_TEST_CASE(Consecutive25) { + char mem[25+8]; + memset(mem, 0, sizeof(mem)); + for (uint64_t b = 0; b < 25 * 8; b += 25) { + WriteInt25(mem, b, 25, test25); + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); + } + for (uint64_t b = 0; b < 25 * 8; b += 25) { + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); } } diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 05826b51..84d7aa02 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -9,52 +9,96 @@ namespace util { -inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { - std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); - // Cap for floating point rounding - return (ret < width) ? ret : width - 1; -} -/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { - return static_cast(static_cast(off) * static_cast(width) / static_cast(range)); +template class IdentityAccessor { + public: + typedef T Key; + T operator()(const uint64_t *in) const { return *in; } +}; + +struct Pivot64 { + static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) { + std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); + // Cap for floating point rounding + return (ret < width) ? ret : width - 1; + } +}; + +// Use when off * width is <2^64. This is guaranteed when each of them is actually a 32-bit value. +struct Pivot32 { + static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) { + return static_cast((off * width) / (range + 1)); + } +}; + +// Usage: PivotSelect::T +template struct PivotSelect; +template <> struct PivotSelect<8> { typedef Pivot64 T; }; +template <> struct PivotSelect<4> { typedef Pivot32 T; }; +template <> struct PivotSelect<2> { typedef Pivot32 T; }; + +/* Binary search. */ +template bool BinaryFind( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key, Iterator &out) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { - return static_cast(static_cast(off) * width / static_cast(range)); + +// Search the range [before_it + 1, after_it - 1] for key. +// Preconditions: +// before_v <= key <= after_v +// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v +// range is sorted. +template bool BoundedSortedUniformFind( + const Accessor &accessor, + Iterator before_it, typename Accessor::Key before_v, + Iterator after_it, typename Accessor::Key after_v, + const typename Accessor::Key key, Iterator &out) { + while (after_it - before_it > 1) { + Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1))); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + before_it = pivot; + before_v = mid; + } else if (mid > key) { + after_it = pivot; + after_v = mid; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { - return static_cast(static_cast(off) * width / static_cast(range)); -}*/ -template bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { +template bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) { if (begin == end) return false; - Key below(begin->GetKey()); + typename Accessor::Key below(accessor(begin)); if (key <= below) { if (key == below) { out = begin; return true; } return false; } // Make the range [begin, end]. --end; - Key above(end->GetKey()); + typename Accessor::Key above(accessor(end)); if (key >= above) { if (key == above) { out = end; return true; } return false; } - - // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. - while (end - begin > 1) { - Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast(end - begin - 1)))); - Key mid(pivot->GetKey()); - if (mid < key) { - begin = pivot; - below = mid; - } else if (mid > key) { - end = pivot; - above = mid; - } else { - out = pivot; - return true; - } - } - return false; + return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } // To use this template, you need to define a Pivot function to match Key. @@ -64,7 +108,13 @@ template class SortedUniformMap { typedef typename Packing::ConstIterator ConstIterator; typedef typename Packing::MutableIterator MutableIterator; - public: + struct Accessor { + public: + typedef typename Packing::Key Key; + const Key &operator()(const ConstIterator &i) const { return i->GetKey(); } + Key &operator()(const MutableIterator &i) const { return i->GetKey(); } + }; + // Offer consistent API with probing hash. static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) { return sizeof(uint64_t) + entries * Packing::kBytes; @@ -120,7 +170,7 @@ template class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind(begin_, end_, key, out); + return SortedUniformFind(begin_, end_, key, out); } // Do not call before FinishedInserting. @@ -129,7 +179,7 @@ template class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind(ConstIterator(begin_), ConstIterator(end_), key, out); + return SortedUniformFind(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out); } ConstIterator begin() const { return begin_; } -- cgit v1.2.3 From fe4b60f8669f0bdfcc67832e5487b33bd4b28938 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 6 Jul 2011 19:54:58 -0400 Subject: ngram count features --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_ngrams.cc | 319 +++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_ngrams.h | 29 +++++ 4 files changed, 351 insertions(+) create mode 100644 decoder/ff_ngrams.cc create mode 100644 decoder/ff_ngrams.h (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 244da2de..d884c431 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -65,6 +65,7 @@ libcdec_a_SOURCES = \ ff_charset.cc \ ff_lm.cc \ ff_klm.cc \ + ff_ngrams.cc \ ff_spans.cc \ ff_ruleshape.cc \ ff_wordalign.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 31f88a4f..3451c9fb 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -4,6 +4,7 @@ #include "ff_spans.h" #include "ff_lm.h" #include "ff_klm.h" +#include "ff_ngrams.h" #include "ff_csplit.h" #include "ff_wordalign.h" #include "ff_tagger.h" @@ -51,6 +52,7 @@ void register_feature_functions() { ff_registry.Register("RandLM", new FFFactory); #endif ff_registry.Register("SpanFeatures", new FFFactory()); + ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new FFFactory >()); diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc new file mode 100644 index 00000000..54b394ae --- /dev/null +++ b/decoder/ff_ngrams.cc @@ -0,0 +1,319 @@ +#include "ff_ngrams.h" + +#include +#include + +#include + +#include "filelib.h" +#include "stringlib.h" +#include "hg.h" +#include "tdict.h" + +using namespace std; + +static const unsigned char HAS_FULL_CONTEXT = 1; +static const unsigned char HAS_EOS_ON_RIGHT = 2; +static const unsigned char MASK = 7; + +namespace { +template +struct State { + explicit State() { + memset(state, 0, sizeof(state)); + } + explicit State(int order) { + memset(state, 0, (order - 1) * sizeof(WordID)); + } + State(char order, const WordID* mem) { + memcpy(state, mem, (order - 1) * sizeof(WordID)); + } + State(const State& other) { + memcpy(state, other.state, sizeof(state)); + } + const State& operator=(const State& other) { + memcpy(state, other.state, sizeof(state)); + } + explicit State(const State& other, unsigned order, WordID extend) { + char om1 = order - 1; + assert(om1 > 0); + for (char i = 1; i < om1; ++i) state[i - 1]= other.state[i]; + state[om1 - 1] = extend; + } + const WordID& operator[](size_t i) const { return state[i]; } + WordID& operator[](size_t i) { return state[i]; } + WordID state[MAX_ORDER]; +}; +} + +class NgramDetectorImpl { + + // returns the number of unscored words at the left edge of a span + inline int UnscoredSize(const void* state) const { + return *(static_cast(state) + unscored_size_offset_); + } + + inline void SetUnscoredSize(int size, void* state) const { + *(static_cast(state) + unscored_size_offset_) = size; + } + + inline State<5> RemnantLMState(const void* cstate) const { + return State<5>(order_, static_cast(cstate)); + } + + inline const State<5> BeginSentenceState() const { + State<5> state(order_); + state.state[0] = kSOS_; + return state; + } + + inline void SetRemnantLMState(const State<5>& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, lmstate.state, (order_-1) * sizeof(WordID)); + } + + WordID IthUnscoredWord(int i, const void* state) const { + const WordID* const mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); + return mem[i]; + } + + void SetIthUnscoredWord(int i, const WordID index, void *state) const { + WordID* mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); + mem[i] = index; + } + + inline bool GetFlag(const void *state, unsigned char flag) const { + return (*(static_cast(state) + is_complete_offset_) & flag); + } + + inline void SetFlag(bool on, unsigned char flag, void *state) const { + if (on) { + *(static_cast(state) + is_complete_offset_) |= flag; + } else { + *(static_cast(state) + is_complete_offset_) &= (MASK ^ flag); + } + } + + inline bool HasFullContext(const void *state) const { + return GetFlag(state, HAS_FULL_CONTEXT); + } + + inline void SetHasFullContext(bool flag, void *state) const { + SetFlag(flag, HAS_FULL_CONTEXT, state); + } + + void FireFeatures(const State<5>& state, const WordID cur, SparseVector* feats) { + assert(order_ == 2); + if (cur >= unimap_.size()) + unimap_.resize(cur + 10, 0); + int& uf = unimap_[cur]; + if (!uf) { + ostringstream os; + os << "U:" << TD::Convert(cur); + uf = FD::Convert(os.str()); + } + feats->set_value(uf, 1.0); + if (state.state[0]) { + if (state.state[0] >= bimap_.size()) + bimap_.resize(state.state[0] + 10); + int& bf = bimap_[state.state[0]][cur]; + if (!bf) { + ostringstream os; + os << "B:" << TD::Convert(state[0]) << '_' << TD::Convert(cur); + bf = FD::Convert(os.str()); + } + feats->set_value(bf, 1.0); + } + } + + public: + void LookupWords(const TRule& rule, const vector& ant_states, SparseVector* feats, SparseVector* est_feats, void* remnant) { + double sum = 0.0; + double est_sum = 0.0; + int num_scored = 0; + int num_estimated = 0; + bool saw_eos = false; + bool has_some_history = false; + State<5> state; + const vector& e = rule.e(); + bool context_complete = false; + for (int j = 0; j < e.size(); ++j) { + if (e[j] < 1) { // handle non-terminal substitution + const void* astate = (ant_states[-e[j]]); + int unscored_ant_len = UnscoredSize(astate); + for (int k = 0; k < unscored_ant_len; ++k) { + const WordID cur_word = IthUnscoredWord(k, astate); + const bool is_oov = (cur_word == 0); + SparseVector p; + if (cur_word == kSOS_) { + state = BeginSentenceState(); + if (has_some_history) { // this is immediately fully scored, and bad + p.set_value(FD::Convert("Malformed"), 1.0); + context_complete = true; + } else { // this might be a real + num_scored = max(0, order_ - 2); + } + } else { + FireFeatures(state, cur_word, &p); + const State<5> scopy = State<5>(state, order_, cur_word); + state = scopy; + if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); } + saw_eos = (cur_word == kEOS_); + } + has_some_history = true; + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + (*feats) += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + (*est_feats) += p; + } + } + saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); + if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 + state = RemnantLMState(astate); + context_complete = true; + } + } else { // handle terminal + const WordID cur_word = e[j]; + SparseVector p; + if (cur_word == kSOS_) { + state = BeginSentenceState(); + if (has_some_history) { // this is immediately fully scored, and bad + p.set_value(FD::Convert("Malformed"), -100); + context_complete = true; + } else { // this might be a real + num_scored = max(0, order_ - 2); + } + } else { + FireFeatures(state, cur_word, &p); + const State<5> scopy = State<5>(state, order_, cur_word); + state = scopy; + if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); } + saw_eos = (cur_word == kEOS_); + } + has_some_history = true; + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + (*feats) += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + (*est_feats) += p; + } + } + } + if (remnant) { + SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); + SetRemnantLMState(state, remnant); + SetUnscoredSize(num_estimated, remnant); + SetHasFullContext(context_complete || (num_scored >= order_), remnant); + } + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for (n-1 left words) and (n-1 right words) + void FinalTraversal(const void* state, SparseVector* feats) { + if (add_sos_eos_) { // rules do not produce , so do it here + SetRemnantLMState(BeginSentenceState(), dummy_state_); + SetHasFullContext(1, dummy_state_); + SetUnscoredSize(0, dummy_state_); + dummy_ants_[1] = state; + LookupWords(*dummy_rule_, dummy_ants_, feats, NULL, NULL); + } else { // rules DO produce ... +#if 0 + double p = 0; + if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + if (UnscoredSize(state) > 0) { // are there unscored words + if (kSOS_ != IthUnscoredWord(0, state)) { + p -= 100 * UnscoredSize(state); + } + } + return p; +#endif + } + } + + public: + explicit NgramDetectorImpl(bool explicit_markers) : + kCDEC_UNK(TD::Convert("")) , + add_sos_eos_(!explicit_markers) { + order_ = 2; + state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID); + unscored_size_offset_ = (order_ - 1) * sizeof(WordID); + is_complete_offset_ = unscored_size_offset_ + 1; + unscored_words_offset_ = is_complete_offset_ + 1; + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); + kSOS_ = TD::Convert(""); + kEOS_ = TD::Convert(""); + } + + ~NgramDetectorImpl() { + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + WordID kSOS_; // - requires special handling. + WordID kEOS_; // + const bool add_sos_eos_; // flag indicating whether the hypergraph produces and + // if this is true, FinalTransitionFeatures will "add" and + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + int unscored_size_offset_; + int is_complete_offset_; + int unscored_words_offset_; + char* dummy_state_; + vector dummy_ants_; + TRulePtr dummy_rule_; + mutable std::vector unimap_; // [left][right] + mutable std::vector > bimap_; // [left][right] +}; + +NgramDetector::NgramDetector(const string& param) { + string filename, mapfile, featname; + bool explicit_markers = (param == "-x"); + pimpl_ = new NgramDetectorImpl(explicit_markers); + SetStateSize(pimpl_->ReserveStateSize()); +} + +NgramDetector::~NgramDetector() { + delete pimpl_; +} + +void NgramDetector::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + pimpl_->LookupWords(*edge.rule_, ant_states, features, estimated_features, state); +} + +void NgramDetector::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + pimpl_->FinalTraversal(ant_state, features); +} + diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h new file mode 100644 index 00000000..82f61b33 --- /dev/null +++ b/decoder/ff_ngrams.h @@ -0,0 +1,29 @@ +#ifndef _NGRAMS_FF_H_ +#define _NGRAMS_FF_H_ + +#include +#include +#include + +#include "ff.h" + +struct NgramDetectorImpl; +class NgramDetector : public FeatureFunction { + public: + // param = "filename.lm [-o n]" + NgramDetector(const std::string& param); + ~NgramDetector(); + virtual void FinalTraversalFeatures(const void* context, + SparseVector* features) const; + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + NgramDetectorImpl* pimpl_; +}; + +#endif -- cgit v1.2.3 From 75b814cb246052746134f32c723cf6d278b148df Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 6 Jul 2011 23:32:53 -0400 Subject: better handling of ngram features --- decoder/ff_ngrams.cc | 49 +++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc index 54b394ae..d52667cd 100644 --- a/decoder/ff_ngrams.cc +++ b/decoder/ff_ngrams.cc @@ -103,27 +103,29 @@ class NgramDetectorImpl { SetFlag(flag, HAS_FULL_CONTEXT, state); } - void FireFeatures(const State<5>& state, const WordID cur, SparseVector* feats) { - assert(order_ == 2); - if (cur >= unimap_.size()) - unimap_.resize(cur + 10, 0); - int& uf = unimap_[cur]; - if (!uf) { - ostringstream os; - os << "U:" << TD::Convert(cur); - uf = FD::Convert(os.str()); - } - feats->set_value(uf, 1.0); - if (state.state[0]) { - if (state.state[0] >= bimap_.size()) - bimap_.resize(state.state[0] + 10); - int& bf = bimap_[state.state[0]][cur]; - if (!bf) { + void FireFeatures(const State<5>& state, WordID cur, SparseVector* feats) { + FidTree* ft = &fidroot_; + int n = 0; + WordID buf[10]; + int ci = order_ - 1; + WordID curword = cur; + while(curword) { + buf[n] = curword; + int& fid = ft->fids[curword]; + ++n; + if (!fid) { + const char* code="_UBT456789"; ostringstream os; - os << "B:" << TD::Convert(state[0]) << '_' << TD::Convert(cur); - bf = FD::Convert(os.str()); + os << code[n] << ':'; + for (int i = n-1; i >= 0; --i) + os << (i != n-1 ? "_" : "") << TD::Convert(buf[i]); + fid = FD::Convert(os.str()); } - feats->set_value(bf, 1.0); + feats->set_value(fid, 1); + ft = &ft->levels[curword]; + --ci; + if (ci < 0) break; + curword = state[ci]; } } @@ -248,7 +250,7 @@ class NgramDetectorImpl { explicit NgramDetectorImpl(bool explicit_markers) : kCDEC_UNK(TD::Convert("")) , add_sos_eos_(!explicit_markers) { - order_ = 2; + order_ = 3; state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID); unscored_size_offset_ = (order_ - 1) * sizeof(WordID); is_complete_offset_ = unscored_size_offset_ + 1; @@ -288,8 +290,11 @@ class NgramDetectorImpl { char* dummy_state_; vector dummy_ants_; TRulePtr dummy_rule_; - mutable std::vector unimap_; // [left][right] - mutable std::vector > bimap_; // [left][right] + struct FidTree { + map fids; + map levels; + }; + mutable FidTree fidroot_; }; NgramDetector::NgramDetector(const string& param) { -- cgit v1.2.3 From 71daf4bf0b91a247d0d1663ae7850a3db85a378d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 7 Jul 2011 18:39:38 -0400 Subject: support for extracting k-best derivation trees --- decoder/decoder.cc | 12 +++++++++--- decoder/oracle_bleu.h | 22 +++++++++++++++------- 2 files changed, 24 insertions(+), 10 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index ff068be9..2c3a06de 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value(), "Extract the rules used in translation (de-duped) to this file") + ("show_derivations", po::value(), "Directory to print the derivation structures to") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value(), "Beam approximation to get max translation from the chart") ("max_translation_sample,X", po::value(), "Sample the max translation from the chart") @@ -426,6 +427,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value(),"Directory to write forests to"); + // ob.AddOptions(&opts); #ifdef FSA_RESCORING po::options_description cfgo(cfg_options.description()); @@ -677,6 +679,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream kbest = conf.count("k_best"); unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); + oracle.show_derivation=conf.count("show_derivations"); #ifdef FSA_RESCORING cfg_options.Validate(); @@ -938,7 +941,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } else { if (kbest && !has_ref) { //TODO: does this work properly? - oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-"); + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-", deriv_fname); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { @@ -1055,8 +1059,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } } if (conf.count("graphviz")) forest.PrintGraphviz(); - if (kbest) - oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-"); + if (kbest) { + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-", deriv_fname); + } if (conf.count("show_conditional_prob")) { const prob_t ref_z = Inside(forest); cout << (log(ref_z) - log(first_z)) << endl << flush; diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 15d48588..b603e27a 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -272,23 +272,31 @@ struct OracleBleu { } kbest_out<score)<<"\n"; deriv_out< > >(sent_id,forest,k,ko.get(),std::cerr); + kbest > >(sent_id,forest,k,ko.get(),oderiv.get()); else { - kbest(sent_id,forest,k,ko.get(),std::cerr); + kbest(sent_id,forest,k,ko.get(),oderiv.get()); } } @@ -296,7 +304,7 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"< Date: Fri, 8 Jul 2011 13:56:42 +0200 Subject: add Fast Cube Pruning --- decoder/apply_models.cc | 196 ++++++++++++++++++++++++++++++++++++++++++++++-- decoder/apply_models.h | 6 +- decoder/decoder.cc | 10 ++- 3 files changed, 204 insertions(+), 8 deletions(-) (limited to 'decoder') diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9390c809..62eff262 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -17,6 +17,10 @@ #include "hg.h" #include "ff.h" +#define NORMAL_CP 1 +#define FAST_CP 2 +#define FAST_CP_2 3 + using namespace std; using namespace std::tr1; @@ -164,13 +168,15 @@ public: const SentenceMetadata& sm, const Hypergraph& i, int pop_limit, - Hypergraph* o) : + Hypergraph* o, + int s = NORMAL_CP ) : models(m), smeta(sm), in(i), out(*o), D(in.nodes_.size()), - pop_limit_(pop_limit) { + pop_limit_(pop_limit), + strategy_(s){ if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; node_states_.reserve(kRESERVE_NUM_NODES); } @@ -186,7 +192,15 @@ public: if (!SILENT) cerr << " "; for (int i = 0; i < in.nodes_.size(); ++i) { if (!SILENT && i % every == 0) cerr << '.'; - KBest(i, i == goal_id); + if (strategy_==NORMAL_CP){ + KBest(i, i == goal_id); + } + if (strategy_==FAST_CP){ + KBestFast(i, i == goal_id); + } + if (strategy_==FAST_CP_2){ + KBestFast2(i, i == goal_id); + } } if (!SILENT) { cerr << endl; @@ -283,6 +297,114 @@ public: delete freelist[i]; } + void KBestFast(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + assert(unique_accepted.insert(item).second); // these should all be unique! + // cerr << "POPPED: " << *item << endl; + + PushSuccFast2(*item, is_goal, &cand, &unique_accepted); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } + } + + //PushSucc only if all ancest Cand are added + void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } + } + + bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; + } + const ModelSet& models; const SentenceMetadata& smeta; const Hypergraph& in; @@ -311,6 +481,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -412,15 +583,28 @@ void ApplyModelSet(const Hypergraph& in, if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state ma.Apply(); - } else if (config.algorithm == IntersectionConfiguration::CUBE) { + } else if (config.algorithm == IntersectionConfiguration::CUBE + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { int pl = config.pop_limit; const int max_pl_for_large=50; if (pl > max_pl_for_large && in.nodes_.size() > 80000) { pl = max_pl_for_large; cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + if (config.algorithm == IntersectionConfiguration::CUBE) { + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); + } + } else { cerr << "Don't understand intersection algorithm " << config.algorithm << endl; exit(1); diff --git a/decoder/apply_models.h b/decoder/apply_models.h index a85694aa..19a4c7be 100644 --- a/decoder/apply_models.h +++ b/decoder/apply_models.h @@ -13,6 +13,8 @@ struct IntersectionConfiguration { enum { FULL, CUBE, + FAST_CUBE_PRUNING, + FAST_CUBE_PRUNING_2, N_ALGORITHMS }; @@ -25,7 +27,9 @@ enum { inline std::ostream& operator<<(std::ostream& os, const IntersectionConfiguration& c) { if (c.algorithm == 0) { os << "FULL"; } else if (c.algorithm == 1) { os << "CUBE:k=" << c.pop_limit; } - else if (c.algorithm == 2) { os << "N_ALGORITHMS"; } + else if (c.algorithm == 2) { os << "FAST_CUBE_PRUNING"; } + else if (c.algorithm == 3) { os << "FAST_CUBE_PRUNING_2"; } + else if (c.algorithm == 4) { os << "N_ALGORITHMS"; } else os << "OTHER"; return os; } diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 2c3a06de..8a4a1485 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -357,7 +357,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("weights,w",po::value(),"Feature weights file (initial forest / pass 1)") ("feature_function,F",po::value >()->composing(), "Pass 1 additional feature function(s) (-L for list)") - ("intersection_strategy,I",po::value()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") + ("intersection_strategy,I",po::value()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2") ("summary_feature", po::value(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z") ("summary_feature_type", po::value()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob") ("density_prune", po::value(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") @@ -597,6 +597,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (LowercaseString(str(isn.c_str(),conf)) == "full") { palg = 0; } + if (LowercaseString(conf["intersection_strategy"].as()) == "fast_cube_pruning") { + palg = 2; + cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n"; + } + if (LowercaseString(conf["intersection_strategy"].as()) == "fast_cube_pruning_2") { + palg = 3; + cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n"; + } rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit)); } else { break; // TODO alert user if there are any future configurations -- cgit v1.2.3 From ed8a6e81d87f6e917ecffc290cde0a340b6aa03b Mon Sep 17 00:00:00 2001 From: andrea gesmundo Date: Fri, 8 Jul 2011 15:33:47 +0200 Subject: add cp time measure (def macro) --- decoder/cdec.cc | 8 ++++++++ decoder/decoder.cc | 13 +++++++++++++ decoder/decoder.h | 14 ++++++++++++++ 3 files changed, 35 insertions(+) (limited to 'decoder') diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 5c40f56e..c671af57 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -19,11 +19,19 @@ int main(int argc, char** argv) { assert(*in); string buf; +#ifdef CP_TIME + clock_t time_cp(0);//, end_cp; +#endif while(*in) { getline(*in, buf); if (buf.empty()) continue; decoder.Decode(buf); } +#ifdef CP_TIME + cerr << "Time required for Cube Pruning execution: " + << CpTime::Get() + << " seconds." << "\n\n"; +#endif if (show_feature_dictionary) { int num = FD::NumFeats(); for (int i = 1; i < num; ++i) { diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 8a4a1485..76f31352 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -46,6 +46,13 @@ #include "cfg_options.h" #endif +#ifdef CP_TIME + clock_t CpTime::time_; + void CpTime::Add(clock_t x){time_+=x;} + void CpTime::Sub(clock_t x){time_-=x;} + double CpTime::Get(){return (double)(time_)/CLOCKS_PER_SEC;} +#endif + static const double kMINUS_EPSILON = -1e-6; // don't be too strict using namespace std; @@ -806,11 +813,17 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Timer t("Forest rescoring:"); rp.models->PrepareForInput(smeta); Hypergraph rescored_forest; +#ifdef CP_TIME + CpTime::Sub(clock()); +#endif ApplyModelSet(forest, smeta, *rp.models, *rp.inter_conf, &rescored_forest); +#ifdef CP_TIME + CpTime::Add(clock()); +#endif forest.swap(rescored_forest); forest.Reweight(cur_weights); if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation); diff --git a/decoder/decoder.h b/decoder/decoder.h index 813400e3..5491369f 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -7,6 +7,20 @@ #include #include +#undef CP_TIME +//#define CP_TIME +#ifdef CP_TIME +#include +struct CpTime{ +public: + static void Add(clock_t x); + static void Sub(clock_t x); + static double Get(); +private: + static clock_t time_; +}; +#endif + class SentenceMetadata; struct Hypergraph; struct DecoderImpl; -- cgit v1.2.3 From a037f52a87f7d5711b5521047e7fb3fcd756c647 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 13 Jul 2011 00:14:34 -0400 Subject: escape feature names --- decoder/ff_spans.cc | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index e1da088d..bc23974d 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -13,6 +13,17 @@ using namespace std; +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + // log transform to make long spans cluster together // but preserve differences int SpanSizeTransform(unsigned span_size) { @@ -140,19 +151,19 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { word = MapIfNecessary(word); ostringstream sfid; sfid << "ES:" << TD::Convert(word); - end_span_ids_[i] = FD::Convert(sfid.str()); + end_span_ids_[i] = FD::Convert(Escape(sfid.str())); ostringstream esbiid; esbiid << "EBI:" << TD::Convert(bword) << "_" << TD::Convert(word); - end_bigram_ids_[i] = FD::Convert(esbiid.str()); + end_bigram_ids_[i] = FD::Convert(Escape(esbiid.str())); ostringstream bsbiid; bsbiid << "BBI:" << TD::Convert(bword) << "_" << TD::Convert(word); - beg_bigram_ids_[i] = FD::Convert(bsbiid.str()); + beg_bigram_ids_[i] = FD::Convert(Escape(bsbiid.str())); ostringstream bfid; bfid << "BS:" << TD::Convert(bword); - beg_span_ids_[i] = FD::Convert(bfid.str()); + beg_span_ids_[i] = FD::Convert(Escape(bfid.str())); if (use_collapsed_features_) { - end_span_vals_[i] = feat2val_[sfid.str()] + feat2val_[esbiid.str()]; - beg_span_vals_[i] = feat2val_[bfid.str()] + feat2val_[bsbiid.str()]; + end_span_vals_[i] = feat2val_[Escape(sfid.str())] + feat2val_[Escape(esbiid.str())]; + beg_span_vals_[i] = feat2val_[Escape(bfid.str())] + feat2val_[Escape(bsbiid.str())]; } } for (int i = 0; i <= lattice.size(); ++i) { @@ -167,16 +178,16 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { word = MapIfNecessary(word); ostringstream pf; pf << "S:" << TD::Convert(bword) << "_" << TD::Convert(word); - span_feats_(i,j).first = FD::Convert(pf.str()); - span_feats_(i,j).second = FD::Convert("S_" + pf.str()); + span_feats_(i,j).first = FD::Convert(Escape(pf.str())); + span_feats_(i,j).second = FD::Convert(Escape("S_" + pf.str())); ostringstream lf; const unsigned span_size = (i < j ? j - i : i - j); lf << "LS:" << SpanSizeTransform(span_size) << "_" << TD::Convert(bword) << "_" << TD::Convert(word); - len_span_feats_(i,j).first = FD::Convert(lf.str()); - len_span_feats_(i,j).second = FD::Convert("S_" + lf.str()); + len_span_feats_(i,j).first = FD::Convert(Escape(lf.str())); + len_span_feats_(i,j).second = FD::Convert(Escape("S_" + lf.str())); if (use_collapsed_features_) { - span_vals_(i,j).first = feat2val_[pf.str()] + feat2val_[lf.str()]; - span_vals_(i,j).second = feat2val_["S_" + pf.str()] + feat2val_["S_" + lf.str()]; + span_vals_(i,j).first = feat2val_[Escape(pf.str())] + feat2val_[Escape(lf.str())]; + span_vals_(i,j).second = feat2val_[Escape("S_" + pf.str())] + feat2val_[Escape("S_" + lf.str())]; } } } @@ -209,14 +220,14 @@ void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, const string& cur = TD::Convert(w); ostringstream os; os << "RB:" << prev << '_' << cur; - const int fid = FD::Convert(os.str()); + const int fid = FD::Convert(Escape(os.str())); if (fid <= 0) return; f.add_value(fid, 1.0); prev = cur; } ostringstream os; os << "RB:" << prev << '_' << ""; - f.set_value(FD::Convert(os.str()), 1.0); + f.set_value(FD::Convert(Escape(os.str())), 1.0); } (*features) += it->second; } -- cgit v1.2.3 From 816bee82abc909335d4f3a300cff99afa4dd1da5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 13 Jul 2011 18:00:22 -0400 Subject: escape bad feature names --- decoder/ff_ngrams.cc | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc index d52667cd..04dd1906 100644 --- a/decoder/ff_ngrams.cc +++ b/decoder/ff_ngrams.cc @@ -46,6 +46,17 @@ struct State { }; } +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + class NgramDetectorImpl { // returns the number of unscored words at the left edge of a span @@ -114,11 +125,17 @@ class NgramDetectorImpl { int& fid = ft->fids[curword]; ++n; if (!fid) { - const char* code="_UBT456789"; + const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.) ostringstream os; os << code[n] << ':'; - for (int i = n-1; i >= 0; --i) - os << (i != n-1 ? "_" : "") << TD::Convert(buf[i]); + for (int i = n-1; i >= 0; --i) { + os << (i != n-1 ? "_" : ""); + const string& tok = TD::Convert(buf[i]); + if (tok.find('=') == string::npos) + os << tok; + else + os << Escape(tok); + } fid = FD::Convert(os.str()); } feats->set_value(fid, 1); -- cgit v1.2.3 From d73b5d25bd0af14a4a83490d67ba2553b6af9884 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 28 Jul 2011 17:08:59 +0100 Subject: stuff --- decoder/apply_models.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 62eff262..26cdb881 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -190,8 +190,12 @@ public: if (num_nodes > 100) every = 10; assert(in.nodes_[pregoal].out_edges_.size() == 1); if (!SILENT) cerr << " "; + int has = 0; for (int i = 0; i < in.nodes_.size(); ++i) { - if (!SILENT && i % every == 0) cerr << '.'; + if (!SILENT) { + int needs = (50 * i / in.nodes_.size()); + while (has < needs) { cerr << '.'; ++has; } + } if (strategy_==NORMAL_CP){ KBest(i, i == goal_id); } -- cgit v1.2.3 From 2c14cf2218031c29a9884bccf17e9273c71a33b2 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 18 Aug 2011 12:14:01 +0100 Subject: KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime. --- decoder/cdec_ff.cc | 5 +- decoder/ff_klm.cc | 38 +++++- decoder/ff_klm.h | 7 +- klm/compile.sh | 2 +- klm/lm/Makefile.am | 1 + klm/lm/bhiksha.cc | 93 +++++++++++++++ klm/lm/bhiksha.hh | 108 +++++++++++++++++ klm/lm/binary_format.cc | 13 ++- klm/lm/binary_format.hh | 9 +- klm/lm/build_binary.cc | 54 ++++++--- klm/lm/config.cc | 1 + klm/lm/config.hh | 5 +- klm/lm/model.cc | 67 ++++++----- klm/lm/model.hh | 12 +- klm/lm/model_test.cc | 73 ++++++++++-- klm/lm/ngram_query.cc | 9 ++ klm/lm/quantize.cc | 1 + klm/lm/quantize.hh | 4 +- klm/lm/read_arpa.cc | 6 +- klm/lm/search_hashed.cc | 2 +- klm/lm/search_hashed.hh | 3 +- klm/lm/search_trie.cc | 45 +++---- klm/lm/search_trie.hh | 20 ++-- klm/lm/test_nounk.arpa | 120 +++++++++++++++++++ klm/lm/trie.cc | 57 ++++----- klm/lm/trie.hh | 24 ++-- klm/lm/vocab.cc | 6 +- klm/lm/vocab.hh | 4 + klm/util/bit_packing.hh | 13 ++- klm/util/murmur_hash.cc | 258 ++++++++++++++++++++--------------------- klm/util/probing_hash_table.hh | 2 +- klm/util/sorted_uniform.hh | 23 +++- 32 files changed, 792 insertions(+), 293 deletions(-) create mode 100644 klm/lm/bhiksha.cc create mode 100644 klm/lm/bhiksha.hh create mode 100644 klm/lm/test_nounk.arpa (limited to 'decoder') diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 3451c9fb..1ef76a05 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,10 +55,7 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); - ff_registry.Register("KLanguageModel", new FFFactory >()); - ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); - ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); - ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); + ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); ff_registry.Register("RelativeSentencePosition", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9b7fe2d3..24dcb9c3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -9,6 +9,7 @@ #include "stringlib.h" #include "hg.h" #include "tdict.h" +#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; @@ -434,8 +435,37 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, features->set_value(oov_fid_, oovs); } -// instantiate templates -template class KLanguageModel; -template class KLanguageModel; -template class KLanguageModel; +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 5eafe8be..6efe50f6 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -4,8 +4,8 @@ #include #include +#include "ff_factory.h" #include "ff.h" -#include "lm/model.hh" template struct KLanguageModelImpl; @@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction { KLanguageModelImpl* pimpl_; }; +struct KLanguageModelFactory : public FactoryBase { + FP Create(std::string param) const; + std::string usage(bool params,bool verbose) const; +}; + #endif diff --git a/klm/compile.sh b/klm/compile.sh index 6ca85e1f..abe3473a 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -5,7 +5,7 @@ set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 395494bc..fae6b41a 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz noinst_LIBRARIES = libklm.a libklm_a_SOURCES = \ + bhiksha.cc \ binary_format.cc \ config.cc \ lm_exception.cc \ diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc new file mode 100644 index 00000000..bf86fd4b --- /dev/null +++ b/klm/lm/bhiksha.cc @@ -0,0 +1,93 @@ +#include "lm/bhiksha.hh" +#include "lm/config.hh" + +#include + +namespace lm { +namespace ngram { +namespace trie { + +DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) : + next_(util::BitsMask::ByMax(max_next)) {} + +const uint8_t kArrayBhikshaVersion = 0; + +void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { + uint8_t version; + uint8_t configured_bits; + if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { + UTIL_THROW(util::ErrnoException, "Could not read from binary file"); + } + if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); + config.pointer_bhiksha_bits = configured_bits; +} + +namespace { + +// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset) +uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t best_chop = 0; + int64_t lowest_change = std::numeric_limits::max(); + // There are probably faster ways but I don't care because this is only done once per order at construction time. + for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) { + int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */ + - max_offset * static_cast(chop); /* savings in bits*/ + if (change < lowest_change) { + lowest_change = change; + best_chop = chop; + } + } + return best_chop; +} + +std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t chopping = ChopBits(max_offset, max_next, config); + return (max_next >> (required - chopping)) + 1 /* we store 0 too */; +} +} // namespace + +std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { + return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; +} + +uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config); +} + +namespace { + +void *AlignTo8(void *from) { + uint8_t *val = reinterpret_cast(from); + std::size_t remainder = reinterpret_cast(val) & 7; + if (!remainder) return val; + return val + 8 - remainder; +} + +} // namespace + +ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config) + : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))), + offset_begin_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */), + offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)), + write_to_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */), + original_base_(base) {} + +void ArrayBhiksha::FinishedLoading(const Config &config) { + // *offset_begin_ = 0 but without a const_cast. + *(write_to_ - (write_to_ - offset_begin_)) = 0; + + if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected."); + + uint8_t *head_write = reinterpret_cast(original_base_); + *(head_write++) = kArrayBhikshaVersion; + *(head_write++) = config.pointer_bhiksha_bits; +} + +void ArrayBhiksha::LoadedBinary() { +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh new file mode 100644 index 00000000..cfb2b053 --- /dev/null +++ b/klm/lm/bhiksha.hh @@ -0,0 +1,108 @@ +/* Simple implementation of + * @inproceedings{bhikshacompression, + * author={Bhiksha Raj and Ed Whittaker}, + * year={2003}, + * title={Lossless Compression of Language Model Structure and Word Identifiers}, + * booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing}, + * pages={388--391}, + * } + * + * Currently only used for next pointers. + */ + +#include + +#include "lm/binary_format.hh" +#include "lm/trie.hh" +#include "util/bit_packing.hh" +#include "util/sorted_uniform.hh" + +namespace lm { +namespace ngram { +class Config; + +namespace trie { + +class DontBhiksha { + public: + static const ModelType kModelTypeAdd = static_cast(0); + + static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + + static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + + static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { + return util::RequiredBits(max_next); + } + + DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const { + out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask); + out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask); + //assert(out.end >= out.begin); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) { + util::WriteInt57(base, bit_offset, next_.bits, value); + } + + void FinishedLoading(const Config &/*config*/) {} + + void LoadedBinary() {} + + uint8_t InlineBits() const { return next_.bits; } + + private: + util::BitsMask next_; +}; + +class ArrayBhiksha { + public: + static const ModelType kModelTypeAdd = kArrayAdd; + + static void UpdateConfigFromBinary(int fd, Config &config); + + static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + + static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); + + ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { + const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor(), offset_begin_, offset_end_, index); + const uint64_t *end_it; + for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + --end_it; + out.begin = ((begin_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); + out.end = ((end_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { + uint64_t encode = value >> next_inline_.bits; + for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index; + util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask); + } + + void FinishedLoading(const Config &config); + + void LoadedBinary(); + + uint8_t InlineBits() const { return next_inline_.bits; } + + private: + const util::BitsMask next_inline_; + + const uint64_t *const offset_begin_; + const uint64_t *const offset_end_; + + uint64_t *write_to_; + + void *original_base_; +}; + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 92b1008b..e02e621a 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -40,7 +40,7 @@ struct Sanity { } }; -const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; std::size_t Align8(std::size_t in) { std::size_t off = in % 8; @@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ } } -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) { +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { + std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); + if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) + UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = backing.vocab.size() % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); + off_t alignment_cruft = adjusted_vocab % page_size; + backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); return reinterpret_cast(backing.search.get()) + alignment_cruft; } else { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 2b32b450..d28cb6c5 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,12 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off); // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); // Write header to binary file. This is done last to prevent incomplete files // from loading. diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4552c419..b7aee4de 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,12 +15,12 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" -"-u sets the default log10 probability for if the ARPA file does not have\n" -"one.\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" +"-u sets the log10 probability for if the ARPA file does not have one.\n" +" Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is either probing or trie:\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" @@ -29,10 +29,11 @@ void Usage(const char *name) { "-t is the temporary directory prefix. Default is the output file name.\n" "-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" -"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" -"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" -"Passing only an input file will print memory usage of each data structure.\n" -"If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n" +"-a compresses pointers using an array of offsets. The parameter is the\n" +" maximum number of bits encoded by the array. Memory is minimized subject\n" +" to the maximum, so pick 255 to minimize memory.\n\n" +"Get a memory estimate by passing an ARPA file without an output file name.\n"; exit(1); } @@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[3]; + std::size_t sizes[5]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = TrieModel::Size(counts, config); sizes[2] = QuantTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + 3); - std::size_t min_length = *std::max_element(sizes, sizes + 3); + sizes[3] = ArrayTrieModel::Size(counts, config); + sizes[4] = QuantArrayTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); + std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); std::size_t divide; char prefix; if (min_length < (1 << 10) * 10) { @@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::cout << prefix << "B\n" "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; } void ProbingQuantizationUnsupported() { @@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() { int main(int argc, char *argv[]) { using namespace lm::ngram; - bool quantize = false, set_backoff_bits = false; try { + bool quantize = false, set_backoff_bits = false, bhiksha = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -121,6 +126,9 @@ int main(int argc, char *argv[]) { config.backoff_bits = ParseBitCount(optarg); set_backoff_bits = true; break; + case 'a': + config.pointer_bhiksha_bits = ParseBitCount(optarg); + bhiksha = true; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -162,9 +170,17 @@ int main(int argc, char *argv[]) { ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { if (quantize) { - QuantTrieModel(from_file, config); + if (bhiksha) { + QuantArrayTrieModel(from_file, config); + } else { + QuantTrieModel(from_file, config); + } } else { - TrieModel(from_file, config); + if (bhiksha) { + ArrayTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } } else { Usage(argv[0]); @@ -173,9 +189,9 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } - catch (std::exception &e) { + catch (const std::exception &e) { std::cerr << e.what() << std::endl; - abort(); + return 1; } return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 08e1af5c..297589a4 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -20,6 +20,7 @@ Config::Config() : include_vocab(true), prob_bits(8), backoff_bits(8), + pointer_bhiksha_bits(22), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index dcc7cf35..227b8512 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -73,9 +73,12 @@ struct Config { // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize. + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; + // Bhiksha compression (simple form). Only works with trie. + uint8_t pointer_bhiksha_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a1d10b3d..27e24b1c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -21,6 +21,8 @@ size_t hash_value(const State &state) { namespace detail { +template const ModelType GenericModel::kModelType = Search::kModelType; + template size_t GenericModel::Size(const std::vector &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } @@ -56,35 +58,40 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); - std::vector counts; - // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. - ReadARPACounts(f, counts); - - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); - if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); - if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - - std::size_t vocab_size = VocabularyT::Size(counts[0], config); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); - - if (config.write_mmap) { - WriteWordsWrapper wrap(config.enumerate_vocab); - vocab_.ConfigureEnumerate(&wrap, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get()); - } else { - vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - } + try { + std::vector counts; + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. + ReadARPACounts(f, counts); + + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + + std::size_t vocab_size = VocabularyT::Size(counts[0], config); + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + + if (config.write_mmap) { + WriteWordsWrapper wrap(config.enumerate_vocab); + vocab_.ConfigureEnumerate(&wrap, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + wrap.Write(backing_.file.get()); + } else { + vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + } - if (!vocab_.SawUnk()) { - assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. - search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_logprob; + if (!vocab_.SawUnk()) { + assert(config.unknown_missing != THROW_UP); + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_logprob; + } + FinishFile(config, kModelType, counts, backing_); + } catch (util::Exception &e) { + e << " Byte: " << f.Offset(); + throw; } - FinishFile(config, kModelType, counts, backing_); } template FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { @@ -225,8 +232,10 @@ template FullScoreReturn GenericModel; // HASH_PROBING -template class GenericModel, SortedVocabulary>; // TRIE_SORTED -template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 1f49a382..21595321 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,6 +1,7 @@ #ifndef LM_MODEL__ #define LM_MODEL__ +#include "lm/bhiksha.hh" #include "lm/binary_format.hh" #include "lm/config.hh" #include "lm/facade.hh" @@ -71,6 +72,9 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: + // This is the model type returned by RecognizeBinary. + static const ModelType kModelType; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -131,8 +135,6 @@ template class GenericModel : public base::Mod Backing &MutableBacking() { return backing_; } - static const ModelType kModelType = Search::kModelType; - Backing backing_; VocabularyT vocab_; @@ -152,9 +154,11 @@ typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> ArrayTrieModel; -typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantArrayTrieModel; } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template void Stateless(const M &model) { BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); } +template void NoUnkCheck(const M &model) { + WordIndex unk_index = 0; + State state; + + FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); + BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} + template void Everything(const M &m) { Starters(m); Continuation(m); @@ -231,25 +239,38 @@ template void LoadingTest() { Config config; config.arpa_complain = Config::NONE; config.messages = NULL; - ExpectEnumerateVocab enumerate; - config.enumerate_vocab = &enumerate; config.probing_multiplier = 2.0; - ModelT m("test.arpa", config); - enumerate.Check(m.GetVocabulary()); - Everything(m); + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test.arpa", config); + enumerate.Check(m.GetVocabulary()); + Everything(m); + } + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test_nounk.arpa", config); + enumerate.Check(m.GetVocabulary()); + NoUnkCheck(m); + } } BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } - BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { + LoadingTest(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { + LoadingTest(); +} template void BinaryTest() { Config config; @@ -267,10 +288,34 @@ template void BinaryTest() { config.write_mmap = NULL; - ModelT binary("test.binary", config); - enumerate.Check(binary.GetVocabulary()); - Everything(binary); + ModelType type; + BOOST_REQUIRE(RecognizeBinary("test.binary", type)); + BOOST_CHECK_EQUAL(ModelT::kModelType, type); + + { + ModelT binary("test.binary", config); + enumerate.Check(binary.GetVocabulary()); + Everything(binary); + } unlink("test.binary"); + + // Now test without . + config.write_mmap = "test_nounk.binary"; + config.messages = NULL; + enumerate.Clear(); + { + ModelT copy_model("test_nounk.arpa", config); + enumerate.Check(copy_model.GetVocabulary()); + enumerate.Clear(); + NoUnkCheck(copy_model); + } + config.write_mmap = NULL; + { + ModelT binary("test_nounk.binary", config); + enumerate.Check(binary.GetVocabulary()); + NoUnkCheck(binary); + } + unlink("test_nounk.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { + BinaryTest(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 9454a6d1..d9db4aa2 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -99,6 +99,15 @@ int main(int argc, char *argv[]) { case lm::ngram::TRIE_SORTED: Query(argv[1], sentence_context); break; + case lm::ngram::QUANT_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::QUANT_ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; case lm::ngram::HASH_SORTED: default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 4bb6b1b8..fd371cc8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector(0); static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } @@ -108,7 +108,7 @@ class SeparatelyQuantize { }; public: - static const ModelType kModelType = QUANT_TRIE_SORTED; + static const ModelType kModelTypeAdd = kQuantAdd; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 060a97ea..455bc4ba 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; void ReadARPACounts(util::FilePiece &in, std::vector &number) { number.clear(); StringPiece line; - if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + if (line != "\\data\\") { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); } if (static_cast(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); - UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); + UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\."); } - if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); // So strtol doesn't go off the end of line. diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index c56ba7b8..82c53ec8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -98,7 +98,7 @@ template uint8_t *TemplateHashedSearch template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index f3acdefc..c62985e4 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -52,12 +52,11 @@ struct HashedSearch { Unigram unigram; - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { const ProbBackoff &entry = unigram.Lookup(word); prob = entry.prob; backoff = entry.backoff; next = static_cast(word); - return true; } }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 91f87f1c..05059ffb 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@ /* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" +#include "lm/bhiksha.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" @@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector appears. - size_t extra_count = counts[0] + 1; - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -610,9 +611,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -template class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -649,7 +650,7 @@ template class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; + BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; }; @@ -821,7 +822,7 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, So } // namespace -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -846,7 +847,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); @@ -863,7 +864,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -901,14 +902,14 @@ template void BuildTrie(const std::string &file_prefix, std::vecto /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { - for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { - i->FinishedLoading((i+1)->InsertIndex()); + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); } } -template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { quant_.SetupMemory(start, config); start += Quant::Size(counts.size(), config); unigram.Init(start); @@ -919,22 +920,24 @@ template uint8_t *TrieSearch::SetupMemory(uint8_t *start, c std::vector middle_starts(counts.size() - 2); for (unsigned char i = 2; i < counts.size(); ++i) { middle_starts[i-2] = start; - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config); } - // Crazy backwards thing so we initialize in the correct order. + // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], quant_.Mid(i), + counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1]), + config); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template void TrieSearch::LoadedBinary() { +template void TrieSearch::LoadedBinary() { unigram.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); @@ -942,7 +945,7 @@ template void TrieSearch::LoadedBinary() { longest.LoadedBinary(); } -template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -966,14 +969,16 @@ template void TrieSearch::InitializeFromARPA(const char *fi // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } -template class TrieSearch; -template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; } // namespace trie } // namespace ngram diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing; class SortedVocabulary; namespace trie { -template class TrieSearch; -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template class TrieSearch { +template class TrieSearch { public: typedef NodeRange Node; typedef ::lm::ngram::trie::Unigram Unigram; Unigram unigram; - typedef trie::BitPackedMiddle Middle; + typedef trie::BitPackedMiddle Middle; typedef trie::BitPackedLongest Longest; Longest longest; - static const ModelType kModelType = Quant::kModelType; + static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); + AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + Bhiksha::UpdateConfigFromBinary(fd, config); } static std::size_t Size(const std::vector &counts, const Config &config) { std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } @@ -55,8 +57,8 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + unigram.Find(word, prob, backoff, node); } bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template class TrieSearch { } private: - friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa new file mode 100644 index 00000000..060733d9 --- /dev/null +++ b/klm/lm/test_nounk.arpa @@ -0,0 +1,120 @@ + +\data\ +ngram 1=36 +ngram 2=45 +ngram 3=10 +ngram 4=6 +ngram 5=4 + +\1-grams: +-1.383514 , -0.30103 +-1.139057 . -0.845098 +-1.029493 +-99 -0.4149733 +-1.285941 a -0.69897 +-1.687872 also -0.30103 +-1.687872 beyond -0.30103 +-1.687872 biarritz -0.30103 +-1.687872 call -0.30103 +-1.687872 concerns -0.30103 +-1.687872 consider -0.30103 +-1.687872 considering -0.30103 +-1.687872 for -0.30103 +-1.509559 higher -0.30103 +-1.687872 however -0.30103 +-1.687872 i -0.30103 +-1.687872 immediate -0.30103 +-1.687872 in -0.30103 +-1.687872 is -0.30103 +-1.285941 little -0.69897 +-1.383514 loin -0.30103 +-1.687872 look -0.30103 +-1.285941 looking -0.4771212 +-1.206319 more -0.544068 +-1.509559 on -0.4771212 +-1.509559 screening -0.4771212 +-1.687872 small -0.30103 +-1.687872 the -0.30103 +-1.687872 to -0.30103 +-1.687872 watch -0.30103 +-1.687872 watching -0.30103 +-1.687872 what -0.30103 +-1.687872 would -0.30103 +-3.141592 foo +-2.718281 bar 3.0 +-6.535897 baz -0.0 + +\2-grams: +-0.6925742 , . +-0.7522095 , however +-0.7522095 , is +-0.0602359 . +-0.4846522 looking -0.4771214 +-1.051485 screening +-1.07153 the +-1.07153 watching +-1.07153 what +-0.09132547 a little -0.69897 +-0.2922095 also call +-0.2922095 beyond immediate +-0.2705918 biarritz . +-0.2922095 call for +-0.2922095 concerns in +-0.2922095 consider watch +-0.2922095 considering consider +-0.2834328 for , +-0.5511513 higher more +-0.5845945 higher small +-0.2834328 however , +-0.2922095 i would +-0.2922095 immediate concerns +-0.2922095 in biarritz +-0.2922095 is to +-0.09021038 little more -0.1998621 +-0.7273645 loin , +-0.6925742 loin . +-0.6708385 loin +-0.2922095 look beyond +-0.4638903 looking higher +-0.4638903 looking on -0.4771212 +-0.5136299 more . -0.4771212 +-0.3561665 more loin +-0.1649931 on a -0.4771213 +-0.1649931 screening a -0.4771213 +-0.2705918 small . +-0.287799 the screening +-0.2922095 to look +-0.2622373 watch +-0.2922095 watching considering +-0.2922095 what i +-0.2922095 would also +-2 also would -6 +-6 foo bar + +\3-grams: +-0.01916512 more . +-0.0283603 on a little -0.4771212 +-0.0283603 screening a little -0.4771212 +-0.01660496 a little more -0.09409451 +-0.3488368 looking higher +-0.3488368 looking on -0.4771212 +-0.1892331 little more loin +-0.04835128 looking on a -0.4771212 +-3 also would consider -7 +-7 to look good + +\4-grams: +-0.009249173 looking on a little -0.4771212 +-0.005464747 on a little more -0.4771212 +-0.005464747 screening a little more +-0.1453306 a little more loin +-0.01552657 looking on a -0.4771212 +-4 also would consider higher -8 + +\5-grams: +-0.003061223 looking on a little +-0.001813953 looking on a little more +-0.0432557 on a little more loin +-5 also would consider higher looking + +\end\ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 63c2a612..8c536e66 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,5 +1,6 @@ #include "lm/trie.hh" +#include "lm/bhiksha.hh" #include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" @@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { + return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { - if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : + BitPacked(), + quant_(quant), + // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. + bhiksha_(base, entries + 1, max_next, config), + next_source_(&next_source) { + if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); + BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); } -template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; @@ -75,47 +81,42 @@ template void BitPackedMiddle::Insert(WordIndex word, float quant_.Write(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); - assert(next <= next_mask_); - util::WriteInt57(base_, at_pointer, next_bits_, next); + bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } + uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); + return true; } -template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; - at_pointer *= total_bits_; +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { + uint64_t index; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; + uint64_t at_pointer = index * total_bits_; at_pointer += word_bits_; quant_.ReadBackoff(base_, at_pointer, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); return true; } -template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { - assert(next_end <= next_mask_); - uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_, last_next_write, next_bits_, next_end); +template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { + uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); + bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); + bhiksha_.FinishedLoading(config); } template void BitPackedLongest::Insert(WordIndex index, float prob) { @@ -135,8 +136,10 @@ template bool BitPackedLongest::Find(WordIndex word, float return true; } -template class BitPackedMiddle; -template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; template class BitPackedLongest; template class BitPackedLongest; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@ namespace lm { namespace ngram { +class Config; namespace trie { struct NodeRange { @@ -46,13 +47,12 @@ class Unigram { void LoadedBinary() {} - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { + void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { UnigramValue *val = unigram_ + word; prob = val->weights.prob; backoff = val->weights.backoff; next.begin = val->next; next.end = (val+1)->next; - return true; } private: @@ -67,8 +67,6 @@ class BitPacked { return insert_index_; } - void LoadedBinary() {} - protected: static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked { uint64_t insert_index_, max_vocab_; }; -template class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. - BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); void Insert(WordIndex word, float prob, float backoff); + void FinishedLoading(uint64_t next_end, const Config &config); + + void LoadedBinary() { bhiksha_.LoadedBinary(); } + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; - void FinishedLoading(uint64_t next_end); - private: Quant quant_; - uint8_t next_bits_; - uint64_t next_mask_; + Bhiksha bhiksha_; const BitPacked *next_source_; }; - template class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { @@ -120,6 +118,8 @@ template class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_.TotalBits()); } + void LoadedBinary() {} + void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7defd5c1..04979d51 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { WordIndex index = 0; while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); - if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; ssize_t ret = read(fd, &next_char, 1); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); + UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); buf.push_back(next_char); } // Ok now we have null terminated strings. diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c92518e4..9d218fff 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary { } } + // Size for purposes of file writing static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. @@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary { // Reorders reorder_vocab so that the IDs are sorted. void FinishedLoading(ProbBackoff *reorder_vocab); + // Trie stores the correct counts including in the header. If this was previously sized based on a count exluding , padding with 8 bytes will make it the correct size based on a count including . + std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } + bool SawUnk() const { return saw_unk_; } void LoadedBinary(int fd, EnumerateVocab *to); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index b35d80c8..9f47d559 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -107,9 +107,20 @@ void BitPackingSanity(); uint8_t RequiredBits(uint64_t max_value); struct BitsMask { + static BitsMask ByMax(uint64_t max_value) { + BitsMask ret; + ret.FromMax(max_value); + return ret; + } + static BitsMask ByBits(uint8_t bits) { + BitsMask ret; + ret.bits = bits; + ret.mask = (1ULL << bits) - 1; + return ret; + } void FromMax(uint64_t max_value) { bits = RequiredBits(max_value); - mask = (1 << bits) - 1; + mask = (1ULL << bits) - 1; } uint8_t bits; uint64_t mask; diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 00be0ed7..2ec342a6 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -57,7 +57,7 @@ template class IdentityAccessor { public: typedef T Key; - T operator()(const uint64_t *in) const { return *in; } + T operator()(const T *in) const { return *in; } }; struct Pivot64 { @@ -101,6 +101,27 @@ template bool SortedUniformFind(co return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } +// May return begin - 1. +template Iterator BinaryBelow( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {} + return pivot - 1; + } + } + return begin - 1; +} + // To use this template, you need to define a Pivot function to match Key. template class SortedUniformMap { public: -- cgit v1.2.3 From 51e8f4a5b9ffc96f3486ede77fe4511918156cf4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 Sep 2011 13:24:11 +0200 Subject: fix viterbi to work with non prob_t types --- decoder/viterbi.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/viterbi.h b/decoder/viterbi.h index ac0b9a11..daee3d7a 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -25,7 +25,7 @@ typename WeightFunction::Weight Viterbi(const Hypergraph& hg, typedef typename WeightFunction::Weight WeightType; const int num_nodes = hg.nodes_.size(); std::vector vit_result(num_nodes); - std::vector vit_weight(num_nodes, WeightType::Zero()); + std::vector vit_weight(num_nodes, WeightType()); for (int i = 0; i < num_nodes; ++i) { const Hypergraph::Node& cur_node = hg.nodes_[i]; -- cgit v1.2.3 From 9f7a0765905e2906c43fbb5359d00ccdac38ca7f Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 9 Sep 2011 10:15:56 +0200 Subject: rule feature refactoring --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_rules.cc | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_rules.h | 40 ++++++++++++++++++++ decoder/ff_spans.cc | 39 ------------------- decoder/ff_spans.h | 15 -------- 6 files changed, 150 insertions(+), 54 deletions(-) create mode 100644 decoder/ff_rules.cc create mode 100644 decoder/ff_rules.h (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index d884c431..e5f7505f 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -61,6 +61,7 @@ libcdec_a_SOURCES = \ phrasetable_fst.cc \ trule.cc \ ff.cc \ + ff_rules.cc \ ff_wordset.cc \ ff_charset.cc \ ff_lm.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 1ef76a05..588842f1 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -9,6 +9,7 @@ #include "ff_wordalign.h" #include "ff_tagger.h" #include "ff_factory.h" +#include "ff_rules.h" #include "ff_ruleshape.h" #include "ff_bleu.h" #include "ff_lm_fsa.h" @@ -53,6 +54,7 @@ void register_feature_functions() { #endif ff_registry.Register("SpanFeatures", new FFFactory()); ff_registry.Register("NgramFeatures", new FFFactory()); + ff_registry.Register("RuleIdentityFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc new file mode 100644 index 00000000..bd4c4cc0 --- /dev/null +++ b/decoder/ff_rules.cc @@ -0,0 +1,107 @@ +#include "ff_rules.h" + +#include +#include +#include + +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" + +using namespace std; + +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + +RuleIdentityFeatures::RuleIdentityFeatures(const std::string& param) { +} + +void RuleIdentityFeatures::PrepareForInput(const SentenceMetadata& smeta) { +// std::map > + rule2_fid_.clear(); +} + +void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + map::iterator it = rule2_fid_.find(edge.rule_.get()); + if (it == rule2_fid_.end()) { + const TRule& rule = *edge.rule_; + ostringstream os; + os << "R:"; + if (rule.lhs_ < 0) os << TD::Convert(-rule.lhs_) << ':'; + for (unsigned i = 0; i < rule.f_.size(); ++i) { + if (i > 0) os << '_'; + WordID w = rule.f_[i]; + if (w < 0) { os << 'N'; w = -w; } + assert(w > 0); + os << TD::Convert(w); + } + os << ':'; + for (unsigned i = 0; i < rule.e_.size(); ++i) { + if (i > 0) os << '_'; + WordID w = rule.e_[i]; + if (w <= 0) { + os << 'N' << (1-w); + } else { + os << TD::Convert(w); + } + } + it = rule2_fid_.insert(make_pair(&rule, FD::Convert(Escape(os.str())))).first; + } + features->add_value(it->second, 1); +} + +RuleNgramFeatures::RuleNgramFeatures(const std::string& param) { +} + +void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) { +// std::map > + rule2_feats_.clear(); +} + +void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + map >::iterator it = rule2_feats_.find(edge.rule_.get()); + if (it == rule2_feats_.end()) { + const TRule& rule = *edge.rule_; + it = rule2_feats_.insert(make_pair(&rule, SparseVector())).first; + SparseVector& f = it->second; + string prev = ""; + for (int i = 0; i < rule.f_.size(); ++i) { + WordID w = rule.f_[i]; + if (w < 0) w = -w; + assert(w > 0); + const string& cur = TD::Convert(w); + ostringstream os; + os << "RB:" << prev << '_' << cur; + const int fid = FD::Convert(Escape(os.str())); + if (fid <= 0) return; + f.add_value(fid, 1.0); + prev = cur; + } + ostringstream os; + os << "RB:" << prev << '_' << ""; + f.set_value(FD::Convert(Escape(os.str())), 1.0); + } + (*features) += it->second; +} + diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h new file mode 100644 index 00000000..48d8bd05 --- /dev/null +++ b/decoder/ff_rules.h @@ -0,0 +1,40 @@ +#ifndef _FF_RULES_H_ +#define _FF_RULES_H_ + +#include +#include +#include "ff.h" +#include "array2d.h" +#include "wordid.h" + +class RuleIdentityFeatures : public FeatureFunction { + public: + RuleIdentityFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + mutable std::map rule2_fid_; +}; + +class RuleNgramFeatures : public FeatureFunction { + public: + RuleNgramFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + mutable std::map > rule2_feats_; +}; + +#endif diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc index bc23974d..0483517b 100644 --- a/decoder/ff_spans.cc +++ b/decoder/ff_spans.cc @@ -193,45 +193,6 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) { } } -RuleNgramFeatures::RuleNgramFeatures(const std::string& param) { -} - -void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) { -// std::map > - rule2_feats_.clear(); -} - -void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector& ant_contexts, - SparseVector* features, - SparseVector* estimated_features, - void* context) const { - map >::iterator it = rule2_feats_.find(edge.rule_.get()); - if (it == rule2_feats_.end()) { - const TRule& rule = *edge.rule_; - it = rule2_feats_.insert(make_pair(&rule, SparseVector())).first; - SparseVector& f = it->second; - string prev = ""; - for (int i = 0; i < rule.f_.size(); ++i) { - WordID w = rule.f_[i]; - if (w < 0) w = -w; - assert(w > 0); - const string& cur = TD::Convert(w); - ostringstream os; - os << "RB:" << prev << '_' << cur; - const int fid = FD::Convert(Escape(os.str())); - if (fid <= 0) return; - f.add_value(fid, 1.0); - prev = cur; - } - ostringstream os; - os << "RB:" << prev << '_' << ""; - f.set_value(FD::Convert(Escape(os.str())), 1.0); - } - (*features) += it->second; -} - inline bool IsArity2RuleReordered(const TRule& rule) { const vector& e = rule.e_; for (int i = 0; i < e.size(); ++i) { diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index b22c4d03..24e0dede 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -44,21 +44,6 @@ class SpanFeatures : public FeatureFunction { WordID oov_; }; -class RuleNgramFeatures : public FeatureFunction { - public: - RuleNgramFeatures(const std::string& param); - protected: - virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const std::vector& ant_contexts, - SparseVector* features, - SparseVector* estimated_features, - void* context) const; - virtual void PrepareForInput(const SentenceMetadata& smeta); - private: - mutable std::map > rule2_feats_; -}; - class CMR2008ReorderingFeatures : public FeatureFunction { public: CMR2008ReorderingFeatures(const std::string& param); -- cgit v1.2.3 From 700b2abf48bf0a455064d6cf08754cbfd4e3a383 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 12 Sep 2011 19:22:59 +0100 Subject: source syntax features ~ blunsom emnlp 2008 --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_source_syntax.cc | 157 ++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_source_syntax.h | 24 +++++++ utils/stringlib.cc | 7 +- 5 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 decoder/ff_source_syntax.cc create mode 100644 decoder/ff_source_syntax.h (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index e5f7505f..ede1cff0 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -72,6 +72,7 @@ libcdec_a_SOURCES = \ ff_wordalign.cc \ ff_csplit.cc \ ff_tagger.cc \ + ff_source_syntax.cc \ ff_bleu.cc \ ff_factory.cc \ freqdict.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 588842f1..d562bc3a 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -14,6 +14,7 @@ #include "ff_bleu.h" #include "ff_lm_fsa.h" #include "ff_sample_fsa.h" +#include "ff_source_syntax.h" #include "ff_register.h" #include "ff_charset.h" #include "ff_wordset.h" @@ -55,6 +56,7 @@ void register_feature_functions() { ff_registry.Register("SpanFeatures", new FFFactory()); ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleIdentityFeatures", new FFFactory()); + ff_registry.Register("SourceSyntaxFeatures", new FFFactory); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc new file mode 100644 index 00000000..99acbd87 --- /dev/null +++ b/decoder/ff_source_syntax.cc @@ -0,0 +1,157 @@ +#include "ff_source_syntax.h" + +#include +#include + +#include "sentence_metadata.h" +#include "array2d.h" +#include "filelib.h" + +using namespace std; + +// implements the source side syntax features described in Blunsom et al. (EMNLP 2008) +// source trees must be represented in Penn Treebank format, e.g. +// (S (NP John) (VP (V left))) + +struct SourceSyntaxFeaturesImpl { + SourceSyntaxFeaturesImpl() {} + + void InitializeGrids(const string& tree, unsigned src_len) { + assert(tree.size() > 0); + fids_cat.clear(); + fids_fonly.clear(); + fids_ef.clear(); + src_tree.clear(); + fids_cat.resize(src_len, src_len + 1); + fids_fonly.resize(src_len, src_len + 1); + fids_ef.resize(src_len, src_len + 1); + src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); + ParseTreeString(tree, src_len); + } + + void ParseTreeString(const string& tree, unsigned src_len) { + stack > stk; // first = i, second = category + pair cur_cat; cur_cat.first = -1; + unsigned i = 0; + unsigned p = 0; + while(p < tree.size()) { + const char cur = tree[p]; + if (cur == '(') { + stk.push(cur_cat); + ++p; + unsigned k = p + 1; + while (k < tree.size() && tree[k] != ' ') { ++k; } + cur_cat.first = i; + cur_cat.second = TD::Convert(tree.substr(p, k - p)); + // cerr << "NT: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; + p = k + 1; + } else if (cur == ')') { + unsigned k = p; + while (k < tree.size() && tree[k] == ')') { ++k; } + const unsigned num_closes = k - p; + for (unsigned ci = 0; ci < num_closes; ++ci) { + // cur_cat.second spans from cur_cat.first to i + // cerr << TD::Convert(cur_cat.second) << " from " << cur_cat.first << " to " << i << endl; + // NOTE: unary rule chains end up being labeled with the top-most category + src_tree(cur_cat.first, i) = cur_cat.second; + cur_cat = stk.top(); + stk.pop(); + } + p = k; + while (p < tree.size() && (tree[p] == ' ' || tree[p] == '\t')) { ++p; } + } else if (cur == ' ' || cur == '\t') { + cerr << "Unexpected whitespace in: " << tree << endl; + abort(); + } else { // terminal symbol + unsigned k = p + 1; + do { + while (k < tree.size() && tree[k] != ')' && tree[k] != ' ') { ++k; } + // cerr << "TERM: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; + ++i; + assert(i <= src_len); + while (k < tree.size() && tree[k] == ' ') { ++k; } + p = k; + } while (p < tree.size() && tree[p] != ')'); + } + } + // cerr << "i=" << i << " src_len=" << src_len << endl; + assert(i == src_len); // make sure tree specified in src_tree is + // the same length as the source sentence + } + + WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { + //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl; + const WordID lhs = src_tree(i,j); + int& fid_cat = fids_cat(i,j); + int& fid_fonly = fids_fonly(i,j)[&rule]; + int& fid_ef = fids_ef(i,j)[&rule]; + if (fid_ef <= 0) { + ostringstream os; + os << "SYN:" << TD::Convert(lhs); + fid_cat = FD::Convert(os.str()); + os << ':'; + unsigned ntc = 0; + for (unsigned k = 0; k < rule.f_.size(); ++k) { + if (k > 0) os << '_'; + int fj = rule.f_[k]; + if (fj <= 0) { + os << '[' << TD::Convert(ants[ntc++]) << ']'; + } else { + os << TD::Convert(fj); + } + } + fid_fonly = FD::Convert(os.str()); + os << ':'; + for (unsigned k = 0; k < rule.e_.size(); ++k) { + const int ei = rule.e_[k]; + if (k > 0) os << '_'; + if (ei <= 0) + os << '[' << (1-ei) << ']'; + else + os << TD::Convert(ei); + } + fid_ef = FD::Convert(os.str()); + } + if (fid_cat > 0) + feats->set_value(fid_cat, 1.0); + if (fid_fonly > 0) + feats->set_value(fid_fonly, 1.0); + if (fid_ef > 0) + feats->set_value(fid_ef, 1.0); + return lhs; + } + + Array2D src_tree; // src_tree(i,j) NT = type + mutable Array2D fids_cat; // fires for an LHS match + mutable Array2D > fids_fonly; // fires for an f-string + mutable Array2D > fids_ef; // fires for fully lexicalized +}; + +SourceSyntaxFeatures::SourceSyntaxFeatures(const string& param) : + FeatureFunction(sizeof(WordID)) { + impl = new SourceSyntaxFeaturesImpl; +} + +SourceSyntaxFeatures::~SourceSyntaxFeatures() { + delete impl; + impl = NULL; +} + +void SourceSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + WordID ants[8]; + for (unsigned i = 0; i < ant_contexts.size(); ++i) + ants[i] = *static_cast(ant_contexts[i]); + + *static_cast(context) = + impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); +} + +void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { + impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); +} + diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h new file mode 100644 index 00000000..1e890736 --- /dev/null +++ b/decoder/ff_source_syntax.h @@ -0,0 +1,24 @@ +#ifndef _FF_SOURCE_TOOLS_H_ +#define _FF_SOURCE_TOOLS_H_ + +#include "ff.h" + +struct SourceSyntaxFeaturesImpl; + +class SourceSyntaxFeatures : public FeatureFunction { + public: + SourceSyntaxFeatures(const std::string& param); + ~SourceSyntaxFeatures(); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + SourceSyntaxFeaturesImpl* impl; +}; + +#endif diff --git a/utils/stringlib.cc b/utils/stringlib.cc index 7aaee9f0..ade02ca9 100644 --- a/utils/stringlib.cc +++ b/utils/stringlib.cc @@ -32,7 +32,12 @@ void ParseTranslatorInput(const string& line, string* input, string* ref) { void ProcessAndStripSGML(string* pline, map* out) { map& meta = *out; string& line = *pline; - string lline = LowercaseString(line); + string lline = *pline; + if (lline.find(" must be lowercase!\n"; + cerr << " " << *pline << endl; + abort(); + } if (lline.find(""); if (close == string::npos) return; // error -- cgit v1.2.3 From b09ca8a5e6f5e8c1840e51a93c9f8e6b8c4bcc33 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 09:45:01 +0100 Subject: add one more source syntax feature --- decoder/ff_source_syntax.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 99acbd87..5b7c16f6 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -13,6 +13,13 @@ using namespace std; // source trees must be represented in Penn Treebank format, e.g. // (S (NP John) (VP (V left))) +// log transform to make long spans cluster together +// but preserve differences +inline int SpanSizeTransform(unsigned span_size) { + if (!span_size) return 0; + return static_cast(log(span_size+1) / log(1.39)) - 1; +} + struct SourceSyntaxFeaturesImpl { SourceSyntaxFeaturesImpl() {} @@ -87,8 +94,10 @@ struct SourceSyntaxFeaturesImpl { int& fid_ef = fids_ef(i,j)[&rule]; if (fid_ef <= 0) { ostringstream os; + ostringstream os2; os << "SYN:" << TD::Convert(lhs); - fid_cat = FD::Convert(os.str()); + os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); + fid_cat = FD::Convert(os2.str()); os << ':'; unsigned ntc = 0; for (unsigned k = 0; k < rule.f_.size(); ++k) { -- cgit v1.2.3 From 38a5bee71f6b49515cd105a9467ff602ff9dee64 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 13:25:46 +0100 Subject: optional support for doing perfect hashing of feature strings to save lots of memory --- decoder/decoder.cc | 22 ++++++++- utils/Makefile.am | 9 +++- utils/fdict.cc | 4 ++ utils/fdict.h | 36 ++++++++++++++ utils/perfect_hash.cc | 37 ++++++++++++++ utils/perfect_hash.h | 24 +++++++++ utils/phmt.cc | 44 +++++++++++++++++ utils/weights.cc | 132 ++++++++++++++++++++++++++++++++++---------------- utils/weights.h | 14 +++--- 9 files changed, 269 insertions(+), 53 deletions(-) create mode 100644 utils/perfect_hash.cc create mode 100644 utils/perfect_hash.h create mode 100644 utils/phmt.cc (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 76f31352..25eb2de4 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -328,6 +328,7 @@ struct DecoderImpl { bool write_gradient; // TODO Observer bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer + bool remove_intersected_rule_annotations; static void ConvertSV(const SparseVector& src, SparseVector* trg) { for (SparseVector::const_iterator it = src.begin(); it != src.end(); ++it) @@ -361,6 +362,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("per_sentence_grammar_file", po::value(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset") ("list_feature_functions,L","List available feature functions") +#ifdef HAVE_CMPH + ("cmph_perfect_feature_hash,h", po::value(), "Load perfect hash function for features") +#endif ("weights,w",po::value(),"Feature weights file (initial forest / pass 1)") ("feature_function,F",po::value >()->composing(), "Pass 1 additional feature function(s) (-L for list)") @@ -433,7 +437,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") - ("forest_output,O",po::value(),"Directory to write forests to"); + ("forest_output,O",po::value(),"Directory to write forests to") + ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); // ob.AddOptions(&opts); #ifdef FSA_RESCORING @@ -443,7 +448,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value >(&cfg_files), "Configuration file(s) - latest has priority") - ("help,h", "Print this help message and exit") + ("help,?", "Print this help message and exit") ("usage,u", po::value(), "Describe a feature function type") ("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'") ; @@ -645,6 +650,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream FD::Freeze(); // this means we can't see the feature names of not-weighted features } + if (conf.count("cmph_perfect_feature_hash")) { + cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as() << " ...\n"; + FD::EnableHash(conf["cmph_perfect_feature_hash"].as()); + cerr << " " << FD::NumFeats() << " features in map\n"; + } + // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); @@ -695,6 +706,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); oracle.show_derivation=conf.count("show_derivations"); + remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); #ifdef FSA_RESCORING cfg_options.Validate(); @@ -1010,6 +1022,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { // if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n"; // for (int i = 0; i < forest.edges_.size(); ++i) // forest.edges_[i].edge_prob_=prob_t::One(); } + if (remove_intersected_rule_annotations) { + for (unsigned i = 0; i < forest.edges_.size(); ++i) + if (forest.edges_[i].rule_ && + forest.edges_[i].rule_->parent_rule_) + forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; + } forest.Reweight(last_weights); if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; diff --git a/utils/Makefile.am b/utils/Makefile.am index 94f9be30..c50747bf 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -1,5 +1,5 @@ -noinst_PROGRAMS = ts -TESTS = ts +noinst_PROGRAMS = ts phmt +TESTS = ts phmt if HAVE_GTEST noinst_PROGRAMS += \ @@ -27,6 +27,11 @@ libutils_a_SOURCES = \ verbose.cc \ weights.cc +if HAVE_CMPH + libutils_a_SOURCES += perfect_hash.cc +endif + +phmt_SOURCES = phmt.cc ts_SOURCES = ts.cc dict_test_SOURCES = dict_test.cc dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) diff --git a/utils/fdict.cc b/utils/fdict.cc index baa0b552..676c951c 100644 --- a/utils/fdict.cc +++ b/utils/fdict.cc @@ -9,6 +9,10 @@ using namespace std; Dict FD::dict_; bool FD::frozen_ = false; +#ifdef HAVE_CMPH +PerfectHashFunction* FD::hash_ = NULL; +#endif + std::string FD::Convert(std::vector const& v) { return Convert(&*v.begin(),&*v.end()); } diff --git a/utils/fdict.h b/utils/fdict.h index f9673023..771e8b91 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -1,23 +1,56 @@ #ifndef _FDICT_H_ #define _FDICT_H_ +#include "config.h" + +#include #include #include #include "dict.h" +#ifdef HAVE_CMPH +#include "perfect_hash.h" +#include "string_to.h" +#endif + struct FD { // once the FD is frozen, new features not already in the // dictionary will return 0 static void Freeze() { frozen_ = true; } + static bool UsingPerfectHashFunction() { +#ifdef HAVE_CMPH + return hash_; +#else + return false; +#endif + } + static void EnableHash(const std::string& cmph_file) { +#ifdef HAVE_CMPH + hash_ = new PerfectHashFunction(cmph_file); +#endif + } static inline int NumFeats() { +#ifdef HAVE_CMPH + if (hash_) return hash_->number_of_keys(); +#endif return dict_.max() + 1; } static inline WordID Convert(const std::string& s) { +#ifdef HAVE_CMPH + if (hash_) return (*hash_)(s); +#endif return dict_.Convert(s, frozen_); } static inline const std::string& Convert(const WordID& w) { +#ifdef HAVE_CMPH + if (hash_) { + static std::string tls; + tls = to_string(w); + return tls; + } +#endif return dict_.Convert(w); } static std::string Convert(WordID const *i,WordID const* e); @@ -29,6 +62,9 @@ struct FD { static Dict dict_; private: static bool frozen_; +#ifdef HAVE_CMPH + static PerfectHashFunction* hash_; +#endif }; #endif diff --git a/utils/perfect_hash.cc b/utils/perfect_hash.cc new file mode 100644 index 00000000..706e2741 --- /dev/null +++ b/utils/perfect_hash.cc @@ -0,0 +1,37 @@ +#include "config.h" + +#ifdef HAVE_CMPH + +#include "perfect_hash.h" + +#include +#include + +using namespace std; + +PerfectHashFunction::~PerfectHashFunction() { + cmph_destroy(mphf_); +} + +PerfectHashFunction::PerfectHashFunction(const string& fname) { + FILE* f = fopen(fname.c_str(), "r"); + if (!f) { + cerr << "Failed to open file " << fname << " for reading: cannot load hash function.\n"; + abort(); + } + mphf_ = cmph_load(f); + if (!mphf_) { + cerr << "cmph_load failed on " << fname << "!\n"; + abort(); + } +} + +size_t PerfectHashFunction::operator()(const string& key) const { + return cmph_search(mphf_, &key[0], key.size()); +} + +size_t PerfectHashFunction::number_of_keys() const { + return cmph_size(mphf_); +} + +#endif diff --git a/utils/perfect_hash.h b/utils/perfect_hash.h new file mode 100644 index 00000000..8ac11f18 --- /dev/null +++ b/utils/perfect_hash.h @@ -0,0 +1,24 @@ +#ifndef _PERFECT_HASH_MAP_H_ +#define _PERFECT_HASH_MAP_H_ + +#include "config.h" + +#ifndef HAVE_CMPH +#error libcmph is required to use PerfectHashFunction +#endif + +#include +#include +#include "cmph.h" + +class PerfectHashFunction : boost::noncopyable { + public: + explicit PerfectHashFunction(const std::string& fname); + ~PerfectHashFunction(); + size_t operator()(const std::string& key) const; + size_t number_of_keys() const; + private: + cmph_t *mphf_; +}; + +#endif diff --git a/utils/phmt.cc b/utils/phmt.cc new file mode 100644 index 00000000..1f59afaf --- /dev/null +++ b/utils/phmt.cc @@ -0,0 +1,44 @@ +#include "config.h" + +#ifndef HAVE_CMPH +int main() { + return 0; +} +#else + +#include +#include "weights.h" +#include "fdict.h" + +using namespace std; + +int main(int argc, char** argv) { + if (argc != 2) { cerr << "Usage: " << argv[0] << " file.mphf\n"; return 1; } + FD::EnableHash(argv[1]); + cerr << "Number of keys: " << FD::NumFeats() << endl; + cerr << "LexFE = " << FD::Convert("LexFE") << endl; + cerr << "LexEF = " << FD::Convert("LexEF") << endl; + { + Weights w; + vector v(FD::NumFeats()); + v[FD::Convert("LexFE")] = 1.0; + v[FD::Convert("LexEF")] = 0.5; + w.InitFromVector(v); + cerr << "Writing...\n"; + w.WriteToFile("weights.bin"); + cerr << "Done.\n"; + } + { + Weights w; + vector v(FD::NumFeats()); + cerr << "Reading...\n"; + w.InitFromFile("weights.bin"); + cerr << "Done.\n"; + w.InitVector(&v); + assert(v[FD::Convert("LexFE")] == 1.0); + assert(v[FD::Convert("LexEF")] == 0.5); + } +} + +#endif + diff --git a/utils/weights.cc b/utils/weights.cc index b994a2fe..0916b72a 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -13,40 +13,75 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ ReadFile in_file(filename); istream& in = *in_file.stream(); assert(in); - int weight_count = 0; - bool fl = false; - string buf; - double val = 0; - while (in) { - getline(in, buf); - if (buf.size() == 0) continue; - if (buf[0] == '#') continue; - for (int i = 0; i < buf.size(); ++i) - if (buf[i] == '=') buf[i] = ' '; - int start = 0; - while(start < buf.size() && buf[start] == ' ') ++start; - int end = 0; - while(end < buf.size() && buf[end] != ' ') ++end; - const int fid = FD::Convert(buf.substr(start, end - start)); - while(end < buf.size() && buf[end] == ' ') ++end; - val = strtod(&buf.c_str()[end], NULL); - if (isnan(val)) { - cerr << FD::Convert(fid) << " has weight NaN!\n"; - abort(); + + bool read_text = true; + if (1) { + ReadFile hdrrf(filename); + istream& hi = *hdrrf.stream(); + assert(hi); + char buf[10]; + hi.get(buf, 6); + assert(hi.good()); + if (strncmp(buf, "_PHWf", 5) == 0) { + read_text = false; + } + } + + if (read_text) { + int weight_count = 0; + bool fl = false; + string buf; + weight_t val = 0; + while (in) { + getline(in, buf); + if (buf.size() == 0) continue; + if (buf[0] == '#') continue; + if (buf[0] == ' ') { + cerr << "Weights file lines may not start with whitespace.\n" << buf << endl; + abort(); + } + for (int i = buf.size() - 1; i > 0; --i) + if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; } + int start = 0; + while(start < buf.size() && buf[start] == ' ') ++start; + int end = 0; + while(end < buf.size() && buf[end] != ' ') ++end; + const int fid = FD::Convert(buf.substr(start, end - start)); + while(end < buf.size() && buf[end] == ' ') ++end; + val = strtod(&buf.c_str()[end], NULL); + if (isnan(val)) { + cerr << FD::Convert(fid) << " has weight NaN!\n"; + abort(); + } + if (wv_.size() <= fid) + wv_.resize(fid + 1); + wv_[fid] = val; + if (feature_list) { feature_list->push_back(FD::Convert(fid)); } + ++weight_count; + if (!SILENT) { + if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } + if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + } } - if (wv_.size() <= fid) - wv_.resize(fid + 1); - wv_[fid] = val; - if (feature_list) { feature_list->push_back(FD::Convert(fid)); } - ++weight_count; if (!SILENT) { - if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } - if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + if (fl) { cerr << endl; } + cerr << "Loaded " << weight_count << " feature weights\n"; + } + } else { // !read_text + char buf[6]; + in.get(buf, 6); + size_t num_keys[2]; + in.get(reinterpret_cast(&num_keys[0]), sizeof(size_t) + 1); + if (num_keys[0] != FD::NumFeats()) { + cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; + abort(); + } + wv_.resize(num_keys[0]); + in.get(reinterpret_cast(&wv_[0]), num_keys[0] * sizeof(weight_t)); + if (!in.good()) { + cerr << "Error loading weights!\n"; + abort(); } - } - if (!SILENT) { - if (fl) { cerr << endl; } - cerr << "Loaded " << weight_count << " feature weights\n"; } } @@ -54,37 +89,48 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature WriteFile out(fname); ostream& o = *out.stream(); assert(o); - if (extra) { o << "# " << *extra << endl; } - o.precision(17); - const int num_feats = FD::NumFeats(); - for (int i = 1; i < num_feats; ++i) { - const double val = (i < wv_.size() ? wv_[i] : 0.0); - if (hide_zero_value_features && val == 0.0) continue; - o << FD::Convert(i) << ' ' << val << endl; + bool write_text = !FD::UsingPerfectHashFunction(); + + if (write_text) { + if (extra) { o << "# " << *extra << endl; } + o.precision(17); + const int num_feats = FD::NumFeats(); + for (int i = 1; i < num_feats; ++i) { + const weight_t val = (i < wv_.size() ? wv_[i] : 0.0); + if (hide_zero_value_features && val == 0.0) continue; + o << FD::Convert(i) << ' ' << val << endl; + } + } else { + o.write("_PHWf", 5); + const size_t keys = FD::NumFeats(); + assert(keys <= wv_.size()); + o.write(reinterpret_cast(&keys), sizeof(keys)); + o.write(reinterpret_cast(&wv_[0]), keys * sizeof(weight_t)); } } -void Weights::InitVector(std::vector* w) const { +void Weights::InitVector(std::vector* w) const { *w = wv_; } -void Weights::InitSparseVector(SparseVector* w) const { +void Weights::InitSparseVector(SparseVector* w) const { for (int i = 1; i < wv_.size(); ++i) { - const double& weight = wv_[i]; + const weight_t& weight = wv_[i]; if (weight) w->set_value(i, weight); } } -void Weights::InitFromVector(const std::vector& w) { +void Weights::InitFromVector(const std::vector& w) { wv_ = w; if (wv_.size() > FD::NumFeats()) cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; wv_.resize(FD::NumFeats(), 0); } -void Weights::InitFromVector(const SparseVector& w) { +void Weights::InitFromVector(const SparseVector& w) { wv_.clear(); wv_.resize(FD::NumFeats(), 0.0); for (int i = 1; i < FD::NumFeats(); ++i) wv_[i] = w.value(i); } + diff --git a/utils/weights.h b/utils/weights.h index cc20283c..7664810b 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -2,21 +2,23 @@ #define _WEIGHTS_H_ #include -#include #include #include "sparse_vector.h" +// warning: in the future this will become float +typedef double weight_t; + class Weights { public: Weights() {} void InitFromFile(const std::string& fname, std::vector* feature_list = NULL); void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const; - void InitVector(std::vector* w) const; - void InitSparseVector(SparseVector* w) const; - void InitFromVector(const std::vector& w); - void InitFromVector(const SparseVector& w); + void InitVector(std::vector* w) const; + void InitSparseVector(SparseVector* w) const; + void InitFromVector(const std::vector& w); + void InitFromVector(const SparseVector& w); private: - std::vector wv_; + std::vector wv_; }; #endif -- cgit v1.2.3 From 251da4347ea356f799e6c227ac8cf541c0cef2f2 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 17:36:23 +0100 Subject: get rid of bad Weights class so it no longer keeps a copy of a vector inside it --- decoder/decoder.cc | 64 ++++++++--------- decoder/decoder.h | 9 ++- mira/kbest_mira.cc | 62 ++++------------- pro-train/mr_pro_map.cc | 8 +-- pro-train/mr_pro_reduce.cc | 16 ++--- training/Makefile.am | 8 --- training/augment_grammar.cc | 4 +- training/collapse_weights.cc | 6 +- training/compute_cllh.cc | 23 +++--- training/grammar_convert.cc | 8 +-- training/mpi_batch_optimize.cc | 127 ++++++++-------------------------- training/mpi_online_optimize.cc | 69 +++++++----------- training/mr_optimize_reduce.cc | 19 ++--- utils/fdict.h | 2 + utils/phmt.cc | 8 +-- utils/weights.cc | 75 ++++++++++++-------- utils/weights.h | 22 +++--- vest/mr_vest_generate_mapper_input.cc | 6 +- 18 files changed, 201 insertions(+), 335 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 25eb2de4..4d4b6245 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -159,8 +159,7 @@ struct RescoringPass { shared_ptr models; shared_ptr inter_conf; vector ffs; - shared_ptr w; // null == use previous weights - vector weight_vector; + shared_ptr > weight_vector; int fid_summary; // 0 == no summary feature double density_prune; // 0 == don't density prune double beam_prune; // 0 == don't beam prune @@ -169,7 +168,7 @@ struct RescoringPass { ostream& operator<<(ostream& os, const RescoringPass& rp) { os << "[num_fn=" << rp.ffs.size(); if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; } - if (rp.w) os << " new_weights"; + //if (rp.weight_vector.size() > 0) os << " new_weights"; if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary); if (rp.density_prune) os << " density_prune=" << rp.density_prune; if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune; @@ -181,13 +180,8 @@ struct DecoderImpl { DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg); ~DecoderImpl(); bool Decode(const string& input, DecoderObserver*); - void SetWeights(const vector& weights) { - init_weights = weights; - for (int i = 0; i < rescoring_passes.size(); ++i) { - if (rescoring_passes[i].models) - rescoring_passes[i].models->SetWeights(weights); - rescoring_passes[i].weight_vector = weights; - } + vector& CurrentWeightVector() { + return *rescoring_passes.back().weight_vector; } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } @@ -300,8 +294,7 @@ struct DecoderImpl { OracleBleu oracle; string formalism; shared_ptr translator; - Weights w_init_weights; // used with initial parse - vector init_weights; // weights used with initial parse + shared_ptr > init_weights; // weights used with initial parse vector > pffs; #ifdef FSA_RESCORING CFGOptions cfg_options; @@ -557,13 +550,18 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream exit(1); } - // load initial feature weights (and possibly freeze feature set) - if (conf.count("weights")) { - w_init_weights.InitFromFile(str("weights",conf)); - w_init_weights.InitVector(&init_weights); - init_weights.resize(FD::NumFeats()); + // load perfect hash function for features + if (conf.count("cmph_perfect_feature_hash")) { + cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as() << " ...\n"; + FD::EnableHash(conf["cmph_perfect_feature_hash"].as()); + cerr << " " << FD::NumFeats() << " features in map\n"; } + // load initial feature weights (and possibly freeze feature set) + init_weights.reset(new vector); + if (conf.count("weights")) + Weights::InitFromFile(str("weights",conf), init_weights.get()); + // cube pruning pop-limit: we may want to configure this on a per-pass basis pop_limit = conf["cubepruning_pop_limit"].as(); @@ -582,9 +580,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream RescoringPass& rp = rescoring_passes.back(); // only configure new weights if pass > 0, otherwise we reuse the initial chart weights if (nth_pass_condition && conf.count(ws)) { - rp.w.reset(new Weights); - rp.w->InitFromFile(str(ws.c_str(), conf)); - rp.w->InitVector(&rp.weight_vector); + rp.weight_vector.reset(new vector()); + Weights::InitFromFile(str(ws.c_str(), conf), rp.weight_vector.get()); } bool has_stateful = false; if (conf.count(ff)) { @@ -624,11 +621,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } // set up weight vectors since later phases may reuse weights from earlier phases - const vector* prev = &init_weights; + shared_ptr > prev_weights = init_weights; for (int pass = 0; pass < rescoring_passes.size(); ++pass) { RescoringPass& rp = rescoring_passes[pass]; - if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; } - rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs)); + if (!rp.weight_vector) { + rp.weight_vector = prev_weights; + } else { + prev_weights = rp.weight_vector; + } + rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs)); string ps = "Pass1 "; ps[4] += pass; if (!SILENT) show_models(conf,*rp.models,ps.c_str()); } @@ -650,12 +651,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream FD::Freeze(); // this means we can't see the feature names of not-weighted features } - if (conf.count("cmph_perfect_feature_hash")) { - cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as() << " ...\n"; - FD::EnableHash(conf["cmph_perfect_feature_hash"].as()); - cerr << " " << FD::NumFeats() << " features in map\n"; - } - // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); @@ -685,7 +680,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } if (!fsa_ffs.empty()) { cerr<<"FSA: "; - show_all_features(fsa_ffs,init_weights,cerr,cerr,true,true); + show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true); } #endif @@ -733,7 +728,8 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) { if (del) delete o; return res; } -void Decoder::SetWeights(const vector& weights) { pimpl_->SetWeights(weights); } +vector& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); } +const vector& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { assert(pimpl_->translator->GetDecoderType() == "SCFG"); static_cast(*pimpl_->translator).SetSupplementalGrammar(grammar_string); @@ -774,7 +770,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { translator->ProcessMarkupHints(smeta.sgml_); Timer t("Translation"); const bool translation_successful = - translator->Translate(to_translate, &smeta, init_weights, &forest); + translator->Translate(to_translate, &smeta, *init_weights, &forest); translator->SentenceComplete(); if (!translation_successful) { @@ -812,7 +808,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { for (int pass = 0; pass < rescoring_passes.size(); ++pass) { const RescoringPass& rp = rescoring_passes[pass]; - const vector& cur_weights = rp.weight_vector; + const vector& cur_weights = *rp.weight_vector; if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl; #ifdef FSA_RESCORING cfg_options.maybe_output_source(forest); @@ -933,7 +929,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { #endif } - const vector& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector); + const vector& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); // Oracle Rescoring if(get_oracle_forest) { diff --git a/decoder/decoder.h b/decoder/decoder.h index 5491369f..9d009ffa 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -7,6 +7,8 @@ #include #include +#include "weights.h" // weight_t + #undef CP_TIME //#define CP_TIME #ifdef CP_TIME @@ -39,7 +41,12 @@ struct Decoder { Decoder(int argc, char** argv); Decoder(std::istream* config_file); bool Decode(const std::string& input, DecoderObserver* observer = NULL); - void SetWeights(const std::vector& weights); + + // access this to either *read* or *write* to the decoder's last + // weight vector (i.e., the weights of the finest past) + std::vector& CurrentWeightVector(); + const std::vector& CurrentWeightVector() const; + void SetId(int id); ~Decoder(); const boost::program_options::variables_map& GetConf() const { return conf; } diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc index 6918a9a1..459a5e6f 100644 --- a/mira/kbest_mira.cc +++ b/mira/kbest_mira.cc @@ -32,21 +32,6 @@ namespace po = boost::program_options; bool invert_score; boost::shared_ptr rng; -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - void RandomPermutation(int len, vector* p_ids) { vector& ids = *p_ids; ids.resize(len); @@ -58,21 +43,6 @@ void RandomPermutation(int len, vector* p_ids) { } } -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - --mid; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -209,14 +179,16 @@ int main(int argc, char** argv) { cerr << "Mismatched number of references (" << ds.size() << ") and sources (" << corpus.size() << ")\n"; return 1; } - // load initial weights - Weights weights; - weights.InitFromFile(conf["input_weights"].as()); - SparseVector lambdas; - weights.InitSparseVector(&lambdas); ReadFile ini_rf(conf["decoder_config"].as()); Decoder decoder(ini_rf.stream()); + + // load initial weights + vector& dense_weights = decoder.CurrentWeightVector(); + SparseVector lambdas; + Weights::InitFromFile(conf["input_weights"].as(), &dense_weights); + Weights::InitSparseVector(dense_weights, &lambdas); + const double max_step_size = conf["max_step_size"].as(); const double mt_metric_scale = conf["mt_metric_scale"].as(); @@ -230,7 +202,6 @@ int main(int argc, char** argv) { double tot_loss = 0; int dots = 0; int cur_pass = 0; - vector dense_weights; SparseVector tot; tot += lambdas; // initial weights normalizer++; // count for initial weights @@ -240,27 +211,22 @@ int main(int argc, char** argv) { vector order; RandomPermutation(corpus.size(), &order); while (lcount <= max_iteration) { - dense_weights.clear(); - weights.InitFromVector(lambdas); - weights.InitVector(&dense_weights); - decoder.SetWeights(dense_weights); + lambdas.init_vector(&dense_weights); if ((cur_sent * 40 / corpus.size()) > dots) { ++dots; cerr << '.'; } if (corpus.size() == cur_sent) { cerr << " [AVG METRIC LAST PASS=" << (tot_loss / corpus.size()) << "]\n"; - ShowLargestFeatures(dense_weights); + Weights::ShowLargestFeatures(dense_weights); cur_sent = 0; tot_loss = 0; dots = 0; ostringstream os; os << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << ".gz"; - weights.WriteToFile(os.str(), true, &msg); SparseVector x = tot; x /= normalizer; ostringstream sa; sa << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "-avg.gz"; - Weights ww; - ww.InitFromVector(x); - ww.WriteToFile(sa.str(), true, &msga); + x.init_vector(&dense_weights); + Weights::WriteToFile(os.str(), dense_weights, true, &msg); ++cur_pass; RandomPermutation(corpus.size(), &order); } @@ -294,11 +260,11 @@ int main(int argc, char** argv) { ++cur_sent; } cerr << endl; - weights.WriteToFile("weights.mira-final.gz", true, &msg); + Weights::WriteToFile("weights.mira-final.gz", dense_weights, true, &msg); tot /= normalizer; - weights.InitFromVector(tot); + tot.init_vector(dense_weights); msg = "# MIRA tuned weights (averaged vector)"; - weights.WriteToFile("weights.mira-final-avg.gz", true, &msg); + Weights::WriteToFile("weights.mira-final-avg.gz", dense_weights, true, &msg); cerr << "Optimization complete.\nAVERAGED WEIGHTS: weights.mira-final-avg.gz\n"; return 0; } diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index 4324e8de..bc59285b 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -301,12 +301,8 @@ int main(int argc, char** argv) { const unsigned gamma = conf["candidate_pairs"].as(); const unsigned xi = conf["best_pairs"].as(); string weightsf = conf["weights"].as(); - vector weights; - { - Weights w; - w.InitFromFile(weightsf); - w.InitVector(&weights); - } + vector weights; + Weights::InitFromFile(weightsf, &weights); string kbest_repo = conf["kbest_repository"].as(); MkDirP(kbest_repo); while(in) { diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 9b422f33..9caaa1d1 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -194,7 +194,7 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); string line; vector > > training, testing; - SparseVector old_weights; + SparseVector old_weights; const bool tune_regularizer = conf.count("tune_regularizer"); if (tune_regularizer && !conf.count("testset")) { cerr << "--tune_regularizer requires --testset to be set\n"; @@ -210,9 +210,9 @@ int main(int argc, char** argv) { const double psi = conf["interpolation"].as(); if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } if (conf.count("weights")) { - Weights w; - w.InitFromFile(conf["weights"].as()); - w.InitSparseVector(&old_weights); + vector dt; + Weights::InitFromFile(conf["weights"].as(), &dt); + Weights::InitSparseVector(dt, &old_weights); } ReadCorpus(&cin, &training); if (conf.count("testset")) { @@ -220,8 +220,8 @@ int main(int argc, char** argv) { ReadCorpus(rf.stream(), &testing); } cerr << "Number of features: " << FD::NumFeats() << endl; - vector x(FD::NumFeats(), 0.0); // x[0] is bias - for (SparseVector::const_iterator it = old_weights.begin(); + vector x(FD::NumFeats(), 0.0); // x[0] is bias + for (SparseVector::const_iterator it = old_weights.begin(); it != old_weights.end(); ++it) x[it->first] = it->second; double tppl = 0.0; @@ -257,7 +257,6 @@ int main(int argc, char** argv) { sigsq = sp[best_i].first; tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); } - Weights w; if (conf.count("weights")) { for (int i = 1; i < x.size(); ++i) x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi); @@ -271,7 +270,6 @@ int main(int argc, char** argv) { cout << "# " << sp[i].first << "\t" << sp[i].second << "\t" << smoothed[i] << endl; } } - w.InitFromVector(x); - w.WriteToFile("-"); + Weights::WriteToFile("-", x); return 0; } diff --git a/training/Makefile.am b/training/Makefile.am index e075e417..6e2c06f5 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -12,9 +12,7 @@ bin_PROGRAMS = \ cllh_filter_grammar \ mpi_online_optimize \ mpi_batch_optimize \ - mpi_em_optimize \ compute_cllh \ - feature_expectations \ augment_grammar noinst_PROGRAMS = \ @@ -29,12 +27,6 @@ mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -feature_expectations_SOURCES = feature_expectations.cc -feature_expectations_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -mpi_em_optimize_SOURCES = mpi_em_optimize.cc optimize.cc -mpi_em_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - compute_cllh_SOURCES = compute_cllh.cc compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc index df8d4ee8..e89a92d5 100644 --- a/training/augment_grammar.cc +++ b/training/augment_grammar.cc @@ -134,9 +134,7 @@ int main(int argc, char** argv) { } else { ngram = NULL; } extra_feature = conf.count("extra_lex_feature") > 0; if (conf.count("collapse_weights")) { - Weights w; - w.InitFromFile(conf["collapse_weights"].as()); - w.InitVector(&col_weights); + Weights::InitFromFile(conf["collapse_weights"].as(), &col_weights); } clear_features = conf.count("clear_features_after_collapse") > 0; gather_rules = false; diff --git a/training/collapse_weights.cc b/training/collapse_weights.cc index 4fb742fb..dc480f6c 100644 --- a/training/collapse_weights.cc +++ b/training/collapse_weights.cc @@ -59,10 +59,8 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); const string wfile = conf["weights"].as(); const string gfile = conf["grammar"].as(); - Weights wm; - wm.InitFromFile(wfile); - vector w; - wm.InitVector(&w); + vector w; + Weights::InitFromFile(wfile, &w); MarginalMap e_tots; MarginalMap f_tots; prob_t tot; diff --git a/training/compute_cllh.cc b/training/compute_cllh.cc index 332f6d0c..b496d196 100644 --- a/training/compute_cllh.cc +++ b/training/compute_cllh.cc @@ -148,15 +148,6 @@ int main(int argc, char** argv) { if (!InitCommandLine(argc, argv, &conf)) return false; - // load initial weights - Weights weights; - if (conf.count("weights")) - weights.InitFromFile(conf["weights"].as()); - - // freeze feature set - //const bool freeze_feature_set = conf.count("freeze_feature_set"); - //if (freeze_feature_set) FD::Freeze(); - // load cdec.ini and set up decoder ReadFile ini_rf(conf["decoder_config"].as()); Decoder decoder(ini_rf.stream()); @@ -165,17 +156,22 @@ int main(int argc, char** argv) { abort(); } + // load weights + vector& weights = decoder.CurrentWeightVector(); + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &weights); + + // freeze feature set + //const bool freeze_feature_set = conf.count("freeze_feature_set"); + //if (freeze_feature_set) FD::Freeze(); + vector corpus; vector ids; ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); assert(corpus.size() > 0); assert(corpus.size() == ids.size()); - vector wv; - weights.InitVector(&wv); - decoder.SetWeights(wv); TrainingObserver observer; double objective = 0; - bool converged = false; observer.Reset(); if (rank == 0) @@ -197,3 +193,4 @@ int main(int argc, char** argv) { return 0; } + diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc index 8d292f8a..bf8abb26 100644 --- a/training/grammar_convert.cc +++ b/training/grammar_convert.cc @@ -251,12 +251,10 @@ int main(int argc, char **argv) { const bool is_split_input = (conf["format"].as() == "split"); const bool is_json_input = is_split_input || (conf["format"].as() == "json"); const bool collapse_weights = conf.count("collapse_weights"); - Weights wts; vector w; - if (conf.count("weights")) { - wts.InitFromFile(conf["weights"].as()); - wts.InitVector(&w); - } + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &w); + if (collapse_weights && !w.size()) { cerr << "--collapse_weights requires a weights file to be specified!\n"; exit(1); diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index 39a8af7d..cc5953f6 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -31,42 +31,12 @@ using namespace std; using boost::shared_ptr; namespace po = boost::program_options; -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("input_weights,w",po::value(),"Input feature weights file") ("training_data,t",po::value(),"Training data") ("decoder_config,d",po::value(),"Decoder configuration file") - ("sharded_input,s",po::value(), "Corpus and grammar files are 'sharded' so each processor loads its own input and grammar file. Argument is the directory containing the shards.") ("output_weights,o",po::value()->default_value("-"),"Output feature weights file") ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") @@ -88,14 +58,10 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { } po::notify(*conf); - if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) { + if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data")) || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; return false; } - if (conf->count("training_data") && conf->count("sharded_input")) { - cerr << "Cannot specify both --training_data and --sharded_input\n"; - return false; - } return true; } @@ -236,42 +202,9 @@ int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; - string shard_dir; - if (conf.count("sharded_input")) { - shard_dir = conf["sharded_input"].as(); - if (!DirectoryExists(shard_dir)) { - if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl; - return 1; - } - if (rank == 0) - cerr << "Shard directory: " << shard_dir << endl; - } - - // load initial weights - Weights weights; - if (rank == 0) { cerr << "Loading weights...\n"; } - weights.InitFromFile(conf["input_weights"].as()); - if (rank == 0) { cerr << "Done loading weights.\n"; } - - // freeze feature set (should be optional?) - const bool freeze_feature_set = true; - if (freeze_feature_set) FD::Freeze(); - // load cdec.ini and set up decoder vector cdec_ini; ReadConfig(conf["decoder_config"].as(), &cdec_ini); - if (shard_dir.size()) { - if (rank == 0) { - for (int i = 0; i < cdec_ini.size(); ++i) { - if (cdec_ini[i].find("grammar=") == 0) { - cerr << "!!! using sharded input and " << conf["decoder_config"].as() << " contains a grammar specification:\n" << cdec_ini[i] << "\n VERIFY THAT THIS IS CORRECT!\n"; - } - } - } - ostringstream g; - g << "grammar=" << shard_dir << "/grammar." << rank << "_of_" << size << ".gz"; - cdec_ini.push_back(g.str()); - } istringstream ini; StoreConfig(cdec_ini, &ini); if (rank == 0) cerr << "Loading grammar...\n"; @@ -282,22 +215,28 @@ int main(int argc, char** argv) { } if (rank == 0) cerr << "Done loading grammar!\n"; + // load initial weights + if (rank == 0) { cerr << "Loading weights...\n"; } + vector& lambdas = decoder->CurrentWeightVector(); + Weights::InitFromFile(conf["input_weights"].as(), &lambdas); + if (rank == 0) { cerr << "Done loading weights.\n"; } + + // freeze feature set (should be optional?) + const bool freeze_feature_set = true; + if (freeze_feature_set) FD::Freeze(); + const int num_feats = FD::NumFeats(); if (rank == 0) cerr << "Number of features: " << num_feats << endl; + lambdas.resize(num_feats); + const bool gaussian_prior = conf.count("gaussian_prior"); - vector means(num_feats, 0); + vector means(num_feats, 0); if (conf.count("means")) { if (!gaussian_prior) { cerr << "Don't use --means without --gaussian_prior!\n"; exit(1); } - Weights wm; - wm.InitFromFile(conf["means"].as()); - if (num_feats != FD::NumFeats()) { - cerr << "[ERROR] Means file had unexpected features!\n"; - exit(1); - } - wm.InitVector(&means); + Weights::InitFromFile(conf["means"].as(), &means); } shared_ptr o; if (rank == 0) { @@ -309,26 +248,13 @@ int main(int argc, char** argv) { cerr << "Optimizer: " << o->Name() << endl; } double objective = 0; - vector lambdas(num_feats, 0.0); - weights.InitVector(&lambdas); - if (lambdas.size() != num_feats) { - cerr << "Initial weights file did not have all features specified!\n feats=" - << num_feats << "\n weights file=" << lambdas.size() << endl; - lambdas.resize(num_feats, 0.0); - } vector gradient(num_feats, 0.0); - vector rcv_grad(num_feats, 0.0); + vector rcv_grad; + rcv_grad.clear(); bool converged = false; vector corpus; - if (shard_dir.size()) { - ostringstream os; os << shard_dir << "/corpus." << rank << "_of_" << size; - ReadTrainingCorpus(os.str(), 0, 1, &corpus); - cerr << os.str() << " has " << corpus.size() << " training examples. " << endl; - if (corpus.size() > 500) { corpus.resize(500); cerr << " TRUNCATING\n"; } - } else { - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); - } + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); assert(corpus.size() > 0); TrainingObserver observer; @@ -341,19 +267,20 @@ int main(int argc, char** argv) { if (rank == 0) { cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; } - decoder->SetWeights(lambdas); for (int i = 0; i < corpus.size(); ++i) decoder->Decode(corpus[i], &observer); cerr << " process " << rank << '/' << size << " done\n"; fill(gradient.begin(), gradient.end(), 0); - fill(rcv_grad.begin(), rcv_grad.end(), 0); observer.SetLocalGradientAndObjective(&gradient, &objective); double to = 0; #ifdef HAVE_MPI + rcv_grad.resize(num_feats, 0.0); mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus(), 0); - mpi::reduce(world, objective, to, plus(), 0); swap(gradient, rcv_grad); + rcv_grad.clear(); + + mpi::reduce(world, objective, to, plus(), 0); objective = to; #endif @@ -378,7 +305,7 @@ int main(int argc, char** argv) { for (int i = 0; i < gradient.size(); ++i) gnorm += gradient[i] * gradient[i]; cerr << " GNORM=" << sqrt(gnorm) << endl; - vector old = lambdas; + vector old = lambdas; int c = 0; while (old == lambdas) { ++c; @@ -387,9 +314,8 @@ int main(int argc, char** argv) { assert(c < 5); } old.clear(); - SanityCheck(lambdas); - ShowLargestFeatures(lambdas); - weights.InitFromVector(lambdas); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); converged = o->HasConverged(); if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } @@ -399,7 +325,7 @@ int main(int argc, char** argv) { ostringstream vv; vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")"; const string svv = vv.str(); - weights.WriteToFile(fname, true, &svv); + Weights::WriteToFile(fname, lambdas, true, &svv); } // rank == 0 int cint = converged; #ifdef HAVE_MPI @@ -411,3 +337,4 @@ int main(int argc, char** argv) { } return 0; } + diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 32033c19..2ef4a2e7 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -31,35 +31,6 @@ namespace mpi = boost::mpi; using namespace std; namespace po = boost::program_options; -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -250,10 +221,25 @@ int main(int argc, char** argv) { if (!InitCommandLine(argc, argv, &conf)) return 1; + vector > agenda; + if (!LoadAgenda(conf["training_agenda"].as(), &agenda)) + return 1; + if (rank == 0) + cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; + + assert(agenda.size() > 0); + + if (1) { // hack to load the feature hash functions -- TODO this should not be in cdec.ini + const string& cur_config = agenda[0].first; + const unsigned max_iteration = agenda[0].second; + ReadFile ini_rf(cur_config); + Decoder decoder(ini_rf.stream()); + } + // load initial weights - Weights weights; + vector init_weights; if (conf.count("input_weights")) - weights.InitFromFile(conf["input_weights"].as()); + Weights::InitFromFile(conf["input_weights"].as(), &init_weights); vector frozen_fids; if (conf.count("frozen_features")) { @@ -310,19 +296,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937); SparseVector x; - weights.InitSparseVector(&x); + Weights::InitSparseVector(init_weights, &x); TrainingObserver observer; int write_weights_every_ith = 100; // TODO configure int titer = -1; - vector > agenda; - if (!LoadAgenda(conf["training_agenda"].as(), &agenda)) - return 1; - if (rank == 0) - cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; - - vector lambdas; for (int ai = 0; ai < agenda.size(); ++ai) { const string& cur_config = agenda[ai].first; const unsigned max_iteration = agenda[ai].second; @@ -331,6 +310,8 @@ int main(int argc, char** argv) { // load cdec.ini and set up decoder ReadFile ini_rf(cur_config); Decoder decoder(ini_rf.stream()); + vector& lambdas = decoder.CurrentWeightVector(); + if (ai == 0) { lambdas.swap(init_weights); init_weights.clear(); } if (rank == 0) o->ResetEpoch(); // resets the learning rate-- TODO is this good? @@ -341,15 +322,13 @@ int main(int argc, char** argv) { #ifdef HAVE_MPI mpi::timer timer; #endif - weights.InitFromVector(x); - weights.InitVector(&lambdas); + x.init_vector(&lambdas); ++iter; ++titer; observer.Reset(); - decoder.SetWeights(lambdas); if (rank == 0) { converged = (iter == max_iteration); - SanityCheck(lambdas); - ShowLargestFeatures(lambdas); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); string fname = "weights.cur.gz"; if (iter % write_weights_every_ith == 0) { ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; @@ -360,7 +339,7 @@ int main(int argc, char** argv) { vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); const string svv = vv.str(); cerr << svv << endl; - weights.WriteToFile(fname, true, &svv); + Weights::WriteToFile(fname, lambdas, true, &svv); } for (int i = 0; i < size_per_proc; ++i) { diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc index b931991d..15e28fa1 100644 --- a/training/mr_optimize_reduce.cc +++ b/training/mr_optimize_reduce.cc @@ -88,25 +88,19 @@ int main(int argc, char** argv) { const bool use_b64 = conf["input_format"].as() == "b64"; - Weights weights; - weights.InitFromFile(conf["input_weights"].as()); + vector lambdas; + Weights::InitFromFile(conf["input_weights"].as(), &lambdas); const string s_obj = "**OBJ**"; int num_feats = FD::NumFeats(); cerr << "Number of features: " << num_feats << endl; const bool gaussian_prior = conf.count("gaussian_prior"); - vector means(num_feats, 0); + vector means(num_feats, 0); if (conf.count("means")) { if (!gaussian_prior) { cerr << "Don't use --means without --gaussian_prior!\n"; exit(1); } - Weights wm; - wm.InitFromFile(conf["means"].as()); - if (num_feats != FD::NumFeats()) { - cerr << "[ERROR] Means file had unexpected features!\n"; - exit(1); - } - wm.InitVector(&means); + Weights::InitFromFile(conf["means"].as(), &means); } shared_ptr o; const string omethod = conf["optimization_method"].as(); @@ -124,8 +118,6 @@ int main(int argc, char** argv) { cerr << "No state file found, assuming ITERATION 1\n"; } - vector lambdas(num_feats, 0); - weights.InitVector(&lambdas); double objective = 0; vector gradient(num_feats, 0); // 0**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; @@ -223,8 +215,7 @@ int main(int argc, char** argv) { old.clear(); SanityCheck(lambdas); ShowLargestFeatures(lambdas); - weights.InitFromVector(lambdas); - weights.WriteToFile(conf["output_weights"].as(), false); + Weights::WriteToFile(conf["output_weights"].as(), lambdas, false); const bool conv = o->HasConverged(); if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } diff --git a/utils/fdict.h b/utils/fdict.h index 771e8b91..f0871b9a 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -28,6 +28,8 @@ struct FD { } static void EnableHash(const std::string& cmph_file) { #ifdef HAVE_CMPH + assert(dict_.max() == 0); // dictionary must not have + // been added to hash_ = new PerfectHashFunction(cmph_file); #endif } diff --git a/utils/phmt.cc b/utils/phmt.cc index 1f59afaf..48d9f093 100644 --- a/utils/phmt.cc +++ b/utils/phmt.cc @@ -19,22 +19,18 @@ int main(int argc, char** argv) { cerr << "LexFE = " << FD::Convert("LexFE") << endl; cerr << "LexEF = " << FD::Convert("LexEF") << endl; { - Weights w; vector v(FD::NumFeats()); v[FD::Convert("LexFE")] = 1.0; v[FD::Convert("LexEF")] = 0.5; - w.InitFromVector(v); cerr << "Writing...\n"; - w.WriteToFile("weights.bin"); + Weights::WriteToFile("weights.bin", v); cerr << "Done.\n"; } { - Weights w; vector v(FD::NumFeats()); cerr << "Reading...\n"; - w.InitFromFile("weights.bin"); + Weights::InitFromFile("weights.bin", &v); cerr << "Done.\n"; - w.InitVector(&v); assert(v[FD::Convert("LexFE")] == 1.0); assert(v[FD::Convert("LexEF")] == 0.5); } diff --git a/utils/weights.cc b/utils/weights.cc index 0916b72a..c49000be 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -8,7 +8,10 @@ using namespace std; -void Weights::InitFromFile(const std::string& filename, vector* feature_list) { +void Weights::InitFromFile(const string& filename, + vector* pweights, + vector* feature_list) { + vector& weights = *pweights; if (!SILENT) cerr << "Reading weights from " << filename << endl; ReadFile in_file(filename); istream& in = *in_file.stream(); @@ -47,16 +50,16 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ int end = 0; while(end < buf.size() && buf[end] != ' ') ++end; const int fid = FD::Convert(buf.substr(start, end - start)); + if (feature_list) { feature_list->push_back(buf.substr(start, end - start)); } while(end < buf.size() && buf[end] == ' ') ++end; val = strtod(&buf.c_str()[end], NULL); if (isnan(val)) { cerr << FD::Convert(fid) << " has weight NaN!\n"; abort(); } - if (wv_.size() <= fid) - wv_.resize(fid + 1); - wv_[fid] = val; - if (feature_list) { feature_list->push_back(FD::Convert(fid)); } + if (weights.size() <= fid) + weights.resize(fid + 1); + weights[fid] = val; ++weight_count; if (!SILENT) { if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } @@ -76,8 +79,8 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; abort(); } - wv_.resize(num_keys[0]); - in.get(reinterpret_cast(&wv_[0]), num_keys[0] * sizeof(weight_t)); + weights.resize(num_keys[0]); + in.get(reinterpret_cast(&weights[0]), num_keys[0] * sizeof(weight_t)); if (!in.good()) { cerr << "Error loading weights!\n"; abort(); @@ -85,7 +88,10 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ } } -void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features, const string* extra) const { +void Weights::WriteToFile(const string& fname, + const vector& weights, + bool hide_zero_value_features, + const string* extra) { WriteFile out(fname); ostream& o = *out.stream(); assert(o); @@ -96,41 +102,54 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature o.precision(17); const int num_feats = FD::NumFeats(); for (int i = 1; i < num_feats; ++i) { - const weight_t val = (i < wv_.size() ? wv_[i] : 0.0); + const weight_t val = (i < weights.size() ? weights[i] : 0.0); if (hide_zero_value_features && val == 0.0) continue; o << FD::Convert(i) << ' ' << val << endl; } } else { o.write("_PHWf", 5); const size_t keys = FD::NumFeats(); - assert(keys <= wv_.size()); + assert(keys <= weights.size()); o.write(reinterpret_cast(&keys), sizeof(keys)); - o.write(reinterpret_cast(&wv_[0]), keys * sizeof(weight_t)); + o.write(reinterpret_cast(&weights[0]), keys * sizeof(weight_t)); } } -void Weights::InitVector(std::vector* w) const { - *w = wv_; +void Weights::InitSparseVector(const vector& dv, + SparseVector* sv) { + sv->clear(); + for (unsigned i = 1; i < dv.size(); ++i) { + if (dv[i]) sv->set_value(i, dv[i]); + } } -void Weights::InitSparseVector(SparseVector* w) const { - for (int i = 1; i < wv_.size(); ++i) { - const weight_t& weight = wv_[i]; - if (weight) w->set_value(i, weight); +void Weights::SanityCheck(const vector& w) { + for (int i = 0; i < w.size(); ++i) { + assert(!isnan(w[i])); + assert(!isinf(w[i])); } } -void Weights::InitFromVector(const std::vector& w) { - wv_ = w; - if (wv_.size() > FD::NumFeats()) - cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; - wv_.resize(FD::NumFeats(), 0); -} +struct FComp { + const vector& w_; + FComp(const vector& w) : w_(w) {} + bool operator()(int a, int b) const { + return fabs(w_[a]) > fabs(w_[b]); + } +}; -void Weights::InitFromVector(const SparseVector& w) { - wv_.clear(); - wv_.resize(FD::NumFeats(), 0.0); - for (int i = 1; i < FD::NumFeats(); ++i) - wv_[i] = w.value(i); +void Weights::ShowLargestFeatures(const vector& w) { + vector fnums(w.size()); + for (int i = 0; i < w.size(); ++i) + fnums[i] = i; + vector::iterator mid = fnums.begin(); + mid += (w.size() > 10 ? 10 : w.size()); + partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); + cerr << "TOP FEATURES:"; + for (vector::iterator i = fnums.begin(); i != mid; ++i) { + cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; + } + cerr << endl; } + diff --git a/utils/weights.h b/utils/weights.h index 7664810b..30f71db0 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -10,15 +10,21 @@ typedef double weight_t; class Weights { public: - Weights() {} - void InitFromFile(const std::string& fname, std::vector* feature_list = NULL); - void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const; - void InitVector(std::vector* w) const; - void InitSparseVector(SparseVector* w) const; - void InitFromVector(const std::vector& w); - void InitFromVector(const SparseVector& w); + static void InitFromFile(const std::string& fname, + std::vector* weights, + std::vector* feature_list = NULL); + static void WriteToFile(const std::string& fname, + const std::vector& weights, + bool hide_zero_value_features = true, + const std::string* extra = NULL); + static void InitSparseVector(const std::vector& dv, + SparseVector* sv); + // check for infinities, NaNs, etc + static void SanityCheck(const std::vector& w); + // write weights with largest magnitude to cerr + static void ShowLargestFeatures(const std::vector& w); private: - std::vector wv_; + Weights(); }; #endif diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index b84c44bc..0c094fd5 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -223,16 +223,16 @@ struct oracle_directions { cerr << "Forest repo: " << forest_repository << endl; assert(DirectoryExists(forest_repository)); vector features; - weights.InitFromFile(weights_file, &features); + vector dorigin; + Weights::InitFromFile(weights_file, &dorigin, &features); if (optimize_features.size()) features=optimize_features; - weights.InitSparseVector(&origin); + Weights::InitSparseVector(dorigin, &origin); fids.clear(); AddFeatureIds(features); oracles.resize(dev_set_size); } - Weights weights; void AddFeatureIds(vector const& features) { int i = fids.size(); fids.resize(fids.size()+features.size()); -- cgit v1.2.3 From bff9f7f6e3ed777c9379c0373657eeaf43a6a213 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 17:57:32 +0100 Subject: fix for crash with no rescoring --- decoder/decoder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 4d4b6245..45404c47 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -181,7 +181,7 @@ struct DecoderImpl { ~DecoderImpl(); bool Decode(const string& input, DecoderObserver*); vector& CurrentWeightVector() { - return *rescoring_passes.back().weight_vector; + return (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } -- cgit v1.2.3 From ddc38ce211d4b38f66e56dfa072856a4e9de2c17 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 18:46:33 +0100 Subject: remove features that are overfitting --- decoder/ff_source_syntax.cc | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 5b7c16f6..ffe07f03 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -25,12 +25,10 @@ struct SourceSyntaxFeaturesImpl { void InitializeGrids(const string& tree, unsigned src_len) { assert(tree.size() > 0); - fids_cat.clear(); - fids_fonly.clear(); + //fids_cat.clear(); fids_ef.clear(); src_tree.clear(); - fids_cat.resize(src_len, src_len + 1); - fids_fonly.resize(src_len, src_len + 1); + //fids_cat.resize(src_len, src_len + 1); fids_ef.resize(src_len, src_len + 1); src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); ParseTreeString(tree, src_len); @@ -89,15 +87,14 @@ struct SourceSyntaxFeaturesImpl { WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl; const WordID lhs = src_tree(i,j); - int& fid_cat = fids_cat(i,j); - int& fid_fonly = fids_fonly(i,j)[&rule]; + //int& fid_cat = fids_cat(i,j); int& fid_ef = fids_ef(i,j)[&rule]; if (fid_ef <= 0) { ostringstream os; - ostringstream os2; + //ostringstream os2; os << "SYN:" << TD::Convert(lhs); - os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); - fid_cat = FD::Convert(os2.str()); + //os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); + //fid_cat = FD::Convert(os2.str()); os << ':'; unsigned ntc = 0; for (unsigned k = 0; k < rule.f_.size(); ++k) { @@ -109,7 +106,6 @@ struct SourceSyntaxFeaturesImpl { os << TD::Convert(fj); } } - fid_fonly = FD::Convert(os.str()); os << ':'; for (unsigned k = 0; k < rule.e_.size(); ++k) { const int ei = rule.e_[k]; @@ -121,18 +117,15 @@ struct SourceSyntaxFeaturesImpl { } fid_ef = FD::Convert(os.str()); } - if (fid_cat > 0) - feats->set_value(fid_cat, 1.0); - if (fid_fonly > 0) - feats->set_value(fid_fonly, 1.0); + //if (fid_cat > 0) + // feats->set_value(fid_cat, 1.0); if (fid_ef > 0) feats->set_value(fid_ef, 1.0); return lhs; } Array2D src_tree; // src_tree(i,j) NT = type - mutable Array2D fids_cat; // fires for an LHS match - mutable Array2D > fids_fonly; // fires for an f-string + // mutable Array2D fids_cat; // this tends to overfit baddly mutable Array2D > fids_ef; // fires for fully lexicalized }; -- cgit v1.2.3 From 08f1814923005f702300d661c4d67f4635fc901c Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Thu, 15 Sep 2011 12:52:59 +0100 Subject: script to filter reachable sentences, weight cleanup --- decoder/apply_models.cc | 3 +- decoder/hg.h | 8 +- training/Makefile.am | 10 +- training/cllh_filter_grammar.cc | 197 -------------------------------------- training/mpi_extract_reachable.cc | 163 +++++++++++++++++++++++++++++++ utils/feature_vector.h | 4 +- 6 files changed, 174 insertions(+), 211 deletions(-) delete mode 100644 training/cllh_filter_grammar.cc create mode 100644 training/mpi_extract_reachable.cc (limited to 'decoder') diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 26cdb881..40fd27e4 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -276,8 +276,7 @@ public: make_heap(cand.begin(), cand.end(), HeapCandCompare()); State2Node state2node; // "buf" in Figure 2 int pops = 0; - int pop_limit_eff=max(1,int(v.promise*pop_limit_)); - while(!cand.empty() && pops < pop_limit_eff) { + while(!cand.empty() && pops < pop_limit_) { pop_heap(cand.begin(), cand.end(), HeapCandCompare()); Candidate* item = cand.back(); cand.pop_back(); diff --git a/decoder/hg.h b/decoder/hg.h index e5ef05f8..f0ddbb76 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -49,16 +49,14 @@ public: // TODO get rid of cat_? // TODO keep cat_ and add span and/or state? :) struct Node { - Node() : id_(), cat_(), promise(1) {} + Node() : id_(), cat_() {} int id_; // equal to this object's position in the nodes_ vector WordID cat_; // non-terminal category if <0, 0 if not set WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ - double promise; // set in global pruning; in [0,infty) so that mean is 1. use: e.g. scale cube poplimit. //TODO: appears to be useless, compile without this? on the other hand, pretty cheap. void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting cat_=o.cat_; - promise=o.promise; } void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) { copy_fixed(o); @@ -81,7 +79,7 @@ public: int head_node_; // refers to a position in nodes_ TailNodeVector tail_nodes_; // contents refer to positions in nodes_ TRulePtr rule_; - FeatureVector feature_values_; + SparseVector feature_values_; prob_t edge_prob_; // dot product of weights and feat_values int id_; // equal to this object's position in the edges_ vector @@ -468,7 +466,7 @@ public: /// drop edge i if edge_margin[i] < prune_below, unless preserve_mask[i] void MarginPrune(EdgeProbs const& edge_margin,prob_t prune_below,EdgeMask const* preserve_mask=0,bool safe_inside=false,bool verbose=false); - //TODO: in my opinion, looking at the ratio of logprobs (features \dot weights) rather than the absolute difference generalizes more nicely across sentence lengths and weight vectors that are constant multiples of one another. at least make that an option. i worked around this a little in cdec by making "beam alpha per source word" but that's not helping with different tuning runs. this would also make me more comfortable about allocating Node.promise + //TODO: in my opinion, looking at the ratio of logprobs (features \dot weights) rather than the absolute difference generalizes more nicely across sentence lengths and weight vectors that are constant multiples of one another. at least make that an option. i worked around this a little in cdec by making "beam alpha per source word" but that's not helping with different tuning runs. // beam_alpha=0 means don't beam prune, otherwise drop things that are e^beam_alpha times worse than best - // prunes any edge whose prob_t on the best path taking that edge is more than e^alpha times //density=0 means don't density prune: // for density>=1.0, keep this many times the edges needed for the 1best derivation diff --git a/training/Makefile.am b/training/Makefile.am index 7ceeda34..5752859e 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -9,9 +9,9 @@ bin_PROGRAMS = \ atools \ plftools \ collapse_weights \ - cllh_filter_grammar \ - mpi_online_optimize \ + mpi_extract_reachable \ mpi_extract_features \ + mpi_online_optimize \ mpi_batch_optimize \ compute_cllh \ augment_grammar @@ -25,6 +25,9 @@ TESTS = lbfgs_test optimize_test mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc +mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + mpi_extract_features_SOURCES = mpi_extract_features.cc mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz @@ -34,9 +37,6 @@ mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/ compute_cllh_SOURCES = compute_cllh.cc compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -cllh_filter_grammar_SOURCES = cllh_filter_grammar.cc -cllh_filter_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - augment_grammar_SOURCES = augment_grammar.cc augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz diff --git a/training/cllh_filter_grammar.cc b/training/cllh_filter_grammar.cc deleted file mode 100644 index 6998ec2b..00000000 --- a/training/cllh_filter_grammar.cc +++ /dev/null @@ -1,197 +0,0 @@ -#include -#include -#include -#include // fork -#include // waitpid - -#include -#include - -#include "tdict.h" -#include "ff_register.h" -#include "verbose.h" -#include "hg.h" -#include "decoder.h" -#include "filelib.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("training_data,t",po::value(),"Training data corpus") - ("decoder_config,c",po::value(),"Decoder configuration file") - ("shards,s",po::value()->default_value(1),"Number of shards") - ("starting_shard,S",po::value()->default_value(0), "In this invocation only process shards >= S") - ("work_limit,l",po::value()->default_value(9999), "Process maximially this many shards") - ("ncpus,C",po::value()->default_value(1),"Number of CPUs to use"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* ids) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - assert(size > 0); - assert(rank < size); - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) { - c->push_back(line); - ids->push_back(lc); - } - ++lc; - } -} - -struct TrainingObserver : public DecoderObserver { - TrainingObserver() : s_lhs(-TD::Convert("S")), goal_lhs(-TD::Convert("Goal")) {} - - void Reset() { - total_complete = 0; - } - - virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { - state = 1; - used.clear(); - failed = true; - } - - virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 1); - for (int i = 0; i < hg->edges_.size(); ++i) { - const TRule* rule = hg->edges_[i].rule_.get(); - if (rule->lhs_ == s_lhs || rule->lhs_ == goal_lhs) // fragile hack to filter out glue rules - continue; - used.insert(rule); - } - state = 2; - } - - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 2); - state = 3; - } - - virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { - if (state == 3) { - failed = false; - } else { - failed = true; - } - } - - set used; - - const int s_lhs; - const int goal_lhs; - bool failed; - int total_complete; - int state; -}; - -void work(const string& fname, int rank, int size, Decoder* decoder) { - cerr << "Worker " << rank << '/' << size << " starting.\n"; - vector corpus; - vector ids; - ReadTrainingCorpus(fname, rank, size, &corpus, &ids); - assert(corpus.size() > 0); - assert(corpus.size() == ids.size()); - cerr << " " << rank << '/' << size << ": has " << corpus.size() << " sentences to process\n"; - ostringstream oc; oc << "corpus." << rank << "_of_" << size; - WriteFile foc(oc.str()); - ostringstream og; og << "grammar." << rank << "_of_" << size << ".gz"; - WriteFile fog(og.str()); - - set all_used; - TrainingObserver observer; - for (int i = 0; i < corpus.size(); ++i) { - const int sent_id = ids[i]; - const string& input = corpus[i]; - decoder->SetId(sent_id); - decoder->Decode(input, &observer); - if (observer.failed) { - // do nothing - } else { - (*foc.stream()) << input << endl; - for (set::iterator it = observer.used.begin(); it != observer.used.end(); ++it) { - if (all_used.insert(*it).second) - (*fog.stream()) << **it << endl; - } - } - } -} - -int main(int argc, char** argv) { - register_feature_functions(); - - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - const string fname = conf["training_data"].as(); - const unsigned ncpus = conf["ncpus"].as(); - const unsigned shards = conf["shards"].as(); - const unsigned start = conf["starting_shard"].as(); - const unsigned work_limit = conf["work_limit"].as(); - const unsigned eff_shards = min(start + work_limit, shards); - cerr << "Processing shards " << start << "/" << shards << " to " << eff_shards << "/" << shards << endl; - assert(ncpus > 0); - ReadFile ini_rf(conf["decoder_config"].as()); - Decoder decoder(ini_rf.stream()); - if (decoder.GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - abort(); - } - SetSilent(true); // turn off verbose decoder output - cerr << "Forking " << ncpus << " time(s)\n"; - vector children; - for (int i = 0; i < ncpus; ++i) { - pid_t pid = fork(); - if (pid < 0) { - cerr << "Fork failed!\n"; - exit(1); - } - if (pid > 0) { - children.push_back(pid); - } else { - for (int j = start; j < eff_shards; ++j) { - if (j % ncpus == i) { - cerr << " CPU " << i << " processing shard " << j << endl; - work(fname, j, shards, &decoder); - cerr << " Shard " << j << "/" << shards << " finished.\n"; - } - } - _exit(0); - } - } - for (int i = 0; i < children.size(); ++i) { - int status; - int w = waitpid(children[i], &status, 0); - if (w < 0) { cerr << "Error while waiting for children!"; return 1; } - if (WIFSIGNALED(status)) { - cerr << "Child " << i << " received signal " << WTERMSIG(status) << endl; - if (WTERMSIG(status) == 11) { cerr << " this is a SEGV- you may be trying to print temporarily created rules\n"; } - } - } - return 0; -} diff --git a/training/mpi_extract_reachable.cc b/training/mpi_extract_reachable.cc new file mode 100644 index 00000000..2a7c2b9d --- /dev/null +++ b/training/mpi_extract_reachable.cc @@ -0,0 +1,163 @@ +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#endif +#include +#include + +#include "ff_register.h" +#include "verbose.h" +#include "filelib.h" +#include "fdict.h" +#include "decoder.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value(),"Training data corpus") + ("decoder_config,c",po::value(),"Decoder configuration file") + ("weights,w", po::value(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") + ("output_prefix,o",po::value()->default_value("reachable"),"Output path prefix"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the inputs that produce reachable parallel parses.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct ReachabilityObserver : public DecoderObserver { + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + reachable = false; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + reachable = true; + } + + bool reachable; +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + if (FD::UsingPerfectHashFunction()) { + cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; + return 1; + } + + // load optional weights + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &decoder.CurrentWeightVector()); + + vector corpus; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); + assert(corpus.size() > 0); + + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + size_t num_reached = 0; + { + ostringstream os; + os << conf["output_prefix"].as() << '.' << rank << "_of_" << size; + WriteFile wf(os.str()); + ostream& out = *wf.stream(); + ReachabilityObserver observer; + for (int i = 0; i < corpus.size(); ++i) { + decoder.Decode(corpus[i], &observer); + if (observer.reachable) { + out << corpus[i] << endl; + ++num_reached; + } + corpus[i].clear(); + } + cerr << "Shard " << rank << '/' << size << " finished, wrote " + << num_reached << " instances to " << os.str() << endl; + } + + size_t total = 0; +#ifdef HAVE_MPI + reduce(world, num_reached, total, std::plus(), 0); +#else + total = num_reached; +#endif + if (rank == 0) { + cerr << "-----------------------------------------\n"; + cerr << "TOTAL = " << total << " instances\n"; + } + return 0; +} + diff --git a/utils/feature_vector.h b/utils/feature_vector.h index 733aa99e..a7b61a66 100755 --- a/utils/feature_vector.h +++ b/utils/feature_vector.h @@ -3,9 +3,9 @@ #include #include "sparse_vector.h" -#include "fdict.h" +#include "weights.h" -typedef double Featval; +typedef weight_t Featval; typedef SparseVector FeatureVector; typedef SparseVector WeightVector; typedef std::vector DenseWeightVector; -- cgit v1.2.3 From 10cfa1082059db646148af1884117082335a48e7 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 17:06:40 +0100 Subject: source span size features --- decoder/ff_source_syntax.cc | 62 +++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_source_syntax.h | 17 +++++++++++++ 2 files changed, 79 insertions(+) (limited to 'decoder') diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index ffe07f03..2df31c3a 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -157,3 +157,65 @@ void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); } +struct SourceSpanSizeFeaturesImpl { + SourceSpanSizeFeaturesImpl() {} + + void InitializeGrids(unsigned src_len) { + fids.clear(); + fids.resize(src_len, src_len + 1); + } + + int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { + int& fid = fids(i,j)[&rule]; + if (fid <= 0) { + ostringstream os; + os << "SSS:"; + unsigned ntc = 0; + for (unsigned k = 0; k < rule.f_.size(); ++k) { + if (k > 0) os << '_'; + int fj = rule.f_[k]; + if (fj <= 0) { + os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; + } else { + os << TD::Convert(fj); + } + } + fid = FD::Convert(os.str()); + } + if (fid > 0) + feats->set_value(fid, 1.0); + return SpanSizeTransform(j - i); + } + + mutable Array2D > fids; +}; + +SourceSpanSizeFeatures::SourceSpanSizeFeatures(const string& param) : + FeatureFunction(sizeof(char)) { + impl = new SourceSpanSizeFeaturesImpl; +} + +SourceSpanSizeFeatures::~SourceSpanSizeFeatures() { + delete impl; + impl = NULL; +} + +void SourceSpanSizeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + int ants[8]; + for (unsigned i = 0; i < ant_contexts.size(); ++i) + ants[i] = *static_cast(ant_contexts[i]); + + *static_cast(context) = + impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); +} + +void SourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) { + impl->InitializeGrids(smeta.GetSourceLength()); +} + + diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h index 1e890736..279563e1 100644 --- a/decoder/ff_source_syntax.h +++ b/decoder/ff_source_syntax.h @@ -21,4 +21,21 @@ class SourceSyntaxFeatures : public FeatureFunction { SourceSyntaxFeaturesImpl* impl; }; +struct SourceSpanSizeFeaturesImpl; +class SourceSpanSizeFeatures : public FeatureFunction { + public: + SourceSpanSizeFeatures(const std::string& param); + ~SourceSpanSizeFeatures(); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + SourceSpanSizeFeaturesImpl* impl; +}; + #endif -- cgit v1.2.3 From e7d2352ed630d16a790113223cd8a80155f61615 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 17:11:55 +0100 Subject: enable sss features --- decoder/cdec_ff.cc | 1 + 1 file changed, 1 insertion(+) (limited to 'decoder') diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index d562bc3a..69f40c93 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -57,6 +57,7 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleIdentityFeatures", new FFFactory()); ff_registry.Register("SourceSyntaxFeatures", new FFFactory); + ff_registry.Register("SourceSpanSizeFeatures", new FFFactory); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); -- cgit v1.2.3 From 5d7ac6050aab3eac5121a2168fe9bd81453d118a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 22:38:42 +0100 Subject: arity > 0 rules only for sss features --- decoder/ff_source_syntax.cc | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 2df31c3a..fc341bb0 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -166,24 +166,26 @@ struct SourceSpanSizeFeaturesImpl { } int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector* feats) { - int& fid = fids(i,j)[&rule]; - if (fid <= 0) { - ostringstream os; - os << "SSS:"; - unsigned ntc = 0; - for (unsigned k = 0; k < rule.f_.size(); ++k) { - if (k > 0) os << '_'; - int fj = rule.f_[k]; - if (fj <= 0) { - os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; - } else { - os << TD::Convert(fj); + if (rule.Arity() > 0) { + int& fid = fids(i,j)[&rule]; + if (fid <= 0) { + ostringstream os; + os << "SSS:"; + unsigned ntc = 0; + for (unsigned k = 0; k < rule.f_.size(); ++k) { + if (k > 0) os << '_'; + int fj = rule.f_[k]; + if (fj <= 0) { + os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; + } else { + os << TD::Convert(fj); + } } + fid = FD::Convert(os.str()); } - fid = FD::Convert(os.str()); + if (fid > 0) + feats->set_value(fid, 1.0); } - if (fid > 0) - feats->set_value(fid, 1.0); return SpanSizeTransform(j - i); } -- cgit v1.2.3 From 388081290e99fdd6eacc9d761ebfdea69647fa72 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Sep 2011 22:42:19 +0100 Subject: add target side for sss features --- decoder/ff_source_syntax.cc | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'decoder') diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index fc341bb0..035132b4 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -181,6 +181,15 @@ struct SourceSpanSizeFeaturesImpl { os << TD::Convert(fj); } } + os << ':'; + for (unsigned k = 0; k < rule.e_.size(); ++k) { + const int ei = rule.e_[k]; + if (k > 0) os << '_'; + if (ei <= 0) + os << '[' << (1-ei) << ']'; + else + os << TD::Convert(ei); + } fid = FD::Convert(os.str()); } if (fid > 0) -- cgit v1.2.3 From e1b61419329c83709018ca397a29d069e4294bd1 Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Fri, 23 Sep 2011 15:44:35 +0100 Subject: make show_partition work even in absence of feature functions --- decoder/decoder.cc | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 45404c47..c4fe3c4d 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -794,6 +794,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; } + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " Partition log(Z): " << log(z) << endl; + } + SummaryFeature summary_feature_type = kNODE_RISK; if (conf["summary_feature_type"].as() == "edge_risk") summary_feature_type = kEDGE_RISK; -- cgit v1.2.3 From 8ecf63852d730f99e7c1bbacfbffdf518d5a0c3f Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Fri, 23 Sep 2011 20:49:43 +0100 Subject: stub work to talk to new kenlm --- decoder/ff_klm.cc | 349 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 349 insertions(+) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 24dcb9c3..016aad26 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,6 +12,353 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#undef NEW_KENLM +#ifdef NEW_KENLM + +#include "lm/left.hh" + +using namespace std; + +// -x : rules include and +// -n NAME : feature id is NAME +bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { + vector const& argv=SplitOnWhitespace(in); + *explicit_markers = false; + *featname="LanguageModel"; + *mapfile = ""; +#define LMSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'x': + *explicit_markers = true; + break; + case 'm': + LMSPEC_NEXTARG; *mapfile=*i; + break; + case 'n': + LMSPEC_NEXTARG; *featname=*i; + break; +#undef LMSPEC_NEXTARG + default: + fail: + cerr<<"Unknown KLanguageModel option "<empty()) + *filename=s; + else { + cerr<<"More than one filename provided. "; + goto usage; + } + } + } + if (!filename->empty()) + return true; +usage: + cerr << "KLanguageModel is incorrect!\n"; + return false; +} + +template +string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { + return "KLanguageModel"; +} + +struct VMapper : public lm::ngram::EnumerateVocab { + VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +template +class KLanguageModelImpl { + + static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { + return *static_cast(state); + } + + inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, &lmstate, ngram_->StateSize()); + } + + public: + double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { + double sum = 0.0; + if (oovs) *oovs = 0; + const vector& e = rule.e(); + lm::ngram::ChartState state; + lm::ngram::RuleScore ruleScore(*ngram_, state); + unsigned i = 0; + if (e.size()) { + if (e[i] == kCDEC_SOS) { + ++i; + ruleScore.BeginSentence(); + } else if (e[i] <= 0) { // special case for left-edge NT + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + ++i; + } + } + for (; i < e.size(); ++i) { + if (e[i] <= 0) { + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); + ruleScore.NonTerminal(prevState, 0.0f); // TODO + } else { + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, + // maybe handle emission + const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id + const bool is_oov = (cur_word == 0); + if (is_oov) (*oovs) += 1.0; + ruleScore.Terminal(cur_word); + } + } + if (remnant) SetRemnantLMState(state, remnant); + return ruleScore.Finish(); + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for (n-1 left words) and (n-1 right words) + double FinalTraversalCost(const void* state, double* oovs) { + if (add_sos_eos_) { // rules do not produce , so do it here + lm::ngram::ChartState cstate; + lm::ngram::RuleScore ruleScore(*ngram_, cstate); + ruleScore.BeginSentence(); + SetRemnantLMState(cstate, dummy_state_); + dummy_ants_[1] = state; + *oovs = 0; + return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + } else { // rules DO produce ... + double p = 0; + cerr << "not implemented"; abort(); // TODO + //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + //if (UnscoredSize(state) > 0) { // are there unscored words + // if (kSOS_ != IthUnscoredWord(0, state)) { + // p -= 100 * UnscoredSize(state); + // } + //} + return p; + } + } + + // if this is not a class-based LM, returns w untransformed, + // otherwise returns a word class mapping of w, + // returns TD::Convert("") if there is no mapping for w + WordID ClassifyWordIfNecessary(WordID w) const { + if (word2class_map_.empty()) return w; + if (w >= word2class_map_.size()) + return kCDEC_UNK; + else + return word2class_map_[w]; + } + + // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 + lm::WordIndex MapWord(WordID w) const { + if (w >= cdec2klm_map_.size()) + return 0; + else + return cdec2klm_map_[w]; + } + + public: + KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : + kCDEC_UNK(TD::Convert("")) , + kCDEC_SOS(TD::Convert("")) , + add_sos_eos_(!explicit_markers) { + { + VMapper vm(&cdec2klm_map_); + lm::ngram::Config conf; + conf.enumerate_vocab = &vm; + ngram_ = new Model(filename.c_str(), conf); + } + order_ = ngram_->Order(); + cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; + state_size_ = sizeof(lm::ngram::ChartState); + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); + kSOS_ = MapWord(kCDEC_SOS); + assert(kSOS_ > 0); + kEOS_ = MapWord(TD::Convert("")); + assert(kEOS_ > 0); + assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant + + // handle class-based LMs (unambiguous word->class mapping reqd.) + if (mapfile.size()) + LoadWordClasses(mapfile); + } + + void LoadWordClasses(const string& file) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector dummy; + int lc = 0; + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + while(in) { + getline(in, line); + if (!in) continue; + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2) { + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(dummy[0], dummy[1]); + } + } + + void AddWordToClassMapping_(WordID word, WordID cls) { + if (word2class_map_.size() <= word) { + word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); + assert(word2class_map_.size() > word); + } + if(word2class_map_[word] != kCDEC_UNK) { + cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; + abort(); + } + word2class_map_[word] = cls; + } + + ~KLanguageModelImpl() { + delete ngram_; + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + const WordID kCDEC_SOS; + lm::WordIndex kSOS_; // - requires special handling. + lm::WordIndex kEOS_; // + Model* ngram_; + const bool add_sos_eos_; // flag indicating whether the hypergraph produces and + // if this is true, FinalTransitionFeatures will "add" and + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + char* dummy_state_; + vector dummy_ants_; + vector cdec2klm_map_; + vector word2class_map_; // if this is a class-based LM, this is the word->class mapping + TRulePtr dummy_rule_; +}; + +template +KLanguageModel::KLanguageModel(const string& param) { + string filename, mapfile, featname; + bool explicit_markers; + if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { + abort(); + } + try { + pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } + fid_ = FD::Convert(featname); + oov_fid_ = FD::Convert(featname+"_OOV"); + // cerr << "FID: " << oov_fid_ << endl; + SetStateSize(pimpl_->ReserveStateSize()); +} + +template +Features KLanguageModel::features() const { + return single_feature(fid_); +} + +template +KLanguageModel::~KLanguageModel() { + delete pimpl_; +} + +template +void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + double est = 0; + double oovs = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); + if (oovs && oov_fid_) + features->set_value(oov_fid_, oovs); +} + +template +void KLanguageModel::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + double oovs = 0; + double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); + features->set_value(fid_, lm); + if (oov_fid_ && oovs) + features->set_value(oov_fid_, oovs); +} + +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} + +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} + +#else + using namespace std; static const unsigned char HAS_FULL_CONTEXT = 1; @@ -469,3 +816,5 @@ boost::shared_ptr KLanguageModelFactory::Create(std::string par std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } + +#endif -- cgit v1.2.3 From d71c74f3924e6c207f3ebfab470b9a30e2551dde Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 23 Sep 2011 16:38:01 -0400 Subject: Go through ff_klm and try to fix it for the new version. --- decoder/ff_klm.cc | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 016aad26..3b2113ad 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -90,19 +90,12 @@ class KLanguageModelImpl { return *static_cast(state); } - inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } - public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { - double sum = 0.0; if (oovs) *oovs = 0; const vector& e = rule.e(); lm::ngram::ChartState state; - lm::ngram::RuleScore ruleScore(*ngram_, state); + lm::ngram::RuleScore ruleScore(*ngram_, remnant ? *static_cast(remnant) : state); unsigned i = 0; if (e.size()) { if (e[i] == kCDEC_SOS) { @@ -123,12 +116,13 @@ class KLanguageModelImpl { // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id const bool is_oov = (cur_word == 0); - if (is_oov) (*oovs) += 1.0; + if (is_oov && oovs) (*oovs) += 1.0; ruleScore.Terminal(cur_word); } } - if (remnant) SetRemnantLMState(state, remnant); - return ruleScore.Finish(); + double ret = ruleScore.Finish(); + state.ZeroRemaining(); + return ret; } // this assumes no target words on final unary -> goal rule. is that ok? @@ -138,10 +132,9 @@ class KLanguageModelImpl { lm::ngram::ChartState cstate; lm::ngram::RuleScore ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - SetRemnantLMState(cstate, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.Terminal(kEOS_); + return ruleScore.Finish(); } else { // rules DO produce ... double p = 0; cerr << "not implemented"; abort(); // TODO @@ -187,14 +180,8 @@ class KLanguageModelImpl { } order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = sizeof(lm::ngram::ChartState); // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); kSOS_ = MapWord(kCDEC_SOS); assert(kSOS_ > 0); kEOS_ = MapWord(TD::Convert("")); @@ -243,10 +230,9 @@ class KLanguageModelImpl { ~KLanguageModelImpl() { delete ngram_; - delete[] dummy_state_; } - int ReserveStateSize() const { return state_size_; } + int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } private: const WordID kCDEC_UNK; @@ -261,12 +247,8 @@ class KLanguageModelImpl { // the sentence) with 0, and anything else with -100 int order_; - int state_size_; - char* dummy_state_; - vector dummy_ants_; vector cdec2klm_map_; vector word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; }; template -- cgit v1.2.3 From 747309fcb0e0b1c6d060a68286ba1cf5ed1fbfa4 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sat, 24 Sep 2011 11:33:22 -0400 Subject: Chris says remnant and oovs should not be null, so stop checking. Also, we were not properly doing ZeroRemaining, sorry. --- decoder/ff_klm.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3b2113ad..6d9aca54 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -92,10 +92,9 @@ class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { - if (oovs) *oovs = 0; + *oovs = 0; const vector& e = rule.e(); - lm::ngram::ChartState state; - lm::ngram::RuleScore ruleScore(*ngram_, remnant ? *static_cast(remnant) : state); + lm::ngram::RuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; if (e.size()) { if (e[i] == kCDEC_SOS) { @@ -115,13 +114,12 @@ class KLanguageModelImpl { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - const bool is_oov = (cur_word == 0); - if (is_oov && oovs) (*oovs) += 1.0; + if (cur_word == 0) (*oovs) += 1.0; ruleScore.Terminal(cur_word); } } double ret = ruleScore.Finish(); - state.ZeroRemaining(); + static_cast(remnant)->ZeroRemaining(); return ret; } -- cgit v1.2.3 From b77d23a3032f42be3705e88ae1734bae779fb9a3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 28 Sep 2011 16:19:09 +0100 Subject: test fixes --- decoder/grammar_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'decoder') diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 62b8f958..cde00efa 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -15,12 +15,12 @@ using namespace std; class GrammarTest : public testing::Test { public: GrammarTest() { - wts.InitFromFile("test_data/weights.gt"); + Weights::InitFromFile("test_data/weights.gt", &wts); } protected: virtual void SetUp() { } virtual void TearDown() { } - Weights wts; + vector wts; }; TEST_F(GrammarTest,TestTextGrammar) { -- cgit v1.2.3 From 0af7d663194beddcde420349bbd91430e0b2e423 Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Tue, 11 Oct 2011 16:16:53 +0100 Subject: remove implicit conversion-to-double operator from LogVal that caused overflow errors, clean up some pf code --- decoder/aligner.cc | 2 +- decoder/cfg.cc | 2 +- decoder/cfg_format.h | 2 +- decoder/decoder.cc | 10 ++++---- decoder/hg.cc | 4 ++-- decoder/rule_lexer.l | 2 ++ decoder/trule.h | 15 +++++++++++- gi/pf/brat.cc | 11 --------- gi/pf/cbgi.cc | 10 -------- gi/pf/dpnaive.cc | 12 ---------- gi/pf/itg.cc | 11 --------- gi/pf/pfbrat.cc | 11 --------- gi/pf/pfdist.cc | 11 --------- gi/pf/pfnaive.cc | 11 --------- mteval/mbr_kbest.cc | 4 ++-- phrasinator/ccrp_nt.h | 24 +++++++++++++++---- training/mpi_batch_optimize.cc | 2 +- training/mpi_compute_cllh.cc | 51 +++++++++++++++++++---------------------- training/mpi_online_optimize.cc | 4 ++-- utils/logval.h | 10 ++++---- 20 files changed, 78 insertions(+), 131 deletions(-) (limited to 'decoder') diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 292ee123..53e059fb 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -165,7 +165,7 @@ inline void WriteProbGrid(const Array2D& m, ostream* pos) { if (m(i,j) == prob_t::Zero()) { os << "\t---X---"; } else { - snprintf(b, 1024, "%0.5f", static_cast(m(i,j))); + snprintf(b, 1024, "%0.5f", m(i,j).as_float()); os << '\t' << b; } } diff --git a/decoder/cfg.cc b/decoder/cfg.cc index 651978d2..cd7e66e9 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -639,7 +639,7 @@ void CFG::Print(std::ostream &o,CFGFormat const& f) const { o << '['<& src, SparseVector* trg) { for (SparseVector::const_iterator it = src.begin(); it != src.end(); ++it) - trg->set_value(it->first, it->second); + trg->set_value(it->first, it->second.as_float()); } }; @@ -788,10 +788,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { - const PRPair res = - Inside, - PRWeightFunction >(forest); - cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + const PRPair res = + Inside, + PRWeightFunction >(forest); + cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl; } if (conf.count("show_partition")) { diff --git a/decoder/hg.cc b/decoder/hg.cc index 3ad17f1a..180986d7 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -157,14 +157,14 @@ prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector* posts) co const ScaledEdgeProb weight(scale); const ScaledTransitionEventWeightFunction w2(scale); SparseVector pv; - const double inside = InsideOutside, ScaledTransitionEventWeightFunction>(*this, &pv, weight, w2); posts->resize(edges_.size()); for (int i = 0; i < edges_.size(); ++i) (*posts)[i] = prob_t(pv.value(i)); - return prob_t(inside); + return inside; } prob_t Hypergraph::ComputeBestPathThroughEdges(vector* post) const { diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l index 9331d8ed..083a5bb1 100644 --- a/decoder/rule_lexer.l +++ b/decoder/rule_lexer.l @@ -220,6 +220,8 @@ NT [^\t \[\],]+ std::cerr << "Line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); } + // const bool ignore_grammar_features = false; + // if (ignore_grammar_features) scfglex_num_feats = 0; TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als)); check_and_update_ctf_stack(rp); TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); diff --git a/decoder/trule.h b/decoder/trule.h index 4df4ec90..8eb2a059 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -5,7 +5,9 @@ #include #include #include -#include + +#include "boost/shared_ptr.hpp" +#include "boost/functional/hash.hpp" #include "sparse_vector.h" #include "wordid.h" @@ -162,4 +164,15 @@ class TRule { bool SanityCheck() const; }; +inline size_t hash_value(const TRule& r) { + size_t h = boost::hash_value(r.e_); + boost::hash_combine(h, -r.lhs_); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +inline bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + #endif diff --git a/gi/pf/brat.cc b/gi/pf/brat.cc index 4c6ba3ef..7b60ef23 100644 --- a/gi/pf/brat.cc +++ b/gi/pf/brat.cc @@ -25,17 +25,6 @@ static unsigned kMAX_SRC_PHRASE; static unsigned kMAX_TRG_PHRASE; struct FSTState; -size_t hash_value(const TRule& r) { - size_t h = 2 - r.lhs_; - boost::hash_combine(h, boost::hash_value(r.e_)); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - double log_poisson(unsigned x, const double& lambda) { assert(lambda > 0.0); return log(lambda) * x - lgamma(x + 1) - lambda; diff --git a/gi/pf/cbgi.cc b/gi/pf/cbgi.cc index 20204e8a..97f1ba34 100644 --- a/gi/pf/cbgi.cc +++ b/gi/pf/cbgi.cc @@ -27,16 +27,6 @@ double log_decay(unsigned x, const double& b) { return log(b - 1) - x * log(b); } -size_t hash_value(const TRule& r) { - // TODO fix hash function - size_t h = boost::hash_value(r.e_) * boost::hash_value(r.f_) * r.lhs_; - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - struct SimpleBase { SimpleBase(unsigned esize, unsigned fsize, unsigned ntsize = 144) : uniform_e(-log(esize)), diff --git a/gi/pf/dpnaive.cc b/gi/pf/dpnaive.cc index 582d1be7..608f73d5 100644 --- a/gi/pf/dpnaive.cc +++ b/gi/pf/dpnaive.cc @@ -20,18 +20,6 @@ namespace po = boost::program_options; static unsigned kMAX_SRC_PHRASE; static unsigned kMAX_TRG_PHRASE; -struct FSTState; - -size_t hash_value(const TRule& r) { - size_t h = 2 - r.lhs_; - boost::hash_combine(h, boost::hash_value(r.e_)); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); diff --git a/gi/pf/itg.cc b/gi/pf/itg.cc index 2c2a86f9..ac3c16a3 100644 --- a/gi/pf/itg.cc +++ b/gi/pf/itg.cc @@ -27,17 +27,6 @@ ostream& operator<<(ostream& os, const vector& p) { return os << ']'; } -size_t hash_value(const TRule& r) { - size_t h = boost::hash_value(r.e_); - boost::hash_combine(h, -r.lhs_); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - double log_poisson(unsigned x, const double& lambda) { assert(lambda > 0.0); return log(lambda) * x - lgamma(x + 1) - lambda; diff --git a/gi/pf/pfbrat.cc b/gi/pf/pfbrat.cc index 4c6ba3ef..7b60ef23 100644 --- a/gi/pf/pfbrat.cc +++ b/gi/pf/pfbrat.cc @@ -25,17 +25,6 @@ static unsigned kMAX_SRC_PHRASE; static unsigned kMAX_TRG_PHRASE; struct FSTState; -size_t hash_value(const TRule& r) { - size_t h = 2 - r.lhs_; - boost::hash_combine(h, boost::hash_value(r.e_)); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - double log_poisson(unsigned x, const double& lambda) { assert(lambda > 0.0); return log(lambda) * x - lgamma(x + 1) - lambda; diff --git a/gi/pf/pfdist.cc b/gi/pf/pfdist.cc index 18dfd03b..81abd61b 100644 --- a/gi/pf/pfdist.cc +++ b/gi/pf/pfdist.cc @@ -24,17 +24,6 @@ namespace po = boost::program_options; shared_ptr prng; -size_t hash_value(const TRule& r) { - size_t h = boost::hash_value(r.e_); - boost::hash_combine(h, -r.lhs_); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() diff --git a/gi/pf/pfnaive.cc b/gi/pf/pfnaive.cc index 43c604c3..c30e7c4f 100644 --- a/gi/pf/pfnaive.cc +++ b/gi/pf/pfnaive.cc @@ -24,17 +24,6 @@ namespace po = boost::program_options; shared_ptr prng; -size_t hash_value(const TRule& r) { - size_t h = boost::hash_value(r.e_); - boost::hash_combine(h, -r.lhs_); - boost::hash_combine(h, boost::hash_value(r.f_)); - return h; -} - -bool operator==(const TRule& a, const TRule& b) { - return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); -} - void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 2867b36b..64a6a8bf 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -32,7 +32,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } struct LossComparer { - bool operator()(const pair, double>& a, const pair, double>& b) const { + bool operator()(const pair, prob_t>& a, const pair, prob_t>& b) const { return a.second < b.second; } }; @@ -108,7 +108,7 @@ int main(int argc, char** argv) { ScoreP s = scorer->ScoreCandidate(list[j].first); double loss = 1.0 - s->ComputeScore(); if (type == TER || type == AER) loss = 1.0 - loss; - double weighted_loss = loss * (joints[j] / marginal); + double weighted_loss = loss * (joints[j] / marginal).as_float(); wl_acc += weighted_loss; if ((!output_list) && wl_acc > mbr_loss) break; } diff --git a/phrasinator/ccrp_nt.h b/phrasinator/ccrp_nt.h index 163b643a..811bce73 100644 --- a/phrasinator/ccrp_nt.h +++ b/phrasinator/ccrp_nt.h @@ -50,15 +50,26 @@ class CCRP_NoTable { return it->second; } - void increment(const Dish& dish) { - ++custs_[dish]; + int increment(const Dish& dish) { + int table_diff = 0; + if (++custs_[dish] == 1) + table_diff = 1; ++num_customers_; + return table_diff; } - void decrement(const Dish& dish) { - if ((--custs_[dish]) == 0) + int decrement(const Dish& dish) { + int table_diff = 0; + int nc = --custs_[dish]; + if (nc == 0) { custs_.erase(dish); + table_diff = -1; + } else if (nc < 0) { + std::cerr << "Dish counts dropped below zero for: " << dish << std::endl; + abort(); + } --num_customers_; + return table_diff; } double prob(const Dish& dish, const double& p0) const { @@ -66,6 +77,11 @@ class CCRP_NoTable { return (at_table + p0 * concentration_) / (num_customers_ + concentration_); } + double logprob(const Dish& dish, const double& logp0) const { + const unsigned at_table = num_customers(dish); + return log(at_table + exp(logp0 + log(concentration_))) - log(num_customers_ + concentration_); + } + double log_crp_prob() const { return log_crp_prob(concentration_); } diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index 0ba8c530..046e921c 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -92,7 +92,7 @@ struct TrainingObserver : public DecoderObserver { void SetLocalGradientAndObjective(vector* g, double* o) const { *o = acc_obj; for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - (*g)[it->first] = it->second; + (*g)[it->first] = it->second.as_float(); } virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { diff --git a/training/mpi_compute_cllh.cc b/training/mpi_compute_cllh.cc index b496d196..d5caa745 100644 --- a/training/mpi_compute_cllh.cc +++ b/training/mpi_compute_cllh.cc @@ -1,6 +1,4 @@ -#include #include -#include #include #include #include @@ -12,6 +10,7 @@ #include #include +#include "sentence_metadata.h" #include "verbose.h" #include "hg.h" #include "prob.h" @@ -52,7 +51,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { return true; } -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* ids) { +void ReadInstances(const string& fname, int rank, int size, vector* c) { + assert(fname != "-"); ReadFile rf(fname); istream& in = *rf.stream(); string line; @@ -60,20 +60,16 @@ void ReadTrainingCorpus(const string& fname, int rank, int size, vector* while(in) { getline(in, line); if (!in) break; - if (lc % size == rank) { - c->push_back(line); - ids->push_back(lc); - } + if (lc % size == rank) c->push_back(line); ++lc; } } static const double kMINUS_EPSILON = -1e-6; -struct TrainingObserver : public DecoderObserver { - void Reset() { - acc_obj = 0; - } +struct ConditionalLikelihoodObserver : public DecoderObserver { + + ConditionalLikelihoodObserver() : trg_words(), acc_obj(), cur_obj() {} virtual void NotifyDecodingStart(const SentenceMetadata&) { cur_obj = 0; @@ -120,8 +116,10 @@ struct TrainingObserver : public DecoderObserver { } assert(!isnan(log_ref_z)); acc_obj += (cur_obj - log_ref_z); + trg_words += smeta.GetReference().size(); } + unsigned trg_words; double acc_obj; double cur_obj; int state; @@ -161,35 +159,32 @@ int main(int argc, char** argv) { if (conf.count("weights")) Weights::InitFromFile(conf["weights"].as(), &weights); - // freeze feature set - //const bool freeze_feature_set = conf.count("freeze_feature_set"); - //if (freeze_feature_set) FD::Freeze(); - - vector corpus; vector ids; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); + vector corpus; + ReadInstances(conf["training_data"].as(), rank, size, &corpus); assert(corpus.size() > 0); - assert(corpus.size() == ids.size()); - - TrainingObserver observer; - double objective = 0; - observer.Reset(); if (rank == 0) - cerr << "Each processor is decoding " << corpus.size() << " training examples...\n"; + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; - for (int i = 0; i < corpus.size(); ++i) { - decoder.SetId(ids[i]); + ConditionalLikelihoodObserver observer; + for (int i = 0; i < corpus.size(); ++i) decoder.Decode(corpus[i], &observer); - } + double objective = 0; + unsigned total_words = 0; #ifdef HAVE_MPI reduce(world, observer.acc_obj, objective, std::plus(), 0); + reduce(world, observer.trg_words, total_words, std::plus(), 0); #else objective = observer.acc_obj; #endif - if (rank == 0) - cout << "OBJECTIVE: " << objective << endl; + if (rank == 0) { + cout << "CONDITIONAL LOG_e LIKELIHOOD: " << objective << endl; + cout << "CONDITIONAL LOG_2 LIKELIHOOD: " << (objective/log(2)) << endl; + cout << " CONDITIONAL ENTROPY: " << (objective/log(2) / total_words) << endl; + cout << " PERPLEXITY: " << pow(2, (objective/log(2) / total_words)) << endl; + } return 0; } diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 2ef4a2e7..f87b7274 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -94,7 +94,7 @@ struct TrainingObserver : public DecoderObserver { void SetLocalGradientAndObjective(vector* g, double* o) const { *o = acc_obj; for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - (*g)[it->first] = it->second; + (*g)[it->first] = it->second.as_float(); } virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { @@ -158,7 +158,7 @@ struct TrainingObserver : public DecoderObserver { void GetGradient(SparseVector* g) const { g->clear(); for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - g->set_value(it->first, it->second); + g->set_value(it->first, it->second.as_float()); } int total_complete; diff --git a/utils/logval.h b/utils/logval.h index 6fdc2c42..8a59d0b1 100644 --- a/utils/logval.h +++ b/utils/logval.h @@ -25,12 +25,13 @@ class LogVal { typedef LogVal Self; LogVal() : s_(), v_(LOGVAL_LOG0) {} - explicit LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} + LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} + const Self& operator=(double x) { s_ = std::signbit(x); v_ = s_ ? std::log(-x) : std::log(x); return *this; } LogVal(init_minus_1) : s_(true),v_(0) { } LogVal(init_1) : s_(),v_(0) { } LogVal(init_0) : s_(),v_(LOGVAL_LOG0) { } - LogVal(int x) : s_(x<0), v_(s_ ? std::log(-x) : std::log(x)) {} - LogVal(unsigned x) : s_(0), v_(std::log(x)) { } + explicit LogVal(int x) : s_(x<0), v_(s_ ? std::log(-x) : std::log(x)) {} + explicit LogVal(unsigned x) : s_(0), v_(std::log(x)) { } LogVal(double lnx,bool sign) : s_(sign),v_(lnx) {} LogVal(double lnx,init_lnx) : s_(),v_(lnx) {} static Self exp(T lnx) { return Self(lnx,false); } @@ -141,9 +142,6 @@ class LogVal { return pow(1/root); } - operator T() const { - if (s_) return -std::exp(v_); else return std::exp(v_); - } T as_float() const { if (s_) return -std::exp(v_); else return std::exp(v_); } -- cgit v1.2.3 From 171027795ba3a01ba2ed82d7036610ac397e1fe8 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 14 Oct 2011 11:51:12 +0100 Subject: remove FSA integration code. will have to be resurrected another day --- decoder/Makefile.am | 1 - decoder/apply_fsa_models.cc | 798 ---------------------------------------- decoder/cdec_ff.cc | 13 - decoder/feature_accum.h | 129 ------- decoder/ff_factory.h | 2 - decoder/ff_from_fsa.h | 304 --------------- decoder/ff_fsa.h | 401 -------------------- decoder/ff_fsa_data.h | 131 ------- decoder/ff_fsa_dynamic.h | 208 ----------- decoder/ff_lm.cc | 48 --- decoder/ff_lm_fsa.h | 140 ------- decoder/ff_register.h | 38 -- decoder/hg_test.cc | 16 +- training/mpi_online_optimize.cc | 2 + 14 files changed, 10 insertions(+), 2221 deletions(-) delete mode 100755 decoder/apply_fsa_models.cc delete mode 100755 decoder/feature_accum.h delete mode 100755 decoder/ff_from_fsa.h delete mode 100755 decoder/ff_fsa.h delete mode 100755 decoder/ff_fsa_data.h delete mode 100755 decoder/ff_fsa_dynamic.h delete mode 100755 decoder/ff_lm_fsa.h (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index ede1cff0..6b9360d8 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -42,7 +42,6 @@ libcdec_a_SOURCES = \ cfg.cc \ dwarf.cc \ ff_dwarf.cc \ - apply_fsa_models.cc \ rule_lexer.cc \ fst_translator.cc \ csplit.cc \ diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc deleted file mode 100755 index 3e93cadd..00000000 --- a/decoder/apply_fsa_models.cc +++ /dev/null @@ -1,798 +0,0 @@ -//see apply_fsa_models.README for notes on the l2r earley fsa+cfg intersection -//implementation in this file (also some comments in this file) -#define SAFE_VALGRIND 1 - -#include "apply_fsa_models.h" -#include -#include -#include -#include - -#include "writer.h" -#include "hg.h" -#include "ff_fsa_dynamic.h" -#include "ff_from_fsa.h" -#include "feature_vector.h" -#include "stringlib.h" -#include "apply_models.h" -#include "cfg.h" -#include "hg_cfg.h" -#include "utoa.h" -#include "hash.h" -#include "value_array.h" -#include "d_ary_heap.h" -#include "agenda.h" -#include "show.h" -#include "string_to.h" - - -#define DFSA(x) x -//fsa earley chart - -#define DPFSA(x) x -//prefix trie - -#define DBUILDTRIE(x) - -#define PRINT_PREFIX 1 -#if PRINT_PREFIX -# define IF_PRINT_PREFIX(x) x -#else -# define IF_PRINT_PREFIX(x) -#endif -// keep backpointers in prefix trie so you can print a meaningful node id - -static const unsigned FSA_AGENDA_RESERVE=10; // TODO: increase to 1<<24 (16M) - -using namespace std; - -//impl details (not exported). flat namespace for my ease. - -typedef CFG::RHS RHS; -typedef CFG::BinRhs BinRhs; -typedef CFG::NTs NTs; -typedef CFG::NT NT; -typedef CFG::NTHandle NTHandle; -typedef CFG::Rules Rules; -typedef CFG::Rule Rule; -typedef CFG::RuleHandle RuleHandle; - -namespace { - -/* - -1) A -> x . * (trie) - -this is somewhat nice. cost pushed for best first, of course. similar benefit as left-branching binarization without the explicit predict/complete steps? - -vs. just - -2) * -> x . y - -here you have to potentially list out all A -> . x y as items * -> . x y immediately, and shared rhs seqs won't be shared except at the usual single-NT predict/complete. of course, the prediction of items -> . x y can occur lazy best-first. - -vs. - -3) * -> x . * - -with 3, we predict all sorts of useless items - that won't give us our goal A and may not partcipate in any parse. this is not a good option at all. - -I'm using option 1. -*/ - -// if we don't greedy-binarize, we want to encode recognized prefixes p (X -> p . rest) efficiently. if we're doing this, we may as well also push costs so we can best-first select rules in a lazy fashion. this is effectively left-branching binarization, of course. - -template -struct fsa_map_type { - typedef std::map type; // change to HASH_MAP ? -}; -//template typedef - and macro to make it less painful -#define FSA_MAP(k,v) fsa_map_type >::type - -struct PrefixTrieNode; -typedef PrefixTrieNode *NodeP; -typedef PrefixTrieNode const *NodePc; - -// for debugging prints only -struct TrieBackP { - WordID w; - NodePc from; - TrieBackP(WordID w=0,NodePc from=0) : w(w),from(from) { } -}; - -FsaFeatureFunction const* print_fsa=0; -CFG const* print_cfg=0; -inline ostream& print_cfg_rhs(std::ostream &o,WordID w,CFG const*pcfg=print_cfg) { - if (pcfg) - pcfg->print_rhs_name(o,w); - else - CFG::static_print_rhs_name(o,w); - return o; -} - -inline std::string nt_name(WordID n,CFG const*pcfg=print_cfg) { - if (pcfg) return pcfg->nt_name(n); - return CFG::static_nt_name(n); -} - -template -ostream& print_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") { - o< "< -ostream& print_map_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") { - o<first,pcfg) << " -> "<second<<"\n"; - } - return o; -} - -struct PrefixTrieEdge { - PrefixTrieEdge() - // : dest(0),w(TD::max_wordid) - {} - PrefixTrieEdge(WordID w,NodeP dest) - : dest(dest),w(w) - {} -// explicit PrefixTrieEdge(best_t p) : p(p),dest(0) { } - - best_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob. note: for final edge, set this. - //DPFSA() - // we can probably just store deltas, but for debugging remember the full p - // best_t delta; // - NodeP dest; - bool is_final() const { return dest==0; } - best_t p_dest() const; - WordID w; // for root and and is_final(), this will be (negated) NTHandle. - - // for sorting most probable first in adj; actually >(p) - inline bool operator <(PrefixTrieEdge const& o) const { - return o.p"< BPs; - void back_vec(BPs &ns) const { - IF_PRINT_PREFIX(if(backp.from) { ns.push_back(backp); backp.from->back_vec(ns); }) - } - - BPs back_vec() const { - BPs ret; - back_vec(ret); - return ret; - } - - unsigned size() const { - unsigned a=adj.size(); - unsigned e=edge_for.size(); - return a>e?a:e; - } - - void print_back_str(std::ostream &o) const { - BPs back=back_vec(); - unsigned i=back.size(); - if (!i) { - o<<"PrefixTrieNode@"<<(uintptr_t)this; - return; - } - bool first=true; - while (i--<=0) { - if (!first) o<<','; - first=false; - WordID w=back[i].w; - print_cfg_rhs(o,w); - } - } - std::string back_str() const { - std::ostringstream o; - print_back_str(o); - return o.str(); - } - -// best_t p_final; // additional prob beyond what we already paid. while building, this is the total prob -// instead of storing final, we'll say that an edge with a NULL dest is a final edge. this way it gets sorted into the list of adj. - - // instead of completed map, we have trie start w/ lhs. - NTHandle lhs; // nonneg. - instead of storing this in Item. - IF_PRINT_PREFIX(BP backp;) - - enum { ROOT=-1 }; - explicit PrefixTrieNode(NTHandle lhs=ROOT,best_t p=1) : p(p),lhs(lhs),IF_PRINT_PREFIX(backp()) { - //final=false; - } - bool is_root() const { return lhs==ROOT; } // means adj are the nonneg lhs indices, and we have the index edge_for still available - - // outgoing edges will be ordered highest p to worst p - - typedef FSA_MAP(WordID,PrefixTrieEdge) PrefixTrieEdgeFor; -public: - PrefixTrieEdgeFor edge_for; //TODO: move builder elsewhere? then need 2nd hash or edge include pointer to builder. just clear this later - bool have_adj() const { - return adj.size()>=edge_for.size(); - } - bool no_adj() const { - return adj.empty(); - } - - void index_adj() { - index_adj(edge_for); - } - template - void index_adj(M &m) { - assert(have_adj()); - m.clear(); - for (int i=0;i - void index_lhs(PV &v) { - for (int i=0,e=adj.size();i!=e;++i) { - PrefixTrieEdge const& edge=adj[i]; - // assert(edge.p.is_1()); // actually, after done_building, e will have telescoped dest->p/p. - NTHandle n=-edge.w; - assert(n>=0); -// SHOWM3(DPFSA,"index_lhs",i,edge,n); - v[n]=edge.dest; - } - } - - template - void done_root(PV &v) { - assert(is_root()); - SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_map_by_nt,edge_for)); - done_building_r(); //sets adj - SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_by_nt,adj)); -// SHOWM1(DBUILDTRIE,done_root,adj); -// index_adj(); // we want an index for the root node?. don't think so - index_lhs handles it. also we stopped clearing edge_for. - index_lhs(v); // uses adj - } - - // call only once. - void done_building_r() { - done_building(); - for (int i=0;idone_building_r(); - } - - // for done_building; compute incremental (telescoped) edge p - PrefixTrieEdge /*const&*/ operator()(PrefixTrieEdgeFor::value_type & pair) const { - PrefixTrieEdge &e=pair.second;//const_cast(pair.second); - e.p=e.p_dest()/p; - return e; - } - - // call only once. - void done_building() { - SHOWM3(DBUILDTRIE,"done_building",edge_for.size(),adj.size(),1); -#if 1 - adj.reinit_map(edge_for,*this); -#else - adj.reinit(edge_for.size()); - SHOWM3(DBUILDTRIE,"done_building_reinit",edge_for.size(),adj.size(),2); - Adj::iterator o=adj.begin(); - for (PrefixTrieEdgeFor::iterator i=edge_for.begin(),e=edge_for.end();i!=e;++i) { - SHOWM3(DBUILDTRIE,"edge_for",o-adj.begin(),i->first,i->second); - PrefixTrieEdge &edge=i->second; - edge.p=(edge.dest->p)/p; - *o++=edge; -// (*this)(*i); - } -#endif - SHOWM1(DBUILDTRIE,"done building adj",prange(adj.begin(),adj.end(),true)); - assert(adj.size()==edge_for.size()); -// if (final) p_final/=p; - std::sort(adj.begin(),adj.end()); - //TODO: store adjacent differences on edges (compared to - } - - typedef ValueArray Adj; -// typedef vector Adj; - Adj adj; - - typedef WordID W; - - // let's compute p_min so that every rule reachable from the created node has p at least this low. - NodeP improve_edge(PrefixTrieEdge const& e,best_t rulep) { - NodeP d=e.dest; - maybe_improve(d->p,rulep); - return d; - } - - inline NodeP build(W w,best_t rulep) { - return build(lhs,w,rulep); - } - inline NodeP build_lhs(NTHandle n,best_t rulep) { - return build(n,-n,rulep); - } - - NodeP build(NTHandle lhs_,W w,best_t rulep) { - PrefixTrieEdgeFor::iterator i=edge_for.find(w); - if (i!=edge_for.end()) - return improve_edge(i->second,rulep); - NodeP r=new PrefixTrieNode(lhs_,rulep); - IF_PRINT_PREFIX(r->backp=BP(w,this)); -// edge_for.insert(i,PrefixTrieEdgeFor::value_type(w,PrefixTrieEdge(w,r))); - add(edge_for,w,PrefixTrieEdge(w,r)); - SHOWM4(DBUILDTRIE,"built node",this,w,*r,r); - return r; - } - - void set_final(NTHandle lhs_,best_t pf) { - assert(no_adj()); -// final=true; - PrefixTrieEdge &e=edge_for[null_wordid]; - e.p=pf; - e.dest=0; - e.w=lhs_; - maybe_improve(p,pf); - } - -private: - void destroy_children() { - assert(adj.size()>=edge_for.size()); - for (int i=0,e=adj.size();i" << p; - o << ',' << size() << ','; - print_back_str(o); - } - PRINT_SELF(PrefixTrieNode) -}; - -inline best_t PrefixTrieEdge::p_dest() const { - return dest ? dest->p : p; // for final edge, p was set (no sentinel node) -} - - -//Trie starts with lhs (nonneg index), then continues w/ rhs (mixed >0 word, else NT) -// trie ends with final edge, which points to a per-lhs prefix node -struct PrefixTrie { - void print(std::ostream &o) const { - o << cfgp << ' ' << root; - } - PRINT_SELF(PrefixTrie); - CFG *cfgp; - Rules const* rulesp; - Rules const& rules() const { return *rulesp; } - CFG const& cfg() const { return *cfgp; } - PrefixTrieNode root; - typedef std::vector LhsToTrie; // will have to check lhs2[lhs].p for best cost of some rule with that lhs, then use edge deltas after? they're just caching a very cheap computation, really - LhsToTrie lhs2; // no reason to use a map or hash table; every NT in the CFG will have some rule rhses. lhs_to_trie[i]=root.edge_for[i], i.e. we still have a root trie node conceptually, we just access through this since it's faster. - typedef LhsToTrie LhsToComplete; - LhsToComplete lhs2complete; // the sentinel "we're completing" node (dot at end) for that lhs. special case of suffix-set=same trie minimization (aka right branching binarization) // these will be used to track kbest completions, along with a l state (r state will be in the list) - PrefixTrie(CFG &cfg) : cfgp(&cfg),rulesp(&cfg.rules),lhs2(cfg.nts.size(),0),lhs2complete(cfg.nts.size()) { -// cfg.SortLocalBestFirst(); // instead we'll sort in done_building_r - print_cfg=cfgp; - SHOWM2(DBUILDTRIE,"PrefixTrie()",rulesp->size(),lhs2.size()); - cfg.VisitRuleIds(*this); - root.done_root(lhs2); - SHOWM3(DBUILDTRIE,"done w/ PrefixTrie: ",root,root.adj.size(),lhs2.size()); - DBUILDTRIE(print_by_nt(cerr,lhs2,cfgp)); - SHOWM1(DBUILDTRIE,"lhs2",OSTRF2(print_by_nt,lhs2,cfgp)); - } - - void operator()(int ri) { - Rule const& r=rules()[ri]; - NTHandle lhs=r.lhs; - best_t p=r.p; -// NodeP n=const_cast(root).build_lhs(lhs,p); - NodeP n=root.build_lhs(lhs,p); - SHOWM4(DBUILDTRIE,"Prefixtrie rule id, root",ri,root,p,*n); - for (RHS::const_iterator i=r.rhs.begin(),e=r.rhs.end();;++i) { - SHOWM2(DBUILDTRIE,"PrefixTrie build or final",i-r.rhs.begin(),*n); - if (i==e) { - n->set_final(lhs,p); - break; - } - n=n->build(*i,p); - SHOWM2(DBUILDTRIE,"PrefixTrie built",*i,*n); - } -// root.build(lhs,r.p)->build(r.rhs,r.p); - } - inline NodeP lhs2_ex(NTHandle n) const { - NodeP r=lhs2[n]; - if (!r) throw std::runtime_error("PrefixTrie: no CFG rule w/ lhs "+cfgp->nt_name(n)); - return r; - } -private: - PrefixTrie(PrefixTrie const& o); -}; - - - -typedef std::size_t ItemHash; - - -struct ItemKey { - explicit ItemKey(NodeP start,Bytes const& start_state) : dot(start),q(start_state),r(start_state) { } - explicit ItemKey(NodeP dot) : dot(dot) { } - NodeP dot; // dot is a function of the stuff already recognized, and gives a set of suffixes y to complete to finish a rhs for lhs() -> dot y. for a lhs A -> . *, this will point to lh2[A] - Bytes q,r; // (q->r are the fsa states; if r is empty it means - bool operator==(ItemKey const& o) const { - return dot==o.dot && q==o.q && r==o.r; - } - inline ItemHash hash() const { - ItemHash h=GOLDEN_MEAN_FRACTION*(ItemHash)(dot-NULL); // i.e. lower order bits of ptr are nonrandom - using namespace boost; - hash_combine(h,q); - hash_combine(h,r); - return h; - } - template - void print(O &o) const { - o<<"lhs="<print_back_str(o); - if (print_fsa) { - o<<'/'; - print_fsa->print_state(o,&q[0]); - o<<"->"; - print_fsa->print_state(o,&r[0]); - } - } - NTHandle lhs() const { return dot->lhs; } - PRINT_SELF(ItemKey) -}; -inline ItemHash hash_value(ItemKey const& x) { - return x.hash(); -} -ItemKey null_item((PrefixTrieNode*)0); - -struct Item; -typedef Item *ItemP; - -/* we use a single type of item so it can live in a single best-first queue. we hold them by pointer so they can have mutable state, e.g. priority/location, but also lists of predictions and kbest completions (i.e. completions[L,r] = L -> * (r,s), by 1best for each possible s. we may discover more s later. we could use different subtypes since we hold by pointer, but for now everything will be packed as variants of Item */ -#undef INIT_LOCATION -#if D_ARY_TRACK_OUT_OF_HEAP -# define INIT_LOCATION , location(D_ARY_HEAP_NULL_INDEX) -#elif !defined(NDEBUG) || SAFE_VALGRIND - // avoid spurious valgrind warning - FIXME: still complains??? -# define INIT_LOCATION , location() -#else -# define INIT_LOCATION -#endif - -// these should go in a global best-first queue -struct ItemPrio { - // NOTE: sum = viterbi (max) - ItemPrio() : priority(init_0()),inner(init_0()) { } - explicit ItemPrio(best_t priority) : priority(priority),inner(init_0()) { } - best_t priority; // includes inner prob. (forward) - /* The forward probability alpha_i(X[k]->x.y) is the sum of the probabilities of all - constrained paths of length i that end in state X[k]->x.y*/ - best_t inner; - /* The inner probability beta_i(X[k]->x.y) is the sum of the probabilities of all - paths of length i-k that start in state X[k,k]->.xy and end in X[k,i]->x.y, and generate the input symbols x[k,...,i-1] */ - template - void print(O &o) const { - o<=0; - } - explicit Item(FFState const& state,NodeP dot,best_t prio,int next=0) : ItemPrio(prio),ItemKey(dot,state),trienext(next),from(0) - INIT_LOCATION - { -// t=ADJ; -// if (dot->adj.size()) - dot->p_delta(next,priority); -// SHOWM1(DFSA,"Item(state,dot,prio)",prio); - } - typedef std::queue Predicted; -// Predicted predicted; // this is empty, unless this is a predicted L -> .asdf item, or a to-complete L -> asdf . - int trienext; // index of dot->adj to complete (if dest==0), or predict (if NT), or scan (if word). note: we could store pointer inside adj since it and trie are @ fixed addrs. less pointer arith, more space. - ItemP from; //backpointer - 0 for L -> . asdf for the rest; L -> a .sdf, it's the L -> .asdf item. - ItemP predicted_from() const { - ItemP p=(ItemP)this; - while(p->from) p=p->from; - return p; - } - template - void print(O &o) const { - o<< '['; - o< -struct ApplyFsa { - ApplyFsa(HgCFG &i, - const SentenceMetadata& smeta, - const FsaFeatureFunction& fsa, - DenseWeightVector const& weights, - ApplyFsaBy const& by, - Hypergraph* oh - ) - :hgcfg(i),smeta(smeta),fsa(fsa),weights(weights),by(by),oh(oh) - { - stateless=!fsa.state_bytes(); - } - void Compute() { - if (by.IsBottomUp() || stateless) - ApplyBottomUp(); - else - ApplyEarley(); - } - void ApplyBottomUp(); - void ApplyEarley(); - CFG const& GetCFG(); -private: - CFG cfg; - HgCFG &hgcfg; - SentenceMetadata const& smeta; - FsaFF const& fsa; -// WeightVector weight_vector; - DenseWeightVector weights; - ApplyFsaBy by; - Hypergraph* oh; - std::string cfg_out; - bool stateless; -}; - -template -void ApplyFsa::ApplyBottomUp() -{ - assert(by.IsBottomUp()); - FeatureFunctionFromFsa buff(&fsa); - buff.Init(); // mandatory to call this (normally factory would do it) - vector ffs(1,&buff); - ModelSet models(weights, ffs); - IntersectionConfiguration i(stateless ? BU_FULL : by.BottomUpAlgorithm(),by.pop_limit); - ApplyModelSet(hgcfg.ih,smeta,models,i,oh); -} - -template -void ApplyFsa::ApplyEarley() -{ - hgcfg.GiveCFG(cfg); - print_cfg=&cfg; - print_fsa=&fsa; - Chart chart(cfg,smeta,fsa); - // don't need to uniq - option to do that already exists in cfg_options - //TODO: - chart.best_first(); - *oh=hgcfg.ih; -} - - -void ApplyFsaModels(HgCFG &i, - const SentenceMetadata& smeta, - const FsaFeatureFunction& fsa, - DenseWeightVector const& weight_vector, - ApplyFsaBy const& by, - Hypergraph* oh) -{ - ApplyFsa a(i,smeta,fsa,weight_vector,by,oh); - a.Compute(); -} - -/* -namespace { -char const* anames[]={ - "BU_CUBE", - "BU_FULL", - "EARLEY", - 0 -}; -} -*/ - -//TODO: named enum type in boost? - -std::string ApplyFsaBy::name() const { -// return anames[algorithm]; - return GetName(algorithm); -} - -std::string ApplyFsaBy::all_names() { - return FsaByNames(" "); - /* - std::ostringstream o; - for (int i=0;i=N_ALGORITHMS) - throw std::runtime_error("Unknown ApplyFsaBy type id: "+itos(i)+" - legal types: "+all_names()); -*/ - GetName(i); // checks validity - algorithm=i; -} - -int ApplyFsaBy::BottomUpAlgorithm() const { - assert(IsBottomUp()); - return algorithm==BU_CUBE ? - IntersectionConfiguration::CUBE - :IntersectionConfiguration::FULL; -} - -void ApplyFsaModels(Hypergraph const& ih, - const SentenceMetadata& smeta, - const FsaFeatureFunction& fsa, - DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this) - ApplyFsaBy const& cfg, - Hypergraph* out) -{ - HgCFG i(ih); - ApplyFsaModels(i,smeta,fsa,weights,cfg,out); -} diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 69f40c93..4ce5749e 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -12,8 +12,6 @@ #include "ff_rules.h" #include "ff_ruleshape.h" #include "ff_bleu.h" -#include "ff_lm_fsa.h" -#include "ff_sample_fsa.h" #include "ff_source_syntax.h" #include "ff_register.h" #include "ff_charset.h" @@ -31,15 +29,6 @@ void register_feature_functions() { } registered = true; - //TODO: these are worthless example target FSA ffs. remove later - RegisterFsaImpl(true); - RegisterFsaImpl(true); - RegisterFsaImpl(true); -// ff_registry.Register("LanguageModelFsaDynamic",new FFFactory > >); // to test correctness of FsaFeatureFunctionDynamic erasure - RegisterFsaDynToFF(); - RegisterFsaImpl(true); // same as LM but using fsa wrapper - RegisterFsaDynToFF(); - RegisterFF(); RegisterFF(); @@ -47,8 +36,6 @@ void register_feature_functions() { RegisterFF(); RegisterFF(); - ff_registry.Register(new FFFactory); // same as WordPenalty, but implemented using ff_fsa - //TODO: use for all features the new Register which requires static FF::usage(false,false) give name #ifdef HAVE_RANDLM ff_registry.Register("RandLM", new FFFactory); diff --git a/decoder/feature_accum.h b/decoder/feature_accum.h deleted file mode 100755 index 4b8338eb..00000000 --- a/decoder/feature_accum.h +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef FEATURE_ACCUM_H -#define FEATURE_ACCUM_H - -#include "ff.h" -#include "sparse_vector.h" -#include "value_array.h" - -struct SparseFeatureAccumulator : public FeatureVector { - typedef FeatureVector State; - SparseFeatureAccumulator() { assert(!"this code is disabled"); } - template - FeatureVector const& describe(FF const& ) { return *this; } - void Store(FeatureVector *fv) const { -//NO fv->set_from(*this); - } - template - void Store(FF const& /* ff */,FeatureVector *fv) const { -//NO fv->set_from(*this); - } - template - void Add(FF const& /* ff */,FeatureVector const& fv) { - (*this)+=fv; - } - void Add(FeatureVector const& fv) { - (*this)+=fv; - } - /* - SparseFeatureAccumulator(FeatureVector const& fv) : State(fv) {} - FeatureAccumulator(Features const& fids) {} - FeatureAccumulator(Features const& fids,FeatureVector const& fv) : State(fv) {} - void Add(Features const& fids,FeatureVector const& fv) { - *this += fv; - } - */ - void Add(int i,Featval v) { -//NO (*this)[i]+=v; - } - void Add(Features const& fids,int i,Featval v) { -//NO (*this)[i]+=v; - } -}; - -struct SingleFeatureAccumulator { - typedef Featval State; - typedef SingleFeatureAccumulator Self; - State v; - /* - void operator +=(State const& o) { - v+=o; - } - */ - void operator +=(Self const& s) { - v+=s.v; - } - SingleFeatureAccumulator() : v() {} - template - State const& describe(FF const& ) const { return v; } - - template - void Store(FF const& ff,FeatureVector *fv) const { - fv->set_value(ff.fid_,v); - } - void Store(Features const& fids,FeatureVector *fv) const { - assert(fids.size()==1); - fv->set_value(fids[0],v); - } - /* - SingleFeatureAccumulator(Features const& fids) { assert(fids.size()==1); } - SingleFeatureAccumulator(Features const& fids,FeatureVector const& fv) - { - assert(fids.size()==1); - v=fv.get_singleton(); - } - */ - - template - void Add(FF const& ff,FeatureVector const& fv) { - v+=fv.get(ff.fid_); - } - void Add(FeatureVector const& fv) { - v+=fv.get_singleton(); - } - - void Add(Features const& fids,FeatureVector const& fv) { - v += fv.get(fids[0]); - } - void Add(Featval dv) { - v+=dv; - } - void Add(int,Featval dv) { - v+=dv; - } - void Add(FeatureVector const& fids,int i,Featval dv) { - assert(fids.size()==1 && i==0); - v+=dv; - } -}; - - -#if 0 -// omitting this so we can default construct an accum. might be worth resurrecting in the future -struct ArrayFeatureAccumulator : public ValueArray { - typedef ValueArray State; - template - ArrayFeatureAccumulator(Fsa const& fsa) : State(fsa.features_.size()) { } - ArrayFeatureAccumulator(Features const& fids) : State(fids.size()) { } - ArrayFeatureAccumulator(Features const& fids) : State(fids.size()) { } - ArrayFeatureAccumulator(Features const& fids,FeatureVector const& fv) : State(fids.size()) { - for (int i=0,e=iset_value(fids[i],(*this)[i]); - } - void Add(Features const& fids,FeatureVector const& fv) { - for (int i=0,e=i -#include "ff_fsa_dynamic.h" - class FeatureFunction; class FsaFeatureFunction; diff --git a/decoder/ff_from_fsa.h b/decoder/ff_from_fsa.h deleted file mode 100755 index f8d79e03..00000000 --- a/decoder/ff_from_fsa.h +++ /dev/null @@ -1,304 +0,0 @@ -#ifndef FF_FROM_FSA_H -#define FF_FROM_FSA_H - -#include "ff_fsa.h" - -#ifndef TD__none -// replacing dependency on SRILM -#define TD__none -1 -#endif - -#ifndef FSA_FF_DEBUG -# define FSA_FF_DEBUG 0 -#endif -#if FSA_FF_DEBUG -# define FSAFFDBG(e,x) FSADBGif(debug(),e,x) -# define FSAFFDBGnl(e) FSADBGif_nl(debug(),e) -#else -# define FSAFFDBG(e,x) -# define FSAFFDBGnl(e) -#endif - -/* regular bottom up scorer from Fsa feature - uses guarantee about markov order=N to score ASAP - encoding of state: if less than N-1 (ctxlen) words - - usage: - typedef FeatureFunctionFromFsa LanguageModelFromFsa; -*/ - -template -class FeatureFunctionFromFsa : public FeatureFunction { - typedef void const* SP; - typedef WordID *W; - typedef WordID const* WP; -public: - template - FeatureFunctionFromFsa(I const& param) : ff(param) { - debug_=true; // because factory won't set until after we construct. - } - template - FeatureFunctionFromFsa(I & param) : ff(param) { - debug_=true; // because factory won't set until after we construct. - } - - static std::string usage(bool args,bool verbose) { - return Impl::usage(args,verbose); - } - void init_name_debug(std::string const& n,bool debug) { - FeatureFunction::init_name_debug(n,debug); - ff.init_name_debug(n,debug); - } - - // this should override - Features features() const { - DBGINIT("FeatureFunctionFromFsa features() name="<=1) - for (int j=0,ee=e.size();;++j) { // items in target side of rule - for(;;++j) { - if (j>=ee) goto rhs_done; // j may go 1 past ee due to k possibly getting to end - if (RHS_WORD(j)) break; - } - // word @j - int k=j; - while(k{"<") - FSAFFDBG(edge," end="<{"< -# define FSADBG(e,x) FSADBGif(d().debug(),e,x) -# define FSADBGnl(e) FSADBGif_nl(d().debug(),e,x) -#else -# define FSADBG(e,x) -# define FSADBGnl(e) -#endif - -#include "fast_lexical_cast.hpp" -#include -#include -#include "ff.h" -#include "sparse_vector.h" -#include "tdict.h" -#include "hg.h" -#include "ff_fsa_data.h" - -/* -usage: see ff_sample_fsa.h or ff_lm_fsa.h - - then, to decode, see ff_from_fsa.h (or TODO: left->right target-earley style rescoring) - - */ - - -template -struct FsaFeatureFunctionBase : public FsaFeatureFunctionData { - Impl const& d() const { return static_cast(*this); } - Impl & d() { return static_cast(*this); } - - // this will get called by factory - override if you have multiple or dynamically named features. note: may be called repeatedly - void Init() { - Init(name()); - DBGINIT("base (single feature) FsaFeatureFunctionBase::Init name="<set_value(fid,val) possibly with duplicates. state and next_state will never be the same memory. - //TODO: decide if we want to require you to support dest same as src, since that's how we use it most often in ff_from_fsa bottom-up wrapper (in l->r scoring, however, distinct copies will be the rule), and it probably wouldn't be too hard for most people to support. however, it's good to hide the complexity here, once (see overly clever FsaScan loop that swaps src/dest addresses repeatedly to scan a sequence by effectively swapping) - -protected: - // overrides have different name because of inheritance method hiding; - - // simple/common case; 1 fid. these need not be overriden if you have multiple feature ids - Featval Scan1(WordID w,void const* state,void *next_state) const { - assert(0); - return 0; - } - Featval Scan1Meta(SentenceMetadata const& /* smeta */,Hypergraph::Edge const& /* edge */, - WordID w,void const* state,void *next_state) const { - return d().Scan1(w,state,next_state); - } -public: - - // must override this or Scan1Meta or Scan1 - template - inline void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const { - Add(d().Scan1Meta(smeta,edge,w,state,next_state),a); - } - - // bounce back and forth between two state vars starting at cs, returning end state location. if we required src=dest addr safe state updating, this concept wouldn't need to exist. - // required that you override this if you score phrases differently than word-by-word, however, you can just use the SCAN_PHRASE_ACCUM_OVERRIDE macro to do that in terms of ScanPhraseAccum - template - void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const { - // extra code - IT'S FOR EFFICIENCY, MAN! IT'S OK! definitely no bugs here. - if (!ssz) { - for (;io - odd: - d().ScanAccum(smeta,edge,i[0],os,es,accum); // o->e - } - return es; - } - - - static const bool simple_phrase_score=true; // if d().simple_phrase_score_, then you should expect different Phrase scores for phrase length > M. so, set this false if you provide ScanPhraseAccum (SCAN_PHRASE_ACCUM_OVERRIDE macro does this) - - // override this (and use SCAN_PHRASE_ACCUM_OVERRIDE ) if you want e.g. maximum possible order ngram scores with markov_order < n-1. in the future SparseFeatureAccumulator will probably be the only option for type-erased FSA ffs. - // note you'll still have to override ScanAccum - template - void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *accum) const { - if (!ssz) { - for (;i \ - void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const { \ - ScanPhraseAccum(smeta,edge,i,end,cs,ns,accum); \ - return ns; \ - } \ - template \ - void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, \ - WordID const* i, WordID const* end, \ - void const* state,Accum *accum) const { \ - char s2[ssz]; ScanPhraseAccum(smeta,edge,i,end,state,(void*)s2,accum); \ - } - - // override this or bounce along with above. note: you can just call ScanPhraseAccum - // doesn't set state (for heuristic in ff_from_fsa) - template - void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *accum) const { - char s1[ssz]; - char s2[ssz]; - state_copy(s1,state); - d().ScanPhraseAccumBounce(smeta,edge,i,end,(void*)s1,(void*)s2,accum); - } - - // for single-feat only. but will work for different accums - template - inline void Add(Featval v,Accum *a) const { - a->Add(fid_,v); - } - inline void set_feat(FeatureVector *features,Featval v) const { - features->set_value(fid_,v); - } - - // don't set state-bytes etc. in ctor because it may depend on parsing param string - FsaFeatureFunctionBase(int statesz=0,Sentence const& end_sentence_phrase=Sentence()) - : FsaFeatureFunctionData(statesz,end_sentence_phrase) - { - name_=name(); // should allow FsaDynamic wrapper to get name copied to it with sync - } - -}; - -template -struct MultipleFeatureFsa : public FsaFeatureFunctionBase { - typedef SparseFeatureAccumulator Accum; -}; - - - - -// if State is pod. sets state size and allocs start, h_start -// usage: -// struct ShorterThanPrev : public FsaTypedBase -// i.e. Impl is a CRTP -template -struct FsaTypedBase : public FsaFeatureFunctionBase { - Impl const& d() const { return static_cast(*this); } - Impl & d() { return static_cast(*this); } -protected: - typedef FsaFeatureFunctionBase Base; - typedef St State; - static inline State & state(void *state) { - return *(State*)state; - } - static inline State const& state(void const* state) { - return *(State const*)state; - } - void set_starts(State const& s,State const& heuristic_s) { - if (0) { // already in ctor - Base::start.resize(sizeof(State)); - Base::h_start.resize(sizeof(State)); - } - assert(Base::start.size()==sizeof(State)); - assert(Base::h_start.size()==sizeof(State)); - state(Base::start.begin())=s; - state(Base::h_start.begin())=heuristic_s; - } - FsaTypedBase(St const& start_st=St() - ,St const& h_start_st=St() - ,Sentence const& end_sentence_phrase=Sentence()) - : Base(sizeof(State),end_sentence_phrase) { - set_starts(start_st,h_start_st); - } -public: - void print_state(std::ostream &o,void const*st) const { - o< - inline void ScanT(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID w,St const& prev_st,St &new_st,Accum *a) const { - Add(d().ScanT1(smeta,edge,w,prev_st,new_st),a); - } - - // note: you're on your own when it comes to Phrase overrides. see FsaFeatureFunctionBase. sorry. - - template - inline void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID w,void const* st,void *next_state,Accum *a) const { - Impl const& im=d(); - FSADBG(edge,"Scan "<describe(im)<<" "<"< -struct FsaScanner { -// enum {ALIGN=8}; - static const int ALIGN=8; - FF const& ff; - SentenceMetadata const& smeta; - int ssz; - Bytes states; // first is at begin, second is at (char*)begin+stride - void *st0; // states - void *st1; // states+stride - void *cs; // initially st0, alternates between st0 and st1 - inline void *nexts() const { - return (cs==st0)?st1:st0; - } - Hypergraph::Edge const& edge; - FsaScanner(FF const& ff,SentenceMetadata const& smeta,Hypergraph::Edge const& edge) : ff(ff),smeta(smeta),edge(edge) - { - ssz=ff.state_bytes(); - int stride=((ssz+ALIGN-1)/ALIGN)*ALIGN; // round up to multiple of ALIGN - states.resize(stride+ssz); - st0=states.begin(); - st1=(char*)st0+stride; -// for (int i=0;i<2;++i) st[i]=cs+(i*stride); - } - void reset(void const* state) { - cs=st0; - std::memcpy(st0,state,ssz); - } - template - void scan(WordID w,Accum *a) { - void *ns=nexts(); - ff.ScanAccum(smeta,edge,w,cs,ns,a); - cs=ns; - } - template - void scan(WordID const* i,WordID const* end,Accum *a) { - // faster. and allows greater-order excursions - cs=ff.ScanPhraseAccumBounce(smeta,edge,i,end,cs,nexts(),a); - } -}; - - -//TODO: combine 2 FsaFeatures typelist style (can recurse for more) - - - - -#endif diff --git a/decoder/ff_fsa_data.h b/decoder/ff_fsa_data.h deleted file mode 100755 index d215e940..00000000 --- a/decoder/ff_fsa_data.h +++ /dev/null @@ -1,131 +0,0 @@ -#ifndef FF_FSA_DATA_H -#define FF_FSA_DATA_H - -#include //C99 -#include -#include "sentences.h" -#include "feature_accum.h" -#include "value_array.h" -#include "ff.h" //debug -typedef ValueArray Bytes; - -// stuff I see no reason to have virtual. but because it's impossible (w/o virtual inheritance to have dynamic fsa ff know where the impl's data starts, implemented a sync (copy) method that needs to be called. init_name_debug was already necessary to keep state in sync between ff and ff_from_fsa, so no sync should be needed after it. supposing all modifications were through setters, then no explicit sync call would ever be needed; updates could be mirrored. -struct FsaFeatureFunctionData -{ - void init_name_debug(std::string const& n,bool debug) { - name_=n; - debug_=debug; - } - //HACK for diamond inheritance (w/o costing performance) - FsaFeatureFunctionData *sync_to_; - - void sync() const { // call this if you modify any fields after your constructor is done - if (sync_to_) { - DBGINIT("sync to "<<*sync_to_); - *sync_to_=*this; - DBGINIT("synced result="<<*sync_to_<< " from this="<<*this); - } else { - DBGINIT("nobody to sync to - from FeatureFunctionData this="<<*this); - } - } - - friend std::ostream &operator<<(std::ostream &o,FsaFeatureFunctionData const& d) { - o << "[FSA "< - static inline T* state_as(void *p) { return (T*)p; } - template - static inline T const* state_as(void const* p) { return (T*)p; } - std::string describe_features(FeatureVector const& feats) { - std::ostringstream o; - o<" for lm. -protected: - int ssz; // don't forget to set this. default 0 (it may depend on params of course) - // this can be called instead or after constructor (also set bytes and end_phrase_) - void set_state_bytes(int sb=0) { - if (start.size()!=sb) start.resize(sb); - if (h_start.size()!=sb) h_start.resize(sb); - ssz=sb; - } - void set_end_phrase(WordID single) { - end_phrase_=singleton_sentence(single); - } - - inline void static to_state(void *state,char const* begin,char const* end) { - std::memcpy(state,begin,end-begin); - } - inline void static to_state(void *state,char const* begin,int n) { - std::memcpy(state,begin,n); - } - template - inline void static to_state(void *state,T const* begin,int n=1) { - to_state(state,(char const*)begin,n*sizeof(T)); - } - template - inline void static to_state(void *state,T const* begin,T const* end) { - to_state(state,(char const*)begin,(char const*)end); - } - inline static char hexdigit(int i) { - int j=i-10; - return j>=0?'a'+j:'0'+i; - } - inline static void print_hex_byte(std::ostream &o,unsigned c) { - o<>4); - o<Add(v); - } - -}; - -#endif diff --git a/decoder/ff_fsa_dynamic.h b/decoder/ff_fsa_dynamic.h deleted file mode 100755 index 6f75bbe5..00000000 --- a/decoder/ff_fsa_dynamic.h +++ /dev/null @@ -1,208 +0,0 @@ -#ifndef FF_FSA_DYNAMIC_H -#define FF_FSA_DYNAMIC_H - -struct SentenceMetadata; - -#include "ff_fsa_data.h" -#include "hg.h" // can't forward declare nested Hypergraph::Edge class -#include - -// the type-erased interface - -//FIXME: diamond inheritance problem. make a copy of the fixed data? or else make the dynamic version not wrap but rather be templated CRTP base (yuck) -struct FsaFeatureFunction : public FsaFeatureFunctionData { - static const bool simple_phrase_score=false; - virtual int markov_order() const = 0; - - // see ff_fsa.h - FsaFeatureFunctionBase gives you reasonable impls of these if you override just ScanAccum - virtual void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const = 0; - virtual void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *accum) const = 0; - virtual void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *accum) const = 0; - virtual void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *accum) const = 0; - - virtual int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const { return 0; } - // called after constructor, before use - virtual void Init() = 0; - virtual std::string usage_v(bool param,bool verbose) const { - return FeatureFunction::usage_helper("unnamed_dynamic_fsa_feature","","",param,verbose); - } - virtual void init_name_debug(std::string const& n,bool debug) { - FsaFeatureFunctionData::init_name_debug(n,debug); - } - - virtual void print_state(std::ostream &o,void const*state) const { - FsaFeatureFunctionData::print_state(o,state); - } - virtual std::string describe() const { return "[FSA unnamed_dynamic_fsa_feature]"; } - - //end_phrase() - virtual ~FsaFeatureFunction() {} - - // no need to override: - std::string describe_state(void const* state) const { - std::ostringstream o; - print_state(o,state); - return o.str(); - } -}; - -// conforming to above interface, type erases FsaImpl -// you might be wondering: why do this? answer: it's cool, and it means that the bottom-up ff over ff_fsa wrapper doesn't go through multiple layers of dynamic dispatch -// usage: typedef FsaFeatureFunctionDynamic MyFsaDyn; -template -struct FsaFeatureFunctionDynamic : public FsaFeatureFunction { - static const bool simple_phrase_score=Impl::simple_phrase_score; - Impl& d() { return impl;//static_cast(*this); - } - Impl const& d() const { return impl; - //static_cast(*this); - } - int markov_order() const { return d().markov_order(); } - - std::string describe() const { - return d().describe(); - } - - virtual void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const { - return d().ScanAccum(smeta,edge,w,state,next_state,a); - } - - virtual void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *a) const { - return d().ScanPhraseAccum(smeta,edge,i,end,state,next_state,a); - } - - virtual void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *a) const { - return d().ScanPhraseAccumOnly(smeta,edge,i,end,state,a); - } - - virtual void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *a) const { - return d().ScanPhraseAccumBounce(smeta,edge,i,end,cs,ns,a); - } - - virtual int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const { - return d().early_score_words(smeta,edge,i,end,accum); - } - - static std::string usage(bool param,bool verbose) { - return Impl::usage(param,verbose); - } - - std::string usage_v(bool param,bool verbose) const { - return Impl::usage(param,verbose); - } - - virtual void print_state(std::ostream &o,void const*state) const { - return d().print_state(o,state); - } - - void init_name_debug(std::string const& n,bool debug) { - FsaFeatureFunction::init_name_debug(n,debug); - d().init_name_debug(n,debug); - } - - virtual void Init() { - d().sync_to_=(FsaFeatureFunctionData*)this; - d().Init(); - d().sync(); - } - - template - FsaFeatureFunctionDynamic(I const& param) : impl(param) { - Init(); - } -private: - Impl impl; -}; - -// constructor takes ptr or shared_ptr to Impl, otherwise same as above - note: not virtual -template -struct FsaFeatureFunctionPimpl : public FsaFeatureFunctionData { - typedef boost::shared_ptr Pimpl; - static const bool simple_phrase_score=Impl::simple_phrase_score; - Impl const& d() const { return *p_; } - int markov_order() const { return d().markov_order(); } - - std::string describe() const { - return d().describe(); - } - - void ScanAccum(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID w,void const* state,void *next_state,Accum *a) const { - return d().ScanAccum(smeta,edge,w,state,next_state,a); - } - - void ScanPhraseAccum(SentenceMetadata const& smeta,Hypergraph::Edge const & edge, - WordID const* i, WordID const* end, - void const* state,void *next_state,Accum *a) const { - return d().ScanPhraseAccum(smeta,edge,i,end,state,next_state,a); - } - - void ScanPhraseAccumOnly(SentenceMetadata const& smeta,Hypergraph::Edge const& edge, - WordID const* i, WordID const* end, - void const* state,Accum *a) const { - return d().ScanPhraseAccumOnly(smeta,edge,i,end,state,a); - } - - void *ScanPhraseAccumBounce(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,void *cs,void *ns,Accum *a) const { - return d().ScanPhraseAccumBounce(smeta,edge,i,end,cs,ns,a); - } - - int early_score_words(SentenceMetadata const& smeta,Hypergraph::Edge const& edge,WordID const* i, WordID const* end,Accum *accum) const { - return d().early_score_words(smeta,edge,i,end,accum); - } - - static std::string usage(bool param,bool verbose) { - return Impl::usage(param,verbose); - } - - std::string usage_v(bool param,bool verbose) const { - return Impl::usage(param,verbose); - } - - void print_state(std::ostream &o,void const*state) const { - return d().print_state(o,state); - } - -#if 0 - // this and Init() don't touch p_ because we want to leave the original alone. - void init_name_debug(std::string const& n,bool debug) { - FsaFeatureFunctionData::init_name_debug(n,debug); - } -#endif - void Init() { - p_=hold_pimpl_.get(); -#if 0 - d().sync_to_=static_cast(this); - d().Init(); -#endif - *static_cast(this)=d(); - } - - FsaFeatureFunctionPimpl(Impl const* const p) : hold_pimpl_(p,null_deleter()) { - Init(); - } - FsaFeatureFunctionPimpl(Pimpl const& p) : hold_pimpl_(p) { - Init(); - } -private: - Impl const* p_; - Pimpl hold_pimpl_; -}; - -typedef FsaFeatureFunctionPimpl FsaFeatureFunctionFwd; // allow ff_from_fsa for an existing dynamic-type ff (as opposed to usual register a wrapped known-type FSA in ff_register, which is more efficient) -//typedef FsaFeatureFunctionDynamic DynamicFsaFeatureFunctionFwd; //if you really need to have a dynamic fsa facade that's also a dynamic fsa - -//TODO: combine 2 (or N) FsaFeatureFunction (type erased) - - -#endif diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc index afa36b96..5e16d4e3 100644 --- a/decoder/ff_lm.cc +++ b/decoder/ff_lm.cc @@ -46,7 +46,6 @@ char const* usage_verbose="-n determines the name of the feature (and its weight #endif #include "ff_lm.h" -#include "ff_lm_fsa.h" #include #include @@ -69,10 +68,6 @@ char const* usage_verbose="-n determines the name of the feature (and its weight using namespace std; -string LanguageModelFsa::usage(bool param,bool verbose) { - return FeatureFunction::usage_helper("LanguageModelFsa",usage_short,usage_verbose,param,verbose); -} - string LanguageModel::usage(bool param,bool verbose) { return FeatureFunction::usage_helper(usage_name,usage_short,usage_verbose,param,verbose); } @@ -524,49 +519,6 @@ LanguageModel::LanguageModel(const string& param) { SetStateSize(LanguageModelImpl::OrderToStateSize(order)); } -//TODO: decide whether to waste a word of space so states are always none-terminated for SRILM. otherwise we have to copy -void LanguageModelFsa::set_ngram_order(int i) { - assert(i>0); - ngram_order_=i; - ctxlen_=i-1; - set_state_bytes(ctxlen_*sizeof(WordID)); - WordID *ss=(WordID*)start.begin(); - WordID *hs=(WordID*)h_start.begin(); - if (ctxlen_) { // avoid segfault in case of unigram lm (0 state) - set_end_phrase(TD::Convert("")); -// se is pretty boring in unigram case, just adds constant prob. check that this is what we want - ss[0]=TD::Convert(""); // start-sentence context (length 1) - hs[0]=0; // empty context - for (int i=1;ifloor_; - set_ngram_order(lmorder); -} - -void LanguageModelFsa::print_state(ostream &o,void const* st) const { - WordID const *wst=(WordID const*)st; - o<<'['; - bool sp=false; - for (int i=ctxlen_;i>0;sp=true) { - --i; - WordID w=wst[i]; - if (w==0) continue; - if (sp) o<<' '; - o << TD::Convert(w); - } - o<<']'; -} - Features LanguageModel::features() const { return single_feature(fid_); } diff --git a/decoder/ff_lm_fsa.h b/decoder/ff_lm_fsa.h deleted file mode 100755 index 85b7ef44..00000000 --- a/decoder/ff_lm_fsa.h +++ /dev/null @@ -1,140 +0,0 @@ -#ifndef FF_LM_FSA_H -#define FF_LM_FSA_H - -//FIXME: when FSA_LM_PHRASE 1, 3gram fsa has differences, especially with unk words, in about the 4th decimal digit (about .05%), compared to regular ff_lm. this is USUALLY a bug (there's way more actual precision in there). this was with #define LM_FSA_SHORTEN_CONTEXT 1 and 0 (so it's not that). also, LM_FSA_SHORTEN_CONTEXT gives identical scores with FSA_LM_PHRASE 0 - -// enabling for now - retest unigram+ more, solve above puzzle - -// some impls in ff_lm.cc - -#define FSA_LM_PHRASE 1 - -#define FSA_LM_DEBUG 0 -#if FSA_LM_DEBUG -# define FSALMDBG(e,x) FSADBGif(debug(),e,x) -# define FSALMDBGnl(e) FSADBGif_nl(debug(),e) -#else -# define FSALMDBG(e,x) -# define FSALMDBGnl(e) -#endif - -#include "ff_fsa.h" -#include "ff_lm.h" - -#ifndef TD__none -// replacing dependency on SRILM -#define TD__none -1 -#endif - -namespace { -WordID empty_context=TD__none; -} - -struct LanguageModelFsa : public FsaFeatureFunctionBase { - typedef WordID * W; - typedef WordID const* WP; - - // overrides; implementations in ff_lm.cc - typedef SingleFeatureAccumulator Accum; - static std::string usage(bool,bool); - LanguageModelFsa(std::string const& param); - int markov_order() const { return ctxlen_; } - void print_state(std::ostream &,void const *) const; - inline Featval floored(Featval p) const { - return pleft;--e) - if (e[-1]!=TD__none) break; - //post: [left,e] are the seen left words - return e; - } - - template - void ScanAccum(SentenceMetadata const& /* smeta */,Hypergraph::Edge const& edge,WordID w,void const* old_st,void *new_st,Accum *a) const { -#if USE_INFO_EDGE - Hypergraph::Edge &de=(Hypergraph::Edge &)edge; -#endif - if (!ctxlen_) { - Add(floored(pimpl_->WordProb(w,&empty_context)),a); - } else { - WordID ctx[ngram_order_]; //alloca if you don't have C99 - state_copy(ctx,old_st); - ctx[ctxlen_]=TD__none; - Featval p=floored(pimpl_->WordProb(w,ctx)); - FSALMDBG(de,"p("<ShortenContext(nst,ctxlen_); -#endif - Add(p,a); - } - } - -#if FSA_LM_PHRASE - //FIXME: there is a bug in here somewhere, or else the 3gram LM we use gives different scores for phrases (impossible? BOW nonzero when shortening context past what LM has?) - template - void ScanPhraseAccum(SentenceMetadata const& /* smeta */,const Hypergraph::Edge&edge,WordID const* begin,WordID const* end,void const* old_st,void *new_st,Accum *a) const { - Hypergraph::Edge &de=(Hypergraph::Edge &)edge;(void)de; - if (begin==end) return; // otherwise w/ shortening it's possible to end up with no words at all. - /* // this is forcing unigram prob always. we will instead build the phrase - if (!ctxlen_) { - Featval p=0; - for (;iWordProb(*i,e&mpty_context)); - Add(p,a); - return; - } */ - int nw=end-begin; - WP st=(WP)old_st; - WP st_end=st+ctxlen_; // may include some null already (or none if full) - int nboth=nw+ctxlen_; - WordID ctx[nboth+1]; - ctx[nboth]=TD__none; - // reverse order - state at very end of context, then [i,end) in rev order ending at ctx[0] - W ctx_score_end=wordcpy_reverse(ctx,begin,end); - wordcpy(ctx_score_end,st,st_end); // st already reversed. - assert(ctx_score_end==ctx+nw); - // we could just copy the filled state words, but it probably doesn't save much time (and might cost some to scan to find the nones. most contexts are full except for the shortest source spans. - FSALMDBG(de," scan.r->l("<ctx;--ctx_score_end) - p+=floored(pimpl_->WordProb(ctx_score_end[-1],ctx_score_end)); - //TODO: look for score discrepancy - - // i had some idea that maybe shortencontext would return a different prob if the length provided was > ctxlen_; however, since the same disagreement happens with LM_FSA_SHORTEN_CONTEXT 0 anyway, it's not that. perhaps look to SCAN_PHRASE_ACCUM_OVERRIDE - make sure they do the right thing. -#if LM_FSA_SHORTEN_CONTEXT - p+=pimpl_->ShortenContext(ctx,nboth - need to use factory rather than ctor. -#if 0 -template -inline void RegisterFsa(bool ff_also=true,bool fsa_prefix_ff=true) { - assert(!ff_also); -// global_fsa_ff_registry->RegisterFsa(); -//if (ff_also) ff_registry.RegisterFF >(prefix_fsa(DynFsa::usage(false,false)),fsa_prefix_ff); -} -#endif - -//TODO: ff from fsa that uses pointer to fsa impl? e.g. in LanguageModel we share underlying lm file by recognizing same param, but without that effort, otherwise stateful ff may duplicate state if we enable both fsa and ff_from_fsa -template -inline void RegisterFsaImpl(bool ff_also=true,bool fsa_prefix_ff=false) { - typedef FsaFeatureFunctionDynamic DynFsa; - typedef FeatureFunctionFromFsa FFFrom; - std::string name=FsaImpl::usage(false,false); - fsa_ff_registry.Register(new FsaFactory); - if (ff_also) - ff_registry.Register(prefix_fsa(name,fsa_prefix_ff),new FFFactory); -} template inline void RegisterFF() { ff_registry.Register(new FFFactory); } -template -inline void RegisterFsaDynToFF(std::string name,bool prefix=true) { - typedef FsaFeatureFunctionDynamic DynFsa; - ff_registry.Register(prefix?"DynamicFsa"+name:name,new FFFactory >); -} - -template -inline void RegisterFsaDynToFF(bool prefix=true) { - RegisterFsaDynToFF(FsaImpl::usage(false,false),prefix); -} - void register_feature_functions(); #endif diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 3be5b82d..5d1910fb 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -57,7 +57,7 @@ TEST_F(HGTest,Union) { c3 = ViterbiESentence(hg1, &t3); int l3 = ViterbiPathLength(hg1); cerr << c3 << "\t" << TD::GetString(t3) << endl; - EXPECT_FLOAT_EQ(c2, c3); + EXPECT_FLOAT_EQ(c2.as_float(), c3.as_float()); EXPECT_EQ(TD::GetString(t2), TD::GetString(t3)); EXPECT_EQ(l2, l3); @@ -117,7 +117,7 @@ TEST_F(HGTest,InsideScore) { cerr << "cost: " << cost << "\n"; hg.PrintGraphviz(); prob_t inside = Inside(hg); - EXPECT_FLOAT_EQ(1.7934048, inside); // computed by hand + EXPECT_FLOAT_EQ(1.7934048, inside.as_float()); // computed by hand vector post; inside = hg.ComputeBestPathThroughEdges(&post); EXPECT_FLOAT_EQ(-0.3, log(inside)); // computed by hand @@ -282,13 +282,13 @@ TEST_F(HGTest, TestGenericInside) { hg.Reweight(wts); vector inside; prob_t ins = Inside(hg, &inside); - EXPECT_FLOAT_EQ(1.7934048, ins); // computed by hand + EXPECT_FLOAT_EQ(1.7934048, ins.as_float()); // computed by hand vector outside; Outside(hg, inside, &outside); EXPECT_EQ(3, outside.size()); - EXPECT_FLOAT_EQ(1.7934048, outside[0]); - EXPECT_FLOAT_EQ(1.3114071, outside[1]); - EXPECT_FLOAT_EQ(1.0, outside[2]); + EXPECT_FLOAT_EQ(1.7934048, outside[0].as_float()); + EXPECT_FLOAT_EQ(1.3114071, outside[1].as_float()); + EXPECT_FLOAT_EQ(1.0, outside[2].as_float()); } TEST_F(HGTest,TestGenericInside2) { @@ -327,8 +327,8 @@ TEST_F(HGTest,TestAddExpectations) { SparseVector feat_exps; prob_t z = InsideOutside, EdgeFeaturesAndProbWeightFunction>(hg, &feat_exps); - EXPECT_FLOAT_EQ(-2.5439765, feat_exps.value(FD::Convert("f1")) / z); - EXPECT_FLOAT_EQ(-2.6357865, feat_exps.value(FD::Convert("f2")) / z); + EXPECT_FLOAT_EQ(-2.5439765, (feat_exps.value(FD::Convert("f1")) / z).as_float()); + EXPECT_FLOAT_EQ(-2.6357865, (feat_exps.value(FD::Convert("f2")) / z).as_float()); cerr << feat_exps << endl; cerr << "Z=" << z << endl; } diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index f87b7274..993627f0 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -9,6 +9,7 @@ #include #include +#include "stringlib.h" #include "verbose.h" #include "hg.h" #include "prob.h" @@ -204,6 +205,7 @@ bool LoadAgenda(const string& file, vector >* a) { } int main(int argc, char** argv) { + cerr << "THIS SOFTWARE IS DEPRECATED YOU SHOULD USE mpi_flex_optimize\n"; #ifdef HAVE_MPI mpi::environment env(argc, argv); mpi::communicator world; -- cgit v1.2.3 From eb4b8a6ca070794db1a01b04570e9aaf346881ae Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 14 Oct 2011 11:53:48 +0100 Subject: one more to remove --- decoder/cdec-fsa.ini | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100755 decoder/cdec-fsa.ini (limited to 'decoder') diff --git a/decoder/cdec-fsa.ini b/decoder/cdec-fsa.ini deleted file mode 100755 index 05aaefd4..00000000 --- a/decoder/cdec-fsa.ini +++ /dev/null @@ -1,10 +0,0 @@ -cubepruning_pop_limit=200 -feature_function=WordPenalty -feature_function=ArityPenalty -feature_function=WordPenaltyFsa -#feature_function=LongerThanPrev -feature_function=ShorterThanPrev debug -add_pass_through_rules=true -formalism=scfg -grammar=mt09.grammar.gz -weights=weights-fsa -- cgit v1.2.3 From 957d90991b4ec80b9877126c736bd60768b094aa Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 17 Oct 2011 16:58:26 +0100 Subject: Chris, I'd like you to review this for use with your rules that contain and . --- decoder/ff_klm.cc | 72 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 23 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 6d9aca54..658aef80 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -71,6 +71,8 @@ string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } +namespace { + struct VMapper : public lm::ngram::EnumerateVocab { VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { @@ -83,66 +85,90 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; -template -class KLanguageModelImpl { +#pragma pack(push) +#pragma pack(1) - static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { - return *static_cast(state); +struct BoundaryAnnotatedState { + lm::ngram::ChartState state; + bool seen_bos, seen_eos; +}; + +#pragma pack(pop) + +void BoundaryCheck(bool &annotated, bool sub, double &ret) { + if (!sub) return; + if (annotated) { + ret -= 100.0; + } else { + annotated = true; } +} +} // namespace + +template +class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - lm::ngram::RuleScore ruleScore(*ngram_, *static_cast(remnant)); + BoundaryAnnotatedState &annotated = *static_cast(remnant); + lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); + annotated.seen_bos = false; + annotated.seen_eos = false; unsigned i = 0; + double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); + annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(sub.state, 0.0f); + annotated.seen_bos = sub.seen_bos; + annotated.seen_eos = sub.seen_eos; ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); - ruleScore.NonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); + ruleScore.NonTerminal(sub.state, 0.0f); + BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); + BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; + BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - double ret = ruleScore.Finish(); - static_cast(remnant)->ZeroRemaining(); + ret += ruleScore.Finish(); + annotated.state.ZeroRemaining(); return ret; } // this assumes no target words on final unary -> goal rule. is that ok? // for (n-1 left words) and (n-1 right words) - double FinalTraversalCost(const void* state, double* oovs) { + double FinalTraversalCost(const void* state_void, double* oovs) { + const BoundaryAnnotatedState &annotated = *static_cast(state_void); if (add_sos_eos_) { // rules do not produce , so do it here + assert(!annotated.seen_bos); + assert(!annotated.seen_eos); lm::ngram::ChartState cstate; lm::ngram::RuleScore ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.NonTerminal(annotated.state, 0.0f); ruleScore.Terminal(kEOS_); return ruleScore.Finish(); } else { // rules DO produce ... - double p = 0; - cerr << "not implemented"; abort(); // TODO - //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - //if (UnscoredSize(state) > 0) { // are there unscored words - // if (kSOS_ != IthUnscoredWord(0, state)) { - // p -= 100 * UnscoredSize(state); - // } - //} - return p; + double ret = 0.0; + if (!annotated.seen_bos) ret -= 100.0; + if (!annotated.seen_eos) ret -= 100.0; + return ret; } } @@ -230,7 +256,7 @@ class KLanguageModelImpl { delete ngram_; } - int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } + int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); } private: const WordID kCDEC_UNK; -- cgit v1.2.3 From 3d1ed02a4e5d81aace80b0e004e96351d116630f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 18 Oct 2011 10:25:56 +0100 Subject: Revised and handling --- decoder/ff_klm.cc | 84 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 26 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 658aef80..3c941fbf 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,8 +12,8 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#define NEW_KENLM #undef NEW_KENLM -#ifdef NEW_KENLM #include "lm/left.hh" @@ -95,14 +95,58 @@ struct BoundaryAnnotatedState { #pragma pack(pop) -void BoundaryCheck(bool &annotated, bool sub, double &ret) { - if (!sub) return; - if (annotated) { - ret -= 100.0; - } else { - annotated = true; - } -} +template class BoundaryRuleScore { + public: + BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) : + back_(m, state.state), + bos_(state.seen_bos), + eos_(state.seen_eos), + penalty_(0.0), + end_sentence_(m.GetVocabulary().EndSentence()) { + bos_ = false; + eos_ = false; + } + + void BeginSentence() { + back_.BeginSentence(); + bos_ = true; + } + + void BeginNonTerminal(const BoundaryAnnotatedState &sub) { + back_.BeginNonTerminal(sub.state, 0.0f); + bos_ = sub.seen_bos; + eos_ = sub.seen_eos; + } + + void NonTerminal(const BoundaryAnnotatedState &sub) { + back_.NonTerminal(sub.state, 0.0f); + // cdec only calls this if there's content. + if (sub.seen_bos) { + bos_ = true; + penalty_ -= 100.0f; + } + if (eos_) penalty_ -= 100.0f; + eos_ |= sub.seen_eos; + } + + void Terminal(lm::WordIndex word) { + back_.Terminal(word); + if (eos_) penalty_ -= 100.0f; + if (word == end_sentence_) eos_ = true; + } + + float Finish() { + return penalty_ + back_.Finish(); + } + + private: + lm::ngram::RuleScore back_; + bool &bos_, &eos_; + + float penalty_; + + lm::WordIndex end_sentence_; +}; } // namespace @@ -112,42 +156,30 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - BoundaryAnnotatedState &annotated = *static_cast(remnant); - lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); - annotated.seen_bos = false; - annotated.seen_eos = false; + BoundaryRuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; - double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); - annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(sub.state, 0.0f); - annotated.seen_bos = sub.seen_bos; - annotated.seen_eos = sub.seen_eos; + ruleScore.BeginNonTerminal(*static_cast(ant_states[-e[0]])); ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); - ruleScore.NonTerminal(sub.state, 0.0f); - BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); - BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); + ruleScore.NonTerminal(*static_cast(ant_states[-e[i]])); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; - BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - ret += ruleScore.Finish(); - annotated.state.ZeroRemaining(); + double ret = ruleScore.Finish(); + static_cast(remnant)->state.ZeroRemaining(); return ret; } -- cgit v1.2.3 From 04e38a57b19ea012895ac2efb39382c2e77833a9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 18 Oct 2011 14:19:09 +0100 Subject: incorporate kenneth's fixes --- decoder/ff_klm.cc | 464 ------------------------------------------------------ 1 file changed, 464 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3c941fbf..ed6f731e 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,9 +12,6 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" -#define NEW_KENLM -#undef NEW_KENLM - #include "lm/left.hh" using namespace std; @@ -395,464 +392,3 @@ std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } -#else - -using namespace std; - -static const unsigned char HAS_FULL_CONTEXT = 1; -static const unsigned char HAS_EOS_ON_RIGHT = 2; -static const unsigned char MASK = 7; - -// -x : rules include and -// -n NAME : feature id is NAME -bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { - vector const& argv=SplitOnWhitespace(in); - *explicit_markers = false; - *featname="LanguageModel"; - *mapfile = ""; -#define LMSPEC_NEXTARG if (i==argv.end()) { \ - cerr << "Missing argument for "<<*last<<". "; goto usage; \ - } else { ++i; } - - for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { - string const& s=*i; - if (s[0]=='-') { - if (s.size()>2) goto fail; - switch (s[1]) { - case 'x': - *explicit_markers = true; - break; - case 'm': - LMSPEC_NEXTARG; *mapfile=*i; - break; - case 'n': - LMSPEC_NEXTARG; *featname=*i; - break; -#undef LMSPEC_NEXTARG - default: - fail: - cerr<<"Unknown KLanguageModel option "<empty()) - *filename=s; - else { - cerr<<"More than one filename provided. "; - goto usage; - } - } - } - if (!filename->empty()) - return true; -usage: - cerr << "KLanguageModel is incorrect!\n"; - return false; -} - -template -string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { - return "KLanguageModel"; -} - -struct VMapper : public lm::ngram::EnumerateVocab { - VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_->size()) - out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); - (*out_)[cdec_id] = index; - } - vector* out_; - const lm::WordIndex kLM_UNKNOWN_TOKEN; -}; - -template -class KLanguageModelImpl { - - // returns the number of unscored words at the left edge of a span - inline int UnscoredSize(const void* state) const { - return *(static_cast(state) + unscored_size_offset_); - } - - inline void SetUnscoredSize(int size, void* state) const { - *(static_cast(state) + unscored_size_offset_) = size; - } - - static inline const lm::ngram::State& RemnantLMState(const void* state) { - return *static_cast(state); - } - - inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } - - lm::WordIndex IthUnscoredWord(int i, const void* state) const { - const lm::WordIndex* const mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); - return mem[i]; - } - - void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const { - lm::WordIndex* mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); - mem[i] = index; - } - - inline bool GetFlag(const void *state, unsigned char flag) const { - return (*(static_cast(state) + is_complete_offset_) & flag); - } - - inline void SetFlag(bool on, unsigned char flag, void *state) const { - if (on) { - *(static_cast(state) + is_complete_offset_) |= flag; - } else { - *(static_cast(state) + is_complete_offset_) &= (MASK ^ flag); - } - } - - inline bool HasFullContext(const void *state) const { - return GetFlag(state, HAS_FULL_CONTEXT); - } - - inline void SetHasFullContext(bool flag, void *state) const { - SetFlag(flag, HAS_FULL_CONTEXT, state); - } - - public: - double LookupWords(const TRule& rule, const vector& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) { - double sum = 0.0; - double est_sum = 0.0; - int num_scored = 0; - int num_estimated = 0; - if (oovs) *oovs = 0; - if (est_oovs) *est_oovs = 0; - bool saw_eos = false; - bool has_some_history = false; - lm::ngram::State state = ngram_->NullContextState(); - const vector& e = rule.e(); - bool context_complete = false; - for (int j = 0; j < e.size(); ++j) { - if (e[j] < 1) { // handle non-terminal substitution - const void* astate = (ant_states[-e[j]]); - int unscored_ant_len = UnscoredSize(astate); - for (int k = 0; k < unscored_ant_len; ++k) { - const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const bool is_oov = (cur_word == 0); - double p = 0; - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); - if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 - state = RemnantLMState(astate); - context_complete = true; - } - } else { // handle terminal - const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future, - // maybe handle emission - const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - double p = 0; - const bool is_oov = (cur_word == 0); - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - } - if (pest_sum) *pest_sum = est_sum; - if (remnant) { - state.ZeroRemaining(); - SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); - SetRemnantLMState(state, remnant); - SetUnscoredSize(num_estimated, remnant); - SetHasFullContext(context_complete || (num_scored >= order_), remnant); - } - return sum; - } - - // this assumes no target words on final unary -> goal rule. is that ok? - // for (n-1 left words) and (n-1 right words) - double FinalTraversalCost(const void* state, double* oovs) { - if (add_sos_eos_) { // rules do not produce , so do it here - SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); - SetHasFullContext(1, dummy_state_); - SetUnscoredSize(0, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL); - } else { // rules DO produce ... - double p = 0; - if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - if (UnscoredSize(state) > 0) { // are there unscored words - if (kSOS_ != IthUnscoredWord(0, state)) { - p -= 100 * UnscoredSize(state); - } - } - return p; - } - } - - // if this is not a class-based LM, returns w untransformed, - // otherwise returns a word class mapping of w, - // returns TD::Convert("") if there is no mapping for w - WordID ClassifyWordIfNecessary(WordID w) const { - if (word2class_map_.empty()) return w; - if (w >= word2class_map_.size()) - return kCDEC_UNK; - else - return word2class_map_[w]; - } - - // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 - lm::WordIndex MapWord(WordID w) const { - if (w >= cdec2klm_map_.size()) - return 0; - else - return cdec2klm_map_[w]; - } - - public: - KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : - kCDEC_UNK(TD::Convert("")) , - add_sos_eos_(!explicit_markers) { - { - VMapper vm(&cdec2klm_map_); - lm::ngram::Config conf; - conf.enumerate_vocab = &vm; - ngram_ = new Model(filename.c_str(), conf); - } - order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); - unscored_size_offset_ = ngram_->StateSize(); - is_complete_offset_ = unscored_size_offset_ + 1; - unscored_words_offset_ = is_complete_offset_ + 1; - - // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); - kSOS_ = MapWord(TD::Convert("")); - assert(kSOS_ > 0); - kEOS_ = MapWord(TD::Convert("")); - assert(kEOS_ > 0); - assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant - - // handle class-based LMs (unambiguous word->class mapping reqd.) - if (mapfile.size()) - LoadWordClasses(mapfile); - } - - void LoadWordClasses(const string& file) { - ReadFile rf(file); - istream& in = *rf.stream(); - string line; - vector dummy; - int lc = 0; - cerr << " Loading word classes from " << file << " ...\n"; - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - while(in) { - getline(in, line); - if (!in) continue; - dummy.clear(); - TD::ConvertSentence(line, &dummy); - ++lc; - if (dummy.size() != 2) { - cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; - abort(); - } - AddWordToClassMapping_(dummy[0], dummy[1]); - } - } - - void AddWordToClassMapping_(WordID word, WordID cls) { - if (word2class_map_.size() <= word) { - word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); - assert(word2class_map_.size() > word); - } - if(word2class_map_[word] != kCDEC_UNK) { - cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; - abort(); - } - word2class_map_[word] = cls; - } - - ~KLanguageModelImpl() { - delete ngram_; - delete[] dummy_state_; - } - - int ReserveStateSize() const { return state_size_; } - - private: - const WordID kCDEC_UNK; - lm::WordIndex kSOS_; // - requires special handling. - lm::WordIndex kEOS_; // - Model* ngram_; - const bool add_sos_eos_; // flag indicating whether the hypergraph produces and - // if this is true, FinalTransitionFeatures will "add" and - // if false, FinalTransitionFeatures will score anything with the - // markers in the right place (i.e., the beginning and end of - // the sentence) with 0, and anything else with -100 - - int order_; - int state_size_; - int unscored_size_offset_; - int is_complete_offset_; - int unscored_words_offset_; - char* dummy_state_; - vector dummy_ants_; - vector cdec2klm_map_; - vector word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; -}; - -template -KLanguageModel::KLanguageModel(const string& param) { - string filename, mapfile, featname; - bool explicit_markers; - if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { - abort(); - } - try { - pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); - } catch (std::exception &e) { - std::cerr << e.what() << std::endl; - abort(); - } - fid_ = FD::Convert(featname); - oov_fid_ = FD::Convert(featname+"_OOV"); - cerr << "FID: " << oov_fid_ << endl; - SetStateSize(pimpl_->ReserveStateSize()); -} - -template -Features KLanguageModel::features() const { - return single_feature(fid_); -} - -template -KLanguageModel::~KLanguageModel() { - delete pimpl_; -} - -template -void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, - const Hypergraph::Edge& edge, - const vector& ant_states, - SparseVector* features, - SparseVector* estimated_features, - void* state) const { - double est = 0; - double oovs = 0; - double est_oovs = 0; - features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state)); - estimated_features->set_value(fid_, est); - if (oov_fid_) { - if (oovs) features->set_value(oov_fid_, oovs); - if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); - } -} - -template -void KLanguageModel::FinalTraversalFeatures(const void* ant_state, - SparseVector* features) const { - double oovs = 0; - double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); - features->set_value(fid_, lm); - if (oov_fid_ && oovs) - features->set_value(oov_fid_, oovs); -} - -template boost::shared_ptr CreateModel(const std::string ¶m) { - KLanguageModel *ret = new KLanguageModel(param); - ret->Init(); - return boost::shared_ptr(ret); -} - -boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { - using namespace lm::ngram; - std::string filename, ignored_map; - bool ignored_markers; - std::string ignored_featname; - ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); - ModelType m; - if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; - - switch (m) { - case HASH_PROBING: - return CreateModel(param); - case TRIE_SORTED: - return CreateModel(param); - case ARRAY_TRIE_SORTED: - return CreateModel(param); - case QUANT_TRIE_SORTED: - return CreateModel(param); - case QUANT_ARRAY_TRIE_SORTED: - return CreateModel(param); - default: - UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); - } -} - -std::string KLanguageModelFactory::usage(bool params,bool verbose) const { - return KLanguageModel::usage(params, verbose); -} - -#endif -- cgit v1.2.3