diff options
Diffstat (limited to 'klm/lm/builder/pipeline.cc')
-rw-r--r-- | klm/lm/builder/pipeline.cc | 103 |
1 files changed, 61 insertions, 42 deletions
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc index 44a2313c..21064ab3 100644 --- a/klm/lm/builder/pipeline.cc +++ b/klm/lm/builder/pipeline.cc @@ -2,6 +2,7 @@ #include "lm/builder/adjust_counts.hh" #include "lm/builder/corpus_count.hh" +#include "lm/builder/hash_gamma.hh" #include "lm/builder/initial_probabilities.hh" #include "lm/builder/interpolate.hh" #include "lm/builder/print.hh" @@ -20,10 +21,13 @@ namespace lm { namespace builder { namespace { -void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) { +void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts) { std::cerr << "Statistics:\n"; for (size_t i = 0; i < counts.size(); ++i) { - std::cerr << (i + 1) << ' ' << counts[i]; + std::cerr << (i + 1) << ' ' << counts_pruned[i]; + if(counts[i] != counts_pruned[i]) + std::cerr << "/" << counts[i]; + for (size_t d = 1; d <= 3; ++d) std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d]; std::cerr << '\n'; @@ -39,7 +43,7 @@ class Master { const PipelineConfig &Config() const { return config_; } - Chains &MutableChains() { return chains_; } + util::stream::Chains &MutableChains() { return chains_; } template <class T> Master &operator>>(const T &worker) { chains_ >> worker; @@ -64,7 +68,7 @@ class Master { } // For initial probabilities, but this is generic. - void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) { + void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) { // Do merge first before allocating chain memory. for (std::size_t i = 1; i < config_.order; ++i) { sorts[i - 1].Merge(0); @@ -198,9 +202,9 @@ class Master { PipelineConfig config_; - Chains chains_; + util::stream::Chains chains_; // Often only unigrams, but sometimes all orders. - FixedArray<util::stream::FileBuffer> files_; + util::FixedArray<util::stream::FileBuffer> files_; }; void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { @@ -221,7 +225,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m WordIndex type_count = config.vocab_estimate; util::FilePiece text(text_file, NULL, &std::cerr); text_file_name = text.FileName(); - CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action); chain >> boost::ref(counter); util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); @@ -231,21 +235,22 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m master.InitForAdjust(sorter, type_count); } -void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, + util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds) { const PipelineConfig &config = master.Config(); - Chains second(config.order); + util::stream::Chains second(config.order); { Sorts<ContextOrder> sorts; master.SetupSorts(sorts); - PrintStatistics(counts, discounts); - lm::ngram::ShowSizes(counts); + PrintStatistics(counts, counts_pruned, discounts); + lm::ngram::ShowSizes(counts_pruned); std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; - master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); + master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in); } - Chains gamma_chains(config.order); - InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); + util::stream::Chains gamma_chains(config.order); + InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds); // Don't care about gamma for 0. gamma_chains[0] >> util::stream::kRecycle; gammas.Init(config.order - 1); @@ -257,19 +262,25 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector master.SetupSorts(primary); } -void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) { std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; const PipelineConfig &config = master.Config(); master.MaximumLazyInput(counts, primary); - Chains gamma_chains(config.order - 1); - util::stream::ChainConfig read_backoffs(config.read_backoffs); - read_backoffs.entry_size = sizeof(float); + util::stream::Chains gamma_chains(config.order - 1); for (std::size_t i = 0; i < config.order - 1; ++i) { + util::stream::ChainConfig read_backoffs(config.read_backoffs); + + // Add 1 because here we are skipping unigrams + if(config.prune_thresholds[i + 1] > 0) + read_backoffs.entry_size = sizeof(HashGamma); + else + read_backoffs.entry_size = sizeof(float); + gamma_chains.push_back(read_backoffs); gamma_chains.back() >> gammas[i].Source(); } - master >> Interpolate(counts[0], ChainPositions(gamma_chains)); + master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.output_q); gamma_chains >> util::stream::kRecycle; master.BufferFinal(counts); } @@ -291,32 +302,40 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) { "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); UTIL_TIMER("(%w s) Total wall time elapsed\n"); - Master master(config); - - util::scoped_fd vocab_file(config.vocab_file.empty() ? - util::MakeTemp(config.TempPrefix()) : - util::CreateOrThrow(config.vocab_file.c_str())); - uint64_t token_count; - std::string text_file_name; - CountText(text_file, vocab_file.get(), master, token_count, text_file_name); - std::vector<uint64_t> counts; - std::vector<Discount> discounts; - master >> AdjustCounts(counts, discounts); + Master master(config); + // master's destructor will wait for chains. But they might be deadlocked if + // this thread dies because e.g. it ran out of memory. + try { + util::scoped_fd vocab_file(config.vocab_file.empty() ? + util::MakeTemp(config.TempPrefix()) : + util::CreateOrThrow(config.vocab_file.c_str())); + uint64_t token_count; + std::string text_file_name; + CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + + std::vector<uint64_t> counts; + std::vector<uint64_t> counts_pruned; + std::vector<Discount> discounts; + master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, config.discount, discounts); + + { + util::FixedArray<util::stream::FileBuffer> gammas; + Sorts<SuffixOrder> primary; + InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds); + InterpolateProbabilities(counts_pruned, master, primary, gammas); + } - { - FixedArray<util::stream::FileBuffer> gammas; - Sorts<SuffixOrder> primary; - InitialProbabilities(counts, discounts, master, primary, gammas); - InterpolateProbabilities(counts, master, primary, gammas); + std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; + VocabReconstitute vocab(vocab_file.get()); + UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); + HeaderInfo header_info(text_file_name, token_count); + master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; + master.MutableChains().Wait(true); + } catch (const util::Exception &e) { + std::cerr << e.what() << std::endl; + abort(); } - - std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; - VocabReconstitute vocab(vocab_file.get()); - UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); - HeaderInfo header_info(text_file_name, token_count); - master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; - master.MutableChains().Wait(true); } }} // namespaces |