summaryrefslogtreecommitdiff
path: root/klm/lm/builder/print.cc
blob: b0323221a834934cfbf9dfb0734c947af548fb72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include "lm/builder/print.hh"

#include "util/double-conversion/double-conversion.h"
#include "util/double-conversion/utils.h"
#include "util/file.hh"
#include "util/mmap.hh"
#include "util/scoped.hh"
#include "util/stream/timer.hh"

#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
#include <boost/lexical_cast.hpp>

#include <algorithm>
#include <sstream>
#include <string>

#include <string.h>

namespace lm { namespace builder {

// Maps the vocabulary file `fd` read-only and indexes it.  The file is read as
// a sequence of NUL-terminated strings; map_ receives a pointer to the start of
// each string in file order (presumably indexed by WordIndex, given how
// Lookup() is called elsewhere -- confirm in print.hh).  One extra pointer, just
// past the last string, is appended for LookupPiece.
VocabReconstitute::VocabReconstitute(int fd) {
  uint64_t size = util::SizeOrThrow(fd);
  util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
  const char *const start = static_cast<const char*>(memory_.get());
  const char *const end = start + size;
  const char *i = start;
  while (i < end) {
    map_.push_back(i);
    // Bounded search: if the final string is missing its NUL (malformed file),
    // stop at the end of the mapping instead of letting strlen read past it.
    const void *terminator = memchr(i, 0, static_cast<std::size_t>(end - i));
    i = terminator ? static_cast<const char*>(terminator) + 1 : end;
  }
  // Last one for LookupPiece.
  map_.push_back(i);
}

namespace {
// Buffers text destined for a file descriptor in a kOutBuf-byte heap buffer
// and writes it out with util::WriteOrThrow when the buffer would overflow,
// on Flush(), and on destruction.  Floats are formatted with
// double-conversion's shortest single-precision representation.
class OutputManager {
  public:
    static const std::size_t kOutBuf = 1048576;

    // Does not take ownership of out.
    explicit OutputManager(int out)
      : buf_(util::MallocOrThrow(kOutBuf)),
        builder_(static_cast<char*>(buf_.get()), kOutBuf),
        // Mostly the default but with inf instead.  And no flags.
        convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
        fd_(out) {}

    // Writes whatever is still buffered.  NOTE(review): WriteOrThrow may throw
    // here; callers should call Flush() themselves on the success path so a
    // write error surfaces as an ordinary exception, not from a destructor.
    ~OutputManager() {
      Flush();
    }

    OutputManager &operator<<(float value) {
      // Odd, but this is the largest number found in the comments.
      EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
      convert_.ToShortestSingle(value, &builder_);
      return *this;
    }

    OutputManager &operator<<(StringPiece str) {
      // StringBuilder::AddSubstring always leaves one byte free, so a string of
      // kOutBuf bytes or more can never fit in the buffer: write it directly.
      if (str.size() >= kOutBuf) {
        Flush();
        util::WriteOrThrow(fd_, str.data(), str.size());
      } else {
        EnsureRemaining(str.size());
        builder_.AddSubstring(str.data(), str.size());
      }
      return *this;
    }

    // Inefficient!
    OutputManager &operator<<(unsigned val) {
      return *this << boost::lexical_cast<std::string>(val);
    }

    OutputManager &operator<<(char c) {
      EnsureRemaining(1);
      builder_.AddCharacter(c);
      return *this;
    }

    // Writes the buffered bytes to fd_ and empties the buffer.  The builder is
    // reset *before* writing so that, if WriteOrThrow throws (possibly after a
    // partial write), the destructor does not write the same bytes again.
    void Flush() {
      const std::size_t length = builder_.position();
      builder_.Reset();
      util::WriteOrThrow(fd_, buf_.get(), length);
    }

  private:
    // Flushes unless strictly more than `amount` bytes are free.  The strict
    // inequality keeps the one spare byte double_conversion's StringBuilder
    // reserves for its terminator (AddSubstring asserts position + n < size).
    void EnsureRemaining(std::size_t amount) {
      if (static_cast<std::size_t>(builder_.size() - builder_.position()) <= amount) {
        Flush();
      }
    }

    util::scoped_malloc buf_;
    double_conversion::StringBuilder builder_;
    double_conversion::DoubleToStringConverter convert_;
    int fd_;
};
} // namespace

// Stores the vocabulary and output descriptor, then immediately writes the ARPA
// preamble to out_fd in a single WriteOrThrow call:
//   - three "#" comment lines (input file, token count, smoothing), only when
//     header_info is non-null;
//   - the data section marker line;
//   - one "ngram N=count" line per entry of counts, N starting at 1;
//   - a terminating blank line.
PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
  : vocab_(vocab), out_fd_(out_fd) {
  std::ostringstream preamble;

  if (header_info) {
    preamble << "# Input file: " << header_info->input_file << '\n'
             << "# Token count: " << header_info->token_count << '\n'
             << "# Smoothing: Modified Kneser-Ney" << '\n';
  }
  preamble << "\\data\\\n";

  std::size_t order = 0;
  for (std::vector<uint64_t>::const_iterator count = counts.begin(); count != counts.end(); ++count) {
    ++order;
    preamble << "ngram " << order << '=' << *count << '\n';
  }
  preamble << '\n';

  const std::string text(preamble.str());
  util::WriteOrThrow(out_fd, text.data(), text.size());
}

// Writes the n-gram sections of the ARPA file to out_fd_.  positions[k] is the
// chain for order k + 1.  For each order it emits a "-grams:" header line
// (e.g. backslash, "1-grams:"), then one line per n-gram:
//   prob <TAB> w1 w2 ... [<TAB> backoff]
// where prob is clamped to at most 0 and the backoff column is omitted when
// the backoff is exactly 0.  Each section is followed by a blank line, and the
// file ends with the end marker line.
void PrintARPA::Run(const ChainPositions &positions) {
  UTIL_TIMER("(%w s) Wrote ARPA file\n");
  OutputManager out(out_fd_);
  for (unsigned order = 1; order <= positions.size(); ++order) {
    out << "\\" << order << "-grams:" << '\n';
    for (NGramStream stream(positions[order - 1]); stream; ++stream) {
      // Correcting for numerical precision issues.  Take that IRST.
      out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
      for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
        out << ' ' << vocab_.Lookup(*i);
      }
      float backoff = stream->Value().complete.backoff;
      if (backoff != 0.0f)
        out << '\t' << backoff;
      out << '\n';
    }
    out << '\n';
  }
  out << "\\end\\\n";
  // Flush explicitly: otherwise the last write happens in ~OutputManager, and
  // since destructors are implicitly noexcept in C++11 a write error there
  // would call std::terminate instead of propagating to the caller.
  out.Flush();
}

}} // namespaces