summaryrefslogtreecommitdiff
path: root/klm/lm/builder/pipeline.cc
diff options
context:
space:
mode:
authorWu, Ke <wuke@cs.umd.edu>2014-12-17 16:11:38 -0500
committerWu, Ke <wuke@cs.umd.edu>2014-12-17 16:11:38 -0500
commit1613f1fc44ca67820afd7e7b21eb54b316c8ce55 (patch)
treee02b77084f28a18df6b854f87a986124db44d717 /klm/lm/builder/pipeline.cc
parentbd9308e22b5434aa220cc57d82ee867464a011f1 (diff)
parent796768086a687d3f1856fef6489c34fe4d373642 (diff)
Merge with upstream
Diffstat (limited to 'klm/lm/builder/pipeline.cc')
-rw-r--r--klm/lm/builder/pipeline.cc103
1 files changed, 61 insertions, 42 deletions
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
index 44a2313c..21064ab3 100644
--- a/klm/lm/builder/pipeline.cc
+++ b/klm/lm/builder/pipeline.cc
@@ -2,6 +2,7 @@
#include "lm/builder/adjust_counts.hh"
#include "lm/builder/corpus_count.hh"
+#include "lm/builder/hash_gamma.hh"
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/interpolate.hh"
#include "lm/builder/print.hh"
@@ -20,10 +21,13 @@
namespace lm { namespace builder {
namespace {
-void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) {
+void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts) {
std::cerr << "Statistics:\n";
for (size_t i = 0; i < counts.size(); ++i) {
- std::cerr << (i + 1) << ' ' << counts[i];
+ std::cerr << (i + 1) << ' ' << counts_pruned[i];
+ if(counts[i] != counts_pruned[i])
+ std::cerr << "/" << counts[i];
+
for (size_t d = 1; d <= 3; ++d)
std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d];
std::cerr << '\n';
@@ -39,7 +43,7 @@ class Master {
const PipelineConfig &Config() const { return config_; }
- Chains &MutableChains() { return chains_; }
+ util::stream::Chains &MutableChains() { return chains_; }
template <class T> Master &operator>>(const T &worker) {
chains_ >> worker;
@@ -64,7 +68,7 @@ class Master {
}
// For initial probabilities, but this is generic.
- void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) {
+ void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) {
// Do merge first before allocating chain memory.
for (std::size_t i = 1; i < config_.order; ++i) {
sorts[i - 1].Merge(0);
@@ -198,9 +202,9 @@ class Master {
PipelineConfig config_;
- Chains chains_;
+ util::stream::Chains chains_;
// Often only unigrams, but sometimes all orders.
- FixedArray<util::stream::FileBuffer> files_;
+ util::FixedArray<util::stream::FileBuffer> files_;
};
void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
@@ -221,7 +225,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
- CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);
util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
@@ -231,21 +235,22 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
master.InitForAdjust(sorter, type_count);
}
-void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary,
+ util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds) {
const PipelineConfig &config = master.Config();
- Chains second(config.order);
+ util::stream::Chains second(config.order);
{
Sorts<ContextOrder> sorts;
master.SetupSorts(sorts);
- PrintStatistics(counts, discounts);
- lm::ngram::ShowSizes(counts);
+ PrintStatistics(counts, counts_pruned, discounts);
+ lm::ngram::ShowSizes(counts_pruned);
std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
- master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in);
+ master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in);
}
- Chains gamma_chains(config.order);
- InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains);
+ util::stream::Chains gamma_chains(config.order);
+ InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds);
// Don't care about gamma for 0.
gamma_chains[0] >> util::stream::kRecycle;
gammas.Init(config.order - 1);
@@ -257,19 +262,25 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
master.SetupSorts(primary);
}
-void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) {
std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
const PipelineConfig &config = master.Config();
master.MaximumLazyInput(counts, primary);
- Chains gamma_chains(config.order - 1);
- util::stream::ChainConfig read_backoffs(config.read_backoffs);
- read_backoffs.entry_size = sizeof(float);
+ util::stream::Chains gamma_chains(config.order - 1);
for (std::size_t i = 0; i < config.order - 1; ++i) {
+ util::stream::ChainConfig read_backoffs(config.read_backoffs);
+
+ // Add 1 because here we are skipping unigrams
+ if(config.prune_thresholds[i + 1] > 0)
+ read_backoffs.entry_size = sizeof(HashGamma);
+ else
+ read_backoffs.entry_size = sizeof(float);
+
gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source();
}
- master >> Interpolate(counts[0], ChainPositions(gamma_chains));
+ master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.output_q);
gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts);
}
@@ -291,32 +302,40 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
"Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");
UTIL_TIMER("(%w s) Total wall time elapsed\n");
- Master master(config);
-
- util::scoped_fd vocab_file(config.vocab_file.empty() ?
- util::MakeTemp(config.TempPrefix()) :
- util::CreateOrThrow(config.vocab_file.c_str()));
- uint64_t token_count;
- std::string text_file_name;
- CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
- std::vector<uint64_t> counts;
- std::vector<Discount> discounts;
- master >> AdjustCounts(counts, discounts);
+ Master master(config);
+ // master's destructor will wait for chains. But they might be deadlocked if
+ // this thread dies because e.g. it ran out of memory.
+ try {
+ util::scoped_fd vocab_file(config.vocab_file.empty() ?
+ util::MakeTemp(config.TempPrefix()) :
+ util::CreateOrThrow(config.vocab_file.c_str()));
+ uint64_t token_count;
+ std::string text_file_name;
+ CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
+
+ std::vector<uint64_t> counts;
+ std::vector<uint64_t> counts_pruned;
+ std::vector<Discount> discounts;
+ master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, config.discount, discounts);
+
+ {
+ util::FixedArray<util::stream::FileBuffer> gammas;
+ Sorts<SuffixOrder> primary;
+ InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds);
+ InterpolateProbabilities(counts_pruned, master, primary, gammas);
+ }
- {
- FixedArray<util::stream::FileBuffer> gammas;
- Sorts<SuffixOrder> primary;
- InitialProbabilities(counts, discounts, master, primary, gammas);
- InterpolateProbabilities(counts, master, primary, gammas);
+ std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
+ VocabReconstitute vocab(vocab_file.get());
+ UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
+ HeaderInfo header_info(text_file_name, token_count);
+ master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
+ master.MutableChains().Wait(true);
+ } catch (const util::Exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
}
-
- std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
- VocabReconstitute vocab(vocab_file.get());
- UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
- HeaderInfo header_info(text_file_name, token_count);
- master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
- master.MutableChains().Wait(true);
}
}} // namespaces