summaryrefslogtreecommitdiff
path: root/klm/lm/read_arpa.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/read_arpa.hh')
-rw-r--r--klm/lm/read_arpa.hh14
1 files changed, 10 insertions, 4 deletions
diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh
index ab996bde..234d130c 100644
--- a/klm/lm/read_arpa.hh
+++ b/klm/lm/read_arpa.hh
@@ -16,7 +16,13 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number);
void ReadNGramHeader(util::FilePiece &in, unsigned int length);
void ReadBackoff(util::FilePiece &in, Prob &weights);
-void ReadBackoff(util::FilePiece &in, ProbBackoff &weights);
+void ReadBackoff(util::FilePiece &in, float &backoff);
+inline void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
+ ReadBackoff(in, weights.backoff);
+}
+inline void ReadBackoff(util::FilePiece &in, RestWeights &weights) {
+ ReadBackoff(in, weights.backoff);
+}
void ReadEnd(util::FilePiece &in);
@@ -35,7 +41,7 @@ class PositiveProbWarn {
WarningAction action_;
};
-template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) {
+template <class Voc, class Weights> void Read1Gram(util::FilePiece &f, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
try {
float prob = f.ReadFloat();
if (prob > 0.0) {
@@ -43,7 +49,7 @@ template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff
prob = 0.0;
}
if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability");
- ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))];
+ Weights &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))];
value.prob = prob;
ReadBackoff(f, value);
} catch(util::Exception &e) {
@@ -53,7 +59,7 @@ template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff
}
// Return true if a positive log probability came out.
-template <class Voc> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) {
+template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
ReadNGramHeader(f, 1);
for (std::size_t i = 0; i < count; ++i) {
Read1Gram(f, vocab, unigrams, warn);