summaryrefslogtreecommitdiff
path: root/klm/lm/builder/print.hh
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-01-21 12:29:43 +0100
committerPatrick Simianer <p@simianer.de>2013-01-21 12:29:43 +0100
commit0d23f8aecbfaf982cd165ebfc2a1611cefcc7275 (patch)
tree8eafa6ea43224ff70635cadd4d6f027d28f4986f /klm/lm/builder/print.hh
parentdbc66cd3944321961c5e11d5254fd914f05a98ad (diff)
parent7cac43b858f3b681555bf0578f54b1f822c43207 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/builder/print.hh')
-rw-r--r--klm/lm/builder/print.hh102
1 files changed, 102 insertions, 0 deletions
diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh
new file mode 100644
index 00000000..aa932e75
--- /dev/null
+++ b/klm/lm/builder/print.hh
@@ -0,0 +1,102 @@
+#ifndef LM_BUILDER_PRINT__
+#define LM_BUILDER_PRINT__
+
+#include "lm/builder/ngram.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/header_info.hh"
+#include "util/file.hh"
+#include "util/mmap.hh"
+#include "util/string_piece.hh"
+
+#include <ostream>
+
+#include <assert.h>
+
+// Warning: print routines read all unigrams before all bigrams before all
+// trigrams etc. So if other parts of the chain move jointly, you'll have to
+// buffer.
+
+namespace lm { namespace builder {
+
+class VocabReconstitute {
+ public:
+ // fd must be alive for life of this object; does not take ownership.
+ explicit VocabReconstitute(int fd);
+
+ const char *Lookup(WordIndex index) const {
+ assert(index < map_.size() - 1);
+ return map_[index];
+ }
+
+ StringPiece LookupPiece(WordIndex index) const {
+ return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
+ }
+
+ std::size_t Size() const {
+ // There's an extra entry to support StringPiece lengths.
+ return map_.size() - 1;
+ }
+
+ private:
+ util::scoped_memory memory_;
+ std::vector<const char*> map_;
+};
+
+// Not defined, only specialized.
+template <class T> void PrintPayload(std::ostream &to, const Payload &payload);
+template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) {
+ to << payload.count;
+}
+template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) {
+ to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
+}
+template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) {
+ to << payload.complete.prob << ' ' << payload.complete.backoff;
+}
+
+// template parameter is the type stored.
+template <class V> class Print {
+ public:
+ explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {}
+
+ void Run(const ChainPositions &chains) {
+ NGramStreams streams(chains);
+ for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
+ DumpStream(*s);
+ }
+ }
+
+ void Run(const util::stream::ChainPosition &position) {
+ NGramStream stream(position);
+ DumpStream(stream);
+ }
+
+ private:
+ void DumpStream(NGramStream &stream) {
+ for (; stream; ++stream) {
+ PrintPayload<V>(to_, stream->Value());
+ for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
+ to_ << ' ' << vocab_.Lookup(*w) << '=' << *w;
+ }
+ to_ << '\n';
+ }
+ }
+
+ const VocabReconstitute &vocab_;
+ std::ostream &to_;
+};
+
+class PrintARPA {
+ public:
+ // header_info may be NULL to disable the header
+ explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
+
+ void Run(const ChainPositions &positions);
+
+ private:
+ const VocabReconstitute &vocab_;
+ int out_fd_;
+};
+
+}} // namespaces
+#endif // LM_BUILDER_PRINT__