From d884099e0db8b4510847ec106b59ef7dca3c245b Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 18 Jan 2013 17:12:51 +0000 Subject: KenLM dffafbf with lmplz source (but not built) --- klm/util/double-conversion/strtod.cc | 554 +++++++++++++++++++++++++++++++++++ 1 file changed, 554 insertions(+) create mode 100644 klm/util/double-conversion/strtod.cc (limited to 'klm/util/double-conversion/strtod.cc') diff --git a/klm/util/double-conversion/strtod.cc b/klm/util/double-conversion/strtod.cc new file mode 100644 index 00000000..9758989f --- /dev/null +++ b/klm/util/double-conversion/strtod.cc @@ -0,0 +1,554 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include + +#include "strtod.h" +#include "bignum.h" +#include "cached-powers.h" +#include "ieee.h" + +namespace double_conversion { + +// 2^53 = 9007199254740992. +// Any integer with at most 15 decimal digits will hence fit into a double +// (which has a 53bit significand) without loss of precision. +static const int kMaxExactDoubleIntegerDecimalDigits = 15; +// 2^64 = 18446744073709551616 > 10^19 +static const int kMaxUint64DecimalDigits = 19; + +// Max double: 1.7976931348623157 x 10^308 +// Min non-zero double: 4.9406564584124654 x 10^-324 +// Any x >= 10^309 is interpreted as +infinity. +// Any x <= 10^-324 is interpreted as 0. +// Note that 2.5e-324 (despite being smaller than the min double) will be read +// as non-zero (equal to the min non-zero double). +static const int kMaxDecimalPower = 309; +static const int kMinDecimalPower = -324; + +// 2^64 = 18446744073709551616 +static const uint64_t kMaxUint64 = UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF); + + +static const double exact_powers_of_ten[] = { + 1.0, // 10^0 + 10.0, + 100.0, + 1000.0, + 10000.0, + 100000.0, + 1000000.0, + 10000000.0, + 100000000.0, + 1000000000.0, + 10000000000.0, // 10^10 + 100000000000.0, + 1000000000000.0, + 10000000000000.0, + 100000000000000.0, + 1000000000000000.0, + 10000000000000000.0, + 100000000000000000.0, + 1000000000000000000.0, + 10000000000000000000.0, + 100000000000000000000.0, // 10^20 + 1000000000000000000000.0, + // 10^22 = 0x21e19e0c9bab2400000 = 0x878678326eac9 * 2^22 + 10000000000000000000000.0 +}; +static const int kExactPowersOfTenSize = ARRAY_SIZE(exact_powers_of_ten); + +// Maximum number of significant digits in the decimal representation. +// In fact the value is 772 (see conversions.cc), but to give us some margin +// we round up to 780. +static const int kMaxSignificantDecimalDigits = 780; + +static Vector TrimLeadingZeros(Vector buffer) { + for (int i = 0; i < buffer.length(); i++) { + if (buffer[i] != '0') { + return buffer.SubVector(i, buffer.length()); + } + } + return Vector(buffer.start(), 0); +} + + +static Vector TrimTrailingZeros(Vector buffer) { + for (int i = buffer.length() - 1; i >= 0; --i) { + if (buffer[i] != '0') { + return buffer.SubVector(0, i + 1); + } + } + return Vector(buffer.start(), 0); +} + + +static void CutToMaxSignificantDigits(Vector buffer, + int exponent, + char* significant_buffer, + int* significant_exponent) { + for (int i = 0; i < kMaxSignificantDecimalDigits - 1; ++i) { + significant_buffer[i] = buffer[i]; + } + // The input buffer has been trimmed. Therefore the last digit must be + // different from '0'. + ASSERT(buffer[buffer.length() - 1] != '0'); + // Set the last digit to be non-zero. This is sufficient to guarantee + // correct rounding. + significant_buffer[kMaxSignificantDecimalDigits - 1] = '1'; + *significant_exponent = + exponent + (buffer.length() - kMaxSignificantDecimalDigits); +} + + +// Trims the buffer and cuts it to at most kMaxSignificantDecimalDigits. +// If possible the input-buffer is reused, but if the buffer needs to be +// modified (due to cutting), then the input needs to be copied into the +// buffer_copy_space. +static void TrimAndCut(Vector buffer, int exponent, + char* buffer_copy_space, int space_size, + Vector* trimmed, int* updated_exponent) { + Vector left_trimmed = TrimLeadingZeros(buffer); + Vector right_trimmed = TrimTrailingZeros(left_trimmed); + exponent += left_trimmed.length() - right_trimmed.length(); + if (right_trimmed.length() > kMaxSignificantDecimalDigits) { + ASSERT(space_size >= kMaxSignificantDecimalDigits); + CutToMaxSignificantDigits(right_trimmed, exponent, + buffer_copy_space, updated_exponent); + *trimmed = Vector(buffer_copy_space, + kMaxSignificantDecimalDigits); + } else { + *trimmed = right_trimmed; + *updated_exponent = exponent; + } +} + + +// Reads digits from the buffer and converts them to a uint64. +// Reads in as many digits as fit into a uint64. +// When the string starts with "1844674407370955161" no further digit is read. +// Since 2^64 = 18446744073709551616 it would still be possible read another +// digit if it was less or equal than 6, but this would complicate the code. +static uint64_t ReadUint64(Vector buffer, + int* number_of_read_digits) { + uint64_t result = 0; + int i = 0; + while (i < buffer.length() && result <= (kMaxUint64 / 10 - 1)) { + int digit = buffer[i++] - '0'; + ASSERT(0 <= digit && digit <= 9); + result = 10 * result + digit; + } + *number_of_read_digits = i; + return result; +} + + +// Reads a DiyFp from the buffer. +// The returned DiyFp is not necessarily normalized. +// If remaining_decimals is zero then the returned DiyFp is accurate. +// Otherwise it has been rounded and has error of at most 1/2 ulp. +static void ReadDiyFp(Vector buffer, + DiyFp* result, + int* remaining_decimals) { + int read_digits; + uint64_t significand = ReadUint64(buffer, &read_digits); + if (buffer.length() == read_digits) { + *result = DiyFp(significand, 0); + *remaining_decimals = 0; + } else { + // Round the significand. + if (buffer[read_digits] >= '5') { + significand++; + } + // Compute the binary exponent. + int exponent = 0; + *result = DiyFp(significand, exponent); + *remaining_decimals = buffer.length() - read_digits; + } +} + + +static bool DoubleStrtod(Vector trimmed, + int exponent, + double* result) { +#if !defined(DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS) + // On x86 the floating-point stack can be 64 or 80 bits wide. If it is + // 80 bits wide (as is the case on Linux) then double-rounding occurs and the + // result is not accurate. + // We know that Windows32 uses 64 bits and is therefore accurate. + // Note that the ARM simulator is compiled for 32bits. It therefore exhibits + // the same problem. + return false; +#endif + if (trimmed.length() <= kMaxExactDoubleIntegerDecimalDigits) { + int read_digits; + // The trimmed input fits into a double. + // If the 10^exponent (resp. 10^-exponent) fits into a double too then we + // can compute the result-double simply by multiplying (resp. dividing) the + // two numbers. + // This is possible because IEEE guarantees that floating-point operations + // return the best possible approximation. + if (exponent < 0 && -exponent < kExactPowersOfTenSize) { + // 10^-exponent fits into a double. + *result = static_cast(ReadUint64(trimmed, &read_digits)); + ASSERT(read_digits == trimmed.length()); + *result /= exact_powers_of_ten[-exponent]; + return true; + } + if (0 <= exponent && exponent < kExactPowersOfTenSize) { + // 10^exponent fits into a double. + *result = static_cast(ReadUint64(trimmed, &read_digits)); + ASSERT(read_digits == trimmed.length()); + *result *= exact_powers_of_ten[exponent]; + return true; + } + int remaining_digits = + kMaxExactDoubleIntegerDecimalDigits - trimmed.length(); + if ((0 <= exponent) && + (exponent - remaining_digits < kExactPowersOfTenSize)) { + // The trimmed string was short and we can multiply it with + // 10^remaining_digits. As a result the remaining exponent now fits + // into a double too. + *result = static_cast(ReadUint64(trimmed, &read_digits)); + ASSERT(read_digits == trimmed.length()); + *result *= exact_powers_of_ten[remaining_digits]; + *result *= exact_powers_of_ten[exponent - remaining_digits]; + return true; + } + } + return false; +} + + +// Returns 10^exponent as an exact DiyFp. +// The given exponent must be in the range [1; kDecimalExponentDistance[. +static DiyFp AdjustmentPowerOfTen(int exponent) { + ASSERT(0 < exponent); + ASSERT(exponent < PowersOfTenCache::kDecimalExponentDistance); + // Simply hardcode the remaining powers for the given decimal exponent + // distance. + ASSERT(PowersOfTenCache::kDecimalExponentDistance == 8); + switch (exponent) { + case 1: return DiyFp(UINT64_2PART_C(0xa0000000, 00000000), -60); + case 2: return DiyFp(UINT64_2PART_C(0xc8000000, 00000000), -57); + case 3: return DiyFp(UINT64_2PART_C(0xfa000000, 00000000), -54); + case 4: return DiyFp(UINT64_2PART_C(0x9c400000, 00000000), -50); + case 5: return DiyFp(UINT64_2PART_C(0xc3500000, 00000000), -47); + case 6: return DiyFp(UINT64_2PART_C(0xf4240000, 00000000), -44); + case 7: return DiyFp(UINT64_2PART_C(0x98968000, 00000000), -40); + default: + UNREACHABLE(); + return DiyFp(0, 0); + } +} + + +// If the function returns true then the result is the correct double. +// Otherwise it is either the correct double or the double that is just below +// the correct double. +static bool DiyFpStrtod(Vector buffer, + int exponent, + double* result) { + DiyFp input; + int remaining_decimals; + ReadDiyFp(buffer, &input, &remaining_decimals); + // Since we may have dropped some digits the input is not accurate. + // If remaining_decimals is different than 0 than the error is at most + // .5 ulp (unit in the last place). + // We don't want to deal with fractions and therefore keep a common + // denominator. + const int kDenominatorLog = 3; + const int kDenominator = 1 << kDenominatorLog; + // Move the remaining decimals into the exponent. + exponent += remaining_decimals; + int error = (remaining_decimals == 0 ? 0 : kDenominator / 2); + + int old_e = input.e(); + input.Normalize(); + error <<= old_e - input.e(); + + ASSERT(exponent <= PowersOfTenCache::kMaxDecimalExponent); + if (exponent < PowersOfTenCache::kMinDecimalExponent) { + *result = 0.0; + return true; + } + DiyFp cached_power; + int cached_decimal_exponent; + PowersOfTenCache::GetCachedPowerForDecimalExponent(exponent, + &cached_power, + &cached_decimal_exponent); + + if (cached_decimal_exponent != exponent) { + int adjustment_exponent = exponent - cached_decimal_exponent; + DiyFp adjustment_power = AdjustmentPowerOfTen(adjustment_exponent); + input.Multiply(adjustment_power); + if (kMaxUint64DecimalDigits - buffer.length() >= adjustment_exponent) { + // The product of input with the adjustment power fits into a 64 bit + // integer. + ASSERT(DiyFp::kSignificandSize == 64); + } else { + // The adjustment power is exact. There is hence only an error of 0.5. + error += kDenominator / 2; + } + } + + input.Multiply(cached_power); + // The error introduced by a multiplication of a*b equals + // error_a + error_b + error_a*error_b/2^64 + 0.5 + // Substituting a with 'input' and b with 'cached_power' we have + // error_b = 0.5 (all cached powers have an error of less than 0.5 ulp), + // error_ab = 0 or 1 / kDenominator > error_a*error_b/ 2^64 + int error_b = kDenominator / 2; + int error_ab = (error == 0 ? 0 : 1); // We round up to 1. + int fixed_error = kDenominator / 2; + error += error_b + error_ab + fixed_error; + + old_e = input.e(); + input.Normalize(); + error <<= old_e - input.e(); + + // See if the double's significand changes if we add/subtract the error. + int order_of_magnitude = DiyFp::kSignificandSize + input.e(); + int effective_significand_size = + Double::SignificandSizeForOrderOfMagnitude(order_of_magnitude); + int precision_digits_count = + DiyFp::kSignificandSize - effective_significand_size; + if (precision_digits_count + kDenominatorLog >= DiyFp::kSignificandSize) { + // This can only happen for very small denormals. In this case the + // half-way multiplied by the denominator exceeds the range of an uint64. + // Simply shift everything to the right. + int shift_amount = (precision_digits_count + kDenominatorLog) - + DiyFp::kSignificandSize + 1; + input.set_f(input.f() >> shift_amount); + input.set_e(input.e() + shift_amount); + // We add 1 for the lost precision of error, and kDenominator for + // the lost precision of input.f(). + error = (error >> shift_amount) + 1 + kDenominator; + precision_digits_count -= shift_amount; + } + // We use uint64_ts now. This only works if the DiyFp uses uint64_ts too. + ASSERT(DiyFp::kSignificandSize == 64); + ASSERT(precision_digits_count < 64); + uint64_t one64 = 1; + uint64_t precision_bits_mask = (one64 << precision_digits_count) - 1; + uint64_t precision_bits = input.f() & precision_bits_mask; + uint64_t half_way = one64 << (precision_digits_count - 1); + precision_bits *= kDenominator; + half_way *= kDenominator; + DiyFp rounded_input(input.f() >> precision_digits_count, + input.e() + precision_digits_count); + if (precision_bits >= half_way + error) { + rounded_input.set_f(rounded_input.f() + 1); + } + // If the last_bits are too close to the half-way case than we are too + // inaccurate and round down. In this case we return false so that we can + // fall back to a more precise algorithm. + + *result = Double(rounded_input).value(); + if (half_way - error < precision_bits && precision_bits < half_way + error) { + // Too imprecise. The caller will have to fall back to a slower version. + // However the returned number is guaranteed to be either the correct + // double, or the next-lower double. + return false; + } else { + return true; + } +} + + +// Returns +// - -1 if buffer*10^exponent < diy_fp. +// - 0 if buffer*10^exponent == diy_fp. +// - +1 if buffer*10^exponent > diy_fp. +// Preconditions: +// buffer.length() + exponent <= kMaxDecimalPower + 1 +// buffer.length() + exponent > kMinDecimalPower +// buffer.length() <= kMaxDecimalSignificantDigits +static int CompareBufferWithDiyFp(Vector buffer, + int exponent, + DiyFp diy_fp) { + ASSERT(buffer.length() + exponent <= kMaxDecimalPower + 1); + ASSERT(buffer.length() + exponent > kMinDecimalPower); + ASSERT(buffer.length() <= kMaxSignificantDecimalDigits); + // Make sure that the Bignum will be able to hold all our numbers. + // Our Bignum implementation has a separate field for exponents. Shifts will + // consume at most one bigit (< 64 bits). + // ln(10) == 3.3219... + ASSERT(((kMaxDecimalPower + 1) * 333 / 100) < Bignum::kMaxSignificantBits); + Bignum buffer_bignum; + Bignum diy_fp_bignum; + buffer_bignum.AssignDecimalString(buffer); + diy_fp_bignum.AssignUInt64(diy_fp.f()); + if (exponent >= 0) { + buffer_bignum.MultiplyByPowerOfTen(exponent); + } else { + diy_fp_bignum.MultiplyByPowerOfTen(-exponent); + } + if (diy_fp.e() > 0) { + diy_fp_bignum.ShiftLeft(diy_fp.e()); + } else { + buffer_bignum.ShiftLeft(-diy_fp.e()); + } + return Bignum::Compare(buffer_bignum, diy_fp_bignum); +} + + +// Returns true if the guess is the correct double. +// Returns false, when guess is either correct or the next-lower double. +static bool ComputeGuess(Vector trimmed, int exponent, + double* guess) { + if (trimmed.length() == 0) { + *guess = 0.0; + return true; + } + if (exponent + trimmed.length() - 1 >= kMaxDecimalPower) { + *guess = Double::Infinity(); + return true; + } + if (exponent + trimmed.length() <= kMinDecimalPower) { + *guess = 0.0; + return true; + } + + if (DoubleStrtod(trimmed, exponent, guess) || + DiyFpStrtod(trimmed, exponent, guess)) { + return true; + } + if (*guess == Double::Infinity()) { + return true; + } + return false; +} + +double Strtod(Vector buffer, int exponent) { + char copy_buffer[kMaxSignificantDecimalDigits]; + Vector trimmed; + int updated_exponent; + TrimAndCut(buffer, exponent, copy_buffer, kMaxSignificantDecimalDigits, + &trimmed, &updated_exponent); + exponent = updated_exponent; + + double guess; + bool is_correct = ComputeGuess(trimmed, exponent, &guess); + if (is_correct) return guess; + + DiyFp upper_boundary = Double(guess).UpperBoundary(); + int comparison = CompareBufferWithDiyFp(trimmed, exponent, upper_boundary); + if (comparison < 0) { + return guess; + } else if (comparison > 0) { + return Double(guess).NextDouble(); + } else if ((Double(guess).Significand() & 1) == 0) { + // Round towards even. + return guess; + } else { + return Double(guess).NextDouble(); + } +} + +float Strtof(Vector buffer, int exponent) { + char copy_buffer[kMaxSignificantDecimalDigits]; + Vector trimmed; + int updated_exponent; + TrimAndCut(buffer, exponent, copy_buffer, kMaxSignificantDecimalDigits, + &trimmed, &updated_exponent); + exponent = updated_exponent; + + double double_guess; + bool is_correct = ComputeGuess(trimmed, exponent, &double_guess); + + float float_guess = static_cast(double_guess); + if (float_guess == double_guess) { + // This shortcut triggers for integer values. + return float_guess; + } + + // We must catch double-rounding. Say the double has been rounded up, and is + // now a boundary of a float, and rounds up again. This is why we have to + // look at previous too. + // Example (in decimal numbers): + // input: 12349 + // high-precision (4 digits): 1235 + // low-precision (3 digits): + // when read from input: 123 + // when rounded from high precision: 124. + // To do this we simply look at the neigbors of the correct result and see + // if they would round to the same float. If the guess is not correct we have + // to look at four values (since two different doubles could be the correct + // double). + + double double_next = Double(double_guess).NextDouble(); + double double_previous = Double(double_guess).PreviousDouble(); + + float f1 = static_cast(double_previous); + float f2 = float_guess; + float f3 = static_cast(double_next); + float f4; + if (is_correct) { + f4 = f3; + } else { + double double_next2 = Double(double_next).NextDouble(); + f4 = static_cast(double_next2); + } + ASSERT(f1 <= f2 && f2 <= f3 && f3 <= f4); + + // If the guess doesn't lie near a single-precision boundary we can simply + // return its float-value. + if (f1 == f4) { + return float_guess; + } + + ASSERT((f1 != f2 && f2 == f3 && f3 == f4) || + (f1 == f2 && f2 != f3 && f3 == f4) || + (f1 == f2 && f2 == f3 && f3 != f4)); + + // guess and next are the two possible canditates (in the same way that + // double_guess was the lower candidate for a double-precision guess). + float guess = f1; + float next = f4; + DiyFp upper_boundary; + if (guess == 0.0f) { + float min_float = 1e-45f; + upper_boundary = Double(static_cast(min_float) / 2).AsDiyFp(); + } else { + upper_boundary = Single(guess).UpperBoundary(); + } + int comparison = CompareBufferWithDiyFp(trimmed, exponent, upper_boundary); + if (comparison < 0) { + return guess; + } else if (comparison > 0) { + return next; + } else if ((Single(guess).Significand() & 1) == 0) { + // Round towards even. + return guess; + } else { + return next; + } +} + +} // namespace double_conversion -- cgit v1.2.3 From b35a7f3a96ff8ae42e15922dd6949bf9f5d15501 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 22 Jan 2013 21:37:49 +0000 Subject: KenLM 58da338b --- klm/lm/Makefile.am | 4 +- klm/lm/build_binary.cc | 228 -------------------------------- klm/lm/build_binary_main.cc | 228 ++++++++++++++++++++++++++++++++ klm/lm/builder/Makefile.am | 2 +- klm/lm/builder/discount.hh | 2 +- klm/lm/builder/lmplz_main.cc | 94 +++++++++++++ klm/lm/builder/main.cc | 94 ------------- klm/lm/filter/filter_main.cc | 248 ++++++++++++++++++++++++++++++++++ klm/lm/filter/main.cc | 249 ----------------------------------- klm/lm/filter/phrase.hh | 1 + klm/lm/filter/vocab.hh | 1 + klm/lm/fragment.cc | 37 ------ klm/lm/fragment_main.cc | 37 ++++++ klm/lm/kenlm_max_order_main.cc | 6 + klm/lm/max_order.cc | 6 - klm/lm/ngram_query.cc | 47 ------- klm/lm/query_main.cc | 47 +++++++ klm/util/Makefile.am | 1 + klm/util/double-conversion/strtod.cc | 4 + klm/util/file.cc | 47 +++++-- klm/util/file_piece.cc | 22 +++- klm/util/file_piece.hh | 10 ++ klm/util/file_piece_test.cc | 14 ++ klm/util/have.hh | 4 - klm/util/read_compressed.cc | 28 +++- klm/util/read_compressed.hh | 7 + klm/util/read_compressed_test.cc | 55 +++++--- klm/util/stream/io.cc | 8 +- klm/util/stream/sort.hh | 12 +- klm/util/string_piece.cc | 3 +- klm/util/string_piece.hh | 41 ------ klm/util/string_piece_hash.hh | 43 ++++++ klm/util/usage.cc | 2 +- 33 files changed, 875 insertions(+), 757 deletions(-) delete mode 100644 klm/lm/build_binary.cc create mode 100644 klm/lm/build_binary_main.cc create mode 100644 klm/lm/builder/lmplz_main.cc delete mode 100644 klm/lm/builder/main.cc create mode 100644 klm/lm/filter/filter_main.cc delete mode 100644 klm/lm/filter/main.cc delete mode 100644 klm/lm/fragment.cc create mode 100644 klm/lm/fragment_main.cc create mode 100644 klm/lm/kenlm_max_order_main.cc delete mode 100644 klm/lm/max_order.cc delete mode 100644 klm/lm/ngram_query.cc create mode 100644 klm/lm/query_main.cc create mode 100644 klm/util/string_piece_hash.hh (limited to 'klm/util/double-conversion/strtod.cc') diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 45f40c43..48b0ba34 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -1,9 +1,9 @@ bin_PROGRAMS = build_binary ngram_query -build_binary_SOURCES = build_binary.cc +build_binary_SOURCES = build_binary_main.cc build_binary_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz -ngram_query_SOURCES = ngram_query.cc +ngram_query_SOURCES = query_main.cc ngram_query_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz #noinst_PROGRAMS = \ diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc deleted file mode 100644 index ab2c0c32..00000000 --- a/klm/lm/build_binary.cc +++ /dev/null @@ -1,228 +0,0 @@ -#include "lm/model.hh" -#include "lm/sizes.hh" -#include "util/file_piece.hh" -#include "util/usage.hh" - -#include -#include -#include -#include -#include -#include - -#include -#include - -#ifdef WIN32 -#include "util/getopt.hh" -#else -#include -#endif - -namespace lm { -namespace ngram { -namespace { - -void Usage(const char *name, const char *default_mem) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" -"-u sets the log10 probability for if the ARPA file does not have one.\n" -" Default is -100. The ARPA file will always take precedence.\n" -"-s allows models to be built even if they do not have and .\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"-w mmap|after determines how writing is done.\n" -" mmap maps the binary file and writes to it. Default for trie.\n" -" after allocates anonymous memory, builds, and writes. Default for probing.\n" -"-r \"order1.arpa order2 order3 order4\" adds lower-order rest costs from these\n" -" model files. order1.arpa must be an ARPA file. All others may be ARPA or\n" -" the same data structure as being built. All files must have the same\n" -" vocabulary. For probing, the unigrams must be in the same order.\n\n" -"type is either probing or trie. Default is probing.\n\n" -"probing uses a probing hash table. It is the fastest but uses the most memory.\n" -"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" -"trie is a straightforward trie with bit-level packing. It uses the least\n" -"memory and is still faster than SRI or IRST. Building the trie format uses an\n" -"on-disk sort to save memory.\n" -"-T is the temporary directory prefix. Default is the output file name.\n" -"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n" -" with GNU sort. The number is followed by a unit: \% for percent of physical\n" -" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n" -" Default unit is K for Kilobytes.\n" -"-q turns quantization on and sets the number of bits (e.g. -q 8).\n" -"-b sets backoff quantization bits. Requires -q and defaults to that value.\n" -"-a compresses pointers using an array of offsets. The parameter is the\n" -" maximum number of bits encoded by the array. Memory is minimized subject\n" -" to the maximum, so pick 255 to minimize memory.\n\n" -"Get a memory estimate by passing an ARPA file without an output file name.\n"; - exit(1); -} - -// I could really use boost::lexical_cast right about now. -float ParseFloat(const char *from) { - char *end; - float ret = strtod(from, &end); - if (*end) throw util::ParseNumberException(from); - return ret; -} -unsigned long int ParseUInt(const char *from) { - char *end; - unsigned long int ret = strtoul(from, &end, 10); - if (*end) throw util::ParseNumberException(from); - return ret; -} - -uint8_t ParseBitCount(const char *from) { - unsigned long val = ParseUInt(from); - if (val > 25) { - util::ParseNumberException e(from); - e << " bit counts are limited to 25."; - } - return val; -} - -void ParseFileList(const char *from, std::vector &to) { - to.clear(); - while (true) { - const char *i; - for (i = from; *i && *i != ' '; ++i) {} - to.push_back(std::string(from, i - from)); - if (!*i) break; - from = i + 1; - } -} - -void ProbingQuantizationUnsupported() { - std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; - exit(1); -} - -} // namespace ngram -} // namespace lm -} // namespace - -int main(int argc, char *argv[]) { - using namespace lm::ngram; - - const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G"; - - try { - bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false; - lm::ngram::Config config; - config.building_memory = util::ParseSize(default_mem); - int opt; - while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) { - switch(opt) { - case 'q': - config.prob_bits = ParseBitCount(optarg); - if (!set_backoff_bits) config.backoff_bits = config.prob_bits; - quantize = true; - break; - case 'b': - config.backoff_bits = ParseBitCount(optarg); - set_backoff_bits = true; - break; - case 'a': - config.pointer_bhiksha_bits = ParseBitCount(optarg); - bhiksha = true; - break; - case 'u': - config.unknown_missing_logprob = ParseFloat(optarg); - break; - case 'p': - config.probing_multiplier = ParseFloat(optarg); - break; - case 't': // legacy - case 'T': - config.temporary_directory_prefix = optarg; - break; - case 'm': // legacy - config.building_memory = ParseUInt(optarg) * 1048576; - break; - case 'S': - config.building_memory = std::min(static_cast(std::numeric_limits::max()), util::ParseSize(optarg)); - break; - case 'w': - set_write_method = true; - if (!strcmp(optarg, "mmap")) { - config.write_method = Config::WRITE_MMAP; - } else if (!strcmp(optarg, "after")) { - config.write_method = Config::WRITE_AFTER; - } else { - Usage(argv[0], default_mem); - } - break; - case 's': - config.sentence_marker_missing = lm::SILENT; - break; - case 'i': - config.positive_log_probability = lm::SILENT; - break; - case 'r': - rest = true; - ParseFileList(optarg, config.rest_lower_files); - config.rest_function = Config::REST_LOWER; - break; - default: - Usage(argv[0], default_mem); - } - } - if (!quantize && set_backoff_bits) { - std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl; - abort(); - } - if (optind + 1 == argc) { - ShowSizes(argv[optind], config); - return 0; - } - const char *model_type; - const char *from_file; - - if (optind + 2 == argc) { - model_type = "probing"; - from_file = argv[optind]; - config.write_mmap = argv[optind + 1]; - } else if (optind + 3 == argc) { - model_type = argv[optind]; - from_file = argv[optind + 1]; - config.write_mmap = argv[optind + 2]; - } else { - Usage(argv[0], default_mem); - } - if (!strcmp(model_type, "probing")) { - if (!set_write_method) config.write_method = Config::WRITE_AFTER; - if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); - if (rest) { - RestProbingModel(from_file, config); - } else { - ProbingModel(from_file, config); - } - } else if (!strcmp(model_type, "trie")) { - if (rest) { - std::cerr << "Rest + trie is not supported yet." << std::endl; - return 1; - } - if (!set_write_method) config.write_method = Config::WRITE_MMAP; - if (quantize) { - if (bhiksha) { - QuantArrayTrieModel(from_file, config); - } else { - QuantTrieModel(from_file, config); - } - } else { - if (bhiksha) { - ArrayTrieModel(from_file, config); - } else { - TrieModel(from_file, config); - } - } - } else { - Usage(argv[0], default_mem); - } - } - catch (const std::exception &e) { - std::cerr << e.what() << std::endl; - std::cerr << "ERROR" << std::endl; - return 1; - } - std::cerr << "SUCCESS" << std::endl; - return 0; -} diff --git a/klm/lm/build_binary_main.cc b/klm/lm/build_binary_main.cc new file mode 100644 index 00000000..ab2c0c32 --- /dev/null +++ b/klm/lm/build_binary_main.cc @@ -0,0 +1,228 @@ +#include "lm/model.hh" +#include "lm/sizes.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef WIN32 +#include "util/getopt.hh" +#else +#include +#endif + +namespace lm { +namespace ngram { +namespace { + +void Usage(const char *name, const char *default_mem) { + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" +"-u sets the log10 probability for if the ARPA file does not have one.\n" +" Default is -100. The ARPA file will always take precedence.\n" +"-s allows models to be built even if they do not have and .\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" +"-w mmap|after determines how writing is done.\n" +" mmap maps the binary file and writes to it. Default for trie.\n" +" after allocates anonymous memory, builds, and writes. Default for probing.\n" +"-r \"order1.arpa order2 order3 order4\" adds lower-order rest costs from these\n" +" model files. order1.arpa must be an ARPA file. All others may be ARPA or\n" +" the same data structure as being built. All files must have the same\n" +" vocabulary. For probing, the unigrams must be in the same order.\n\n" +"type is either probing or trie. Default is probing.\n\n" +"probing uses a probing hash table. It is the fastest but uses the most memory.\n" +"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" +"trie is a straightforward trie with bit-level packing. It uses the least\n" +"memory and is still faster than SRI or IRST. Building the trie format uses an\n" +"on-disk sort to save memory.\n" +"-T is the temporary directory prefix. Default is the output file name.\n" +"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n" +" with GNU sort. The number is followed by a unit: \% for percent of physical\n" +" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n" +" Default unit is K for Kilobytes.\n" +"-q turns quantization on and sets the number of bits (e.g. -q 8).\n" +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n" +"-a compresses pointers using an array of offsets. The parameter is the\n" +" maximum number of bits encoded by the array. Memory is minimized subject\n" +" to the maximum, so pick 255 to minimize memory.\n\n" +"Get a memory estimate by passing an ARPA file without an output file name.\n"; + exit(1); +} + +// I could really use boost::lexical_cast right about now. +float ParseFloat(const char *from) { + char *end; + float ret = strtod(from, &end); + if (*end) throw util::ParseNumberException(from); + return ret; +} +unsigned long int ParseUInt(const char *from) { + char *end; + unsigned long int ret = strtoul(from, &end, 10); + if (*end) throw util::ParseNumberException(from); + return ret; +} + +uint8_t ParseBitCount(const char *from) { + unsigned long val = ParseUInt(from); + if (val > 25) { + util::ParseNumberException e(from); + e << " bit counts are limited to 25."; + } + return val; +} + +void ParseFileList(const char *from, std::vector &to) { + to.clear(); + while (true) { + const char *i; + for (i = from; *i && *i != ' '; ++i) {} + to.push_back(std::string(from, i - from)); + if (!*i) break; + from = i + 1; + } +} + +void ProbingQuantizationUnsupported() { + std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; + exit(1); +} + +} // namespace ngram +} // namespace lm +} // namespace + +int main(int argc, char *argv[]) { + using namespace lm::ngram; + + const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G"; + + try { + bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false; + lm::ngram::Config config; + config.building_memory = util::ParseSize(default_mem); + int opt; + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) { + switch(opt) { + case 'q': + config.prob_bits = ParseBitCount(optarg); + if (!set_backoff_bits) config.backoff_bits = config.prob_bits; + quantize = true; + break; + case 'b': + config.backoff_bits = ParseBitCount(optarg); + set_backoff_bits = true; + break; + case 'a': + config.pointer_bhiksha_bits = ParseBitCount(optarg); + bhiksha = true; + break; + case 'u': + config.unknown_missing_logprob = ParseFloat(optarg); + break; + case 'p': + config.probing_multiplier = ParseFloat(optarg); + break; + case 't': // legacy + case 'T': + config.temporary_directory_prefix = optarg; + break; + case 'm': // legacy + config.building_memory = ParseUInt(optarg) * 1048576; + break; + case 'S': + config.building_memory = std::min(static_cast(std::numeric_limits::max()), util::ParseSize(optarg)); + break; + case 'w': + set_write_method = true; + if (!strcmp(optarg, "mmap")) { + config.write_method = Config::WRITE_MMAP; + } else if (!strcmp(optarg, "after")) { + config.write_method = Config::WRITE_AFTER; + } else { + Usage(argv[0], default_mem); + } + break; + case 's': + config.sentence_marker_missing = lm::SILENT; + break; + case 'i': + config.positive_log_probability = lm::SILENT; + break; + case 'r': + rest = true; + ParseFileList(optarg, config.rest_lower_files); + config.rest_function = Config::REST_LOWER; + break; + default: + Usage(argv[0], default_mem); + } + } + if (!quantize && set_backoff_bits) { + std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl; + abort(); + } + if (optind + 1 == argc) { + ShowSizes(argv[optind], config); + return 0; + } + const char *model_type; + const char *from_file; + + if (optind + 2 == argc) { + model_type = "probing"; + from_file = argv[optind]; + config.write_mmap = argv[optind + 1]; + } else if (optind + 3 == argc) { + model_type = argv[optind]; + from_file = argv[optind + 1]; + config.write_mmap = argv[optind + 2]; + } else { + Usage(argv[0], default_mem); + } + if (!strcmp(model_type, "probing")) { + if (!set_write_method) config.write_method = Config::WRITE_AFTER; + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); + if (rest) { + RestProbingModel(from_file, config); + } else { + ProbingModel(from_file, config); + } + } else if (!strcmp(model_type, "trie")) { + if (rest) { + std::cerr << "Rest + trie is not supported yet." << std::endl; + return 1; + } + if (!set_write_method) config.write_method = Config::WRITE_MMAP; + if (quantize) { + if (bhiksha) { + QuantArrayTrieModel(from_file, config); + } else { + QuantTrieModel(from_file, config); + } + } else { + if (bhiksha) { + ArrayTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } + } + } else { + Usage(argv[0], default_mem); + } + } + catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + std::cerr << "ERROR" << std::endl; + return 1; + } + std::cerr << "SUCCESS" << std::endl; + return 0; +} diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am index b5c147fd..317e03ce 100644 --- a/klm/lm/builder/Makefile.am +++ b/klm/lm/builder/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = builder builder_SOURCES = \ - main.cc \ + lmplz_main.cc \ adjust_counts.cc \ adjust_counts.hh \ corpus_count.cc \ diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh index 754fb20d..4d0aa4fd 100644 --- a/klm/lm/builder/discount.hh +++ b/klm/lm/builder/discount.hh @@ -3,7 +3,7 @@ #include -#include +#include namespace lm { namespace builder { diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc new file mode 100644 index 00000000..90b9dca2 --- /dev/null +++ b/klm/lm/builder/lmplz_main.cc @@ -0,0 +1,94 @@ +#include "lm/builder/pipeline.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include + +#include + +namespace { +class SizeNotify { + public: + SizeNotify(std::size_t &out) : behind_(out) {} + + void operator()(const std::string &from) { + behind_ = util::ParseSize(from); + } + + private: + std::size_t &behind_; +}; + +boost::program_options::typed_value *SizeOption(std::size_t &to, const char *default_value) { + return boost::program_options::value()->notifier(SizeNotify(to))->default_value(default_value); +} + +} // namespace + +int main(int argc, char *argv[]) { + try { + namespace po = boost::program_options; + po::options_description options("Language model building options"); + lm::builder::PipelineConfig pipeline; + + options.add_options() + ("order,o", po::value(&pipeline.order)->required(), "Order of the model") + ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") + ("temp_prefix,T", po::value(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") + ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") + ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step") + ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") + ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") + ("block_count", po::value(&pipeline.block_count)->default_value(2), "Block count (per order)") + ("vocab_file", po::value(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") + ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); + if (argc == 1) { + std::cerr << + "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" + "Please cite:\n" + "@inproceedings{kenlm,\n" + "author = {Kenneth Heafield},\n" + "title = {{KenLM}: Faster and Smaller Language Model Queries},\n" + "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" + "month = {July}, year={2011},\n" + "address = {Edinburgh, UK},\n" + "publisher = {Association for Computational Linguistics},\n" + "}\n\n" + "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n" + "the model (-o) is the only mandatory option. As this is an on-disk program,\n" + "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n" + "Memory sizes are specified like GNU sort: a number followed by a unit character.\n" + "Valid units are \% for percentage of memory (supported platforms only) and (in\n" + "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n"; + std::cerr << options << std::endl; + return 1; + } + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, options), vm); + po::notify(vm); + + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); + + lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; + // TODO: evaluate options for these. + initial.adder_in.total_memory = 32768; + initial.adder_in.block_count = 2; + initial.adder_out.total_memory = 32768; + initial.adder_out.block_count = 2; + pipeline.read_backoffs = initial.adder_out; + + // Read from stdin + try { + lm::builder::Pipeline(pipeline, 0, 1); + } catch (const util::MallocException &e) { + std::cerr << e.what() << std::endl; + std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as() << std::endl; + return 1; + } + util::PrintUsage(std::cerr); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } +} diff --git a/klm/lm/builder/main.cc b/klm/lm/builder/main.cc deleted file mode 100644 index 90b9dca2..00000000 --- a/klm/lm/builder/main.cc +++ /dev/null @@ -1,94 +0,0 @@ -#include "lm/builder/pipeline.hh" -#include "util/file.hh" -#include "util/file_piece.hh" -#include "util/usage.hh" - -#include - -#include - -namespace { -class SizeNotify { - public: - SizeNotify(std::size_t &out) : behind_(out) {} - - void operator()(const std::string &from) { - behind_ = util::ParseSize(from); - } - - private: - std::size_t &behind_; -}; - -boost::program_options::typed_value *SizeOption(std::size_t &to, const char *default_value) { - return boost::program_options::value()->notifier(SizeNotify(to))->default_value(default_value); -} - -} // namespace - -int main(int argc, char *argv[]) { - try { - namespace po = boost::program_options; - po::options_description options("Language model building options"); - lm::builder::PipelineConfig pipeline; - - options.add_options() - ("order,o", po::value(&pipeline.order)->required(), "Order of the model") - ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") - ("temp_prefix,T", po::value(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") - ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") - ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step") - ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") - ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") - ("block_count", po::value(&pipeline.block_count)->default_value(2), "Block count (per order)") - ("vocab_file", po::value(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") - ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); - if (argc == 1) { - std::cerr << - "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" - "Please cite:\n" - "@inproceedings{kenlm,\n" - "author = {Kenneth Heafield},\n" - "title = {{KenLM}: Faster and Smaller Language Model Queries},\n" - "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" - "month = {July}, year={2011},\n" - "address = {Edinburgh, UK},\n" - "publisher = {Association for Computational Linguistics},\n" - "}\n\n" - "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n" - "the model (-o) is the only mandatory option. As this is an on-disk program,\n" - "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n" - "Memory sizes are specified like GNU sort: a number followed by a unit character.\n" - "Valid units are \% for percentage of memory (supported platforms only) and (in\n" - "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n"; - std::cerr << options << std::endl; - return 1; - } - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, options), vm); - po::notify(vm); - - util::NormalizeTempPrefix(pipeline.sort.temp_prefix); - - lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; - // TODO: evaluate options for these. - initial.adder_in.total_memory = 32768; - initial.adder_in.block_count = 2; - initial.adder_out.total_memory = 32768; - initial.adder_out.block_count = 2; - pipeline.read_backoffs = initial.adder_out; - - // Read from stdin - try { - lm::builder::Pipeline(pipeline, 0, 1); - } catch (const util::MallocException &e) { - std::cerr << e.what() << std::endl; - std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as() << std::endl; - return 1; - } - util::PrintUsage(std::cerr); - } catch (const std::exception &e) { - std::cerr << e.what() << std::endl; - return 1; - } -} diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc new file mode 100644 index 00000000..1a4ba84f --- /dev/null +++ b/klm/lm/filter/filter_main.cc @@ -0,0 +1,248 @@ +#include "lm/filter/arpa_io.hh" +#include "lm/filter/format.hh" +#include "lm/filter/phrase.hh" +#ifndef NTHREAD +#include "lm/filter/thread.hh" +#endif +#include "lm/filter/vocab.hh" +#include "lm/filter/wrapper.hh" +#include "util/file_piece.hh" + +#include + +#include +#include +#include +#include + +namespace lm { +namespace { + +void DisplayHelp(const char *name) { + std::cerr + << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n" + "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n" + " parser.\n" + "single mode treats the entire input as a single sentence.\n" + "multiple mode filters to multiple sentences in parallel. Each sentence is on\n" + " a separate line. A separate file is created for each file by appending the\n" + " 0-indexed line number to the output file name.\n" + "union mode produces one filtered model that is the union of models created by\n" + " multiple mode.\n\n" + "context means only the context (all but last word) has to pass the filter, but\n" + " the entire n-gram is output.\n\n" + "phrase means that the vocabulary is actually tab-delimited phrases and that the\n" + " phrases can generate the n-gram when assembled in arbitrary order and\n" + " clipped. Currently works with multiple or union mode.\n\n" + "The file format is set by [raw|arpa] with default arpa:\n" + "raw means space-separated tokens, optionally followed by a tab and arbitrary\n" + " text. This is useful for ngram count files.\n" + "arpa means the ARPA file format for n-gram language models.\n\n" +#ifndef NTHREAD + "threads:m sets m threads (default: conccurrency detected by boost)\n" + "batch_size:m sets the batch size for threading. Expect memory usage from this\n" + " of 2*threads*batch_size n-grams.\n\n" +#else + "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n" + " threading, compile without this flag against Boost >=1.42.0.\n\n" +#endif + "There are two inputs: vocabulary and model. Either may be given as a file\n" + " while the other is on stdin. Specify the type given as a file using\n" + " vocab: or model: before the file name. \n\n" + "For ARPA format, the output must be seekable. For raw format, it can be a\n" + " stream i.e. /dev/stdout\n"; +} + +typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION, MODE_UNSET} FilterMode; +typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format; + +struct Config { + Config() : +#ifndef NTHREAD + batch_size(25000), + threads(boost::thread::hardware_concurrency()), +#endif + phrase(false), + context(false), + format(FORMAT_ARPA) + { +#ifndef NTHREAD + if (!threads) threads = 1; +#endif + } + +#ifndef NTHREAD + size_t batch_size; + size_t threads; +#endif + bool phrase; + bool context; + FilterMode mode; + Format format; +}; + +template void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) { +#ifndef NTHREAD + if (config.threads == 1) { +#endif + Format::RunFilter(in_lm, filter, output); +#ifndef NTHREAD + } else { + typedef Controller Threaded; + Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output); + Format::RunFilter(in_lm, threading, output); + } +#endif +} + +template void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) { + if (config.context) { + ContextFilter context_filter(filter); + RunThreadedFilter, OutputBuffer, Output>(config, in_lm, context_filter, output); + } else { + RunThreadedFilter(config, in_lm, filter, output); + } +} + +template void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) { + typedef BinaryFilter Filter; + RunContextFilter(config, in_lm, Filter(binary), out); +} + +template void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) { + if (config.mode == MODE_MULTIPLE) { + if (config.phrase) { + typedef phrase::Multiple Filter; + phrase::Substrings substrings; + typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings)); + RunContextFilter(config, in_lm, Filter(substrings), out); + } else { + typedef vocab::Multiple Filter; + boost::unordered_map > words; + typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words)); + RunContextFilter(config, in_lm, Filter(words), out); + } + return; + } + + typename Format::Output out(out_name); + + if (config.mode == MODE_COPY) { + Format::Copy(in_lm, out); + return; + } + + if (config.mode == MODE_SINGLE) { + vocab::Single::Words words; + vocab::ReadSingle(in_vocab, words); + DispatchBinaryFilter(config, in_lm, vocab::Single(words), out); + return; + } + + if (config.mode == MODE_UNION) { + if (config.phrase) { + phrase::Substrings substrings; + phrase::ReadMultiple(in_vocab, substrings); + DispatchBinaryFilter(config, in_lm, phrase::Union(substrings), out); + } else { + vocab::Union::Words words; + vocab::ReadMultiple(in_vocab, words); + DispatchBinaryFilter(config, in_lm, vocab::Union(words), out); + } + return; + } +} + +} // namespace +} // namespace lm + +int main(int argc, char *argv[]) { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } + + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + config.mode = lm::MODE_UNSET; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + config.mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + config.mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + config.mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + config.mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; +#ifndef NTHREAD + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." << std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); + return 1; + } + } + + if (config.mode == lm::MODE_UNSET) { + lm::DisplayHelp(argv[0]); + return 1; + } + + if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } + + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + if (!cmd_file) { + err(2, "Could not open input file %s", cmd_input); + } + vocab = &cmd_file; + } + + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); + } + return 0; +} diff --git a/klm/lm/filter/main.cc b/klm/lm/filter/main.cc deleted file mode 100644 index c42243e2..00000000 --- a/klm/lm/filter/main.cc +++ /dev/null @@ -1,249 +0,0 @@ -#include "lm/filter/arpa_io.hh" -#include "lm/filter/format.hh" -#include "lm/filter/phrase.hh" -#ifndef NTHREAD -#include "lm/filter/thread.hh" -#endif -#include "lm/filter/vocab.hh" -#include "lm/filter/wrapper.hh" -#include "util/file_piece.hh" - -#include - -#include -#include -#include -#include - -namespace lm { -namespace { - -void DisplayHelp(const char *name) { - std::cerr - << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n" - "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n" - " parser.\n" - "single mode treats the entire input as a single sentence.\n" - "multiple mode filters to multiple sentences in parallel. Each sentence is on\n" - " a separate line. A separate file is created for each file by appending the\n" - " 0-indexed line number to the output file name.\n" - "union mode produces one filtered model that is the union of models created by\n" - " multiple mode.\n\n" - "context means only the context (all but last word) has to pass the filter, but\n" - " the entire n-gram is output.\n\n" - "phrase means that the vocabulary is actually tab-delimited phrases and that the\n" - " phrases can generate the n-gram when assembled in arbitrary order and\n" - " clipped. Currently works with multiple or union mode.\n\n" - "The file format is set by [raw|arpa] with default arpa:\n" - "raw means space-separated tokens, optionally followed by a tab and arbitrary\n" - " text. This is useful for ngram count files.\n" - "arpa means the ARPA file format for n-gram language models.\n\n" -#ifndef NTHREAD - "threads:m sets m threads (default: conccurrency detected by boost)\n" - "batch_size:m sets the batch size for threading. Expect memory usage from this\n" - " of 2*threads*batch_size n-grams.\n\n" -#else - "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n" - " threading, compile without this flag against Boost >=1.42.0.\n\n" -#endif - "There are two inputs: vocabulary and model. Either may be given as a file\n" - " while the other is on stdin. Specify the type given as a file using\n" - " vocab: or model: before the file name. \n\n" - "For ARPA format, the output must be seekable. For raw format, it can be a\n" - " stream i.e. /dev/stdout\n"; -} - -typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION} FilterMode; -typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format; - -struct Config { - Config() : -#ifndef NTHREAD - batch_size(25000), - threads(boost::thread::hardware_concurrency()), -#endif - phrase(false), - context(false), - format(FORMAT_ARPA) - { -#ifndef NTHREAD - if (!threads) threads = 1; -#endif - } - -#ifndef NTHREAD - size_t batch_size; - size_t threads; -#endif - bool phrase; - bool context; - FilterMode mode; - Format format; -}; - -template void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) { -#ifndef NTHREAD - if (config.threads == 1) { -#endif - Format::RunFilter(in_lm, filter, output); -#ifndef NTHREAD - } else { - typedef Controller Threaded; - Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output); - Format::RunFilter(in_lm, threading, output); - } -#endif -} - -template void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) { - if (config.context) { - ContextFilter context_filter(filter); - RunThreadedFilter, OutputBuffer, Output>(config, in_lm, context_filter, output); - } else { - RunThreadedFilter(config, in_lm, filter, output); - } -} - -template void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) { - typedef BinaryFilter Filter; - RunContextFilter(config, in_lm, Filter(binary), out); -} - -template void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) { - if (config.mode == MODE_MULTIPLE) { - if (config.phrase) { - typedef phrase::Multiple Filter; - phrase::Substrings substrings; - typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings)); - RunContextFilter(config, in_lm, Filter(substrings), out); - } else { - typedef vocab::Multiple Filter; - boost::unordered_map > words; - typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words)); - RunContextFilter(config, in_lm, Filter(words), out); - } - return; - } - - typename Format::Output out(out_name); - - if (config.mode == MODE_COPY) { - Format::Copy(in_lm, out); - return; - } - - if (config.mode == MODE_SINGLE) { - vocab::Single::Words words; - vocab::ReadSingle(in_vocab, words); - DispatchBinaryFilter(config, in_lm, vocab::Single(words), out); - return; - } - - if (config.mode == MODE_UNION) { - if (config.phrase) { - phrase::Substrings substrings; - phrase::ReadMultiple(in_vocab, substrings); - DispatchBinaryFilter(config, in_lm, phrase::Union(substrings), out); - } else { - vocab::Union::Words words; - vocab::ReadMultiple(in_vocab, words); - DispatchBinaryFilter(config, in_lm, vocab::Union(words), out); - } - return; - } -} - -} // namespace -} // namespace lm - -int main(int argc, char *argv[]) { - if (argc < 4) { - lm::DisplayHelp(argv[0]); - return 1; - } - - // I used to have boost::program_options, but some users didn't want to compile boost. - lm::Config config; - boost::optional mode; - for (int i = 1; i < argc - 2; ++i) { - const char *str = argv[i]; - if (!std::strcmp(str, "copy")) { - mode = lm::MODE_COPY; - } else if (!std::strcmp(str, "single")) { - mode = lm::MODE_SINGLE; - } else if (!std::strcmp(str, "multiple")) { - mode = lm::MODE_MULTIPLE; - } else if (!std::strcmp(str, "union")) { - mode = lm::MODE_UNION; - } else if (!std::strcmp(str, "phrase")) { - config.phrase = true; - } else if (!std::strcmp(str, "context")) { - config.context = true; - } else if (!std::strcmp(str, "arpa")) { - config.format = lm::FORMAT_ARPA; - } else if (!std::strcmp(str, "raw")) { - config.format = lm::FORMAT_COUNT; -#ifndef NTHREAD - } else if (!std::strncmp(str, "threads:", 8)) { - config.threads = boost::lexical_cast(str + 8); - if (!config.threads) { - std::cerr << "Specify at least one thread." << std::endl; - return 1; - } - } else if (!std::strncmp(str, "batch_size:", 11)) { - config.batch_size = boost::lexical_cast(str + 11); - if (config.batch_size < 5000) { - std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; - if (!config.batch_size) return 1; - } -#endif - } else { - lm::DisplayHelp(argv[0]); - return 1; - } - } - - if (!mode) { - lm::DisplayHelp(argv[0]); - return 1; - } - config.mode = *mode; - - if (config.phrase && config.mode != lm::MODE_UNION && mode != lm::MODE_MULTIPLE) { - std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; - return 1; - } - - bool cmd_is_model = true; - const char *cmd_input = argv[argc - 2]; - if (!strncmp(cmd_input, "vocab:", 6)) { - cmd_is_model = false; - cmd_input += 6; - } else if (!strncmp(cmd_input, "model:", 6)) { - cmd_input += 6; - } else if (strchr(cmd_input, ':')) { - errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); - } else { - std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; - } - std::ifstream cmd_file; - std::istream *vocab; - if (cmd_is_model) { - vocab = &std::cin; - } else { - cmd_file.open(cmd_input, std::ios::in); - if (!cmd_file) { - err(2, "Could not open input file %s", cmd_input); - } - vocab = &cmd_file; - } - - util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); - - if (config.format == lm::FORMAT_ARPA) { - lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); - } else if (config.format == lm::FORMAT_COUNT) { - lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); - } - return 0; -} diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh index 07479dea..b4edff41 100644 --- a/klm/lm/filter/phrase.hh +++ b/klm/lm/filter/phrase.hh @@ -57,6 +57,7 @@ class Substrings { LM_FILTER_PHRASE_METHOD(Right, right) LM_FILTER_PHRASE_METHOD(Phrase, phrase) +#pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization // sentence_id must be non-decreasing. Iterators are over words in the phrase. template void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) { // Iterate over all substrings. diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh index e2b6adff..7f0fadaa 100644 --- a/klm/lm/filter/vocab.hh +++ b/klm/lm/filter/vocab.hh @@ -5,6 +5,7 @@ #include "util/multi_intersection.hh" #include "util/string_piece.hh" +#include "util/string_piece_hash.hh" #include "util/tokenize_piece.hh" #include diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc deleted file mode 100644 index 0267cd4e..00000000 --- a/klm/lm/fragment.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "lm/binary_format.hh" -#include "lm/model.hh" -#include "lm/left.hh" -#include "util/tokenize_piece.hh" - -template void Query(const char *name) { - Model model(name); - std::string line; - lm::ngram::ChartState ignored; - while (getline(std::cin, line)) { - lm::ngram::RuleScore scorer(model, ignored); - for (util::TokenIter i(line, ' '); i; ++i) { - scorer.Terminal(model.GetVocabulary().Index(*i)); - } - std::cout << scorer.Finish() << '\n'; - } -} - -int main(int argc, char *argv[]) { - if (argc != 2) { - std::cerr << "Expected model file name." << std::endl; - return 1; - } - const char *name = argv[1]; - lm::ngram::ModelType model_type = lm::ngram::PROBING; - lm::ngram::RecognizeBinary(name, model_type); - switch (model_type) { - case lm::ngram::PROBING: - Query(name); - break; - case lm::ngram::REST_PROBING: - Query(name); - break; - default: - std::cerr << "Model type not supported yet." << std::endl; - } -} diff --git a/klm/lm/fragment_main.cc b/klm/lm/fragment_main.cc new file mode 100644 index 00000000..0267cd4e --- /dev/null +++ b/klm/lm/fragment_main.cc @@ -0,0 +1,37 @@ +#include "lm/binary_format.hh" +#include "lm/model.hh" +#include "lm/left.hh" +#include "util/tokenize_piece.hh" + +template void Query(const char *name) { + Model model(name); + std::string line; + lm::ngram::ChartState ignored; + while (getline(std::cin, line)) { + lm::ngram::RuleScore scorer(model, ignored); + for (util::TokenIter i(line, ' '); i; ++i) { + scorer.Terminal(model.GetVocabulary().Index(*i)); + } + std::cout << scorer.Finish() << '\n'; + } +} + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected model file name." << std::endl; + return 1; + } + const char *name = argv[1]; + lm::ngram::ModelType model_type = lm::ngram::PROBING; + lm::ngram::RecognizeBinary(name, model_type); + switch (model_type) { + case lm::ngram::PROBING: + Query(name); + break; + case lm::ngram::REST_PROBING: + Query(name); + break; + default: + std::cerr << "Model type not supported yet." << std::endl; + } +} diff --git a/klm/lm/kenlm_max_order_main.cc b/klm/lm/kenlm_max_order_main.cc new file mode 100644 index 00000000..94221201 --- /dev/null +++ b/klm/lm/kenlm_max_order_main.cc @@ -0,0 +1,6 @@ +#include "lm/max_order.hh" +#include + +int main(int argc, char *argv[]) { + std::cerr << "KenLM was compiled with a maximum supported n-gram order set to " << KENLM_MAX_ORDER << "." << std::endl; +} diff --git a/klm/lm/max_order.cc b/klm/lm/max_order.cc deleted file mode 100644 index 94221201..00000000 --- a/klm/lm/max_order.cc +++ /dev/null @@ -1,6 +0,0 @@ -#include "lm/max_order.hh" -#include - -int main(int argc, char *argv[]) { - std::cerr << "KenLM was compiled with a maximum supported n-gram order set to " << KENLM_MAX_ORDER << "." << std::endl; -} diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc deleted file mode 100644 index 49757d9a..00000000 --- a/klm/lm/ngram_query.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "lm/ngram_query.hh" - -int main(int argc, char *argv[]) { - if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { - std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl; - std::cerr << "Input is wrapped in and unless null is passed." << std::endl; - return 1; - } - try { - bool sentence_context = (argc == 2); - using namespace lm::ngram; - ModelType model_type; - if (RecognizeBinary(argv[1], model_type)) { - switch(model_type) { - case PROBING: - Query(argv[1], sentence_context, std::cin, std::cout); - break; - case REST_PROBING: - Query(argv[1], sentence_context, std::cin, std::cout); - break; - case TRIE: - Query(argv[1], sentence_context, std::cin, std::cout); - break; - case QUANT_TRIE: - Query(argv[1], sentence_context, std::cin, std::cout); - break; - case ARRAY_TRIE: - Query(argv[1], sentence_context, std::cin, std::cout); - break; - case QUANT_ARRAY_TRIE: - Query(argv[1], sentence_context, std::cin, std::cout); - break; - default: - std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; - abort(); - } - } else { - Query(argv[1], sentence_context, std::cin, std::cout); - } - std::cerr << "Total time including destruction:\n"; - util::PrintUsage(std::cerr); - } catch (const std::exception &e) { - std::cerr << e.what() << std::endl; - return 1; - } - return 0; -} diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc new file mode 100644 index 00000000..49757d9a --- /dev/null +++ b/klm/lm/query_main.cc @@ -0,0 +1,47 @@ +#include "lm/ngram_query.hh" + +int main(int argc, char *argv[]) { + if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { + std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl; + std::cerr << "Input is wrapped in and unless null is passed." << std::endl; + return 1; + } + try { + bool sentence_context = (argc == 2); + using namespace lm::ngram; + ModelType model_type; + if (RecognizeBinary(argv[1], model_type)) { + switch(model_type) { + case PROBING: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case REST_PROBING: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case TRIE: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case QUANT_TRIE: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case ARRAY_TRIE: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + case QUANT_ARRAY_TRIE: + Query(argv[1], sentence_context, std::cin, std::cout); + break; + default: + std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; + abort(); + } + } else { + Query(argv[1], sentence_context, std::cin, std::cout); + } + std::cerr << "Total time including destruction:\n"; + util::PrintUsage(std::cerr); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } + return 0; +} diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 248cc844..7f873e96 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -38,6 +38,7 @@ libklm_util_a_SOURCES = \ sized_iterator.hh \ sorted_uniform.hh \ string_piece.hh \ + string_piece_hash.hh \ thread_pool.hh \ tokenize_piece.hh \ usage.hh \ diff --git a/klm/util/double-conversion/strtod.cc b/klm/util/double-conversion/strtod.cc index 9758989f..e298766a 100644 --- a/klm/util/double-conversion/strtod.cc +++ b/klm/util/double-conversion/strtod.cc @@ -506,7 +506,9 @@ float Strtof(Vector buffer, int exponent) { double double_previous = Double(double_guess).PreviousDouble(); float f1 = static_cast(double_previous); +#ifndef NDEBUG float f2 = float_guess; +#endif float f3 = static_cast(double_next); float f4; if (is_correct) { @@ -515,7 +517,9 @@ float Strtof(Vector buffer, int exponent) { double double_next2 = Double(double_next).NextDouble(); f4 = static_cast(double_next2); } +#ifndef NDEBUG ASSERT(f1 <= f2 && f2 <= f3 && f3 <= f4); +#endif // If the guess doesn't lie near a single-precision boundary we can simply // return its float-value. diff --git a/klm/util/file.cc b/klm/util/file.cc index 9a6d2e64..86d9b12d 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -22,6 +22,7 @@ #include #include #include +#include #else #include #endif @@ -99,15 +100,15 @@ uint64_t SizeOrThrow(int fd) { } void ResizeOrThrow(int fd, uint64_t to) { - UTIL_THROW_IF_ARG( #if defined(_WIN32) || defined(_WIN64) - _chsize_s + errno_t ret = _chsize_s #elif defined(OS_ANDROID) - ftruncate64 + int ret = ftruncate64 #else - ftruncate + int ret = ftruncate #endif - (fd, to), FDException, (fd), "while resizing to " << to << " bytes"); + (fd, to); + UTIL_THROW_IF_ARG(ret, FDException, (fd), "while resizing to " << to << " bytes"); } std::size_t PartialRead(int fd, void *to, std::size_t amount) { @@ -150,9 +151,21 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) { uint8_t *to = static_cast(to_void); #if defined(_WIN32) || defined(_WIN64) - UTIL_THROW(Exception, "TODO: PReadOrThrow for windows using ReadFile http://stackoverflow.com/questions/766477/are-there-equivalents-to-pread-on-different-platforms"); -#else + UTIL_THROW(Exception, "This pread implementation for windows is broken. Please send me a patch that does not change the file pointer. Atomically. Or send me an implementation of pwrite that is allowed to change the file pointer but can be called concurrently with pread."); + const std::size_t kMaxDWORD = static_cast(4294967295UL); +#endif for (;size ;) { +#if defined(_WIN32) || defined(_WIN64) + /* BROKEN: changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() which lmplz does. */ + // size_t might be 64-bit. DWORD is always 32. + DWORD reading = static_cast(std::min(kMaxDWORD, size)); + DWORD ret; + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(OVERLAPPED)); + overlapped.Offset = static_cast(off); + overlapped.OffsetHigh = static_cast(off >> 32); + UTIL_THROW_IF(!ReadFile((HANDLE)_get_osfhandle(fd), to, reading, &ret, &overlapped), Exception, "ReadFile failed for offset " << off); +#else ssize_t ret; errno = 0; do { @@ -166,11 +179,11 @@ void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) { UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd)); UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off); } +#endif size -= ret; off += ret; to += ret; } -#endif } void WriteOrThrow(int fd, const void *data_void, std::size_t size) { @@ -218,15 +231,15 @@ typedef CheckOffT::True IgnoredType; // Can't we all just get along? void InternalSeek(int fd, int64_t off, int whence) { - UTIL_THROW_IF_ARG( + if ( #if defined(_WIN32) || defined(_WIN64) - (__int64)-1 == _lseeki64(fd, off, whence), + (__int64)-1 == _lseeki64(fd, off, whence) #elif defined(OS_ANDROID) - (off64_t)-1 == lseek64(fd, off, whence), + (off64_t)-1 == lseek64(fd, off, whence) #else - (off_t)-1 == lseek(fd, off, whence), + (off_t)-1 == lseek(fd, off, whence) #endif - FDException, (fd), "while seeking to " << off << " whence " << whence); + ) UTIL_THROW_ARG(FDException, (fd), "while seeking to " << off << " whence " << whence); } } // namespace @@ -386,7 +399,13 @@ void NormalizeTempPrefix(std::string &base) { struct stat sb; // It's fine for it to not exist. if (-1 == stat(base.c_str(), &sb)) return; - if (S_ISDIR(sb.st_mode)) base += '/'; + if ( +#if defined(_WIN32) || defined(_WIN64) + sb.st_mode & _S_IFDIR +#else + S_ISDIR(sb.st_mode) +#endif + ) base += '/'; } int MakeTemp(const std::string &base) { diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index fbfa0e0e..4d143857 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -49,6 +49,18 @@ FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std: Initialize(NamePossiblyFind(fd, name).c_str(), show_progress, min_buffer); } +FilePiece::FilePiece(std::istream &stream, const char *name, std::size_t min_buffer) : + total_size_(kBadSize), page_(SizePage()) { + InitializeNoRead("istream", min_buffer); + + fallback_to_read_ = true; + data_.reset(MallocOrThrow(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); + position_ = data_.begin(); + position_end_ = position_; + + fell_back_.Reset(stream); +} + FilePiece::~FilePiece() {} StringPiece FilePiece::ReadLine(char delim) { @@ -83,7 +95,8 @@ unsigned long int FilePiece::ReadULong() { return ReadNumber(); } -void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer) { +// Factored out so that istream can call this. +void FilePiece::InitializeNoRead(const char *name, std::size_t min_buffer) { file_name_ = name; default_map_size_ = page_ * std::max((min_buffer / page_ + 1), 2); @@ -91,6 +104,10 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s position_end_ = NULL; mapped_offset_ = 0; at_end_ = false; +} + +void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer) { + InitializeNoRead(name, min_buffer); if (total_size_ == kBadSize) { // So the assertion passes. @@ -239,8 +256,7 @@ void FilePiece::TransitionToRead() { assert(!fallback_to_read_); fallback_to_read_ = true; data_.reset(); - data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); - UTIL_THROW_IF(!data_.get(), ErrnoException, "malloc failed for " << default_map_size_); + data_.reset(MallocOrThrow(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); position_ = data_.begin(); position_end_ = position_; diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 53310976..c07c6011 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -9,6 +9,7 @@ #include "util/string_piece.hh" #include +#include #include #include @@ -31,6 +32,13 @@ class FilePiece { // Takes ownership of fd. name is used for messages. explicit FilePiece(int fd, const char *name = NULL, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); + /* Read from an istream. Don't use this if you can avoid it. Raw fd IO is + * much faster. But sometimes you just have an istream like Boost's HTTP + * server and want to parse it the same way. + * name is just used for messages and FileName(). + */ + explicit FilePiece(std::istream &stream, const char *name = NULL, std::size_t min_buffer = 1048576); + ~FilePiece(); char get() { @@ -71,6 +79,8 @@ class FilePiece { const std::string &FileName() const { return file_name_; } private: + void InitializeNoRead(const char *name, std::size_t min_buffer); + // Calls InitializeNoRead, so don't call both. void Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer); template T ReadNumber(); diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index 91e4c559..7336007d 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -24,6 +24,20 @@ std::string FileLocation() { return ret; } +/* istream */ +BOOST_AUTO_TEST_CASE(IStream) { + std::fstream ref(FileLocation().c_str(), std::ios::in); + std::fstream backing(FileLocation().c_str(), std::ios::in); + FilePiece test(backing); + std::string ref_line; + while (getline(ref, ref_line)) { + StringPiece test_line(test.ReadLine()); + BOOST_CHECK_EQUAL(ref_line, test_line); + } + BOOST_CHECK_THROW(test.get(), EndOfFileException); + BOOST_CHECK_THROW(test.get(), EndOfFileException); +} + /* mmap implementation */ BOOST_AUTO_TEST_CASE(MMapReadLine) { std::fstream ref(FileLocation().c_str(), std::ios::in); diff --git a/klm/util/have.hh b/klm/util/have.hh index e9a4d946..6e18529d 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -10,8 +10,4 @@ //#define HAVE_ICU #endif -#ifndef HAVE_BOOST -//#define HAVE_BOOST -#endif - #endif // UTIL_HAVE__ diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc index 7a1a8fb5..b81549e4 100644 --- a/klm/util/read_compressed.cc +++ b/klm/util/read_compressed.cc @@ -320,6 +320,23 @@ class XZip : public ReadBase { }; #endif // HAVE_XZLIB +class IStreamReader : public ReadBase { + public: + explicit IStreamReader(std::istream &stream) : stream_(stream) {} + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (!stream_.read(static_cast(to), amount)) { + UTIL_THROW_IF(!stream_.eof(), ErrnoException, "istream error"); + amount = stream_.gcount(); + } + ReadCount(thunk) += amount; + return amount; + } + + private: + std::istream &stream_; +}; + enum MagicResult { UNKNOWN, GZIP, BZIP, XZIP }; @@ -329,7 +346,7 @@ MagicResult DetectMagic(const void *from_void) { if (header[0] == 0x1f && header[1] == 0x8b) { return GZIP; } - if (header[0] == 'B' && header[1] == 'Z') { + if (header[0] == 'B' && header[1] == 'Z' && header[2] == 'h') { return BZIP; } const uint8_t xzmagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 }; @@ -387,6 +404,10 @@ ReadCompressed::ReadCompressed(int fd) { Reset(fd); } +ReadCompressed::ReadCompressed(std::istream &in) { + Reset(in); +} + ReadCompressed::ReadCompressed() {} ReadCompressed::~ReadCompressed() {} @@ -396,6 +417,11 @@ void ReadCompressed::Reset(int fd) { internal_.reset(ReadFactory(fd, raw_amount_)); } +void ReadCompressed::Reset(std::istream &in) { + internal_.reset(); + internal_.reset(new IStreamReader(in)); +} + std::size_t ReadCompressed::Read(void *to, std::size_t amount) { return internal_->Read(to, amount, *this); } diff --git a/klm/util/read_compressed.hh b/klm/util/read_compressed.hh index 83ca9fb2..8b54c9e8 100644 --- a/klm/util/read_compressed.hh +++ b/klm/util/read_compressed.hh @@ -45,6 +45,10 @@ class ReadCompressed { // Takes ownership of fd. explicit ReadCompressed(int fd); + // Try to avoid using this. Use the fd instead. + // There is no decompression support for istreams. + explicit ReadCompressed(std::istream &in); + // Must call Reset later. ReadCompressed(); @@ -53,6 +57,9 @@ class ReadCompressed { // Takes ownership of fd. void Reset(int fd); + // Same advice as the constructor. + void Reset(std::istream &in); + std::size_t Read(void *to, std::size_t amount); uint64_t RawAmount() const { return raw_amount_; } diff --git a/klm/util/read_compressed_test.cc b/klm/util/read_compressed_test.cc index 6fd97e5e..9cb4a4b9 100644 --- a/klm/util/read_compressed_test.cc +++ b/klm/util/read_compressed_test.cc @@ -25,19 +25,34 @@ void ReadLoop(ReadCompressed &reader, void *to_void, std::size_t amount) { } } -void TestRandom(const char *compressor) { - const uint32_t kSize4 = 100000 / 4; +const uint32_t kSize4 = 100000 / 4; + +std::string WriteRandom() { char name[] = "tempXXXXXX"; + scoped_fd original(mkstemp(name)); + BOOST_REQUIRE(original.get() > 0); + for (uint32_t i = 0; i < kSize4; ++i) { + WriteOrThrow(original.get(), &i, sizeof(uint32_t)); + } + return name; +} - // Write test file. - { - scoped_fd original(mkstemp(name)); - BOOST_REQUIRE(original.get() > 0); - for (uint32_t i = 0; i < kSize4; ++i) { - WriteOrThrow(original.get(), &i, sizeof(uint32_t)); - } +void VerifyRead(ReadCompressed &reader) { + for (uint32_t i = 0; i < kSize4; ++i) { + uint32_t got; + ReadLoop(reader, &got, sizeof(uint32_t)); + BOOST_CHECK_EQUAL(i, got); } + char ignored; + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); + // Test double EOF call. + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); +} + +void TestRandom(const char *compressor) { + std::string name(WriteRandom()); + char gzname[] = "tempXXXXXX"; scoped_fd gzipped(mkstemp(gzname)); @@ -52,20 +67,11 @@ void TestRandom(const char *compressor) { command += "\""; BOOST_REQUIRE_EQUAL(0, system(command.c_str())); - BOOST_CHECK_EQUAL(0, unlink(name)); + BOOST_CHECK_EQUAL(0, unlink(name.c_str())); BOOST_CHECK_EQUAL(0, unlink(gzname)); ReadCompressed reader(gzipped.release()); - for (uint32_t i = 0; i < kSize4; ++i) { - uint32_t got; - ReadLoop(reader, &got, sizeof(uint32_t)); - BOOST_CHECK_EQUAL(i, got); - } - - char ignored; - BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); - // Test double EOF call. - BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); + VerifyRead(reader); } BOOST_AUTO_TEST_CASE(Uncompressed) { @@ -90,5 +96,14 @@ BOOST_AUTO_TEST_CASE(ReadXZ) { } #endif +BOOST_AUTO_TEST_CASE(IStream) { + std::string name(WriteRandom()); + std::fstream stream(name.c_str(), std::ios::in); + BOOST_CHECK_EQUAL(0, unlink(name.c_str())); + ReadCompressed reader; + reader.Reset(stream); + VerifyRead(reader); +} + } // namespace } // namespace util diff --git a/klm/util/stream/io.cc b/klm/util/stream/io.cc index c7ad2980..0459f706 100644 --- a/klm/util/stream/io.cc +++ b/klm/util/stream/io.cc @@ -29,15 +29,17 @@ void Read::Run(const ChainPosition &position) { void PRead::Run(const ChainPosition &position) { scoped_fd owner; if (own_) owner.reset(file_); - uint64_t size = SizeOrThrow(file_); + const uint64_t size = SizeOrThrow(file_); UTIL_THROW_IF(size % static_cast(position.GetChain().EntrySize()), ReadSizeException, "File size " << file_ << " size is " << size << " not a multiple of " << position.GetChain().EntrySize()); - std::size_t block_size = position.GetChain().BlockSize(); + const std::size_t block_size = position.GetChain().BlockSize(); + const uint64_t block_size64 = static_cast(block_size); Link link(position); uint64_t offset = 0; - for (; offset + block_size < size; offset += block_size, ++link) { + for (; offset + block_size64 < size; offset += block_size64, ++link) { PReadOrThrow(file_, link->Get(), block_size, offset); link->SetValidSize(block_size); } + // size - offset is <= block_size, so it casts to 32-bit fine. if (size - offset) { PReadOrThrow(file_, link->Get(), size - offset, offset); link->SetValidSize(size - offset); diff --git a/klm/util/stream/sort.hh b/klm/util/stream/sort.hh index a86f160f..16aa6a03 100644 --- a/klm/util/stream/sort.hh +++ b/klm/util/stream/sort.hh @@ -365,10 +365,14 @@ template class BlockSorter { // Record the size of each block in a separate file. offsets_->Append(link->ValidSize()); void *end = static_cast(link->Get()) + link->ValidSize(); - std::sort( - SizedIt(link->Get(), entry_size), - SizedIt(end, entry_size), - compare_); +#if defined(_WIN32) || defined(_WIN64) + std::stable_sort +#else + std::sort +#endif + (SizedIt(link->Get(), entry_size), + SizedIt(end, entry_size), + compare_); } offsets_->FinishedAppending(); } diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc index b422cefc..ec394b96 100644 --- a/klm/util/string_piece.cc +++ b/klm/util/string_piece.cc @@ -17,7 +17,8 @@ void StringPiece::CopyToString(std::string* target) const { } size_type StringPiece::find(const StringPiece& s, size_type pos) const { - if (length_ < 0 || pos > static_cast(length_)) + // Not sure why length_ < 0 was here since it's std::size_t. + if (/*length_ < 0 || */pos > static_cast(length_)) return npos; const char* result = std::search(ptr_ + pos, ptr_ + length_, diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index 51481646..9cf4c7f6 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -50,10 +50,6 @@ #include "util/have.hh" -#ifdef HAVE_BOOST -#include -#endif // HAVE_BOOST - #include #include #include @@ -256,46 +252,9 @@ inline std::ostream& operator<<(std::ostream& o, const StringPiece& piece) { return o.write(piece.data(), static_cast(piece.size())); } -#ifdef HAVE_BOOST -inline size_t hash_value(const StringPiece &str) { - return boost::hash_range(str.data(), str.data() + str.length()); -} - -/* Support for lookup of StringPiece in boost::unordered_map */ -struct StringPieceCompatibleHash : public std::unary_function { - size_t operator()(const StringPiece &str) const { - return hash_value(str); - } -}; - -struct StringPieceCompatibleEquals : public std::binary_function { - bool operator()(const StringPiece &first, const StringPiece &second) const { - return first == second; - } -}; -template typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { -#if BOOST_VERSION < 104200 - std::string temp(key.data(), key.size()); - return t.find(temp); -#else - return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); -#endif -} - -template typename T::iterator FindStringPiece(T &t, const StringPiece &key) { -#if BOOST_VERSION < 104200 - std::string temp(key.data(), key.size()); - return t.find(temp); -#else - return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); -#endif -} -#endif - #ifdef HAVE_ICU U_NAMESPACE_END using U_NAMESPACE_QUALIFIER StringPiece; #endif - #endif // BASE_STRING_PIECE_H__ diff --git a/klm/util/string_piece_hash.hh b/klm/util/string_piece_hash.hh new file mode 100644 index 00000000..f206b1d8 --- /dev/null +++ b/klm/util/string_piece_hash.hh @@ -0,0 +1,43 @@ +#ifndef UTIL_STRING_PIECE_HASH__ +#define UTIL_STRING_PIECE_HASH__ + +#include "util/string_piece.hh" + +#include +#include + +inline size_t hash_value(const StringPiece &str) { + return boost::hash_range(str.data(), str.data() + str.length()); +} + +/* Support for lookup of StringPiece in boost::unordered_map */ +struct StringPieceCompatibleHash : public std::unary_function { + size_t operator()(const StringPiece &str) const { + return hash_value(str); + } +}; + +struct StringPieceCompatibleEquals : public std::binary_function { + bool operator()(const StringPiece &first, const StringPiece &second) const { + return first == second; + } +}; +template typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else + return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif +} + +template typename T::iterator FindStringPiece(T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else + return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif +} + +#endif // UTIL_STRING_PIECE_HASH__ diff --git a/klm/util/usage.cc b/klm/util/usage.cc index 16a004bb..b8e125d0 100644 --- a/klm/util/usage.cc +++ b/klm/util/usage.cc @@ -81,7 +81,7 @@ template uint64_t ParseNum(const std::string &arg) { UTIL_THROW_IF_ARG(stream >> throwaway, SizeParseError, (arg), "because there was more cruft " << throwaway << " after the number."); // Silly sort, using kilobytes as your default unit. - if (after.empty()) after == "K"; + if (after.empty()) after = "K"; if (after == "%") { uint64_t mem = GuessPhysicalMemory(); UTIL_THROW_IF_ARG(!mem, SizeParseError, (arg), "because % was specified but the physical memory size could not be determined."); -- cgit v1.2.3