Diffstat (limited to 'klm/lm')
-rw-r--r--  klm/lm/bhiksha.hh       |  5
-rw-r--r--  klm/lm/build_binary.cc  |  2
-rw-r--r--  klm/lm/left.hh          | 39
-rw-r--r--  klm/lm/vocab.cc         |  1
-rw-r--r--  klm/lm/vocab.hh         |  1
5 files changed, 31 insertions, 17 deletions
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index bc705959..3df43dda 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -10,6 +10,9 @@
  *  Currently only used for next pointers.
  */
 
+#ifndef LM_BHIKSHA__
+#define LM_BHIKSHA__
+
 #include <inttypes.h>
 #include <assert.h>
 
@@ -108,3 +111,5 @@ class ArrayBhiksha {
 } // namespace trie
 } // namespace ngram
 } // namespace lm
+
+#endif // LM_BHIKSHA__
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index b7aee4de..fdb62a71 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -15,7 +15,7 @@ namespace ngram {
 namespace {
 
 void Usage(const char *name) {
-  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n"
+  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
 "-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
 "   Default is -100.  The ARPA file will always take precedence.\n"
 "-s allows models to be built even if they do not have <s> and </s>.\n"
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index 15464c82..41f71f84 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -175,24 +175,14 @@ template <class M> class RuleScore {
       float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1];
       float *back = backoffs, *back2 = backoffs2;
-      unsigned char next_use;
+      unsigned char next_use = out_.right.length;
 
       // First word
-      ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use));
-      if (!next_use) {
-        left_done_ = true;
-        out_.right = in.right;
-        return;
-      }
+      if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return;
+
       // Words after the first, so extending a bigram to begin with
-      unsigned char extend_length = 2;
-      for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) {
-        ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use));
-        if (!next_use) {
-          left_done_ = true;
-          out_.right = in.right;
-          return;
-        }
+      for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
+        if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
         std::swap(back, back2);
       }
 
@@ -228,6 +218,25 @@ template <class M> class RuleScore {
     }
 
   private:
+    bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
+      ProcessRet(model_.ExtendLeft(
+            out_.right.words, out_.right.words + next_use, // Words to extend into
+            back_in, // Backoffs to use
+            in.left.pointers[extend_length - 1], extend_length, // Words to be extended
+            back_out, // Backoffs for the next score
+            next_use)); // Length of n-gram to use in next scoring.
+      if (next_use != out_.right.length) {
+        left_done_ = true;
+        if (!next_use) {
+          out_.right = in.right;
+          // Early exit.
+          return true;
+        }
+      }
+      // Continue scoring.
+      return false;
+    }
+
     void ProcessRet(const FullScoreReturn &ret) {
       prob_ += ret.prob;
       if (left_done_) return;
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 03b0767a..ffec41ca 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -135,6 +135,7 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
   end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
   ReadWords(fd, to);
   SetSpecial(Index("<s>"), Index("</s>"), 0);
+  bound_ = end_ - begin_ + 1;
 }
 
 namespace {
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 4cf68196..3c3414fb 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -66,7 +66,6 @@ class SortedVocabulary : public base::Vocabulary {
     static size_t Size(std::size_t entries, const Config &config);
 
     // Vocab words are [0, Bound())  Only valid after FinishedLoading/LoadedBinary.
-    // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases.
     WordIndex Bound() const { return bound_; }
 
     // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
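Note on the left.hh change: it folds two copies of the extend-left-then-maybe-bail logic into a single bool-returning helper, so each call site collapses to "if (ExtendLeft(...)) return;". The following is a minimal standalone sketch of that refactoring pattern only; the names here (Accumulator, Step, Run) are illustrative placeholders and not part of the kenlm API.

// Sketch: the duplicated "do a step, check the result, maybe finish early"
// block becomes one bool-returning helper, and every call site reduces to
// "if (Step(i)) return;".  Illustrative names only, not the kenlm API.
#include <cstddef>
#include <vector>

class Accumulator {
 public:
  explicit Accumulator(const std::vector<int> &values)
      : values_(values), total_(0), done_(false) {}

  void Run() {
    if (values_.empty()) return;
    // First element, then the rest -- mirroring the "first word" /
    // "words after the first" split in RuleScore::NonTerminal.
    if (Step(0)) return;
    for (std::size_t i = 1; i < values_.size(); ++i) {
      if (Step(i)) return;
    }
  }

  int Total() const { return total_; }
  bool Done() const { return done_; }

 private:
  // Returns true when the caller should stop early, analogous to
  // ExtendLeft returning true once next_use drops to zero.
  bool Step(std::size_t i) {
    total_ += values_[i];
    if (values_[i] == 0) {
      done_ = true;   // Early exit.
      return true;
    }
    return false;     // Continue.
  }

  std::vector<int> values_;
  int total_;
  bool done_;
};

For example, Accumulator({3, 2, 0, 5}).Run() stops at the zero, just as the rewritten RuleScore stops extending once the helper reports there is nothing left to use.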
