path: root/klm/lm/builder/print.cc
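// Writes the model out in ARPA format: the PrintARPA constructor emits the
// header and \data\ section; Run() then streams one "\N-grams:" section per
// order.  The output looks roughly like this (counts and values are
// illustrative; fields within an n-gram line are tab-separated):
//
//   \data\
//   ngram 1=4698
//   ngram 2=98210
//
//   \1-grams:
//   -2.348  <unk>
//   -1.102  the  -0.301
//   ...
//
//   \end\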
#include "lm/builder/print.hh"

#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/mmap.hh"
#include "util/scoped.hh"
#include "util/stream/timer.hh"

#include <sstream>

#include <string.h>

namespace lm { namespace builder {

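// Maps the vocabulary file (NUL-delimited word strings in index order) into
// memory so that word indices can be converted back to their surface strings.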
VocabReconstitute::VocabReconstitute(int fd) {
  uint64_t size = util::SizeOrThrow(fd);
  util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
  const char *const start = static_cast<const char*>(memory_.get());
  const char *i;
  for (i = start; i != start + size; i += strlen(i) + 1) {
    map_.push_back(i);
  }
  // One extra sentinel entry marks the end of the mapped buffer so LookupPiece
  // can compute the length of the last word.
  map_.push_back(i);
}

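// Writes the ARPA header to out_fd as soon as the object is constructed:
// optional "#" comment lines taken from the corpus statistics, then the
// \data\ section listing the n-gram count for each order.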
PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
  : vocab_(vocab), out_fd_(out_fd) {
  std::stringstream stream;

  if (header_info) {
    stream << "# Input file: " << header_info->input_file << '\n';
    stream << "# Token count: " << header_info->token_count << '\n';
    stream << "# Smoothing: Modified Kneser-Ney" << '\n';
  }
  stream << "\\data\\\n";
  for (size_t i = 0; i < counts.size(); ++i) {
    stream << "ngram " << (i+1) << '=' << counts[i] << '\n';
  }
  stream << '\n';
  std::string as_string(stream.str());
  util::WriteOrThrow(out_fd, as_string.data(), as_string.size());
}

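// Streams the n-grams of each order out of the pipeline and prints them as
// tab-separated "prob  word1 word2 ...  backoff" lines, one "\N-grams:"
// section per order, terminated by "\end\".  Takes ownership of out_fd_ and
// closes it when done.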
void PrintARPA::Run(const ChainPositions &positions) {
  util::scoped_fd closer(out_fd_);
  UTIL_TIMER("(%w s) Wrote ARPA file\n");
  util::FakeOFStream out(out_fd_);
  for (unsigned order = 1; order <= positions.size(); ++order) {
    out << "\\" << order << "-grams:" << '\n';
    for (NGramStream stream(positions[order - 1]); stream; ++stream) {
      // Clamp to a non-positive log10 probability: accumulated floating point
      // error can push it slightly above zero.  Take that IRST.
      out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
      for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
        out << ' ' << vocab_.Lookup(*i);
      }
      float backoff = stream->Value().complete.backoff;
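      // By ARPA convention the backoff column is omitted when the backoff weight is zero.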
      if (backoff != 0.0)
        out << '\t' << backoff;
      out << '\n';
    }
    out << '\n';
  }
  out << "\\end\\\n";
}

}} // namespaces