diff options
author | Wu, Ke <wuke@cs.umd.edu> | 2014-12-17 16:15:13 -0500 |
---|---|---|
committer | Wu, Ke <wuke@cs.umd.edu> | 2014-12-17 16:15:13 -0500 |
commit | 6829a0bc624b02ebefc79f8cf9ec89d7d64a7c30 (patch) | |
tree | 125dfb20f73342873476c793995397b26fd202dd | |
parent | b455a108a21f4ba5a58ab1bc53a8d2bf4d829067 (diff) | |
parent | 7468e8d85e99b4619442c7afaf4a0d92870111bb (diff) |
Merge branch 'const_reorder_2' into softsyn_2
315 files changed, 5726 insertions, 3598 deletions
@@ -1,3 +1,5 @@ +klm/lm/builder/dump_counts +klm/util/cat_compressed example_extff/ff_example.lo example_extff/libff_example.la mteval/meteor_jar.cc @@ -226,3 +228,4 @@ training/utils/sentserver utils/stringlib_test word-aligner/binderiv word-aligner/fast_align +test-driver diff --git a/THREADS.txt b/THREADS.txt new file mode 100644 index 00000000..4dba2403 --- /dev/null +++ b/THREADS.txt @@ -0,0 +1,5 @@ +The cdec decoder is not, in general, thread safe. There are system components +that make use of multi-threading, but the decoder may not be used from multiple +threads. If you wish to decode in parallel, independent decoder processes +must be run. + diff --git a/configure.ac b/configure.ac index aaad9a78..36cee5af 100644 --- a/configure.ac +++ b/configure.ac @@ -1,5 +1,5 @@ AC_CONFIG_MACRO_DIR([m4]) -AC_INIT([cdec],[2014-09-08]) +AC_INIT([cdec],[2014-10-12]) AC_CONFIG_SRCDIR([decoder/cdec.cc]) AM_INIT_AUTOMAKE AC_CONFIG_HEADERS(config.h) @@ -12,6 +12,7 @@ OLD_CXXFLAGS=$CXXFLAGS AC_PROG_CC AC_PROG_CXX CXXFLAGS=$OLD_CXXFLAGS +AX_PTHREAD AX_CXX_COMPILE_STDCXX_11([],[mandatory]) AC_LANG_CPLUSPLUS AC_OPENMP diff --git a/corpus/conll2cdec.pl b/corpus/conll2cdec.pl new file mode 100755 index 00000000..ee4e07db --- /dev/null +++ b/corpus/conll2cdec.pl @@ -0,0 +1,42 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 file.conll\n\n Converts a CoNLL formatted labeled sequence into cdec's format.\n\n" unless scalar @ARGV == 1; +open F, "<$ARGV[0]" or die "Can't read $ARGV[0]: $!\n"; + +my @xx; +my @yy; +my @os; +my $sec = undef; +my $i = 0; +while(<F>) { + chomp; + if (/^\s*$/) { + print "<seg id=\"$i\""; + $i++; + for (my $j = 0; $j < $sec; $j++) { + my @oo = (); + for (my $k = 0; $k < scalar @xx; $k++) { + my $sym = $os[$k]->[$j]; + $sym =~ s/"/'/g; + push @oo, $sym; + } + my $zz = $j + 1; + print " feat$zz=\"@oo\""; + } + + print "> @xx ||| @yy </seg>\n"; + @xx = (); + @yy = (); + @os = (); + } else { + my ($x, @fs) = split /\s+/; + my $y = pop @fs; + if (!defined $sec) { $sec = scalar @fs; } + die unless $sec == scalar @fs; + push @xx, $x; + push @yy, $y; + push @os, \@fs; + } +} + diff --git a/corpus/tokenize-anything.sh b/corpus/tokenize-anything.sh index bca954d1..c580e88b 100755 --- a/corpus/tokenize-anything.sh +++ b/corpus/tokenize-anything.sh @@ -7,6 +7,13 @@ if [[ $# == 1 && $1 == '-u' ]] ; then NORMARGS="--batchline" SEDFLAGS="-u" else + if [[ $# != 0 ]] ; then + echo Usage: `basename $0` [-u] \< file.in \> file.out 1>&2 + echo 1>&2 + echo Tokenizes text in a reasonable way in most languages. 1>&2 + echo 1>&2 + exit 1 + fi NORMARGS="" SEDFLAGS="" fi diff --git a/decoder/JSON_parser.c b/decoder/JSON_parser.c deleted file mode 100644 index 5e392bc6..00000000 --- a/decoder/JSON_parser.c +++ /dev/null @@ -1,1012 +0,0 @@ -/* JSON_parser.c */ - -/* 2007-08-24 */ - -/* -Copyright (c) 2005 JSON.org - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -The Software shall be used for Good, not Evil. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -*/ - -/* - Callbacks, comments, Unicode handling by Jean Gressmann (jean@0x42.de), 2007-2009. - - For the added features the license above applies also. - - Changelog: - 2009-05-17 - Incorporated benrudiak@googlemail.com fix for UTF16 decoding. - - 2009-05-14 - Fixed float parsing bug related to a locale being set that didn't - use '.' as decimal point character (charles@transmissionbt.com). - - 2008-10-14 - Renamed states.IN to states.IT to avoid name clash which IN macro - defined in windef.h (alexey.pelykh@gmail.com) - - 2008-07-19 - Removed some duplicate code & debugging variable (charles@transmissionbt.com) - - 2008-05-28 - Made JSON_value structure ansi C compliant. This bug was report by - trisk@acm.jhu.edu - - 2008-05-20 - Fixed bug reported by charles@transmissionbt.com where the switching - from static to dynamic parse buffer did not copy the static parse - buffer's content. -*/ - - - -#include <assert.h> -#include <ctype.h> -#include <float.h> -#include <stddef.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <locale.h> - -#include "JSON_parser.h" - -#ifdef _MSC_VER -# if _MSC_VER >= 1400 /* Visual Studio 2005 and up */ -# pragma warning(disable:4996) // unsecure sscanf -# endif -#endif - - -#define true 1 -#define false 0 -#define __ -1 /* the universal error code */ - -/* values chosen so that the object size is approx equal to one page (4K) */ -#ifndef JSON_PARSER_STACK_SIZE -# define JSON_PARSER_STACK_SIZE 128 -#endif - -#ifndef JSON_PARSER_PARSE_BUFFER_SIZE -# define JSON_PARSER_PARSE_BUFFER_SIZE 3500 -#endif - -typedef unsigned short UTF16; - -struct JSON_parser_struct { - JSON_parser_callback callback; - void* ctx; - signed char state, before_comment_state, type, escaped, comment, allow_comments, handle_floats_manually; - UTF16 utf16_high_surrogate; - long depth; - long top; - signed char* stack; - long stack_capacity; - char decimal_point; - char* parse_buffer; - size_t parse_buffer_capacity; - size_t parse_buffer_count; - size_t comment_begin_offset; - signed char static_stack[JSON_PARSER_STACK_SIZE]; - char static_parse_buffer[JSON_PARSER_PARSE_BUFFER_SIZE]; -}; - -#define COUNTOF(x) (sizeof(x)/sizeof(x[0])) - -/* - Characters are mapped into these character classes. This allows for - a significant reduction in the size of the state transition table. -*/ - - - -enum classes { - C_SPACE, /* space */ - C_WHITE, /* other whitespace */ - C_LCURB, /* { */ - C_RCURB, /* } */ - C_LSQRB, /* [ */ - C_RSQRB, /* ] */ - C_COLON, /* : */ - C_COMMA, /* , */ - C_QUOTE, /* " */ - C_BACKS, /* \ */ - C_SLASH, /* / */ - C_PLUS, /* + */ - C_MINUS, /* - */ - C_POINT, /* . */ - C_ZERO , /* 0 */ - C_DIGIT, /* 123456789 */ - C_LOW_A, /* a */ - C_LOW_B, /* b */ - C_LOW_C, /* c */ - C_LOW_D, /* d */ - C_LOW_E, /* e */ - C_LOW_F, /* f */ - C_LOW_L, /* l */ - C_LOW_N, /* n */ - C_LOW_R, /* r */ - C_LOW_S, /* s */ - C_LOW_T, /* t */ - C_LOW_U, /* u */ - C_ABCDF, /* ABCDF */ - C_E, /* E */ - C_ETC, /* everything else */ - C_STAR, /* * */ - NR_CLASSES -}; - -static int ascii_class[128] = { -/* - This array maps the 128 ASCII characters into character classes. - The remaining Unicode characters should be mapped to C_ETC. - Non-whitespace control characters are errors. -*/ - __, __, __, __, __, __, __, __, - __, C_WHITE, C_WHITE, __, __, C_WHITE, __, __, - __, __, __, __, __, __, __, __, - __, __, __, __, __, __, __, __, - - C_SPACE, C_ETC, C_QUOTE, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_STAR, C_PLUS, C_COMMA, C_MINUS, C_POINT, C_SLASH, - C_ZERO, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, - C_DIGIT, C_DIGIT, C_COLON, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - - C_ETC, C_ABCDF, C_ABCDF, C_ABCDF, C_ABCDF, C_E, C_ABCDF, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_LSQRB, C_BACKS, C_RSQRB, C_ETC, C_ETC, - - C_ETC, C_LOW_A, C_LOW_B, C_LOW_C, C_LOW_D, C_LOW_E, C_LOW_F, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_LOW_L, C_ETC, C_LOW_N, C_ETC, - C_ETC, C_ETC, C_LOW_R, C_LOW_S, C_LOW_T, C_LOW_U, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_LCURB, C_ETC, C_RCURB, C_ETC, C_ETC -}; - - -/* - The state codes. -*/ -enum states { - GO, /* start */ - OK, /* ok */ - OB, /* object */ - KE, /* key */ - CO, /* colon */ - VA, /* value */ - AR, /* array */ - ST, /* string */ - ES, /* escape */ - U1, /* u1 */ - U2, /* u2 */ - U3, /* u3 */ - U4, /* u4 */ - MI, /* minus */ - ZE, /* zero */ - IT, /* integer */ - FR, /* fraction */ - E1, /* e */ - E2, /* ex */ - E3, /* exp */ - T1, /* tr */ - T2, /* tru */ - T3, /* true */ - F1, /* fa */ - F2, /* fal */ - F3, /* fals */ - F4, /* false */ - N1, /* nu */ - N2, /* nul */ - N3, /* null */ - C1, /* / */ - C2, /* / * */ - C3, /* * */ - FX, /* *.* *eE* */ - D1, /* second UTF-16 character decoding started by \ */ - D2, /* second UTF-16 character proceeded by u */ - NR_STATES -}; - -enum actions -{ - CB = -10, /* comment begin */ - CE = -11, /* comment end */ - FA = -12, /* false */ - TR = -13, /* false */ - NU = -14, /* null */ - DE = -15, /* double detected by exponent e E */ - DF = -16, /* double detected by fraction . */ - SB = -17, /* string begin */ - MX = -18, /* integer detected by minus */ - ZX = -19, /* integer detected by zero */ - IX = -20, /* integer detected by 1-9 */ - EX = -21, /* next char is escaped */ - UC = -22 /* Unicode character read */ -}; - - -static int state_transition_table[NR_STATES][NR_CLASSES] = { -/* - The state transition table takes the current state and the current symbol, - and returns either a new state or an action. An action is represented as a - negative number. A JSON text is accepted if at the end of the text the - state is OK and if the mode is MODE_DONE. - - white 1-9 ABCDF etc - space | { } [ ] : , " \ / + - . 0 | a b c d e f l n r s t u | E | * */ -/*start GO*/ {GO,GO,-6,__,-5,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ok OK*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*object OB*/ {OB,OB,__,-9,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*key KE*/ {KE,KE,__,__,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*colon CO*/ {CO,CO,__,__,__,__,-2,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*value VA*/ {VA,VA,-6,__,-5,__,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*array AR*/ {AR,AR,-6,__,-5,-7,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*string ST*/ {ST,__,ST,ST,ST,ST,ST,ST,-4,EX,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST}, -/*escape ES*/ {__,__,__,__,__,__,__,__,ST,ST,ST,__,__,__,__,__,__,ST,__,__,__,ST,__,ST,ST,__,ST,U1,__,__,__,__}, -/*u1 U1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U2,U2,U2,U2,U2,U2,U2,U2,__,__,__,__,__,__,U2,U2,__,__}, -/*u2 U2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U3,U3,U3,U3,U3,U3,U3,U3,__,__,__,__,__,__,U3,U3,__,__}, -/*u3 U3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U4,U4,U4,U4,U4,U4,U4,U4,__,__,__,__,__,__,U4,U4,__,__}, -/*u4 U4*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,UC,UC,UC,UC,UC,UC,UC,UC,__,__,__,__,__,__,UC,UC,__,__}, -/*minus MI*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,ZE,IT,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*zero ZE*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*int IT*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,IT,IT,__,__,__,__,DE,__,__,__,__,__,__,__,__,DE,__,__}, -/*frac FR*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*e E1*/ {__,__,__,__,__,__,__,__,__,__,__,E2,E2,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ex E2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*exp E3*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*tr T1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T2,__,__,__,__,__,__,__}, -/*tru T2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T3,__,__,__,__}, -/*true T3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*fa F1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*fal F2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F3,__,__,__,__,__,__,__,__,__}, -/*fals F3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F4,__,__,__,__,__,__}, -/*false F4*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*nu N1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N2,__,__,__,__}, -/*nul N2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N3,__,__,__,__,__,__,__,__,__}, -/*null N3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__}, -/*/ C1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,C2}, -/*/* C2*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/** C3*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,CE,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/*_. FX*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*\ D1*/ {__,__,__,__,__,__,__,__,__,D2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*\ D2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,U1,__,__,__,__}, -}; - - -/* - These modes can be pushed on the stack. -*/ -enum modes { - MODE_ARRAY = 1, - MODE_DONE = 2, - MODE_KEY = 3, - MODE_OBJECT = 4 -}; - -static int -push(JSON_parser jc, int mode) -{ -/* - Push a mode onto the stack. Return false if there is overflow. -*/ - jc->top += 1; - if (jc->depth < 0) { - if (jc->top >= jc->stack_capacity) { - size_t bytes_to_allocate; - jc->stack_capacity *= 2; - bytes_to_allocate = jc->stack_capacity * sizeof(jc->static_stack[0]); - if (jc->stack == &jc->static_stack[0]) { - jc->stack = (signed char*)malloc(bytes_to_allocate); - memcpy(jc->stack, jc->static_stack, sizeof(jc->static_stack)); - } else { - jc->stack = (signed char*)realloc(jc->stack, bytes_to_allocate); - } - } - } else { - if (jc->top >= jc->depth) { - return false; - } - } - - jc->stack[jc->top] = mode; - return true; -} - - -static int -pop(JSON_parser jc, int mode) -{ -/* - Pop the stack, assuring that the current mode matches the expectation. - Return false if there is underflow or if the modes mismatch. -*/ - if (jc->top < 0 || jc->stack[jc->top] != mode) { - return false; - } - jc->top -= 1; - return true; -} - - -#define parse_buffer_clear(jc) \ - do {\ - jc->parse_buffer_count = 0;\ - jc->parse_buffer[0] = 0;\ - } while (0) - -#define parse_buffer_pop_back_char(jc)\ - do {\ - assert(jc->parse_buffer_count >= 1);\ - --jc->parse_buffer_count;\ - jc->parse_buffer[jc->parse_buffer_count] = 0;\ - } while (0) - -void delete_JSON_parser(JSON_parser jc) -{ - if (jc) { - if (jc->stack != &jc->static_stack[0]) { - free((void*)jc->stack); - } - if (jc->parse_buffer != &jc->static_parse_buffer[0]) { - free((void*)jc->parse_buffer); - } - free((void*)jc); - } -} - - -JSON_parser -new_JSON_parser(JSON_config* config) -{ -/* - new_JSON_parser starts the checking process by constructing a JSON_parser - object. It takes a depth parameter that restricts the level of maximum - nesting. - - To continue the process, call JSON_parser_char for each character in the - JSON text, and then call JSON_parser_done to obtain the final result. - These functions are fully reentrant. -*/ - - int depth = 0; - JSON_config default_config; - - JSON_parser jc = (JSON_parser)malloc(sizeof(struct JSON_parser_struct)); - - memset(jc, 0, sizeof(*jc)); - - - /* initialize configuration */ - init_JSON_config(&default_config); - - /* set to default configuration if none was provided */ - if (config == NULL) { - config = &default_config; - } - - depth = config->depth; - - /* We need to be able to push at least one object */ - if (depth == 0) { - depth = 1; - } - - jc->state = GO; - jc->top = -1; - - /* Do we want non-bound stack? */ - if (depth > 0) { - jc->stack_capacity = depth; - jc->depth = depth; - if (depth <= (int)COUNTOF(jc->static_stack)) { - jc->stack = &jc->static_stack[0]; - } else { - jc->stack = (signed char*)malloc(jc->stack_capacity * sizeof(jc->static_stack[0])); - } - } else { - jc->stack_capacity = COUNTOF(jc->static_stack); - jc->depth = -1; - jc->stack = &jc->static_stack[0]; - } - - /* set parser to start */ - push(jc, MODE_DONE); - - /* set up the parse buffer */ - jc->parse_buffer = &jc->static_parse_buffer[0]; - jc->parse_buffer_capacity = COUNTOF(jc->static_parse_buffer); - parse_buffer_clear(jc); - - /* set up callback, comment & float handling */ - jc->callback = config->callback; - jc->ctx = config->callback_ctx; - jc->allow_comments = config->allow_comments != 0; - jc->handle_floats_manually = config->handle_floats_manually != 0; - - /* set up decimal point */ - jc->decimal_point = *localeconv()->decimal_point; - - return jc; -} - -static void grow_parse_buffer(JSON_parser jc) -{ - size_t bytes_to_allocate; - jc->parse_buffer_capacity *= 2; - bytes_to_allocate = jc->parse_buffer_capacity * sizeof(jc->parse_buffer[0]); - if (jc->parse_buffer == &jc->static_parse_buffer[0]) { - jc->parse_buffer = (char*)malloc(bytes_to_allocate); - memcpy(jc->parse_buffer, jc->static_parse_buffer, jc->parse_buffer_count); - } else { - jc->parse_buffer = (char*)realloc(jc->parse_buffer, bytes_to_allocate); - } -} - -#define parse_buffer_push_back_char(jc, c)\ - do {\ - if (jc->parse_buffer_count + 1 >= jc->parse_buffer_capacity) grow_parse_buffer(jc);\ - jc->parse_buffer[jc->parse_buffer_count++] = c;\ - jc->parse_buffer[jc->parse_buffer_count] = 0;\ - } while (0) - -#define assert_is_non_container_type(jc) \ - assert( \ - jc->type == JSON_T_NULL || \ - jc->type == JSON_T_FALSE || \ - jc->type == JSON_T_TRUE || \ - jc->type == JSON_T_FLOAT || \ - jc->type == JSON_T_INTEGER || \ - jc->type == JSON_T_STRING) - - -static int parse_parse_buffer(JSON_parser jc) -{ - if (jc->callback) { - JSON_value value, *arg = NULL; - - if (jc->type != JSON_T_NONE) { - assert_is_non_container_type(jc); - - switch(jc->type) { - case JSON_T_FLOAT: - arg = &value; - if (jc->handle_floats_manually) { - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - } else { - /*sscanf(jc->parse_buffer, "%Lf", &value.vu.float_value);*/ - - /* not checking with end pointer b/c there may be trailing ws */ - value.vu.float_value = strtod(jc->parse_buffer, NULL); - } - break; - case JSON_T_INTEGER: - arg = &value; - sscanf(jc->parse_buffer, JSON_PARSER_INTEGER_SSCANF_TOKEN, &value.vu.integer_value); - break; - case JSON_T_STRING: - arg = &value; - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - break; - } - - if (!(*jc->callback)(jc->ctx, jc->type, arg)) { - return false; - } - } - } - - parse_buffer_clear(jc); - - return true; -} - -#define IS_HIGH_SURROGATE(uc) (((uc) & 0xFC00) == 0xD800) -#define IS_LOW_SURROGATE(uc) (((uc) & 0xFC00) == 0xDC00) -#define DECODE_SURROGATE_PAIR(hi,lo) ((((hi) & 0x3FF) << 10) + ((lo) & 0x3FF) + 0x10000) -static unsigned char utf8_lead_bits[4] = { 0x00, 0xC0, 0xE0, 0xF0 }; - -static int decode_unicode_char(JSON_parser jc) -{ - int i; - unsigned uc = 0; - char* p; - int trail_bytes; - - assert(jc->parse_buffer_count >= 6); - - p = &jc->parse_buffer[jc->parse_buffer_count - 4]; - - for (i = 12; i >= 0; i -= 4, ++p) { - unsigned x = *p; - - if (x >= 'a') { - x -= ('a' - 10); - } else if (x >= 'A') { - x -= ('A' - 10); - } else { - x &= ~0x30u; - } - - assert(x < 16); - - uc |= x << i; - } - - /* clear UTF-16 char from buffer */ - jc->parse_buffer_count -= 6; - jc->parse_buffer[jc->parse_buffer_count] = 0; - - /* attempt decoding ... */ - if (jc->utf16_high_surrogate) { - if (IS_LOW_SURROGATE(uc)) { - uc = DECODE_SURROGATE_PAIR(jc->utf16_high_surrogate, uc); - trail_bytes = 3; - jc->utf16_high_surrogate = 0; - } else { - /* high surrogate without a following low surrogate */ - return false; - } - } else { - if (uc < 0x80) { - trail_bytes = 0; - } else if (uc < 0x800) { - trail_bytes = 1; - } else if (IS_HIGH_SURROGATE(uc)) { - /* save the high surrogate and wait for the low surrogate */ - jc->utf16_high_surrogate = uc; - return true; - } else if (IS_LOW_SURROGATE(uc)) { - /* low surrogate without a preceding high surrogate */ - return false; - } else { - trail_bytes = 2; - } - } - - jc->parse_buffer[jc->parse_buffer_count++] = (char) ((uc >> (trail_bytes * 6)) | utf8_lead_bits[trail_bytes]); - - for (i = trail_bytes * 6 - 6; i >= 0; i -= 6) { - jc->parse_buffer[jc->parse_buffer_count++] = (char) (((uc >> i) & 0x3F) | 0x80); - } - - jc->parse_buffer[jc->parse_buffer_count] = 0; - - return true; -} - -static int add_escaped_char_to_parse_buffer(JSON_parser jc, int next_char) -{ - jc->escaped = 0; - /* remove the backslash */ - parse_buffer_pop_back_char(jc); - switch(next_char) { - case 'b': - parse_buffer_push_back_char(jc, '\b'); - break; - case 'f': - parse_buffer_push_back_char(jc, '\f'); - break; - case 'n': - parse_buffer_push_back_char(jc, '\n'); - break; - case 'r': - parse_buffer_push_back_char(jc, '\r'); - break; - case 't': - parse_buffer_push_back_char(jc, '\t'); - break; - case '"': - parse_buffer_push_back_char(jc, '"'); - break; - case '\\': - parse_buffer_push_back_char(jc, '\\'); - break; - case '/': - parse_buffer_push_back_char(jc, '/'); - break; - case 'u': - parse_buffer_push_back_char(jc, '\\'); - parse_buffer_push_back_char(jc, 'u'); - break; - default: - return false; - } - - return true; -} - -#define add_char_to_parse_buffer(jc, next_char, next_class) \ - do { \ - if (jc->escaped) { \ - if (!add_escaped_char_to_parse_buffer(jc, next_char)) \ - return false; \ - } else if (!jc->comment) { \ - if ((jc->type != JSON_T_NONE) | !((next_class == C_SPACE) | (next_class == C_WHITE)) /* non-white-space */) { \ - parse_buffer_push_back_char(jc, (char)next_char); \ - } \ - } \ - } while (0) - - -#define assert_type_isnt_string_null_or_bool(jc) \ - assert(jc->type != JSON_T_FALSE); \ - assert(jc->type != JSON_T_TRUE); \ - assert(jc->type != JSON_T_NULL); \ - assert(jc->type != JSON_T_STRING) - - -int -JSON_parser_char(JSON_parser jc, int next_char) -{ -/* - After calling new_JSON_parser, call this function for each character (or - partial character) in your JSON text. It can accept UTF-8, UTF-16, or - UTF-32. It returns true if things are looking ok so far. If it rejects the - text, it returns false. -*/ - int next_class, next_state; - -/* - Determine the character's class. -*/ - if (next_char < 0) { - return false; - } - if (next_char >= 128) { - next_class = C_ETC; - } else { - next_class = ascii_class[next_char]; - if (next_class <= __) { - return false; - } - } - - add_char_to_parse_buffer(jc, next_char, next_class); - -/* - Get the next state from the state transition table. -*/ - next_state = state_transition_table[jc->state][next_class]; - if (next_state >= 0) { -/* - Change the state. -*/ - jc->state = next_state; - } else { -/* - Or perform one of the actions. -*/ - switch (next_state) { -/* Unicode character */ - case UC: - if(!decode_unicode_char(jc)) { - return false; - } - /* check if we need to read a second UTF-16 char */ - if (jc->utf16_high_surrogate) { - jc->state = D1; - } else { - jc->state = ST; - } - break; -/* escaped char */ - case EX: - jc->escaped = 1; - jc->state = ES; - break; -/* integer detected by minus */ - case MX: - jc->type = JSON_T_INTEGER; - jc->state = MI; - break; -/* integer detected by zero */ - case ZX: - jc->type = JSON_T_INTEGER; - jc->state = ZE; - break; -/* integer detected by 1-9 */ - case IX: - jc->type = JSON_T_INTEGER; - jc->state = IT; - break; - -/* floating point number detected by exponent*/ - case DE: - assert_type_isnt_string_null_or_bool(jc); - jc->type = JSON_T_FLOAT; - jc->state = E1; - break; - -/* floating point number detected by fraction */ - case DF: - assert_type_isnt_string_null_or_bool(jc); - if (!jc->handle_floats_manually) { -/* - Some versions of strtod (which underlies sscanf) don't support converting - C-locale formated floating point values. -*/ - assert(jc->parse_buffer[jc->parse_buffer_count-1] == '.'); - jc->parse_buffer[jc->parse_buffer_count-1] = jc->decimal_point; - } - jc->type = JSON_T_FLOAT; - jc->state = FX; - break; -/* string begin " */ - case SB: - parse_buffer_clear(jc); - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_STRING; - jc->state = ST; - break; - -/* n */ - case NU: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_NULL; - jc->state = N1; - break; -/* f */ - case FA: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_FALSE; - jc->state = F1; - break; -/* t */ - case TR: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_TRUE; - jc->state = T1; - break; - -/* closing comment */ - case CE: - jc->comment = 0; - assert(jc->parse_buffer_count == 0); - assert(jc->type == JSON_T_NONE); - jc->state = jc->before_comment_state; - break; - -/* opening comment */ - case CB: - if (!jc->allow_comments) { - return false; - } - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - assert(jc->parse_buffer_count == 0); - assert(jc->type != JSON_T_STRING); - switch (jc->stack[jc->top]) { - case MODE_ARRAY: - case MODE_OBJECT: - switch(jc->state) { - case VA: - case AR: - jc->before_comment_state = jc->state; - break; - default: - jc->before_comment_state = OK; - break; - } - break; - default: - jc->before_comment_state = jc->state; - break; - } - jc->type = JSON_T_NONE; - jc->state = C1; - jc->comment = 1; - break; -/* empty } */ - case -9: - parse_buffer_clear(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { - return false; - } - if (!pop(jc, MODE_KEY)) { - return false; - } - jc->state = OK; - break; - -/* } */ case -8: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { - return false; - } - if (!pop(jc, MODE_OBJECT)) { - return false; - } - jc->type = JSON_T_NONE; - jc->state = OK; - break; - -/* ] */ case -7: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_END, NULL)) { - return false; - } - if (!pop(jc, MODE_ARRAY)) { - return false; - } - - jc->type = JSON_T_NONE; - jc->state = OK; - break; - -/* { */ case -6: - parse_buffer_pop_back_char(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_BEGIN, NULL)) { - return false; - } - if (!push(jc, MODE_KEY)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = OB; - break; - -/* [ */ case -5: - parse_buffer_pop_back_char(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_BEGIN, NULL)) { - return false; - } - if (!push(jc, MODE_ARRAY)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = AR; - break; - -/* string end " */ case -4: - parse_buffer_pop_back_char(jc); - switch (jc->stack[jc->top]) { - case MODE_KEY: - assert(jc->type == JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = CO; - - if (jc->callback) { - JSON_value value; - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - if (!(*jc->callback)(jc->ctx, JSON_T_KEY, &value)) { - return false; - } - } - parse_buffer_clear(jc); - break; - case MODE_ARRAY: - case MODE_OBJECT: - assert(jc->type == JSON_T_STRING); - if (!parse_parse_buffer(jc)) { - return false; - } - jc->type = JSON_T_NONE; - jc->state = OK; - break; - default: - return false; - } - break; - -/* , */ case -3: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - switch (jc->stack[jc->top]) { - case MODE_OBJECT: -/* - A comma causes a flip from object mode to key mode. -*/ - if (!pop(jc, MODE_OBJECT) || !push(jc, MODE_KEY)) { - return false; - } - assert(jc->type != JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = KE; - break; - case MODE_ARRAY: - assert(jc->type != JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = VA; - break; - default: - return false; - } - break; - -/* : */ case -2: -/* - A colon causes a flip from key mode to object mode. -*/ - parse_buffer_pop_back_char(jc); - if (!pop(jc, MODE_KEY) || !push(jc, MODE_OBJECT)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = VA; - break; -/* - Bad action. -*/ - default: - return false; - } - } - return true; -} - - -int -JSON_parser_done(JSON_parser jc) -{ - const int result = jc->state == OK && pop(jc, MODE_DONE); - - return result; -} - - -int JSON_parser_is_legal_white_space_string(const char* s) -{ - int c, char_class; - - if (s == NULL) { - return false; - } - - for (; *s; ++s) { - c = *s; - - if (c < 0 || c >= 128) { - return false; - } - - char_class = ascii_class[c]; - - if (char_class != C_SPACE && char_class != C_WHITE) { - return false; - } - } - - return true; -} - - - -void init_JSON_config(JSON_config* config) -{ - if (config) { - memset(config, 0, sizeof(*config)); - - config->depth = JSON_PARSER_STACK_SIZE - 1; - } -} diff --git a/decoder/JSON_parser.h b/decoder/JSON_parser.h deleted file mode 100644 index de980072..00000000 --- a/decoder/JSON_parser.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef JSON_PARSER_H -#define JSON_PARSER_H - -/* JSON_parser.h */ - - -#include <stddef.h> - -/* Windows DLL stuff */ -#ifdef _WIN32 -# ifdef JSON_PARSER_DLL_EXPORTS -# define JSON_PARSER_DLL_API __declspec(dllexport) -# else -# define JSON_PARSER_DLL_API __declspec(dllimport) -# endif -#else -# define JSON_PARSER_DLL_API -#endif - -/* Determine the integer type use to parse non-floating point numbers */ -#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1 -typedef long long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld" -#else -typedef long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld" -#endif - - -#ifdef __cplusplus -extern "C" { -#endif - -typedef enum -{ - JSON_T_NONE = 0, - JSON_T_ARRAY_BEGIN, // 1 - JSON_T_ARRAY_END, // 2 - JSON_T_OBJECT_BEGIN, // 3 - JSON_T_OBJECT_END, // 4 - JSON_T_INTEGER, // 5 - JSON_T_FLOAT, // 6 - JSON_T_NULL, // 7 - JSON_T_TRUE, // 8 - JSON_T_FALSE, // 9 - JSON_T_STRING, // 10 - JSON_T_KEY, // 11 - JSON_T_MAX // 12 -} JSON_type; - -typedef struct JSON_value_struct { - union { - JSON_int_t integer_value; - - double float_value; - - struct { - const char* value; - size_t length; - } str; - } vu; -} JSON_value; - -typedef struct JSON_parser_struct* JSON_parser; - -/*! \brief JSON parser callback - - \param ctx The pointer passed to new_JSON_parser. - \param type An element of JSON_type but not JSON_T_NONE. - \param value A representation of the parsed value. This parameter is NULL for - JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END, - JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned - as zero-terminated C strings. - - \return Non-zero if parsing should continue, else zero. -*/ -typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value); - - -/*! \brief The structure used to configure a JSON parser object - - \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise - the depth is the limit - \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity. - \param Callback context. This parameter may be NULL. - \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting. - \param allowComments. To allow C style comments in JSON, set to non-zero. - \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero. - - \return The parser object. -*/ -typedef struct { - JSON_parser_callback callback; - void* callback_ctx; - int depth; - int allow_comments; - int handle_floats_manually; -} JSON_config; - - -/*! \brief Initializes the JSON parser configuration structure to default values. - - The default configuration is - - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c) - - no parsing, just checking for JSON syntax - - no comments - - \param config. Used to configure the parser. -*/ -JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config); - -/*! \brief Create a JSON parser object - - \param config. Used to configure the parser. Set to NULL to use the default configuration. - See init_JSON_config - - \return The parser object. -*/ -JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config); - -/*! \brief Destroy a previously created JSON parser object. */ -JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc); - -/*! \brief Parse a character. - - \return Non-zero, if all characters passed to this function are part of are valid JSON. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char); - -/*! \brief Finalize parsing. - - Call this method once after all input characters have been consumed. - - \return Non-zero, if all parsed characters are valid JSON, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc); - -/*! \brief Determine if a given string is valid JSON white space - - \return Non-zero if the string is valid, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s); - - -#ifdef __cplusplus -} -#endif - - -#endif /* JSON_PARSER_H */ diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 727e5af5..dbec532e 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -33,10 +33,10 @@ noinst_LIBRARIES = libcdec.a EXTRA_DIST = test_data rule_lexer.ll libcdec_a_SOURCES = \ - JSON_parser.h \ aligner.h \ apply_models.h \ bottom_up_parser.h \ + bottom_up_parser-rs.h \ csplit.h \ decoder.h \ earley_composer.h \ @@ -45,6 +45,7 @@ libcdec_a_SOURCES = \ ff_basic.h \ ff_bleu.h \ ff_charset.h \ + ff_conll.h \ ff_const_reorder_common.h \ ff_const_reorder.h \ ff_context.h \ @@ -52,7 +53,7 @@ libcdec_a_SOURCES = \ ff_external.h \ ff_factory.h \ ff_klm.h \ - ff_lexical.h \ + ff_lexical.h \ ff_lm.h \ ff_ngrams.h \ ff_parse_match.h \ @@ -83,7 +84,6 @@ libcdec_a_SOURCES = \ hg_union.h \ incremental.h \ inside_outside.h \ - json_parse.h \ kbest.h \ lattice.h \ lexalign.h \ @@ -103,6 +103,7 @@ libcdec_a_SOURCES = \ aligner.cc \ apply_models.cc \ bottom_up_parser.cc \ + bottom_up_parser-rs.cc \ cdec.cc \ cdec_ff.cc \ csplit.cc \ @@ -113,6 +114,7 @@ libcdec_a_SOURCES = \ ff_basic.cc \ ff_bleu.cc \ ff_charset.cc \ + ff_conll.cc \ ff_context.cc \ ff_const_reorder.cc \ ff_csplit.cc \ @@ -146,7 +148,6 @@ libcdec_a_SOURCES = \ hg_sampler.cc \ hg_union.cc \ incremental.cc \ - json_parse.cc \ lattice.cc \ lexalign.cc \ lextrans.cc \ @@ -162,5 +163,4 @@ libcdec_a_SOURCES = \ tagger.cc \ translator.cc \ trule.cc \ - viterbi.cc \ - JSON_parser.c + viterbi.cc diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 232e022a..fd648370 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -198,13 +198,13 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice, } const Hypergraph* g = &in_g; HypergraphP new_hg; - if (!src_lattice.IsSentence() || - !trg_lattice.IsSentence()) { + if (!IsSentence(src_lattice) || + !IsSentence(trg_lattice)) { if (map_instead_of_viterbi) { cerr << " Lattice alignment: using Viterbi instead of MAP alignment\n"; } map_instead_of_viterbi = false; - fix_up_src_spans = !src_lattice.IsSentence(); + fix_up_src_spans = !IsSentence(src_lattice); } KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal> kbest(in_g, k_best); diff --git a/decoder/aligner.h b/decoder/aligner.h index a34795c9..d68ceefc 100644 --- a/decoder/aligner.h +++ b/decoder/aligner.h @@ -1,4 +1,4 @@ -#ifndef _ALIGNER_H_ +#ifndef ALIGNER_H #include <string> #include <iostream> diff --git a/decoder/apply_models.h b/decoder/apply_models.h index 19a4c7be..f03c973a 100644 --- a/decoder/apply_models.h +++ b/decoder/apply_models.h @@ -1,5 +1,5 @@ -#ifndef _APPLY_MODELS_H_ -#define _APPLY_MODELS_H_ +#ifndef APPLY_MODELS_H_ +#define APPLY_MODELS_H_ #include <iostream> diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc new file mode 100644 index 00000000..fbde7e24 --- /dev/null +++ b/decoder/bottom_up_parser-rs.cc @@ -0,0 +1,341 @@ +#include "bottom_up_parser-rs.h" + +#include <iostream> +#include <map> + +#include "node_state_hash.h" +#include "nt_span.h" +#include "hg.h" +#include "array2d.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +static WordID kEPS = 0; + +struct RSActiveItem; +class RSChart { + public: + RSChart(const string& goal, + const vector<GrammarPtr>& grammars, + const Lattice& input, + Hypergraph* forest); + ~RSChart(); + + void AddToChart(const RSActiveItem& x, int i, int j); + void ConsumeTerminal(const RSActiveItem& x, int i, int j, int k); + void ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k); + bool Parse(); + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() const { return goal_idx_; } + + private: + void ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost); + + // returns true if a new node was added to the chart + // false otherwise + bool ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost); + + void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); + void TopoSortUnaries(); + + const vector<GrammarPtr>& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D<vector<int>> chart_; // chart_(i,j) is the list of nodes (represented + // by their index in forest_->nodes_) derived spanning i,j + typedef map<int, int> Cat2NodeMap; + Array2D<Cat2NodeMap> nodemap_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + const int lc_fid_; + vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars + + static WordID kGOAL; // [Goal] +}; + +WordID RSChart::kGOAL = 0; + +// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" +struct RSActiveItem { + explicit RSActiveItem(const GrammarIter* g, int i) : + gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {} + void ExtendTerminal(int symbol, float src_cost) { + lattice_cost += src_cost; + if (symbol != kEPS) + gptr_ = gptr_->Extend(symbol); + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index) { + gptr_ = gptr_->Extend(hg->nodes_[node_index].cat_); + ant_nodes_.push_back(node_index); + } + // returns false if the extension has failed + explicit operator bool() const { + return gptr_; + } + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + float lattice_cost; // TODO: use SparseVector<double> to encode input features + short i_; +}; + +// some notes on the implementation +// "X" in Rico's Algorithm 2 roughly looks like it is just a pointer into a grammar +// trie, but it is actually a full "dotted item" since it needs to contain the information +// to build the hypergraph (i.e., it must remember the antecedent nodes and where they are, +// also any information about the path costs). + +RSChart::RSChart(const string& goal, + const vector<GrammarPtr>& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), + goal_idx_(-1), + lc_fid_(FD::Convert("LatticeCost")), + unaries_() { + for (unsigned i = 0; i < grammars_.size(); ++i) { + const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; +} + +static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector<TRulePtr>& edges = nit->second; + vector<bool> okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void RSChart::TopoSortUnaries() { + vector<TRulePtr> u(unaries_.size()); u.clear(); + map<int, vector<TRulePtr> > g; + map<int, int> mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + +bool RSChart::ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + if (lattice_cost && lc_fid_) + new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + bool added_node = false; + if (ni == c2n.end()) { + //cerr << "(" << i << "," << j << ") => " << TD::Convert(-r->GetLHS()) << endl; + added_node = true; + node = forest_->AddNode(r->GetLHS()); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); + return added_node; +} + +void RSChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost) { + const int n = rules->GetNumRules(); + //cerr << i << " " << j << ": NUM RULES: " << n << endl; + for (int k = 0; k < n; ++k) { + //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; + TRulePtr rule = rules->GetIthRule(k); + // apply rule, and if we create a new node, apply any necessary + // unary rules + if (ApplyRule(i, j, rule, tail, lattice_cost)) { + unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; + ApplyUnaryRules(i, j, rule->lhs_, nodeidx); + } + } +} + +void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx) { + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; + WordID new_lhs = unaries_[ri]->GetLHS(); + const Hypergraph::TailNodeVector ant(1, nodeidx); + if (ApplyRule(i, j, unaries_[ri], ant, 0)) { + //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; + unsigned nodeidx = nodemap_(i,j)[new_lhs]; + ApplyUnaryRules(i, j, new_lhs, nodeidx); + } + } + } +} + +void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { + // deal with completed rules + const RuleBin* rb = x.gptr_->GetRules(); + if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost); + + //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; + // continue looking for extensions of the rule to the right + for (unsigned k = j+1; k <= input_.size(); ++k) { + ConsumeTerminal(x, i, j, k); + ConsumeNonTerminal(x, i, j, k); + } +} + +void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeT(" << i << "," << j << "," << k << "):\n"; + + const unsigned check_edge_len = k - j; + // long-term TODO preindex this search so i->len->words is constant time rather than fan out + for (auto& in_edge : input_[j]) { + if (in_edge.dist2next == check_edge_len) { + //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; + RSActiveItem copy = x; + copy.ExtendTerminal(in_edge.label, in_edge.cost); + if (copy) AddToChart(copy, i, k); + } + } +} + +void RSChart::ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeNT(" << i << "," << j << "," << k << "):\n"; + for (auto& nodeidx : chart_(j,k)) { + //cerr << " Found completed NT in (" << j << "," << k << ") of type " << TD::Convert(-forest_->nodes_[nodeidx].cat_) << endl; + RSActiveItem copy = x; + copy.ExtendNonTerminal(forest_, nodeidx); + if (copy) AddToChart(copy, i, k); + } +} + +bool RSChart::Parse() { + size_t in_size_2 = input_.size() * input_.size(); + forest_->nodes_.reserve(in_size_2 * 2); + size_t res = min(static_cast<size_t>(2000000), static_cast<size_t>(in_size_2 * 1000)); + forest_->edges_.reserve(res); + goal_idx_ = -1; + const int N = input_.size(); + for (int i = N - 1; i >= 0; --i) { + for (int j = i + 1; j <= N; ++j) { + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeTerminal(item, i, i, j); + } + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeNonTerminal(item, i, i, j); + } + } + } + + // look for goal + const vector<int>& dh = chart_(0, input_.size()); + for (unsigned di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant, 0); + } + } + if (!SILENT) cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +RSChart::~RSChart() {} + +RSExhaustiveBottomUpParser::RSExhaustiveBottomUpParser( + const string& goal_sym, + const vector<GrammarPtr>& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool RSExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + kEPS = TD::Convert("*EPS*"); + RSChart chart(goal_sym_, grammars_, input, forest); + const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } + return result; +} diff --git a/decoder/bottom_up_parser-rs.h b/decoder/bottom_up_parser-rs.h new file mode 100644 index 00000000..2e271e99 --- /dev/null +++ b/decoder/bottom_up_parser-rs.h @@ -0,0 +1,29 @@ +#ifndef RSBOTTOM_UP_PARSER_H_ +#define RSBOTTOM_UP_PARSER_H_ + +#include <vector> +#include <string> + +#include "lattice.h" +#include "grammar.h" + +class Hypergraph; + +// implementation of Sennrich (2014) parser +// http://aclweb.org/anthology/W/W14/W14-4011.pdf +class RSExhaustiveBottomUpParser { + public: + RSExhaustiveBottomUpParser(const std::string& goal_sym, + const std::vector<GrammarPtr>& grammars); + + // returns true if goal reached spanning the full input + // forest contains the full (i.e., unpruned) parse forest + bool Parse(const Lattice& input, + Hypergraph* forest) const; + + private: + const std::string goal_sym_; + const std::vector<GrammarPtr> grammars_; +}; + +#endif diff --git a/decoder/bottom_up_parser.h b/decoder/bottom_up_parser.h index 546bfb54..628bb96d 100644 --- a/decoder/bottom_up_parser.h +++ b/decoder/bottom_up_parser.h @@ -1,5 +1,5 @@ -#ifndef _BOTTOM_UP_PARSER_H_ -#define _BOTTOM_UP_PARSER_H_ +#ifndef BOTTOM_UP_PARSER_H_ +#define BOTTOM_UP_PARSER_H_ #include <vector> #include <string> diff --git a/decoder/csplit.cc b/decoder/csplit.cc index 4a723822..7ee4092e 100644 --- a/decoder/csplit.cc +++ b/decoder/csplit.cc @@ -151,6 +151,7 @@ bool CompoundSplit::TranslateImpl(const string& input, smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign for (int i = 0; i < in.size(); ++i) smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1))); + smeta->ComputeInputLatticeType(); pimpl_->BuildTrellis(in, forest); forest->Reweight(weights); return true; diff --git a/decoder/csplit.h b/decoder/csplit.h index 82ed23fc..83d457b8 100644 --- a/decoder/csplit.h +++ b/decoder/csplit.h @@ -1,5 +1,5 @@ -#ifndef _CSPLIT_H_ -#define _CSPLIT_H_ +#ifndef CSPLIT_H_ +#define CSPLIT_H_ #include "translator.h" #include "lattice.h" diff --git a/decoder/decoder.cc b/decoder/decoder.cc index c384c33f..1e6c3194 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; } #include "fdict.h" #include "timing_stats.h" #include "verbose.h" +#include "b64featvector.h" #include "translator.h" #include "phrasebased_translator.h" @@ -86,7 +87,7 @@ struct ELengthWeightFunction { } }; inline void ShowBanner() { - cerr << "cdec (c) 2009--2014 by Chris Dyer\n"; + cerr << "cdec (c) 2009--2014 by Chris Dyer" << endl; } inline string str(char const* name,po::variables_map const& conf) { @@ -195,7 +196,7 @@ struct DecoderImpl { } forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1); if (!forestname.empty()) forestname=" "+forestname; - if (!SILENT) { + if (!SILENT) { forest_stats(forest," Pruned "+forestname+" forest",false,false); cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl; } @@ -261,7 +262,7 @@ struct DecoderImpl { assert(ref); LatticeTools::ConvertTextOrPLF(sref, ref); } - } + } // used to construct the suffix string to get the name of arguments for multiple passes // e.g., the "2" in --weights2 @@ -284,7 +285,7 @@ struct DecoderImpl { boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng; int sample_max_trans; bool aligner_mode; - bool graphviz; + bool graphviz; bool joshua_viz; bool encode_b64; bool kbest; @@ -301,6 +302,7 @@ struct DecoderImpl { bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer bool remove_intersected_rule_annotations; + bool mr_mira_compat; // Mr.MIRA compatibility mode. boost::scoped_ptr<IncrementalBase> incremental; @@ -404,6 +406,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory") ("show_derivations", po::value<string>(), "Directory to print the derivation structures to") + ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart") ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") @@ -414,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value<string>(),"Directory to write forests to") - ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); + ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)") + ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)"); // ob.AddOptions(&opts); po::options_description clo("Command line options"); @@ -665,7 +669,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); oracle.show_derivation=conf.count("show_derivations"); + oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>(); remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); + mr_mira_compat = conf.count("mr_mira_compat"); combine_size = conf["combine_size"].as<int>(); if (combine_size < 1) combine_size = 1; @@ -699,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string); } +static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) { + SparseVector<weight_t> delta; + DecodeFeatureVector(delta_b64, &delta); + if (delta.empty()) return; + // Apply updates + for (SparseVector<weight_t>::iterator dit = delta.begin(); + dit != delta.end(); ++dit) { + int feat_id = dit->first; + union { weight_t weight; unsigned long long repr; } feat_delta; + feat_delta.weight = dit->second; + if (!SILENT) + cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight + << " = " << hex << feat_delta.repr << endl; + if (weights->size() <= feat_id) weights->resize(feat_id + 1); + (*weights)[feat_id] += feat_delta.weight; + } +} + bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { string buf = input; NgramCache::Clear(); // clear ngram cache for remote LM (if used) @@ -709,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (sgml.find("id") != sgml.end()) sent_id = atoi(sgml["id"].c_str()); + // Add delta from input to weights before decoding + if (mr_mira_compat) + ApplyWeightDelta(sgml["delta"], init_weights.get()); + if (!SILENT) { cerr << "\nINPUT: "; if (buf.size() < 100) @@ -928,14 +956,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Hypergraph new_hg; { ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); if (!succeeded) abort(); } HG::Union(forest, &new_hg); - bool succeeded = writer.Write(new_hg, false); + bool succeeded = writer.Write(new_hg); if (!succeeded) abort(); } else { - bool succeeded = writer.Write(forest, false); + bool succeeded = writer.Write(forest); if (!succeeded) abort(); } } @@ -947,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (kbest && !has_ref) { //TODO: does this work properly? const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { @@ -1021,14 +1049,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Hypergraph new_hg; { ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); if (!succeeded) abort(); } HG::Union(forest, &new_hg); - bool succeeded = writer.Write(new_hg, false); + bool succeeded = writer.Write(new_hg); if (!succeeded) abort(); } else { - bool succeeded = writer.Write(forest, false); + bool succeeded = writer.Write(forest); if (!succeeded) abort(); } } @@ -1078,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("graphviz")) forest.PrintGraphviz(); if (kbest) { const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); } if (conf.count("show_conditional_prob")) { const prob_t ref_z = Inside<prob_t, EdgeProb>(forest); @@ -1098,4 +1126,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { o->NotifyDecodingComplete(smeta); return true; } - diff --git a/decoder/decoder.h b/decoder/decoder.h index 8039a42b..a545206b 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -1,5 +1,5 @@ -#ifndef _DECODER_H_ -#define _DECODER_H_ +#ifndef DECODER_H_ +#define DECODER_H_ #include <iostream> #include <string> diff --git a/decoder/earley_composer.h b/decoder/earley_composer.h index 9f786bf6..31602f67 100644 --- a/decoder/earley_composer.h +++ b/decoder/earley_composer.h @@ -1,5 +1,5 @@ -#ifndef _EARLEY_COMPOSER_H_ -#define _EARLEY_COMPOSER_H_ +#ifndef EARLEY_COMPOSER_H_ +#define EARLEY_COMPOSER_H_ #include <iostream> diff --git a/decoder/factored_lexicon_helper.h b/decoder/factored_lexicon_helper.h index 7fedc517..460bdebb 100644 --- a/decoder/factored_lexicon_helper.h +++ b/decoder/factored_lexicon_helper.h @@ -1,5 +1,5 @@ -#ifndef _FACTORED_LEXICON_HELPER_ -#define _FACTORED_LEXICON_HELPER_ +#ifndef FACTORED_LEXICON_HELPER_ +#define FACTORED_LEXICON_HELPER_ #include <cassert> #include <vector> diff --git a/decoder/ff.h b/decoder/ff.h index afa3dbca..d6487d97 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -1,5 +1,5 @@ -#ifndef _FF_H_ -#define _FF_H_ +#ifndef FF_H_ +#define FF_H_ #include <string> #include <vector> @@ -27,6 +27,12 @@ class FeatureFunction { // search. When non-zero, the last N bytes in the state should be ignored when // splitting a hypernode by the state. This allows the feature function to // store some side data and later retrieve it via the state bytes. + // + // In general, this should not be necessary and it should always be possible + // to replace this with a more appropriate design of state (if you find + // yourself having to ignore some part of the state, you are most likely + // storing redundant information in the state). Be sure that you + // understand how this affects ApplyModelSet() before using it. int IgnoredStateSize() const { return ignored_state_size_; } // override this. not virtual because we want to expose this to factory template for help before creating a FF @@ -82,6 +88,7 @@ class FeatureFunction { state_size_ = state_size; } + // See document of IgnoredStateSize() above. void SetIgnoredStateSize(size_t ignored_state_size) { ignored_state_size_ = ignored_state_size; } diff --git a/decoder/ff_basic.cc b/decoder/ff_basic.cc index f9404d24..f960418a 100644 --- a/decoder/ff_basic.cc +++ b/decoder/ff_basic.cc @@ -49,9 +49,7 @@ void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, features->set_value(fid_, edge.rule_->FWords() * value_); } - -ArityPenalty::ArityPenalty(const std::string& param) : - value_(-1.0 / log(10)) { +ArityPenalty::ArityPenalty(const std::string& param) { string fname = "Arity_"; unsigned MAX=DEFAULT_MAX_ARITY; using namespace boost; @@ -61,7 +59,8 @@ ArityPenalty::ArityPenalty(const std::string& param) : WordID fid=FD::Convert(fname+lexical_cast<string>(i)); fids_.push_back(fid); } - while (!fids_.empty() && fids_.back()==0) fids_.pop_back(); // pretty up features vector in case FD was frozen. doesn't change anything + // pretty up features vector in case FD was frozen. doesn't change anything + while (!fids_.empty() && fids_.back()==0) fids_.pop_back(); } void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, @@ -75,6 +74,6 @@ void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, (void) state; (void) estimated_features; unsigned a=edge.Arity(); - features->set_value(a<fids_.size()?fids_[a]:0, value_); + if (a < fids_.size()) features->set_value(fids_[a], 1.0); } diff --git a/decoder/ff_basic.h b/decoder/ff_basic.h index 901c0110..c63daf0f 100644 --- a/decoder/ff_basic.h +++ b/decoder/ff_basic.h @@ -1,5 +1,5 @@ -#ifndef _FF_BASIC_H_ -#define _FF_BASIC_H_ +#ifndef FF_BASIC_H_ +#define FF_BASIC_H_ #include "ff.h" @@ -41,7 +41,7 @@ class SourceWordPenalty : public FeatureFunction { const double value_; }; -#define DEFAULT_MAX_ARITY 9 +#define DEFAULT_MAX_ARITY 50 #define DEFAULT_MAX_ARITY_STRINGIZE(x) #x #define DEFAULT_MAX_ARITY_STRINGIZE_EVAL(x) DEFAULT_MAX_ARITY_STRINGIZE(x) #define DEFAULT_MAX_ARITY_STR DEFAULT_MAX_ARITY_STRINGIZE_EVAL(DEFAULT_MAX_ARITY) @@ -62,7 +62,6 @@ class ArityPenalty : public FeatureFunction { void* context) const; private: std::vector<WordID> fids_; - const double value_; }; #endif diff --git a/decoder/ff_bleu.h b/decoder/ff_bleu.h index 344dc788..8ca2c095 100644 --- a/decoder/ff_bleu.h +++ b/decoder/ff_bleu.h @@ -1,5 +1,5 @@ -#ifndef _BLEU_FF_H_ -#define _BLEU_FF_H_ +#ifndef BLEU_FF_H_ +#define BLEU_FF_H_ #include <vector> #include <string> diff --git a/decoder/ff_charset.h b/decoder/ff_charset.h index 267ef65d..e22ece2b 100644 --- a/decoder/ff_charset.h +++ b/decoder/ff_charset.h @@ -1,5 +1,5 @@ -#ifndef _FFCHARSET_H_ -#define _FFCHARSET_H_ +#ifndef FFCHARSET_H_ +#define FFCHARSET_H_ #include <string> #include <map> diff --git a/decoder/ff_conll.cc b/decoder/ff_conll.cc new file mode 100644 index 00000000..8ded44b7 --- /dev/null +++ b/decoder/ff_conll.cc @@ -0,0 +1,250 @@ +#include "ff_conll.h" + +#include <stdlib.h> +#include <sstream> +#include <cassert> +#include <cmath> +#include <boost/lexical_cast.hpp> + +#include "hg.h" +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" +#include "tdict.h" + +CoNLLFeatures::CoNLLFeatures(const string& param) { + // cerr << "initializing CoNLLFeatures with parameters: " << param; + kSOS = TD::Convert("<s>"); + kEOS = TD::Convert("</s>"); + macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]"); + ParseArgs(param); +} + +string CoNLLFeatures::Escape(const string& x) const { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; +} + +// replace %x[relative_location] or %y[relative_location] with actual_token +// within feature_instance +void CoNLLFeatures::ReplaceMacroWithString( + string& feature_instance, bool token_vs_label, int relative_location, + const string& actual_token) const { + + stringstream macro; + if (token_vs_label) { + macro << "%x["; + } else { + macro << "%y["; + } + macro << relative_location << "]"; + int macro_index = feature_instance.find(macro.str()); + if (macro_index == string::npos) { + cerr << "Can't find macro " << macro.str() << " in feature template " + << feature_instance; + abort(); + } + feature_instance.replace(macro_index, macro.str().size(), actual_token); +} + +void CoNLLFeatures::ReplaceTokenMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, true, relative_location, + actual_token); +} + +void CoNLLFeatures::ReplaceLabelMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, false, relative_location, + actual_token); +} + +void CoNLLFeatures::Error(const string& error_message) const { + cerr << "Error: " << error_message << "\n\n" + + << "CoNLLFeatures Usage: \n" + << " feature_function=CoNLLFeatures -t <TEMPLATE>\n\n" + + << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n" + + << "%x[k] and %y[k] are macros to be instantiated with an input\n" + << "token (for x) or a label (for y). k specifies the relative\n" + << "location of the input token or label with respect to the current\n" + << "position. For x, k is an integer value. For y, k must be 0 (to\n" + << "be extended).\n\n"; + + abort(); +} + +void CoNLLFeatures::ParseArgs(const string& in) { + which_feat = 0; + vector<string> const& argv = SplitOnWhitespace(in); + for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) { + string const& s = *i; + if (s[0] == '-') { + if (s.size() > 2) { + stringstream msg; + msg << s << " is an invalid option for CoNLLFeatures."; + Error(msg.str()); + } + + switch (s[1]) { + + case 'w': { + if (++i == argv.end()) { + Error("Missing parameter to -w"); + } + which_feat = boost::lexical_cast<unsigned>(*i); + break; + } + // feature template + case 't': { + if (++i == argv.end()) { + Error("Can't find template."); + } + feature_template = *i; + string::const_iterator start = feature_template.begin(); + string::const_iterator end = feature_template.end(); + smatch macro_match; + + // parse the template + while (regex_search(start, end, macro_match, macro_regex)) { + // get the relative location + string relative_location_str(macro_match[2].first, + macro_match[2].second); + int relative_location = atoi(relative_location_str.c_str()); + // add it to the list of relative locations for token or label + // (i.e. x or y) + bool valid_location = true; + if (*macro_match[1].first == 'x') { + // add it to token locations + token_relative_locations.push_back(relative_location); + } else { + if (relative_location != 0) { valid_location = false; } + // add it to label locations + label_relative_locations.push_back(relative_location); + } + if (!valid_location) { + stringstream msg; + msg << "Relative location " << relative_location + << " in feature template " << feature_template + << " is invalid."; + Error(msg.str()); + } + start = macro_match[0].second; + } + break; + } + + // TODO: arguments to specify kSOS and kEOS + + default: { + stringstream msg; + msg << "Invalid option on CoNLLFeatures: " << s; + Error(msg.str()); + break; + } + } // end of switch + } // end of if (token starts with hyphen) + } // end of for loop (over arguments) + + // the -t (i.e. template) option is mandatory in this feature function + if (label_relative_locations.size() == 0 || + token_relative_locations.size() == 0) { + stringstream msg; + msg << "Feature template must specify at least one" + << "token macro (e.g. x[-1]) and one label macro (e.g. y[0])."; + Error(msg.str()); + } +} + +void CoNLLFeatures::PrepareForInput(const SentenceMetadata& smeta) { + const Lattice& sl = smeta.GetSourceLattice(); + current_input.resize(sl.size()); + for (unsigned i = 0; i < sl.size(); ++i) { + if (sl[i].size() != 1) { + stringstream msg; + msg << "CoNLLFeatures don't support lattice inputs!\nid=" + << smeta.GetSentenceId() << endl; + Error(msg.str()); + } + current_input[i] = sl[i][0].label; + } + vector<WordID> wids; + string fn = "feat"; + fn += boost::lexical_cast<string>(which_feat); + string feats = smeta.GetSGMLValue(fn); + if (feats.size() == 0) { + Error("Can't find " + fn + " in <seg>\n"); + } + TD::ConvertSentence(feats, &wids); + assert(current_input.size() == wids.size()); + current_input = wids; +} + +void CoNLLFeatures::TraversalFeaturesImpl( + const SentenceMetadata& smeta, const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, SparseVector<double>* features, + SparseVector<double>* estimated_features, void* context) const { + + const TRule& rule = *edge.rule_; + // arity = 0, no nonterminals + // size = 1, predicted label is a single token + if (rule.Arity() != 0 || + rule.e_.size() != 1) { + return; + } + + // replace label macros with actual label strings + // NOTE: currently, this feature function doesn't allow any label + // macros except %y[0]. but you can look at as much of the source as you want + const WordID y0 = rule.e_[0]; + string y0_str = TD::Convert(y0); + + // start of the span in the input being labeled + const int from_src_index = edge.i_; + // end of the span in the input + const int to_src_index = edge.j_; + + // in the case of tagging the size of the spans being labeled will + // always be 1, but in other formalisms, you can have bigger spans + if (to_src_index - from_src_index != 1) { + cerr << "CoNLLFeatures doesn't support input spans of length != 1"; + abort(); + } + + string feature_instance = feature_template; + // replace token macros with actual token strings + for (unsigned i = 0; i < token_relative_locations.size(); ++i) { + int loc = token_relative_locations[i]; + WordID x = loc < 0? kSOS: kEOS; + if(from_src_index + loc >= 0 && + from_src_index + loc < current_input.size()) { + x = current_input[from_src_index + loc]; + } + string x_str = TD::Convert(x); + ReplaceTokenMacroWithString(feature_instance, loc, x_str); + } + + ReplaceLabelMacroWithString(feature_instance, 0, y0_str); + + // pick a real value for this feature + double fval = 1.0; + + // add it to the feature vector + // FD::Convert converts the feature string to a feature int + // Escape makes sure the feature string doesn't have any bad + // symbols that could confuse a parser somewhere + features->add_value(FD::Convert(Escape(feature_instance)), fval); +} diff --git a/decoder/ff_conll.h b/decoder/ff_conll.h new file mode 100644 index 00000000..b37356d8 --- /dev/null +++ b/decoder/ff_conll.h @@ -0,0 +1,45 @@ +#ifndef FF_CONLL_H_ +#define FF_CONLL_H_ + +#include <vector> +#include <boost/xpressive/xpressive.hpp> +#include "ff.h" + +using namespace boost::xpressive; +using namespace std; + +class CoNLLFeatures : public FeatureFunction { + public: + CoNLLFeatures(const string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + virtual void ParseArgs(const string& in); + virtual string Escape(const string& x) const; + virtual void ReplaceMacroWithString(string& feature_instance, + bool token_vs_label, + int relative_location, + const string& actual_token) const; + virtual void ReplaceTokenMacroWithString(string& feature_instance, + int relative_location, + const string& actual_token) const; + virtual void ReplaceLabelMacroWithString(string& feature_instance, + int relative_location, + const string& actual_token) const; + virtual void Error(const string&) const; + + private: + vector<int> token_relative_locations, label_relative_locations; + string feature_template; + vector<WordID> current_input; + WordID kSOS, kEOS; + sregex macro_regex; + unsigned which_feat; +}; + +#endif diff --git a/decoder/ff_const_reorder.cc b/decoder/ff_const_reorder.cc index 95546793..f1a6f7cb 100644 --- a/decoder/ff_const_reorder.cc +++ b/decoder/ff_const_reorder.cc @@ -1071,6 +1071,9 @@ ConstReorderFeature::~ConstReorderFeature() { // TODO void ConstReorderFeature::PrepareForInput(const SentenceMetadata& smeta) { string parse_file = smeta.GetSGMLValue("parse"); + if (parse_file.empty()) { + parse_file = smeta.GetSGMLValue("src_tree"); + } string srl_file = smeta.GetSGMLValue("srl"); assert(!(parse_file == "" && srl_file == "")); diff --git a/decoder/ff_const_reorder_common.h b/decoder/ff_const_reorder_common.h index 7c111de3..755fd948 100644 --- a/decoder/ff_const_reorder_common.h +++ b/decoder/ff_const_reorder_common.h @@ -1081,7 +1081,7 @@ typedef std::unordered_map<std::string, int>::iterator Iterator; struct Tsuruoka_Maxent { Tsuruoka_Maxent(const char* pszModelFName) { if (pszModelFName != NULL) { - m_pModel = new ME_Model(); + m_pModel = new maxent::ME_Model(); m_pModel->load_from_file(pszModelFName); } else m_pModel = NULL; @@ -1091,102 +1091,9 @@ struct Tsuruoka_Maxent { if (m_pModel != NULL) delete m_pModel; } - void fnTrain(const char* pszInstanceFName, const char* pszAlgorithm, - const char* pszModelFName, int /*iNumIteration*/) { - assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 || - strcmp(pszAlgorithm, "sgd") == 0 || - strcmp(pszAlgorithm, "SGD") == 0); - FILE* fpIn = fopen(pszInstanceFName, "r"); - - ME_Model* pModel = new ME_Model(); - - char* pszLine = new char[100001]; - int iNumInstances = 0; - int iLen; - while (!feof(fpIn)) { - pszLine[0] = '\0'; - fgets(pszLine, 20000, fpIn); - if (strlen(pszLine) == 0) { - continue; - } - - iLen = strlen(pszLine); - while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { - pszLine[iLen - 1] = '\0'; - iLen--; - } - - iNumInstances++; - - ME_Sample* pmes = new ME_Sample(); - - char* p = strrchr(pszLine, ' '); - assert(p != NULL); - p[0] = '\0'; - p++; - std::vector<std::string> vecContext; - SplitOnWhitespace(std::string(pszLine), &vecContext); - - pmes->label = std::string(p); - for (size_t i = 0; i < vecContext.size(); i++) - pmes->add_feature(vecContext[i]); - pModel->add_training_sample((*pmes)); - if (iNumInstances % 100000 == 0) - fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); - delete pmes; - } - fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); - fclose(fpIn); - - if (strcmp(pszAlgorithm, "l1") == 0) - pModel->use_l1_regularizer(1.0); - else if (strcmp(pszAlgorithm, "l2") == 0) - pModel->use_l2_regularizer(1.0); - else - pModel->use_SGD(); - - pModel->train(); - pModel->save_to_file(pszModelFName); - - delete pModel; - fprintf(stdout, "......Finished Training\n"); - fprintf(stdout, "......Model saved as %s\n", pszModelFName); - delete[] pszLine; - } - - double fnEval(const char* pszContext, const char* pszOutcome) const { - std::vector<std::string> vecContext; - ME_Sample* pmes = new ME_Sample(); - SplitOnWhitespace(std::string(pszContext), &vecContext); - - for (size_t i = 0; i < vecContext.size(); i++) - pmes->add_feature(vecContext[i]); - std::vector<double> vecProb = m_pModel->classify(*pmes); - delete pmes; - int iLableID = m_pModel->get_class_id(pszOutcome); - return vecProb[iLableID]; - } - void fnEval(const char* pszContext, - std::vector<std::pair<std::string, double> >& vecOutput) const { - std::vector<std::string> vecContext; - ME_Sample* pmes = new ME_Sample(); - SplitOnWhitespace(std::string(pszContext), &vecContext); - - vecOutput.clear(); - - for (size_t i = 0; i < vecContext.size(); i++) - pmes->add_feature(vecContext[i]); - std::vector<double> vecProb = m_pModel->classify(*pmes); - - for (size_t i = 0; i < vecProb.size(); i++) { - std::string label = m_pModel->get_class_label(i); - vecOutput.push_back(make_pair(label, vecProb[i])); - } - delete pmes; - } void fnEval(const char* pszContext, std::vector<double>& vecOutput) const { std::vector<std::string> vecContext; - ME_Sample* pmes = new ME_Sample(); + maxent::ME_Sample* pmes = new maxent::ME_Sample(); SplitOnWhitespace(std::string(pszContext), &vecContext); vecOutput.clear(); @@ -1206,7 +1113,7 @@ struct Tsuruoka_Maxent { } private: - ME_Model* m_pModel; + maxent::ME_Model* m_pModel; }; // an argument item or a predicate item (the verb itself) diff --git a/decoder/ff_context.h b/decoder/ff_context.h index 19198ec3..ed1aea2b 100644 --- a/decoder/ff_context.h +++ b/decoder/ff_context.h @@ -1,6 +1,5 @@ - -#ifndef _FF_CONTEXT_H_ -#define _FF_CONTEXT_H_ +#ifndef FF_CONTEXT_H_ +#define FF_CONTEXT_H_ #include <vector> #include <boost/xpressive/xpressive.hpp> diff --git a/decoder/ff_csplit.h b/decoder/ff_csplit.h index 79bf2886..227f2a14 100644 --- a/decoder/ff_csplit.h +++ b/decoder/ff_csplit.h @@ -1,5 +1,5 @@ -#ifndef _FF_CSPLIT_H_ -#define _FF_CSPLIT_H_ +#ifndef FF_CSPLIT_H_ +#define FF_CSPLIT_H_ #include <boost/shared_ptr.hpp> diff --git a/decoder/ff_external.h b/decoder/ff_external.h index 3e2bee51..fd12a37c 100644 --- a/decoder/ff_external.h +++ b/decoder/ff_external.h @@ -1,5 +1,5 @@ -#ifndef _FFEXTERNAL_H_ -#define _FFEXTERNAL_H_ +#ifndef FFEXTERNAL_H_ +#define FFEXTERNAL_H_ #include "ff.h" diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h index 1aa8e55f..ba9be9ac 100644 --- a/decoder/ff_factory.h +++ b/decoder/ff_factory.h @@ -1,5 +1,5 @@ -#ifndef _FF_FACTORY_H_ -#define _FF_FACTORY_H_ +#ifndef FF_FACTORY_H_ +#define FF_FACTORY_H_ //TODO: use http://www.boost.org/doc/libs/1_43_0/libs/functional/factory/doc/html/index.html ? diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index db4032f7..c8350623 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -1,5 +1,5 @@ -#ifndef _KLM_FF_H_ -#define _KLM_FF_H_ +#ifndef KLM_FF_H_ +#define KLM_FF_H_ #include <vector> #include <string> diff --git a/decoder/ff_lm.h b/decoder/ff_lm.h index 85e79704..83a2e186 100644 --- a/decoder/ff_lm.h +++ b/decoder/ff_lm.h @@ -1,5 +1,5 @@ -#ifndef _LM_FF_H_ -#define _LM_FF_H_ +#ifndef LM_FF_H_ +#define LM_FF_H_ #include <vector> #include <string> diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h index 4965d235..5dea9a7d 100644 --- a/decoder/ff_ngrams.h +++ b/decoder/ff_ngrams.h @@ -1,5 +1,5 @@ -#ifndef _NGRAMS_FF_H_ -#define _NGRAMS_FF_H_ +#ifndef NGRAMS_FF_H_ +#define NGRAMS_FF_H_ #include <vector> #include <map> diff --git a/decoder/ff_parse_match.h b/decoder/ff_parse_match.h index 7820b418..188c406a 100644 --- a/decoder/ff_parse_match.h +++ b/decoder/ff_parse_match.h @@ -1,5 +1,5 @@ -#ifndef _FF_PARSE_MATCH_H_ -#define _FF_PARSE_MATCH_H_ +#ifndef FF_PARSE_MATCH_H_ +#define FF_PARSE_MATCH_H_ #include "ff.h" #include "hg.h" diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h index f210dc65..5c4cf45e 100644 --- a/decoder/ff_rules.h +++ b/decoder/ff_rules.h @@ -1,5 +1,5 @@ -#ifndef _FF_RULES_H_ -#define _FF_RULES_H_ +#ifndef FF_RULES_H_ +#define FF_RULES_H_ #include <vector> #include <map> diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h index 488cfd84..66914f5d 100644 --- a/decoder/ff_ruleshape.h +++ b/decoder/ff_ruleshape.h @@ -1,5 +1,5 @@ -#ifndef _FF_RULESHAPE_H_ -#define _FF_RULESHAPE_H_ +#ifndef FF_RULESHAPE_H_ +#define FF_RULESHAPE_H_ #include <vector> #include <map> diff --git a/decoder/ff_soft_syntax.h b/decoder/ff_soft_syntax.h index e71825d5..da51df7f 100644 --- a/decoder/ff_soft_syntax.h +++ b/decoder/ff_soft_syntax.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOFT_SYNTAX_H_ -#define _FF_SOFT_SYNTAX_H_ +#ifndef FF_SOFT_SYNTAX_H_ +#define FF_SOFT_SYNTAX_H_ #include "ff.h" #include "hg.h" diff --git a/decoder/ff_soft_syntax_mindist.h b/decoder/ff_soft_syntax_mindist.h index bf938b38..205eff4b 100644 --- a/decoder/ff_soft_syntax_mindist.h +++ b/decoder/ff_soft_syntax_mindist.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOFT_SYNTAX_MINDIST_H_ -#define _FF_SOFT_SYNTAX_MINDIST_H_ +#ifndef FF_SOFT_SYNTAX_MINDIST_H_ +#define FF_SOFT_SYNTAX_MINDIST_H_ #include "ff.h" #include "hg.h" diff --git a/decoder/ff_source_path.h b/decoder/ff_source_path.h index 03126412..fc309264 100644 --- a/decoder/ff_source_path.h +++ b/decoder/ff_source_path.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOURCE_PATH_H_ -#define _FF_SOURCE_PATH_H_ +#ifndef FF_SOURCE_PATH_H_ +#define FF_SOURCE_PATH_H_ #include <vector> #include <map> diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h index bdd638c1..6316e881 100644 --- a/decoder/ff_source_syntax.h +++ b/decoder/ff_source_syntax.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOURCE_SYNTAX_H_ -#define _FF_SOURCE_SYNTAX_H_ +#ifndef FF_SOURCE_SYNTAX_H_ +#define FF_SOURCE_SYNTAX_H_ #include "ff.h" #include "hg.h" diff --git a/decoder/ff_source_syntax2.h b/decoder/ff_source_syntax2.h index f606c2bf..bbfa9eb6 100644 --- a/decoder/ff_source_syntax2.h +++ b/decoder/ff_source_syntax2.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOURCE_SYNTAX2_H_ -#define _FF_SOURCE_SYNTAX2_H_ +#ifndef FF_SOURCE_SYNTAX2_H_ +#define FF_SOURCE_SYNTAX2_H_ #include "ff.h" #include "hg.h" diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index d2f5e84c..e2475491 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -1,5 +1,5 @@ -#ifndef _FF_SPANS_H_ -#define _FF_SPANS_H_ +#ifndef FF_SPANS_H_ +#define FF_SPANS_H_ #include <vector> #include <map> diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h index 46418b0c..0cb8c648 100644 --- a/decoder/ff_tagger.h +++ b/decoder/ff_tagger.h @@ -1,5 +1,5 @@ -#ifndef _FF_TAGGER_H_ -#define _FF_TAGGER_H_ +#ifndef FF_TAGGER_H_ +#define FF_TAGGER_H_ #include <map> #include <boost/scoped_ptr.hpp> diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index 0161f603..ec454621 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -1,5 +1,5 @@ -#ifndef _FF_WORD_ALIGN_H_ -#define _FF_WORD_ALIGN_H_ +#ifndef FF_WORD_ALIGN_H_ +#define FF_WORD_ALIGN_H_ #include "ff.h" #include "array2d.h" diff --git a/decoder/ff_wordset.h b/decoder/ff_wordset.h index e78cd2fb..94f5ff8a 100644 --- a/decoder/ff_wordset.h +++ b/decoder/ff_wordset.h @@ -1,5 +1,5 @@ -#ifndef _FF_WORDSET_H_ -#define _FF_WORDSET_H_ +#ifndef FF_WORDSET_H_ +#define FF_WORDSET_H_ #include "ff.h" #include "tdict.h" diff --git a/decoder/ffset.h b/decoder/ffset.h index a69a75fa..84f9fdb9 100644 --- a/decoder/ffset.h +++ b/decoder/ffset.h @@ -1,5 +1,5 @@ -#ifndef _FFSET_H_ -#define _FFSET_H_ +#ifndef FFSET_H_ +#define FFSET_H_ #include <utility> #include <vector> diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc index 6e4cccb3..cc9094d7 100644 --- a/decoder/forest_writer.cc +++ b/decoder/forest_writer.cc @@ -11,13 +11,13 @@ using namespace std; ForestWriter::ForestWriter(const std::string& path, int num) : - fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {} + fname_(path + '/' + boost::lexical_cast<string>(num) + ".bin.gz"), used_(false) {} -bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) { +bool ForestWriter::Write(const Hypergraph& forest) { assert(!used_); used_ = true; cerr << " Writing forest to " << fname_ << endl; WriteFile wf(fname_); - return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream()); + return HypergraphIO::WriteToBinary(forest, wf.stream()); } diff --git a/decoder/forest_writer.h b/decoder/forest_writer.h index 819a8940..54e83470 100644 --- a/decoder/forest_writer.h +++ b/decoder/forest_writer.h @@ -1,5 +1,5 @@ -#ifndef _FOREST_WRITER_H_ -#define _FOREST_WRITER_H_ +#ifndef FOREST_WRITER_H_ +#define FOREST_WRITER_H_ #include <string> @@ -7,7 +7,7 @@ class Hypergraph; struct ForestWriter { ForestWriter(const std::string& path, int num); - bool Write(const Hypergraph& forest, bool minimal_rules); + bool Write(const Hypergraph& forest); const std::string fname_; bool used_; diff --git a/decoder/freqdict.h b/decoder/freqdict.h index 4e03fadd..07d797e2 100644 --- a/decoder/freqdict.h +++ b/decoder/freqdict.h @@ -1,5 +1,5 @@ -#ifndef _FREQDICT_H_ -#define _FREQDICT_H_ +#ifndef FREQDICT_H_ +#define FREQDICT_H_ #include <iostream> #include <map> diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 4253b652..fe28f4c6 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -27,11 +27,15 @@ struct FSTTranslatorImpl { const vector<double>& weights, Hypergraph* forest) { bool composed = false; - if (input.find("{\"rules\"") == 0) { + if (input.find("::forest::") == 0) { istringstream is(input); + string header, fname; + is >> header >> fname; + ReadFile rf(fname); + if (!rf) { cerr << "Failed to open " << fname << endl; abort(); } Hypergraph src_cfg_hg; - if (!HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)) { - cerr << "Failed to read HG from JSON.\n"; + if (!HypergraphIO::ReadFromBinary(rf.stream(), &src_cfg_hg)) { + cerr << "Failed to read HG.\n"; abort(); } if (add_pass_through_rules) { @@ -95,6 +99,7 @@ bool FSTTranslator::TranslateImpl(const string& input, const vector<double>& weights, Hypergraph* minus_lm_forest) { smeta->SetSourceLength(0); // don't know how to compute this + smeta->input_type_ = cdec::kFOREST; return pimpl_->Translate(input, weights, minus_lm_forest); } diff --git a/decoder/hg.h b/decoder/hg.h index 4ed27d87..c756012e 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -1,5 +1,5 @@ -#ifndef _HG_H_ -#define _HG_H_ +#ifndef HG_H_ +#define HG_H_ // define USE_INFO_EDGE 1 if you want lots of debug info shown with --show_derivations - otherwise it adds quite a bit of overhead if ffs have their logging enabled (e.g. ff_from_fsa) #ifndef USE_INFO_EDGE @@ -18,6 +18,7 @@ #include <string> #include <vector> #include <boost/shared_ptr.hpp> +#include <boost/serialization/vector.hpp> #include "feature_vector.h" #include "small_vector.h" @@ -69,6 +70,18 @@ namespace HG { short int j_; short int prev_i_; short int prev_j_; + template<class Archive> + void serialize(Archive & ar, const unsigned int /*version*/) { + ar & head_node_; + ar & tail_nodes_; + ar & rule_; + ar & feature_values_; + ar & i_; + ar & j_; + ar & prev_i_; + ar & prev_j_; + ar & id_; + } void show(std::ostream &o,unsigned mask=SPAN|RULE) const { o<<'{'; if (mask&CATEGORY) @@ -149,6 +162,24 @@ namespace HG { WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ + template<class Archive> + void save(Archive & ar, const unsigned int /*version*/) const { + ar & node_hash; + ar & id_; + ar & TD::Convert(-cat_); + ar & in_edges_; + ar & out_edges_; + } + template<class Archive> + void load(Archive & ar, const unsigned int /*version*/) { + ar & node_hash; + ar & id_; + std::string cat; ar & cat; + cat_ = -TD::Convert(cat); + ar & in_edges_; + ar & out_edges_; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting node_hash = o.node_hash; cat_=o.cat_; @@ -492,6 +523,27 @@ public: void set_ids(); // resync edge,node .id_ void check_ids() const; // assert that .id_ have been kept in sync + template<class Archive> + void save(Archive & ar, const unsigned int /*version*/) const { + unsigned ns = nodes_.size(); ar & ns; + unsigned es = edges_.size(); ar & es; + for (auto& n : nodes_) ar & n; + for (auto& e : edges_) ar & e; + int x; + x = edges_topo_; ar & x; + x = is_linear_chain_; ar & x; + } + template<class Archive> + void load(Archive & ar, const unsigned int /*version*/) { + unsigned ns; ar & ns; nodes_.resize(ns); + unsigned es; ar & es; edges_.resize(es); + for (auto& n : nodes_) ar & n; + for (auto& e : edges_) ar & e; + int x; + ar & x; edges_topo_ = x; + ar & x; is_linear_chain_ = x; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() private: Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} }; diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index 02f5a401..b9381d02 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -88,7 +88,7 @@ namespace HG { bool Intersect(const Lattice& target, Hypergraph* hg) { // there are a number of faster algorithms available for restricted // classes of hypergraph and/or target. - if (hg->IsLinearChain() && target.IsSentence()) + if (hg->IsLinearChain() && IsSentence(target)) return FastLinearIntersect(target, hg); vector<bool> rem(hg->edges_.size(), false); diff --git a/decoder/hg_intersect.h b/decoder/hg_intersect.h index 29a5ea2a..19c1c177 100644 --- a/decoder/hg_intersect.h +++ b/decoder/hg_intersect.h @@ -1,5 +1,5 @@ -#ifndef _HG_INTERSECT_H_ -#define _HG_INTERSECT_H_ +#ifndef HG_INTERSECT_H_ +#define HG_INTERSECT_H_ #include "lattice.h" diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index eb0be3d4..626b2954 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -6,362 +6,27 @@ #include <sstream> #include <iostream> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> + #include "fast_lexical_cast.hpp" #include "tdict.h" -#include "json_parse.h" #include "hg.h" using namespace std; -struct HGReader : public JSONParser { - HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } - - void CreateNode(const string& cat, const string& shash, const vector<int>& in_edges) { - WordID c = TD::Convert("X") * -1; - if (!cat.empty()) c = TD::Convert(cat) * -1; - Hypergraph::Node* node = hg.AddNode(c); - char* dend; - if (shash.size()) - node->node_hash = strtoull(shash.c_str(), &dend, 16); - else - node->node_hash = 0; - for (int i = 0; i < in_edges.size(); ++i) { - if (in_edges[i] >= hg.edges_.size()) { - cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] - << ", but hg only has " << hg.edges_.size() << " edges!\n"; - abort(); - } - hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); - } - } - void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVectorUnsigned& tail) { - Hypergraph::Edge* edge = hg.AddEdge(rule, tail); - feats->swap(edge->feature_values_); - edge->i_ = spans[0]; - edge->j_ = spans[1]; - edge->prev_i_ = spans[2]; - edge->prev_j_ = spans[3]; - } - - bool HandleJSONEvent(int type, const JSON_value* value) { - switch(state) { - case -1: - assert(type == JSON_T_OBJECT_BEGIN); - state = 0; - break; - case 0: - if (type == JSON_T_OBJECT_END) { - //cerr << "HG created\n"; // TODO, signal some kind of callback - } else if (type == JSON_T_KEY) { - string val = value->vu.str.value; - if (val == "features") { assert(fdict.empty()); state = 1; } - else if (val == "is_sorted") { state = 3; } - else if (val == "rules") { assert(rules.empty()); state = 4; } - else if (val == "node") { state = 8; } - else if (val == "edges") { state = 13; } - else { cerr << "Unexpected key: " << val << endl; return false; } - } - break; - - // features - case 1: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 2; - break; - case 2: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_STRING); - fdict.push_back(FD::Convert(value->vu.str.value)); - assert(fdict.back() > 0); - break; - - // is_sorted - case 3: - assert(type == JSON_T_TRUE || type == JSON_T_FALSE); - is_sorted = (type == JSON_T_TRUE); - if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; } - state = 0; - break; - - // rules - case 4: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 5; - break; - case 5: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_INTEGER); - state = 6; - rule_id = value->vu.integer_value; - break; - case 6: - assert(type == JSON_T_STRING); - rules[rule_id] = TRulePtr(new TRule(value->vu.str.value)); - state = 5; - break; - - // Nodes - case 8: - assert(type == JSON_T_OBJECT_BEGIN); - ++nodes; - in_edges.clear(); - cat.clear(); - shash.clear(); - state = 9; break; - case 9: - if (type == JSON_T_OBJECT_END) { - //cerr << "Creating NODE\n"; - CreateNode(cat, shash, in_edges); - state = 0; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } - if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } - if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; } - cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; - return false; - case 10: - assert(type == JSON_T_STRING || type == JSON_T_NULL); - cat = value->vu.str.value; - state = 9; break; - case 11: - if (type == JSON_T_NULL) { state = 9; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 12; break; - case 12: - if (type == JSON_T_ARRAY_END) { state = 9; break; } - assert(type == JSON_T_INTEGER); - //cerr << "in_edges: " << value->vu.integer_value << endl; - in_edges.push_back(value->vu.integer_value); - break; - - // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12}, - // { "tail": null, "feats" : [0,0.87,1,0.02], "spans":[1,2,3,4], "rule": 17}, - // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}] - case 13: - assert(type == JSON_T_ARRAY_BEGIN); - state = 14; - break; - case 14: - if (type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_OBJECT_BEGIN); - //cerr << "New edge\n"; - ++edges; - cur_rule.reset(); feats.clear(); tail.clear(); - state = 15; break; - case 15: - if (type == JSON_T_OBJECT_END) { - CreateEdge(cur_rule, &feats, tail); - state = 14; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - //cerr << "edge key " << cur_key << endl; - if (cur_key == "rule") { assert(!cur_rule); state = 16; break; } - if (cur_key == "spans") { assert(!cur_rule); state = 22; break; } - if (cur_key == "feats") { assert(feats.empty()); state = 17; break; } - if (cur_key == "tail") { assert(tail.empty()); state = 20; break; } - cerr << "Unexpected key " << cur_key << " in edge specification\n"; - return false; - case 16: // edge.rule - if (type == JSON_T_INTEGER) { - int rule_id = value->vu.integer_value; - if (rules.find(rule_id) == rules.end()) { - // rules list must come before the edge definitions! - cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n"; - return false; - } - cur_rule = rules[rule_id]; - } else if (type == JSON_T_STRING) { - cur_rule.reset(new TRule(value->vu.str.value)); - } else { - cerr << "Rule must be either a rule id or a rule string" << endl; - return false; - } - // cerr << "Edge: rule=" << cur_rule->AsString() << endl; - state = 15; - break; - case 17: // edge.feats - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 18; break; - case 18: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - if (type != JSON_T_INTEGER && type != JSON_T_STRING) { - cerr << "Unexpected feature id type\n"; return false; - } - if (type == JSON_T_INTEGER) { - fid = value->vu.integer_value; - assert(fid < fdict.size()); - fid = fdict[fid]; - } else if (JSON_T_STRING) { - fid = FD::Convert(value->vu.str.value); - } else { abort(); } - state = 19; - break; - case 19: - { - assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT); - double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) : - strtod(value->vu.str.value, NULL)); - feats.set_value(fid, val); - state = 18; - break; - } - case 20: // edge.tail - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 21; break; - case 21: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - assert(type == JSON_T_INTEGER); - tail.push_back(value->vu.integer_value); - break; - case 22: // edge.spans - assert(type == JSON_T_ARRAY_BEGIN); - state = 23; - spans[0] = spans[1] = spans[2] = spans[3] = -1; - spanc = 0; - break; - case 23: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - assert(type == JSON_T_INTEGER); - assert(spanc < 4); - spans[spanc] = value->vu.integer_value; - ++spanc; - break; - case 24: // read node hash - assert(type == JSON_T_STRING); - shash = value->vu.str.value; - state = 9; - break; - } - return true; - } - string rp; - string cat; - SmallVectorUnsigned tail; - vector<int> in_edges; - string shash; - TRulePtr cur_rule; - map<int, TRulePtr> rules; - vector<int> fdict; - SparseVector<double> feats; - int state; - int fid; - int nodes; - int edges; - int spans[4]; - int spanc; - string cur_key; - Hypergraph& hg; - int rule_id; - bool nodes_needed; - bool edges_needed; - bool is_sorted; -}; - -bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { +bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) { + boost::archive::binary_iarchive oa(*in); hg->clear(); - HGReader reader(hg); - return reader.Parse(in); -} - -static void WriteRule(const TRule& r, ostream* out) { - if (!r.lhs_) { (*out) << "[X] ||| "; } - JSONParser::WriteEscapedString(r.AsString(), out); + oa >> *hg; + return true; } -bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { - if (hg.empty()) { *out << "{}\n"; return true; } - map<const TRule*, int> rid; - ostream& o = *out; - rid[NULL] = 0; - o << '{'; - if (!remove_rules) { - o << "\"rules\":["; - for (int i = 0; i < hg.edges_.size(); ++i) { - const TRule* r = hg.edges_[i].rule_.get(); - int &id = rid[r]; - if (!id) { - id=rid.size() - 1; - if (id > 1) o << ','; - o << id << ','; - WriteRule(*r, &o); - }; - } - o << "],"; - } - const bool use_fdict = FD::NumFeats() < 1000; - if (use_fdict) { - o << "\"features\":["; - for (int i = 1; i < FD::NumFeats(); ++i) { - o << (i==1 ? "":","); - JSONParser::WriteEscapedString(FD::Convert(i), &o); - } - o << "],"; - } - vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order - int edge_count = 0; - for (int i = 0; i < hg.nodes_.size(); ++i) { - const Hypergraph::Node& node = hg.nodes_[i]; - if (i > 0) { o << ","; } - o << "\"edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; - edgemap[edge.id_] = edge_count; - ++edge_count; - o << (j == 0 ? "" : ",") << "{"; - - o << "\"tail\":["; - for (int k = 0; k < edge.tail_nodes_.size(); ++k) { - o << (k > 0 ? "," : "") << edge.tail_nodes_[k]; - } - o << "],"; - - o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],"; - - o << "\"feats\":["; - bool first = true; - for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) { - if (!it->second) continue; // don't write features that have a zero value - if (!it->first) continue; // if the feature set was frozen this might happen - if (!first) o << ','; - if (use_fdict) - o << (it->first - 1); - else { - JSONParser::WriteEscapedString(FD::Convert(it->first), &o); - } - o << ',' << it->second; - first = false; - } - o << "]"; - if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; } - o << "}"; - } - o << "],"; - - o << "\"node\":{\"in_edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - int mapped_edge = edgemap[node.in_edges_[j]]; - assert(mapped_edge >= 0); - o << (j == 0 ? "" : ",") << mapped_edge; - } - o << "]"; - if (node.cat_ < 0) { - o << ",\"cat\":"; - JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o); - } - char buf[48]; - sprintf(buf, "%016lX", node.node_hash); - o << ",\"node_hash\":\"" << buf << "\""; - o << "}"; - } - o << "}\n"; +bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) { + boost::archive::binary_oarchive oa(*out); + oa << hg; return true; } diff --git a/decoder/hg_io.h b/decoder/hg_io.h index 58af8132..93a9e280 100644 --- a/decoder/hg_io.h +++ b/decoder/hg_io.h @@ -1,5 +1,5 @@ -#ifndef _HG_IO_H_ -#define _HG_IO_H_ +#ifndef HG_IO_H_ +#define HG_IO_H_ #include <iostream> #include <string> @@ -9,19 +9,11 @@ class Hypergraph; struct HypergraphIO { - // the format is basically a list of nodes and edges in topological order - // any edge you read, you must have already read its tail nodes - // any node you read, you must have already read its incoming edges - // this may make writing a bit more challenging if your forest is not - // topologically sorted (but that probably doesn't happen very often), - // but it makes reading much more memory efficient. - // see test_data/small.json.gz for an email encoding - static bool ReadFromJSON(std::istream* in, Hypergraph* out); + static bool ReadFromBinary(std::istream* in, Hypergraph* out); + static bool WriteToBinary(const Hypergraph& hg, std::ostream* out); // if remove_rules is used, the hypergraph is serialized without rule information // (so it only contains structure and feature information) - static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out); - static void WriteAsCFG(const Hypergraph& hg); // Write only the target size information in bottom-up order. diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h index 82f06039..f67fe6e2 100644 --- a/decoder/hg_remove_eps.h +++ b/decoder/hg_remove_eps.h @@ -1,5 +1,5 @@ -#ifndef _HG_REMOVE_EPS_H_ -#define _HG_REMOVE_EPS_H_ +#ifndef HG_REMOVE_EPS_H_ +#define HG_REMOVE_EPS_H_ #include "wordid.h" class Hypergraph; diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h index 6ac39a20..4267b5ec 100644 --- a/decoder/hg_sampler.h +++ b/decoder/hg_sampler.h @@ -1,6 +1,5 @@ -#ifndef _HG_SAMPLER_H_ -#define _HG_SAMPLER_H_ - +#ifndef HG_SAMPLER_H_ +#define HG_SAMPLER_H_ #include <vector> #include <string> diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 5cb8626a..366b269d 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -1,10 +1,14 @@ #define BOOST_TEST_MODULE hg_test #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +#include <boost/serialization/vector.hpp> +#include <sstream> #include <iostream> #include "tdict.h" -#include "json_parse.h" #include "hg_intersect.h" #include "hg_union.h" #include "viterbi.h" @@ -394,16 +398,6 @@ BOOST_AUTO_TEST_CASE(Small) { BOOST_CHECK_CLOSE(2.1431036, log(c2), 1e-4); } -BOOST_AUTO_TEST_CASE(JSONTest) { - std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); - ostringstream os; - JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); - BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str()); - ostringstream os2; - JSONParser::WriteEscapedString("yes", &os2); - BOOST_CHECK_EQUAL("\"yes\"", os2.str()); -} - BOOST_AUTO_TEST_CASE(TestGenericKBest) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; @@ -427,19 +421,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) { } } -BOOST_AUTO_TEST_CASE(TestReadWriteHG) { +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); - Hypergraph hg,hg2; - CreateHG(path, &hg); - hg.edges_.front().j_ = 23; - hg.edges_.back().prev_i_ = 99; - ostringstream os; - HypergraphIO::WriteToJSON(hg, false, &os); - istringstream is(os.str()); - HypergraphIO::ReadFromJSON(&is, &hg2); - BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); - BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); - BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); + Hypergraph hg; + Hypergraph hg2; + std::string out; + { + CreateHG(path, &hg); + hg.edges_.front().j_ = 23; + hg.edges_.back().prev_i_ = 99; + ostringstream os; + boost::archive::text_oarchive oa(os); + oa << hg; + out = os.str(); + } + { + cerr << out << endl; + istringstream is(out); + boost::archive::text_iarchive ia(is); + ia >> hg2; + BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); + BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); + BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); + } } BOOST_AUTO_TEST_SUITE_END() diff --git a/decoder/hg_test.h b/decoder/hg_test.h index b7bab3c2..575b9c54 100644 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -12,12 +12,8 @@ namespace { typedef char const* Name; -Name urdu_json="urdu.json.gz"; -Name urdu_wts="Arity_0 1.70741473606976 Arity_1 1.12426238048012 Arity_2 1.14986187839554 Glue -0.04589037041388 LanguageModel 1.09051 PassThrough -3.66226367902928 PhraseModel_0 -1.94633451863252 PhraseModel_1 -0.1475347695476 PhraseModel_2 -1.614818994946 WordPenalty -3.0 WordPenaltyFsa -0.56028442964748 ShorterThanPrev -10 LongerThanPrev -10"; -Name small_json="small.json.gz"; +Name small_json="small.bin.gz"; Name small_wts="Model_0 -2 Model_1 -.5 Model_2 -1.1 Model_3 -1 Model_4 -1 Model_5 .5 Model_6 .2 Model_7 -.3"; -Name perro_json="perro.json.gz"; -Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 Glue -1.0 EgivenF -0.5 FgivenE -0.5 LexEgivenF -0.5 LexFgivenE -0.5 LM 1"; } @@ -32,7 +28,7 @@ struct HGSetup { static void JsonFile(Hypergraph *hg,std::string f) { ReadFile rf(f); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } static void JsonTestFile(Hypergraph *hg,std::string path,std::string n) { JsonFile(hg,path + "/"+n); @@ -48,35 +44,35 @@ void AddNullEdge(Hypergraph* hg) { } void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.tiny_lattice"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.tiny_lattice.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); AddNullEdge(hg); } void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.lattice"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.lattice.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); AddNullEdge(hg); } void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) { - ReadFile rf(path + "/hg_test.tiny"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.tiny.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg_int"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg_int.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg_balanced"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg_balanced.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } #endif diff --git a/decoder/hg_union.h b/decoder/hg_union.h index 34624246..bb7e2d09 100644 --- a/decoder/hg_union.h +++ b/decoder/hg_union.h @@ -1,5 +1,5 @@ -#ifndef _HG_UNION_H_ -#define _HG_UNION_H_ +#ifndef HG_UNION_H_ +#define HG_UNION_H_ class Hypergraph; namespace HG { diff --git a/decoder/incremental.h b/decoder/incremental.h index f791a626..46b4817b 100644 --- a/decoder/incremental.h +++ b/decoder/incremental.h @@ -1,5 +1,5 @@ -#ifndef _INCREMENTAL_H_ -#define _INCREMENTAL_H_ +#ifndef INCREMENTAL_H_ +#define INCREMENTAL_H_ #include "weights.h" #include <vector> diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h index c0377fe8..d5bda63c 100644 --- a/decoder/inside_outside.h +++ b/decoder/inside_outside.h @@ -1,5 +1,5 @@ -#ifndef _INSIDE_OUTSIDE_H_ -#define _INSIDE_OUTSIDE_H_ +#ifndef INSIDE_OUTSIDE_H_ +#define INSIDE_OUTSIDE_H_ #include <vector> #include <algorithm> diff --git a/decoder/json_parse.cc b/decoder/json_parse.cc deleted file mode 100644 index f6fdfea8..00000000 --- a/decoder/json_parse.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "json_parse.h" - -#include <string> -#include <iostream> - -using namespace std; - -static const char *json_hex_chars = "0123456789abcdef"; - -void JSONParser::WriteEscapedString(const string& in, ostream* out) { - int pos = 0; - int start_offset = 0; - unsigned char c = 0; - (*out) << '"'; - while(pos < in.size()) { - c = in[pos]; - switch(c) { - case '\b': - case '\n': - case '\r': - case '\t': - case '"': - case '\\': - case '/': - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - if(c == '\b') (*out) << "\\b"; - else if(c == '\n') (*out) << "\\n"; - else if(c == '\r') (*out) << "\\r"; - else if(c == '\t') (*out) << "\\t"; - else if(c == '"') (*out) << "\\\""; - else if(c == '\\') (*out) << "\\\\"; - else if(c == '/') (*out) << "\\/"; - start_offset = ++pos; - break; - default: - if(c < ' ') { - cerr << "Warning, bad character (" << static_cast<int>(c) << ") in string\n"; - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf]; - start_offset = ++pos; - } else pos++; - } - } - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - (*out) << '"'; -} - diff --git a/decoder/json_parse.h b/decoder/json_parse.h deleted file mode 100644 index c3cba954..00000000 --- a/decoder/json_parse.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef _JSON_WRAPPER_H_ -#define _JSON_WRAPPER_H_ - -#include <iostream> -#include <cassert> -#include "JSON_parser.h" - -class JSONParser { - public: - JSONParser() { - init_JSON_config(&config); - hack.mf = &JSONParser::Callback; - config.depth = 10; - config.callback_ctx = reinterpret_cast<void*>(this); - config.callback = hack.cb; - config.allow_comments = 1; - config.handle_floats_manually = 1; - jc = new_JSON_parser(&config); - } - virtual ~JSONParser() { - delete_JSON_parser(jc); - } - bool Parse(std::istream* in) { - int count = 0; - int lc = 1; - for (; in ; ++count) { - int next_char = in->get(); - if (!in->good()) break; - if (lc == '\n') { ++lc; } - if (!JSON_parser_char(jc, next_char)) { - std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl; - return false; - } - } - if (!JSON_parser_done(jc)) { - std::cerr << "JSON_parser_done: syntax error\n"; - return false; - } - return true; - } - static void WriteEscapedString(const std::string& in, std::ostream* out); - protected: - virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0; - private: - int Callback(int type, const JSON_value* value) { - if (HandleJSONEvent(type, value)) return 1; - return 0; - } - JSON_parser_struct* jc; - JSON_config config; - typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value); - union CBHack { - JSON_parser_callback cb; - MF mf; - } hack; -}; - -#endif diff --git a/decoder/kbest.h b/decoder/kbest.h index c7194c7e..d6b3eb94 100644 --- a/decoder/kbest.h +++ b/decoder/kbest.h @@ -1,5 +1,5 @@ -#ifndef _HG_KBEST_H_ -#define _HG_KBEST_H_ +#ifndef HG_KBEST_H_ +#define HG_KBEST_H_ #include <vector> #include <utility> diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 89da3cd0..1f97048d 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -50,7 +50,6 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { l.resize(ids.size()); for (int i = 0; i < l.size(); ++i) l[i].push_back(LatticeArc(ids[i], 0.0, 1)); - l.is_sentence_ = true; } void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { diff --git a/decoder/lattice.h b/decoder/lattice.h index ad4ca50d..1258d3f5 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -1,5 +1,5 @@ -#ifndef __LATTICE_H_ -#define __LATTICE_H_ +#ifndef LATTICE_H_ +#define LATTICE_H_ #include <string> #include <vector> @@ -25,22 +25,24 @@ class Lattice : public std::vector<std::vector<LatticeArc> > { friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl); friend void LatticeTools::ConvertTextToLattice(const std::string& text, Lattice* pl); public: - Lattice() : is_sentence_(false) {} + Lattice() {} explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) : - std::vector<std::vector<LatticeArc> >(t, v), - is_sentence_(false) {} + std::vector<std::vector<LatticeArc>>(t, v) {} int Distance(int from, int to) const { if (dist_.empty()) return (to - from); return dist_(from, to); } - // TODO this should actually be computed based on the contents - // of the lattice - bool IsSentence() const { return is_sentence_; } private: void ComputeDistances(); Array2D<int> dist_; - bool is_sentence_; }; +inline bool IsSentence(const Lattice& in) { + bool res = true; + for (auto& alt : in) + if (alt.size() > 1) { res = false; break; } + return res; +} + #endif diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc index 11f20de7..dd529311 100644 --- a/decoder/lexalign.cc +++ b/decoder/lexalign.cc @@ -114,10 +114,9 @@ bool LexicalAlign::TranslateImpl(const string& input, Hypergraph* forest) { Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextOrPLF(input, &lattice); - if (!lattice.IsSentence()) { - // lexical models make independence assumptions - // that don't work with lattices or conf nets - cerr << "LexicalTrans: cannot deal with lattice source input!\n"; + smeta->ComputeInputLatticeType(); + if (smeta->GetInputType() != cdec::kSEQUENCE) { + cerr << "LexicalTrans: cannot deal with non-sequence input!"; abort(); } smeta->SetSourceLength(lattice.size()); diff --git a/decoder/lexalign.h b/decoder/lexalign.h index 7ba4fe64..6415f4f9 100644 --- a/decoder/lexalign.h +++ b/decoder/lexalign.h @@ -1,5 +1,5 @@ -#ifndef _LEXALIGN_H_ -#define _LEXALIGN_H_ +#ifndef LEXALIGN_H_ +#define LEXALIGN_H_ #include "translator.h" #include "lattice.h" diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 74a18c3f..d13a891a 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -271,10 +271,9 @@ bool LexicalTrans::TranslateImpl(const string& input, Hypergraph* forest) { Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextOrPLF(input, &lattice); - if (!lattice.IsSentence()) { - // lexical models make independence assumptions - // that don't work with lattices or conf nets - cerr << "LexicalTrans: cannot deal with lattice source input!\n"; + smeta->ComputeInputLatticeType(); + if (smeta->GetInputType() != cdec::kSEQUENCE) { + cerr << "LexicalTrans: cannot deal with non-sequence inputs\n"; abort(); } smeta->SetSourceLength(lattice.size()); diff --git a/decoder/lextrans.h b/decoder/lextrans.h index 2d51e7c0..a23a4e0d 100644 --- a/decoder/lextrans.h +++ b/decoder/lextrans.h @@ -1,5 +1,5 @@ -#ifndef _LEXTrans_H_ -#define _LEXTrans_H_ +#ifndef LEXTrans_H_ +#define LEXTrans_H_ #include "translator.h" #include "lattice.h" diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h index 9fc01a09..f380fcb1 100644 --- a/decoder/node_state_hash.h +++ b/decoder/node_state_hash.h @@ -1,5 +1,5 @@ -#ifndef _NODE_STATE_HASH_ -#define _NODE_STATE_HASH_ +#ifndef NODE_STATE_HASH_ +#define NODE_STATE_HASH_ #include <cassert> #include <cstring> diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d2c4715c..cd587833 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -21,6 +21,7 @@ #include "kbest.h" #include "timing_stats.h" #include "sentences.h" +#include "b64featvector.h" //TODO: put function impls into .cc //TODO: move Translation into its own .h and use in cdec @@ -252,19 +253,31 @@ struct OracleBleu { } bool show_derivation; + int show_derivation_mask; + template <class Filter> - void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { + void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat, + int src_len, std::ostream& kbest_out = std::cout, + std::ostream& deriv_out = std::cerr) { using namespace std; using namespace boost; typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K; K kbest(forest,k); //add length (f side) src length of this sentence to the psuedo-doc src length count float curr_src_length = doc_src_length + tmp_src_length; - for (int i = 0; i < k; ++i) { + if (mr_mira_compat) kbest_out << k << "\n"; + int i = 0; + for (; i < k; ++i) { typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; - kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " - << d->feature_values << " ||| " << log(d->score); + kbest_out << sent_id << " ||| "; + if (mr_mira_compat) kbest_out << src_len << " ||| "; + kbest_out << TD::GetString(d->yield) << " ||| "; + if (mr_mira_compat) + kbest_out << EncodeFeatureVector(d->feature_values); + else + kbest_out << d->feature_values; + kbest_out << " ||| " << log(d->score); if (!refs.empty()) { ScoreP sentscore = GetScore(d->yield,sent_id); sentscore->PlusEquals(*doc_score,float(1)); @@ -275,14 +288,21 @@ struct OracleBleu { if (show_derivation) { deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k deriv_out<<log(d->score)<<"\n"; - deriv_out<<kbest.derivation_tree(*d,true); + deriv_out<<kbest.derivation_tree(*d,true, show_derivation_mask); deriv_out<<"\n"<<flush; } } + if (mr_mira_compat) { + for (; i < k; ++i) kbest_out << "\n"; + kbest_out << flush; + } } // TODO decoder output should probably be moved to another file - how about oracle_bleu.h - void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) { + void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, + const bool unique, const bool mr_mira_compat, + const int src_len, std::string const& kbest_out_filename_, + std::string const& deriv_out_filename_) { WriteFile ko(kbest_out_filename_); std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl; @@ -295,9 +315,11 @@ struct OracleBleu { WriteFile oderiv(sderiv.str()); if (!unique) - kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get()); + kbest<KBest::NoFilter<std::vector<WordID> > >( + sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get()); else { - kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get()); + kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len, + ko.get(), oderiv.get()); } } @@ -305,7 +327,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; - DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-"); + DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(), + "-"); } }; diff --git a/decoder/phrasebased_translator.cc b/decoder/phrasebased_translator.cc index 8048248e..8415353a 100644 --- a/decoder/phrasebased_translator.cc +++ b/decoder/phrasebased_translator.cc @@ -114,6 +114,7 @@ struct PhraseBasedTranslatorImpl { Lattice lattice; LatticeTools::ConvertTextOrPLF(input, &lattice); smeta->SetSourceLength(lattice.size()); + smeta->ComputeInputLatticeType(); size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion); minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100); if (add_pass_through_rules) { diff --git a/decoder/phrasebased_translator.h b/decoder/phrasebased_translator.h index e5e3f8a2..10790d0d 100644 --- a/decoder/phrasebased_translator.h +++ b/decoder/phrasebased_translator.h @@ -1,5 +1,5 @@ -#ifndef _PHRASEBASED_TRANSLATOR_H_ -#define _PHRASEBASED_TRANSLATOR_H_ +#ifndef PHRASEBASED_TRANSLATOR_H_ +#define PHRASEBASED_TRANSLATOR_H_ #include "translator.h" diff --git a/decoder/phrasetable_fst.h b/decoder/phrasetable_fst.h index 477de1f7..966bb14d 100644 --- a/decoder/phrasetable_fst.h +++ b/decoder/phrasetable_fst.h @@ -1,5 +1,5 @@ -#ifndef _PHRASETABLE_FST_H_ -#define _PHRASETABLE_FST_H_ +#ifndef PHRASETABLE_FST_H_ +#define PHRASETABLE_FST_H_ #include <vector> #include <string> diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 10192f7a..2c5fa9c4 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -3,6 +3,7 @@ #include <sstream> #include <boost/shared_ptr.hpp> +#include "filelib.h" #include "sentence_metadata.h" #include "hg.h" #include "hg_io.h" @@ -20,16 +21,18 @@ struct RescoreTranslatorImpl { bool Translate(const string& input, const vector<double>& weights, Hypergraph* forest) { - if (input == "{}") return false; - if (input.find("{\"rules\"") == 0) { - istringstream is(input); - Hypergraph src_cfg_hg; - if (!HypergraphIO::ReadFromJSON(&is, forest)) { - cerr << "Parse error while reading HG from JSON.\n"; - abort(); - } - } else { - cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; + istringstream is(input); + string header, fname; + is >> header >> fname; + if (header != "::forest::") { + cerr << "RescoreTranslator: expected input lines of form ::forest:: filename.gz\n"; + abort(); + } + ReadFile rf(fname); + if (!rf) { cerr << "Can't read " << fname << endl; abort(); } + Hypergraph src_cfg_hg; + if (!HypergraphIO::ReadFromBinary(rf.stream(), forest)) { + cerr << "Parse error while reading HG.\n"; abort(); } Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); @@ -53,6 +56,7 @@ bool RescoreTranslator::TranslateImpl(const string& input, const vector<double>& weights, Hypergraph* minus_lm_forest) { smeta->SetSourceLength(0); // don't know how to compute this + smeta->input_type_ = cdec::kFOREST; return pimpl_->Translate(input, weights, minus_lm_forest); } diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index e15c056d..5267f9ca 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -1,5 +1,5 @@ -#ifndef _RULE_LEXER_H_ -#define _RULE_LEXER_H_ +#ifndef RULE_LEXER_H_ +#define RULE_LEXER_H_ #include <iostream> #include <string> diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index d4a8d86b..8b48ab7b 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) { init_default_feature_names(); + scfglex_fname = srule; lex_mono_rules = mono; lex_line = 1; rule_callback_extra = extra; diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 83b65c28..538f82ec 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -195,6 +195,7 @@ struct SCFGTranslatorImpl { Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextOrPLF(input, &lattice); smeta->SetSourceLength(lattice.size()); + smeta->ComputeInputLatticeType(); if (add_pass_through_rules){ if (!SILENT) cerr << "Adding pass through grammar" << endl; PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features); diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h index f2a779f4..e13c2ca5 100644 --- a/decoder/sentence_metadata.h +++ b/decoder/sentence_metadata.h @@ -1,14 +1,20 @@ -#ifndef _SENTENCE_METADATA_H_ -#define _SENTENCE_METADATA_H_ +#ifndef SENTENCE_METADATA_H_ +#define SENTENCE_METADATA_H_ #include <string> #include <map> #include <cassert> #include "lattice.h" +#include "tree_fragment.h" struct DocScorer; // deprecated, will be removed struct Score; // deprecated, will be removed +namespace cdec { +enum InputType { kSEQUENCE, kTREE, kLATTICE, kFOREST, kUNKNOWN }; +class TreeFragment; +} + class SentenceMetadata { public: friend class DecoderImpl; @@ -17,7 +23,17 @@ class SentenceMetadata { src_len_(-1), has_reference_(ref.size() > 0), trg_len_(ref.size()), - ref_(has_reference_ ? &ref : NULL) {} + ref_(has_reference_ ? &ref : NULL), + input_type_(cdec::kUNKNOWN) {} + + // helper function for lattice inputs + void ComputeInputLatticeType() { + input_type_ = cdec::kSEQUENCE; + for (auto& alt : src_lattice_) { + if (alt.size() > 1) { input_type_ = cdec::kLATTICE; break; } + } + } + cdec::InputType GetInputType() const { return input_type_; } int GetSentenceId() const { return sent_id_; } @@ -25,6 +41,8 @@ class SentenceMetadata { // it has parsed the source void SetSourceLength(int sl) { src_len_ = sl; } + const cdec::TreeFragment& GetSourceTree() const { return src_tree_; } + // this should be called if a separate model needs to // specify how long the target sentence should be void SetTargetLength(int tl) { @@ -64,12 +82,15 @@ class SentenceMetadata { const Score* app_score; public: Lattice src_lattice_; // this will only be set if inputs are finite state! + cdec::TreeFragment src_tree_; // this will be set only if inputs are trees private: // you need to be very careful when depending on these values // they will only be set during training / alignment contexts const bool has_reference_; int trg_len_; const Lattice* const ref_; + public: + cdec::InputType input_type_; }; #endif diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 30fb055f..500d2061 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -100,6 +100,8 @@ bool Tagger::TranslateImpl(const string& input, Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextToLattice(input, &lattice); smeta->SetSourceLength(lattice.size()); + smeta->ComputeInputLatticeType(); + assert(smeta->GetInputType() == cdec::kSEQUENCE); vector<WordID> sequence(lattice.size()); for (int i = 0; i < lattice.size(); ++i) { assert(lattice[i].size() == 1); diff --git a/decoder/tagger.h b/decoder/tagger.h index 9ac820d9..51659d5b 100644 --- a/decoder/tagger.h +++ b/decoder/tagger.h @@ -1,5 +1,5 @@ -#ifndef _TAGGER_H_ -#define _TAGGER_H_ +#ifndef TAGGER_H_ +#define TAGGER_H_ #include "translator.h" diff --git a/decoder/test_data/hg_test.hg.bin.gz b/decoder/test_data/hg_test.hg.bin.gz Binary files differnew file mode 100644 index 00000000..c07fbe8c --- /dev/null +++ b/decoder/test_data/hg_test.hg.bin.gz diff --git a/decoder/test_data/hg_test.hg_balanced.bin.gz b/decoder/test_data/hg_test.hg_balanced.bin.gz Binary files differnew file mode 100644 index 00000000..896d3d60 --- /dev/null +++ b/decoder/test_data/hg_test.hg_balanced.bin.gz diff --git a/decoder/test_data/hg_test.hg_int.bin.gz b/decoder/test_data/hg_test.hg_int.bin.gz Binary files differnew file mode 100644 index 00000000..e0bd6187 --- /dev/null +++ b/decoder/test_data/hg_test.hg_int.bin.gz diff --git a/decoder/test_data/hg_test.lattice.bin.gz b/decoder/test_data/hg_test.lattice.bin.gz Binary files differnew file mode 100644 index 00000000..8a8c05f4 --- /dev/null +++ b/decoder/test_data/hg_test.lattice.bin.gz diff --git a/decoder/test_data/hg_test.tiny.bin.gz b/decoder/test_data/hg_test.tiny.bin.gz Binary files differnew file mode 100644 index 00000000..0e68eb40 --- /dev/null +++ b/decoder/test_data/hg_test.tiny.bin.gz diff --git a/decoder/test_data/hg_test.tiny_lattice.bin.gz b/decoder/test_data/hg_test.tiny_lattice.bin.gz Binary files differnew file mode 100644 index 00000000..97e8dc05 --- /dev/null +++ b/decoder/test_data/hg_test.tiny_lattice.bin.gz diff --git a/decoder/test_data/perro.json.gz b/decoder/test_data/perro.json.gz Binary files differdeleted file mode 100644 index 41de5758..00000000 --- a/decoder/test_data/perro.json.gz +++ /dev/null diff --git a/decoder/test_data/small.bin.gz b/decoder/test_data/small.bin.gz Binary files differnew file mode 100644 index 00000000..1c5a1631 --- /dev/null +++ b/decoder/test_data/small.bin.gz diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz Binary files differdeleted file mode 100644 index f6f37293..00000000 --- a/decoder/test_data/small.json.gz +++ /dev/null diff --git a/decoder/test_data/urdu.json.gz b/decoder/test_data/urdu.json.gz Binary files differdeleted file mode 100644 index 84535402..00000000 --- a/decoder/test_data/urdu.json.gz +++ /dev/null diff --git a/decoder/translator.h b/decoder/translator.h index ba218a0b..096cf191 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -1,5 +1,5 @@ -#ifndef _TRANSLATOR_H_ -#define _TRANSLATOR_H_ +#ifndef TRANSLATOR_H_ +#define TRANSLATOR_H_ #include <string> #include <vector> diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index bd3b01d0..cdd83ffc 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -32,12 +32,12 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo ++lc; if (line.size() == 0 || line[0] == '#') continue; std::vector<StringPiece> fields = TokenizeMultisep(line, " ||| "); - if (has_multiple_states && fields.size() != 4) { - cerr << "Expected 4 fields in rule file but line " << lc << " is:\n" << line << endl; + if (has_multiple_states && fields.size() < 4) { + cerr << "Expected at least 4 fields in rule file but line " << lc << " is:\n" << line << endl; abort(); } - if (!has_multiple_states && fields.size() != 3) { - cerr << "Expected 3 fields in rule file but line " << lc << " is:\n" << line << endl; + if (!has_multiple_states && fields.size() < 3) { + cerr << "Expected at least 3 fields in rule file but line " << lc << " is:\n" << line << endl; abort(); } @@ -73,6 +73,7 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo cerr << "Not implemented...\n"; abort(); // TODO read in states } else { os << " ||| " << fields[1] << " ||| " << fields[2]; + if (fields.size() > 3) os << " ||| " << fields[3]; rule.reset(new TRule(os.str())); } cur->rules.push_back(rule); @@ -266,6 +267,7 @@ struct Tree2StringTranslatorImpl { for (auto sym : rule_src) cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); + rule->a_.push_back(AlignmentPoint(0, 0)); rule->ComputeArity(); rule->scores_.set_value(ntfid, 1.0); rule->scores_.set_value(kFID, 1.0); @@ -287,6 +289,8 @@ struct Tree2StringTranslatorImpl { const vector<double>& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); + smeta->src_tree_ = input_tree; + smeta->input_type_ = cdec::kTREE; if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 42f7793a..5f717c5b 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -64,6 +64,13 @@ int TreeFragment::SetupSpansRec(unsigned cur, int left) { return right; } +vector<int> TreeFragment::Terminals() const { + vector<int> terms; + for (auto& x : *this) + if (IsTerminal(x)) terms.push_back(x); + return terms; +} + // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built // symp keeps track of the terminal symbols that have been built diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 6b4842ee..e19b79fb 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -72,6 +72,8 @@ class TreeFragment { BreadthFirstIterator bfs_begin(unsigned node_idx) const; BreadthFirstIterator bfs_end() const; + std::vector<int> Terminals() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built diff --git a/decoder/trule.h b/decoder/trule.h index 243b0da9..7af46747 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -1,5 +1,5 @@ -#ifndef _RULE_H_ -#define _RULE_H_ +#ifndef TRULE_H_ +#define TRULE_H_ #include <algorithm> #include <vector> @@ -11,6 +11,7 @@ #include "sparse_vector.h" #include "wordid.h" +#include "tdict.h" class TRule; typedef boost::shared_ptr<TRule> TRulePtr; @@ -26,6 +27,7 @@ struct AlignmentPoint { short s_; short t_; }; + inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) { return os << static_cast<int>(p.s_) << '-' << static_cast<int>(p.t_); } @@ -163,6 +165,66 @@ class TRule { // optional, shows internal structure of TSG rules boost::shared_ptr<cdec::TreeFragment> tree_structure; + friend class boost::serialization::access; + template<class Archive> + void save(Archive & ar, const unsigned int /*version*/) const { + ar & TD::Convert(-lhs_); + unsigned f_size = f_.size(); + ar & f_size; + assert(f_size <= (sizeof(size_t) * 8)); + size_t f_nt_mask = 0; + for (int i = f_.size() - 1; i >= 0; --i) { + f_nt_mask <<= 1; + f_nt_mask |= (f_[i] <= 0 ? 1 : 0); + } + ar & f_nt_mask; + for (unsigned i = 0; i < f_.size(); ++i) + ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]); + unsigned e_size = e_.size(); + ar & e_size; + size_t e_nt_mask = 0; + assert(e_size <= (sizeof(size_t) * 8)); + for (int i = e_.size() - 1; i >= 0; --i) { + e_nt_mask <<= 1; + e_nt_mask |= (e_[i] <= 0 ? 1 : 0); + } + ar & e_nt_mask; + for (unsigned i = 0; i < e_.size(); ++i) + if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]); + ar & arity_; + ar & scores_; + } + template<class Archive> + void load(Archive & ar, const unsigned int /*version*/) { + std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs); + unsigned f_size; ar & f_size; + f_.resize(f_size); + size_t f_nt_mask; ar & f_nt_mask; + std::string sym; + for (unsigned i = 0; i < f_size; ++i) { + bool mask = (f_nt_mask & 1); + ar & sym; + f_[i] = TD::Convert(sym) * (mask ? -1 : 1); + f_nt_mask >>= 1; + } + unsigned e_size; ar & e_size; + e_.resize(e_size); + size_t e_nt_mask; ar & e_nt_mask; + for (unsigned i = 0; i < e_size; ++i) { + bool mask = (e_nt_mask & 1); + if (mask) { + ar & e_[i]; + } else { + ar & sym; + e_[i] = TD::Convert(sym); + } + e_nt_mask >>= 1; + } + ar & arity_; + ar & scores_; + } + + BOOST_SERIALIZATION_SPLIT_MEMBER() private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} }; diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc index 0cb7e2e8..d75c2016 100644 --- a/decoder/trule_test.cc +++ b/decoder/trule_test.cc @@ -4,6 +4,10 @@ #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> #include <iostream> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +#include <sstream> #include "tdict.h" using namespace std; @@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) { BOOST_CHECK_EQUAL(t6.e_[3], 0); } +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { + string str; + string t7str; + { + TRule t7; + t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121"); + cerr << t7.AsString() << endl; + ostringstream os; + TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2")); + TRulePtr tp2 = tp1; + boost::archive::text_oarchive oa(os); + oa << t7; + oa << tp1; + oa << tp2; + str = os.str(); + t7str = t7.AsString(); + } + { + istringstream is(str); + boost::archive::text_iarchive ia(is); + TRule t8; + ia >> t8; + TRulePtr tp3, tp4; + ia >> tp3; + ia >> tp4; + cerr << t8.AsString() << endl; + BOOST_CHECK_EQUAL(t7str, t8.AsString()); + cerr << tp3->AsString() << endl; + cerr << tp4->AsString() << endl; + } +} + diff --git a/decoder/viterbi.h b/decoder/viterbi.h index a8a0ea7f..20ee73cc 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -1,5 +1,5 @@ -#ifndef _VITERBI_H_ -#define _VITERBI_H_ +#ifndef VITERBI_H_ +#define VITERBI_H_ #include <vector> #include "prob.h" diff --git a/extractor/README.md b/extractor/README.md index 642fbd1d..b83ff900 100644 --- a/extractor/README.md +++ b/extractor/README.md @@ -1,10 +1,10 @@ -C++ implementation of the online grammar extractor originally developed by [Adam Lopez](http://www.cs.jhu.edu/~alopez/). +A simple and fast C++ implementation of a SCFG grammar extractor using suffix arrays. The implementation is described in this [paper](https://ufal.mff.cuni.cz/pbml/102/art-baltescu-blunsom.pdf). The original cython extractor is described in [Adam Lopez](http://www.cs.jhu.edu/~alopez/)'s PhD [thesis](http://www.cs.jhu.edu/~alopez/papers/adam.lopez.dissertation.pdf). The grammar extraction takes place in two steps: (a) precomputing a number of data structures and (b) actually extracting the grammars. All the flags below have the same meaning as in the cython implementation. To compile the data structures you need to run: - cdec/extractor/compile -a <alignment> -b <parallel_corpus> -c <compile_config_file> -o <compile_directory> + cdec/extractor/sacompile -a <alignment> -b <parallel_corpus> -c <compile_config_file> -o <compile_directory> To extract the grammars you need to run: diff --git a/extractor/sacompile.cc b/extractor/sacompile.cc index 3ee668ce..d80ab64d 100644 --- a/extractor/sacompile.cc +++ b/extractor/sacompile.cc @@ -114,6 +114,7 @@ int main(int argc, char** argv) { stop_write = Clock::now(); write_duration += GetDuration(start_write, stop_write); + stop_time = Clock::now(); cerr << "Constructing suffix array took " << GetDuration(start_time, stop_time) << " seconds" << endl; diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 350571a6..134beb2f 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -10,17 +10,19 @@ * Currently only used for next pointers. */ -#ifndef LM_BHIKSHA__ -#define LM_BHIKSHA__ - -#include <stdint.h> -#include <assert.h> +#ifndef LM_BHIKSHA_H +#define LM_BHIKSHA_H #include "lm/model_type.hh" #include "lm/trie.hh" #include "util/bit_packing.hh" #include "util/sorted_uniform.hh" +#include <algorithm> + +#include <stdint.h> +#include <assert.h> + namespace lm { namespace ngram { struct Config; @@ -73,15 +75,24 @@ class ArrayBhiksha { ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { - const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor<uint64_t>(), offset_begin_, offset_end_, index); + // Some assertions are commented out because they are expensive. + // assert(*offset_begin_ == 0); + // std::upper_bound returns the first element that is greater. Want the + // last element that is <= to the index. + const uint64_t *begin_it = std::upper_bound(offset_begin_, offset_end_, index) - 1; + // Since *offset_begin_ == 0, the position should be in range. + // assert(begin_it >= offset_begin_); const uint64_t *end_it; - for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + for (end_it = begin_it + 1; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + // assert(end_it == std::upper_bound(offset_begin_, offset_end_, index + 1)); --end_it; + // assert(end_it >= begin_it); out.begin = ((begin_it - offset_begin_) << next_inline_.bits) | util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); out.end = ((end_it - offset_begin_) << next_inline_.bits) | util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); - //assert(out.end >= out.begin); + // If this fails, consider rebuilding your model using KenLM after 1e333d786b748555e8f368d2bbba29a016c98052 + assert(out.end >= out.begin); } void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { @@ -109,4 +120,4 @@ class ArrayBhiksha { } // namespace ngram } // namespace lm -#endif // LM_BHIKSHA__ +#endif // LM_BHIKSHA_H diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 9c744b13..48117404 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -149,7 +149,7 @@ void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int s void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { assert(header_size_ != kInvalidSize); - util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_); + util::ErsatzPRead(file_.get(), to, amount, offset_excluding_header + header_size_); } void *BinaryFormat::LoadBinary(std::size_t size) { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index f33f88d7..136d6b1a 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -1,5 +1,5 @@ -#ifndef LM_BINARY_FORMAT__ -#define LM_BINARY_FORMAT__ +#ifndef LM_BINARY_FORMAT_H +#define LM_BINARY_FORMAT_H #include "lm/config.hh" #include "lm/model_type.hh" @@ -103,4 +103,4 @@ bool IsBinaryFormat(int fd); } // namespace ngram } // namespace lm -#endif // LM_BINARY_FORMAT__ +#endif // LM_BINARY_FORMAT_H diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 4da81209..94a71ad2 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -1,5 +1,5 @@ -#ifndef LM_BLANK__ -#define LM_BLANK__ +#ifndef LM_BLANK_H +#define LM_BLANK_H #include <limits> @@ -40,4 +40,4 @@ inline bool HasExtension(const float &backoff) { } // namespace ngram } // namespace lm -#endif // LM_BLANK__ +#endif // LM_BLANK_H diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am index 38259c51..bb15ff04 100644 --- a/klm/lm/builder/Makefile.am +++ b/klm/lm/builder/Makefile.am @@ -1,4 +1,8 @@ -bin_PROGRAMS = lmplz +bin_PROGRAMS = lmplz dump_counts + +dump_counts_SOURCES = \ + print.cc \ + dump_counts_main.cc lmplz_SOURCES = \ lmplz_main.cc \ @@ -7,6 +11,7 @@ lmplz_SOURCES = \ corpus_count.cc \ corpus_count.hh \ discount.hh \ + hash_gamma.hh \ header_info.hh \ initial_probabilities.cc \ initial_probabilities.hh \ @@ -22,6 +27,7 @@ lmplz_SOURCES = \ print.hh \ sort.hh +dump_counts_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS) lmplz_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS) AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc index a6f48011..803c557d 100644 --- a/klm/lm/builder/adjust_counts.cc +++ b/klm/lm/builder/adjust_counts.cc @@ -1,8 +1,9 @@ #include "lm/builder/adjust_counts.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "util/stream/timer.hh" #include <algorithm> +#include <iostream> namespace lm { namespace builder { @@ -10,56 +11,78 @@ BadDiscountException::BadDiscountException() throw() {} BadDiscountException::~BadDiscountException() throw() {} namespace { -// Return last word in full that is different. +// Return last word in full that is different. const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { const WordIndex *cur_word = full.end() - 1; const WordIndex *pre_word = lower_last.end() - 1; - // Find last difference. + // Find last difference. for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {} return cur_word; } class StatCollector { public: - StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts) - : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) { + StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts) + : orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) { memset(&orders_[0], 0, sizeof(OrderStat) * order); } ~StatCollector() {} - void CalculateDiscounts() { + void CalculateDiscounts(const DiscountConfig &config) { counts_.resize(orders_.size()); - discounts_.resize(orders_.size()); + counts_pruned_.resize(orders_.size()); for (std::size_t i = 0; i < orders_.size(); ++i) { const OrderStat &s = orders_[i]; counts_[i] = s.count; + counts_pruned_[i] = s.count_pruned; + } - for (unsigned j = 1; j < 4; ++j) { - // TODO: Specialize error message for j == 3, meaning 3+ - UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " - << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " - << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?"); - } - - // See equation (26) in Chen and Goodman. - discounts_[i].amount[0] = 0.0; - float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]); - for (unsigned j = 1; j < 4; ++j) { - discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]); - UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); + discounts_ = config.overwrite; + discounts_.resize(orders_.size()); + for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) { + const OrderStat &s = orders_[i]; + try { + for (unsigned j = 1; j < 4; ++j) { + // TODO: Specialize error message for j == 3, meaning 3+ + UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " + << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " + << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?"); + } + + // See equation (26) in Chen and Goodman. + discounts_[i].amount[0] = 0.0; + float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]); + for (unsigned j = 1; j < 4; ++j) { + discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]); + UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); + } + } catch (const BadDiscountException &e) { + switch (config.bad_action) { + case THROW_UP: + throw; + case COMPLAIN: + std::cerr << e.what() << " Substituting fallback discounts D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl; + case SILENT: + break; + } + discounts_[i] = config.fallback; } } } - void Add(std::size_t order_minus_1, uint64_t count) { + void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) { OrderStat &stat = orders_[order_minus_1]; ++stat.count; + if (!pruned) + ++stat.count_pruned; if (count < 5) ++stat.n[count]; } - void AddFull(uint64_t count) { + void AddFull(uint64_t count, bool pruned = false) { ++full_.count; + if (!pruned) + ++full_.count_pruned; if (count < 5) ++full_.n[count]; } @@ -68,24 +91,27 @@ class StatCollector { // n_1 in equation 26 of Chen and Goodman etc uint64_t n[5]; uint64_t count; + uint64_t count_pruned; }; std::vector<OrderStat> orders_; OrderStat &full_; std::vector<uint64_t> &counts_; + std::vector<uint64_t> &counts_pruned_; std::vector<Discount> &discounts_; }; -// Reads all entries in order like NGramStream does. +// Reads all entries in order like NGramStream does. // But deletes any entries that have <s> in the 1st (not 0th) position on the // way out by putting other entries in their place. This disrupts the sort -// order but we don't care because the data is going to be sorted again. +// order but we don't care because the data is going to be sorted again. class CollapseStream { public: - CollapseStream(const util::stream::ChainPosition &position) : + CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold) : current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), - block_(position) { + prune_threshold_(prune_threshold), + block_(position) { StartBlock(); } @@ -96,10 +122,18 @@ class CollapseStream { CollapseStream &operator++() { assert(block_); + if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) { memcpy(current_.Base(), copy_from_, current_.TotalSize()); UpdateCopyFrom(); + + // Mark highest order n-grams for later pruning + if(current_.Count() <= prune_threshold_) { + current_.Mark(); + } + } + current_.NextInMemory(); uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); if (current_.Base() == block_base + block_->ValidSize()) { @@ -107,6 +141,12 @@ class CollapseStream { ++block_; StartBlock(); } + + // Mark highest order n-grams for later pruning + if(current_.Count() <= prune_threshold_) { + current_.Mark(); + } + return *this; } @@ -119,9 +159,15 @@ class CollapseStream { current_.ReBase(block_->Get()); copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize(); UpdateCopyFrom(); + + // Mark highest order n-grams for later pruning + if(current_.Count() <= prune_threshold_) { + current_.Mark(); + } + } - // Find last without bos. + // Find last without bos. void UpdateCopyFrom() { for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; @@ -132,83 +178,107 @@ class CollapseStream { // Goes backwards in the block uint8_t *copy_from_; - + uint64_t prune_threshold_; util::stream::Link block_; }; } // namespace -void AdjustCounts::Run(const ChainPositions &positions) { +void AdjustCounts::Run(const util::stream::ChainPositions &positions) { UTIL_TIMER("(%w s) Adjusted counts\n"); const std::size_t order = positions.size(); - StatCollector stats(order, counts_, discounts_); + StatCollector stats(order, counts_, counts_pruned_, discounts_); if (order == 1) { + // Only unigrams. Just collect stats. for (NGramStream full(positions[0]); full; ++full) stats.AddFull(full->Count()); - stats.CalculateDiscounts(); + + stats.CalculateDiscounts(discount_config_); return; } NGramStreams streams; streams.Init(positions, positions.size() - 1); - CollapseStream full(positions[positions.size() - 1]); + + CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back()); - // Initialization: <unk> has count 0 and so does <s>. + // Initialization: <unk> has count 0 and so does <s>. NGramStream *lower_valid = streams.begin(); streams[0]->Count() = 0; *streams[0]->begin() = kUNK; stats.Add(0, 0); (++streams[0])->Count() = 0; *streams[0]->begin() = kBOS; - // not in stats because it will get put in later. + // not in stats because it will get put in later. + std::vector<uint64_t> lower_counts(positions.size(), 0); + // iterate over full (the stream of the highest order ngrams) - for (; full; ++full) { + for (; full; ++full) { const WordIndex *different = FindDifference(*full, **lower_valid); std::size_t same = full->end() - 1 - different; - // Increment the adjusted count. + // Increment the adjusted count. if (same) ++streams[same - 1]->Count(); - // Output all the valid ones that changed. + // Output all the valid ones that changed. for (; lower_valid >= &streams[same]; --lower_valid) { - stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count()); + + // mjd: review this! + uint64_t order = (*lower_valid)->Order(); + uint64_t realCount = lower_counts[order - 1]; + if(order > 1 && prune_thresholds_[order - 1] && realCount <= prune_thresholds_[order - 1]) + (*lower_valid)->Mark(); + + stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked()); ++*lower_valid; } + + // Count the true occurrences of lower-order n-grams + for (std::size_t i = 0; i < lower_counts.size(); ++i) { + if (i >= same) { + lower_counts[i] = 0; + } + lower_counts[i] += full->UnmarkedCount(); + } // This is here because bos is also const WordIndex *, so copy gets - // consistent argument types. + // consistent argument types. const WordIndex *full_end = full->end(); - // Initialize and mark as valid up to bos. + // Initialize and mark as valid up to bos. const WordIndex *bos; for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { ++lower_valid; std::copy(bos, full_end, (*lower_valid)->begin()); (*lower_valid)->Count() = 1; } - // Now bos indicates where <s> is or is the 0th word of full. + // Now bos indicates where <s> is or is the 0th word of full. if (bos != full->begin()) { - // There is an <s> beyond the 0th word. + // There is an <s> beyond the 0th word. NGramStream &to = *++lower_valid; std::copy(bos, full_end, to->begin()); - to->Count() = full->Count(); + + // mjd: what is this doing? + to->Count() = full->UnmarkedCount(); } else { - stats.AddFull(full->Count()); + stats.AddFull(full->UnmarkedCount(), full->IsMarked()); } assert(lower_valid >= &streams[0]); } // Output everything valid. for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { - stats.Add(s - streams.begin(), (*s)->Count()); + if((*s)->Count() <= prune_thresholds_[(*s)->Order() - 1]) + (*s)->Mark(); + stats.Add(s - streams.begin(), (*s)->UnmarkedCount(), (*s)->IsMarked()); ++*s; } - // Poison everyone! Except the N-grams which were already poisoned by the input. + // Poison everyone! Except the N-grams which were already poisoned by the input. for (NGramStream *s = streams.begin(); s != streams.end(); ++s) s->Poison(); - stats.CalculateDiscounts(); + stats.CalculateDiscounts(discount_config_); // NOTE: See special early-return case for unigrams near the top of this function } diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh index f38ff79d..a5435c28 100644 --- a/klm/lm/builder/adjust_counts.hh +++ b/klm/lm/builder/adjust_counts.hh @@ -1,24 +1,35 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H #include "lm/builder/discount.hh" +#include "lm/lm_exception.hh" #include "util/exception.hh" #include <vector> #include <stdint.h> +namespace util { namespace stream { class ChainPositions; } } + namespace lm { namespace builder { -class ChainPositions; - class BadDiscountException : public util::Exception { public: BadDiscountException() throw(); ~BadDiscountException() throw(); }; +struct DiscountConfig { + // Overrides discounts for orders [1,discount_override.size()]. + std::vector<Discount> overwrite; + // If discounting fails for an order, copy them from here. + Discount fallback; + // What to do when discounts are out of range or would trigger divison by + // zero. It it does something other than THROW_UP, use fallback_discount. + WarningAction bad_action; +}; + /* Compute adjusted counts. * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. * Output: [1,N]-grams with adjusted counts. @@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception { */ class AdjustCounts { public: - AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) - : counts_(counts), discounts_(discounts) {} + // counts: output + // counts_pruned: output + // discounts: mostly output. If the input already has entries, they will be kept. + // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. + AdjustCounts( + const std::vector<uint64_t> &prune_thresholds, + std::vector<uint64_t> &counts, + std::vector<uint64_t> &counts_pruned, + const DiscountConfig &discount_config, + std::vector<Discount> &discounts) + : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts) + {} - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: + const std::vector<uint64_t> &prune_thresholds_; std::vector<uint64_t> &counts_; + std::vector<uint64_t> &counts_pruned_; + + DiscountConfig discount_config_; std::vector<Discount> &discounts_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc index 68b5f33e..073c5dfe 100644 --- a/klm/lm/builder/adjust_counts_test.cc +++ b/klm/lm/builder/adjust_counts_test.cc @@ -1,6 +1,6 @@ #include "lm/builder/adjust_counts.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "util/scoped.hh" #include <boost/thread/thread.hpp> @@ -61,19 +61,24 @@ BOOST_AUTO_TEST_CASE(Simple) { util::stream::ChainConfig config; config.total_memory = 100; config.block_count = 1; - Chains chains(4); + util::stream::Chains chains(4); for (unsigned i = 0; i < 4; ++i) { config.entry_size = NGram::TotalSize(i + 1); chains.push_back(config); } chains[3] >> WriteInput(); - ChainPositions for_adjust(chains); + util::stream::ChainPositions for_adjust(chains); for (unsigned i = 0; i < 4; ++i) { chains[i] >> boost::ref(outputs[i]); } chains >> util::stream::kRecycle; - BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException); + std::vector<uint64_t> counts_pruned(4); + std::vector<uint64_t> prune_thresholds(4); + DiscountConfig discount_config; + discount_config.fallback = Discount(); + discount_config.bad_action = THROW_UP; + BOOST_CHECK_THROW(AdjustCounts(prune_thresholds, counts, counts_pruned, discount_config, discount).Run(for_adjust), BadDiscountException); } BOOST_REQUIRE_EQUAL(4UL, counts.size()); BOOST_CHECK_EQUAL(4UL, counts[0]); diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index ccc06efc..590e79fa 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -2,6 +2,7 @@ #include "lm/builder/ngram.hh" #include "lm/lm_exception.hh" +#include "lm/vocab.hh" #include "lm/word_index.hh" #include "util/fake_ofstream.hh" #include "util/file.hh" @@ -37,60 +38,6 @@ struct VocabEntry { }; #pragma pack(pop) -const float kProbingMultiplier = 1.5; - -class VocabHandout { - public: - static std::size_t MemUsage(WordIndex initial_guess) { - if (initial_guess < 2) initial_guess = 2; - return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier)); - } - - explicit VocabHandout(int fd, WordIndex initial_guess) : - table_backing_(util::CallocOrThrow(MemUsage(initial_guess))), - table_(table_backing_.get(), MemUsage(initial_guess)), - double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)), - word_list_(fd) { - Lookup("<unk>"); // Force 0 - Lookup("<s>"); // Force 1 - Lookup("</s>"); // Force 2 - } - - WordIndex Lookup(const StringPiece &word) { - VocabEntry entry; - entry.key = util::MurmurHashNative(word.data(), word.size()); - entry.value = table_.SizeNoSerialization(); - - Table::MutableIterator it; - if (table_.FindOrInsert(entry, it)) - return it->value; - word_list_ << word << '\0'; - UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); - if (Size() >= double_cutoff_) { - table_backing_.call_realloc(table_.DoubleTo()); - table_.Double(table_backing_.get()); - double_cutoff_ *= 2; - } - return entry.value; - } - - WordIndex Size() const { - return table_.SizeNoSerialization(); - } - - private: - // TODO: factor out a resizable probing hash table. - // TODO: use mremap on linux to get all zeros on resizes. - util::scoped_malloc table_backing_; - - typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table; - Table table_; - - std::size_t double_cutoff_; - - util::FakeOFStream word_list_; -}; - class DedupeHash : public std::unary_function<const WordIndex *, bool> { public: explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} @@ -127,6 +74,10 @@ struct DedupeEntry { } }; + +// TODO: don't have this here, should be with probing hash table defaults? +const float kProbingMultiplier = 1.5; + typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe; class Writer { @@ -220,37 +171,50 @@ float CorpusCount::DedupeMultiplier(std::size_t order) { } std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { - return VocabHandout::MemUsage(vocab_estimate); + return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate); } -CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), - dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { + dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)), + disallowed_symbol_action_(disallowed_symbol) { } -void CorpusCount::Run(const util::stream::ChainPosition &position) { - UTIL_TIMER("(%w s) Counted n-grams\n"); +namespace { + void ComplainDisallowed(StringPiece word, WarningAction &action) { + switch (action) { + case SILENT: + return; + case COMPLAIN: + std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl; + action = SILENT; + return; + case THROW_UP: + UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace."); + } + } +} // namespace - VocabHandout vocab(vocab_write_, type_count_); +void CorpusCount::Run(const util::stream::ChainPosition &position) { + ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_); token_count_ = 0; type_count_ = 0; - const WordIndex end_sentence = vocab.Lookup("</s>"); + const WordIndex end_sentence = vocab.FindOrInsert("</s>"); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; bool delimiters[256]; - memset(delimiters, 0, sizeof(delimiters)); - const char kDelimiterSet[] = "\0\t\n\r "; - for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) { - delimiters[static_cast<unsigned char>(*i)] = true; - } + util::BoolCharacter::Build("\0\t\n\r ", delimiters); try { while(true) { StringPiece line(from_.ReadLine()); writer.StartSentence(); for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) { - WordIndex word = vocab.Lookup(*w); - UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future."); + WordIndex word = vocab.FindOrInsert(*w); + if (word <= 2) { + ComplainDisallowed(*w, disallowed_symbol_action_); + continue; + } writer.Append(word); ++count; } diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh index aa0ed8ed..da4ff9fc 100644 --- a/klm/lm/builder/corpus_count.hh +++ b/klm/lm/builder/corpus_count.hh @@ -1,6 +1,7 @@ -#ifndef LM_BUILDER_CORPUS_COUNT__ -#define LM_BUILDER_CORPUS_COUNT__ +#ifndef LM_BUILDER_CORPUS_COUNT_H +#define LM_BUILDER_CORPUS_COUNT_H +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/scoped.hh" @@ -28,7 +29,7 @@ class CorpusCount { // token_count: out. // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. - CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); + CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol); void Run(const util::stream::ChainPosition &position); @@ -40,8 +41,10 @@ class CorpusCount { std::size_t dedupe_mem_size_; util::scoped_malloc dedupe_mem_; + + WarningAction disallowed_symbol_action_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_CORPUS_COUNT__ +#endif // LM_BUILDER_CORPUS_COUNT_H diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc index 6d325ef5..26cb6346 100644 --- a/klm/lm/builder/corpus_count_test.cc +++ b/klm/lm/builder/corpus_count_test.cc @@ -45,7 +45,7 @@ BOOST_AUTO_TEST_CASE(Short) { NGramStream stream; uint64_t token_count; WordIndex type_count = 10; - CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize(), SILENT); chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh index 4d0aa4fd..e2f40846 100644 --- a/klm/lm/builder/discount.hh +++ b/klm/lm/builder/discount.hh @@ -1,5 +1,5 @@ -#ifndef BUILDER_DISCOUNT__ -#define BUILDER_DISCOUNT__ +#ifndef LM_BUILDER_DISCOUNT_H +#define LM_BUILDER_DISCOUNT_H #include <algorithm> @@ -23,4 +23,4 @@ struct Discount { } // namespace builder } // namespace lm -#endif // BUILDER_DISCOUNT__ +#endif // LM_BUILDER_DISCOUNT_H diff --git a/klm/lm/builder/dump_counts_main.cc b/klm/lm/builder/dump_counts_main.cc new file mode 100644 index 00000000..fa001679 --- /dev/null +++ b/klm/lm/builder/dump_counts_main.cc @@ -0,0 +1,36 @@ +#include "lm/builder/print.hh" +#include "lm/word_index.hh" +#include "util/file.hh" +#include "util/read_compressed.hh" + +#include <boost/lexical_cast.hpp> + +#include <iostream> +#include <vector> + +int main(int argc, char *argv[]) { + if (argc != 4) { + std::cerr << "Usage: " << argv[0] << " counts vocabulary order\n" + "The counts file contains records with 4-byte vocabulary ids followed by 8-byte\n" + "counts. Each record has order many vocabulary ids.\n" + "The vocabulary file contains the words delimited by NULL in order of id.\n" + "The vocabulary file may not be compressed because it is mmapped but the counts\n" + "file can be compressed.\n"; + return 1; + } + util::ReadCompressed counts(util::OpenReadOrThrow(argv[1])); + util::scoped_fd vocab_file(util::OpenReadOrThrow(argv[2])); + lm::builder::VocabReconstitute vocab(vocab_file.get()); + unsigned int order = boost::lexical_cast<unsigned int>(argv[3]); + std::vector<char> record(sizeof(uint32_t) * order + sizeof(uint64_t)); + while (std::size_t got = counts.ReadOrEOF(&*record.begin(), record.size())) { + UTIL_THROW_IF(got != record.size(), util::Exception, "Read " << got << " bytes at the end of file, which is not a complete record of length " << record.size()); + const lm::WordIndex *words = reinterpret_cast<const lm::WordIndex*>(&*record.begin()); + for (const lm::WordIndex *i = words; i != words + order; ++i) { + UTIL_THROW_IF(*i >= vocab.Size(), util::Exception, "Vocab ID " << *i << " is larger than the vocab file's maximum of " << vocab.Size() << ". Are you sure you have the right order and vocab file for these counts?"); + std::cout << vocab.Lookup(*i) << ' '; + } + // TODO don't use std::cout because it is slow. Add fast uint64_t printing support to FakeOFStream. + std::cout << *reinterpret_cast<const uint64_t*>(words + order) << '\n'; + } +} diff --git a/klm/lm/builder/hash_gamma.hh b/klm/lm/builder/hash_gamma.hh new file mode 100644 index 00000000..4bef47e8 --- /dev/null +++ b/klm/lm/builder/hash_gamma.hh @@ -0,0 +1,19 @@ +#ifndef LM_BUILDER_HASH_GAMMA__ +#define LM_BUILDER_HASH_GAMMA__ + +#include <stdint.h> + +namespace lm { namespace builder { + +#pragma pack(push) +#pragma pack(4) + +struct HashGamma { + uint64_t hash_value; + float gamma; +}; + +#pragma pack(pop) + +}} // namespaces +#endif // LM_BUILDER_HASH_GAMMA__ diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh index ccca1456..16f3f609 100644 --- a/klm/lm/builder/header_info.hh +++ b/klm/lm/builder/header_info.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_HEADER_INFO__ -#define LM_BUILDER_HEADER_INFO__ +#ifndef LM_BUILDER_HEADER_INFO_H +#define LM_BUILDER_HEADER_INFO_H #include <string> #include <stdint.h> diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc index 58b42a20..5d19a897 100644 --- a/klm/lm/builder/initial_probabilities.cc +++ b/klm/lm/builder/initial_probabilities.cc @@ -3,6 +3,8 @@ #include "lm/builder/discount.hh" #include "lm/builder/ngram_stream.hh" #include "lm/builder/sort.hh" +#include "lm/builder/hash_gamma.hh" +#include "util/murmur_hash.hh" #include "util/file.hh" #include "util/stream/chain.hh" #include "util/stream/io.hh" @@ -14,55 +16,182 @@ namespace lm { namespace builder { namespace { struct BufferEntry { - // Gamma from page 20 of Chen and Goodman. + // Gamma from page 20 of Chen and Goodman. float gamma; - // \sum_w a(c w) for all w. + // \sum_w a(c w) for all w. float denominator; }; -// Extract an array of gamma from an array of BufferEntry. +struct HashBufferEntry : public BufferEntry { + // Hash value of ngram. Used to join contexts with backoffs. + uint64_t hash_value; +}; + +// Reads all entries in order like NGramStream does. +// But deletes any entries that have CutoffCount below or equal to pruning +// threshold. +class PruneNGramStream { + public: + PruneNGramStream(const util::stream::ChainPosition &position) : + current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + currentCount_(0), + block_(position) + { + StartBlock(); + } + + NGram &operator*() { return current_; } + NGram *operator->() { return ¤t_; } + + operator bool() const { + return block_; + } + + PruneNGramStream &operator++() { + assert(block_); + + if (current_.Order() > 1) { + if(currentCount_ > 0) { + if(dest_.Base() < current_.Base()) { + memcpy(dest_.Base(), current_.Base(), current_.TotalSize()); + } + dest_.NextInMemory(); + } + } else { + dest_.NextInMemory(); + } + + current_.NextInMemory(); + + uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); + if (current_.Base() == block_base + block_->ValidSize()) { + block_->SetValidSize(dest_.Base() - block_base); + ++block_; + StartBlock(); + if (block_) { + currentCount_ = current_.CutoffCount(); + } + } else { + currentCount_ = current_.CutoffCount(); + } + + return *this; + } + + private: + void StartBlock() { + for (; ; ++block_) { + if (!block_) return; + if (block_->ValidSize()) break; + } + current_.ReBase(block_->Get()); + currentCount_ = current_.CutoffCount(); + + dest_.ReBase(block_->Get()); + } + + NGram current_; // input iterator + NGram dest_; // output iterator + + uint64_t currentCount_; + + util::stream::Link block_; +}; + +// Extract an array of HashedGamma from an array of BufferEntry. class OnlyGamma { public: + OnlyGamma(bool pruning) : pruning_(pruning) {} + void Run(const util::stream::ChainPosition &position) { for (util::stream::Link block_it(position); block_it; ++block_it) { - float *out = static_cast<float*>(block_it->Get()); - const float *in = out; - const float *end = static_cast<const float*>(block_it->ValidEnd()); - for (out += 1, in += 2; in < end; out += 1, in += 2) { - *out = *in; + if(pruning_) { + const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get()); + const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd()); + + // Just make it point to the beginning of the stream so it can be overwritten + // With HashGamma values. Do not attempt to interpret the values until set below. + HashGamma *out = static_cast<HashGamma*>(block_it->Get()); + for (; in < end; out += 1, in += 1) { + // buffering, otherwise might overwrite values too early + float gamma_buf = in->gamma; + uint64_t hash_buf = in->hash_value; + + out->gamma = gamma_buf; + out->hash_value = hash_buf; + } + block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry)); + } + else { + float *out = static_cast<float*>(block_it->Get()); + const float *in = out; + const float *end = static_cast<const float*>(block_it->ValidEnd()); + for (out += 1, in += 2; in < end; out += 1, in += 2) { + *out = *in; + } + block_it->SetValidSize(block_it->ValidSize() / 2); } - block_it->SetValidSize(block_it->ValidSize() / 2); } } + + private: + bool pruning_; }; class AddRight { public: - AddRight(const Discount &discount, const util::stream::ChainPosition &input) - : discount_(discount), input_(input) {} + AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning) + : discount_(discount), input_(input), pruning_(pruning) {} void Run(const util::stream::ChainPosition &output) { NGramStream in(input_); util::stream::Stream out(output); std::vector<WordIndex> previous(in->Order() - 1); + // Silly windows requires this workaround to just get an invalid pointer when empty. + void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]); const std::size_t size = sizeof(WordIndex) * previous.size(); + for(; in; ++out) { - memcpy(&previous[0], in->begin(), size); + memcpy(previous_raw, in->begin(), size); uint64_t denominator = 0; + uint64_t normalizer = 0; + uint64_t counts[4]; memset(counts, 0, sizeof(counts)); do { - denominator += in->Count(); - ++counts[std::min(in->Count(), static_cast<uint64_t>(3))]; - } while (++in && !memcmp(&previous[0], in->begin(), size)); + denominator += in->UnmarkedCount(); + + // Collect unused probability mass from pruning. + // Becomes 0 for unpruned ngrams. + normalizer += in->UnmarkedCount() - in->CutoffCount(); + + // Chen&Goodman do not mention counting based on cutoffs, but + // backoff becomes larger than 1 otherwise, so probably needs + // to count cutoffs. Counts normally without pruning. + if(in->CutoffCount() > 0) + ++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))]; + + } while (++in && !memcmp(previous_raw, in->begin(), size)); + BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get()); entry.denominator = static_cast<float>(denominator); entry.gamma = 0.0; for (unsigned i = 1; i <= 3; ++i) { entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]); } + + // Makes model sum to 1 with pruning (I hope). + entry.gamma += normalizer; + entry.gamma /= entry.denominator; + + if(pruning_) { + // If pruning is enabled the stream actually contains HashBufferEntry, see InitialProbabilities(...), + // so add a hash value that identifies the current ngram. + static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size); + } } out.Poison(); } @@ -70,6 +199,7 @@ class AddRight { private: const Discount &discount_; const util::stream::ChainPosition input_; + bool pruning_; }; class MergeRight { @@ -82,7 +212,7 @@ class MergeRight { void Run(const util::stream::ChainPosition &primary) { util::stream::Stream summed(from_adder_); - NGramStream grams(primary); + PruneNGramStream grams(primary); // Without interpolation, the interpolation weight goes to <unk>. if (grams->Order() == 1 && !interpolate_unigrams_) { @@ -97,15 +227,16 @@ class MergeRight { ++summed; return; } - + std::vector<WordIndex> previous(grams->Order() - 1); const std::size_t size = sizeof(WordIndex) * previous.size(); for (; grams; ++summed) { memcpy(&previous[0], grams->begin(), size); const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get()); + do { Payload &pay = grams->Value(); - pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; + pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator; pay.uninterp.gamma = sums.gamma; } while (++grams && !memcmp(&previous[0], grams->begin(), size)); } @@ -119,17 +250,29 @@ class MergeRight { } // namespace -void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { - util::stream::ChainConfig gamma_config = config.adder_out; - gamma_config.entry_size = sizeof(BufferEntry); +void InitialProbabilities( + const InitialProbabilitiesConfig &config, + const std::vector<Discount> &discounts, + util::stream::Chains &primary, + util::stream::Chains &second_in, + util::stream::Chains &gamma_out, + const std::vector<uint64_t> &prune_thresholds) { for (size_t i = 0; i < primary.size(); ++i) { + util::stream::ChainConfig gamma_config = config.adder_out; + if(prune_thresholds[i] > 0) + gamma_config.entry_size = sizeof(HashBufferEntry); + else + gamma_config.entry_size = sizeof(BufferEntry); + util::stream::ChainPosition second(second_in[i].Add()); second_in[i] >> util::stream::kRecycle; gamma_out.push_back(gamma_config); - gamma_out[i] >> AddRight(discounts[i], second); + gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0); + primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); - // Don't bother with the OnlyGamma thread for something to discard. - if (i) gamma_out[i] >> OnlyGamma(); + + // Don't bother with the OnlyGamma thread for something to discard. + if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0); } } diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh index 626388eb..c1010e08 100644 --- a/klm/lm/builder/initial_probabilities.hh +++ b/klm/lm/builder/initial_probabilities.hh @@ -1,14 +1,15 @@ -#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ -#define LM_BUILDER_INITIAL_PROBABILITIES__ +#ifndef LM_BUILDER_INITIAL_PROBABILITIES_H +#define LM_BUILDER_INITIAL_PROBABILITIES_H #include "lm/builder/discount.hh" #include "util/stream/config.hh" #include <vector> +namespace util { namespace stream { class Chains; } } + namespace lm { namespace builder { -class Chains; struct InitialProbabilitiesConfig { // These should be small buffers to keep the adder from getting too far ahead @@ -26,9 +27,15 @@ struct InitialProbabilitiesConfig { * The values are bare floats and should be buffered for interpolation to * use. */ -void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out); +void InitialProbabilities( + const InitialProbabilitiesConfig &config, + const std::vector<Discount> &discounts, + util::stream::Chains &primary, + util::stream::Chains &second_in, + util::stream::Chains &gamma_out, + const std::vector<uint64_t> &prune_thresholds); } // namespace builder } // namespace lm -#endif // LM_BUILDER_INITIAL_PROBABILITIES__ +#endif // LM_BUILDER_INITIAL_PROBABILITIES_H diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc index 50026806..a7947a42 100644 --- a/klm/lm/builder/interpolate.cc +++ b/klm/lm/builder/interpolate.cc @@ -1,18 +1,74 @@ #include "lm/builder/interpolate.hh" +#include "lm/builder/hash_gamma.hh" #include "lm/builder/joint_order.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/builder/sort.hh" #include "lm/lm_exception.hh" +#include "util/fixed_array.hh" +#include "util/murmur_hash.hh" #include <assert.h> +#include <math.h> namespace lm { namespace builder { namespace { -class Callback { +/* Calculate q, the collapsed probability and backoff, as defined in + * @inproceedings{Heafield-rest, + * author = {Kenneth Heafield and Philipp Koehn and Alon Lavie}, + * title = {Language Model Rest Costs and Space-Efficient Storage}, + * year = {2012}, + * month = {July}, + * booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning}, + * address = {Jeju Island, Korea}, + * pages = {1169--1178}, + * url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf}, + * } + * This is particularly convenient to calculate during interpolation because + * the needed backoff terms are already accessed at the same time. + */ +class OutputQ { public: - Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + explicit OutputQ(std::size_t order) : q_delta_(order) {} + + void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) { + float &q_del = q_delta_[order_minus_1]; + if (order_minus_1) { + // Divide by context's backoff (which comes in as out.backoff) + q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff; + } else { + q_del = full_backoff; + } + out.prob = log10f(out.prob * q_del); + // TODO: stop wastefully outputting this! + out.backoff = 0.0; + } + + private: + // Product of backoffs in the numerator divided by backoffs in the + // denominator. Does not include + std::vector<float> q_delta_; +}; + +/* Default: output probability and backoff */ +class OutputProbBackoff { + public: + explicit OutputProbBackoff(std::size_t /*order*/) {} + + void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const { + // Correcting for numerical precision issues. Take that IRST. + out.prob = std::min(0.0f, log10f(out.prob)); + out.backoff = log10f(full_backoff); + } +}; + +template <class Output> class Callback { + public: + Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds) + : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), + prune_thresholds_(prune_thresholds), + output_(backoffs.size() + 1 /* order */) { probs_[0] = uniform_prob; for (std::size_t i = 0; i < backoffs.size(); ++i) { backoffs_.push_back(backoffs[i]); @@ -21,6 +77,10 @@ class Callback { ~Callback() { for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if(prune_thresholds_[i + 1] > 0) + while(backoffs_[i]) + ++backoffs_[i]; + if (backoffs_[i]) { std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; abort(); @@ -32,34 +92,66 @@ class Callback { Payload &pay = gram.Value(); pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; probs_[order_minus_1 + 1] = pay.complete.prob; - pay.complete.prob = log10(pay.complete.prob); - // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + + float out_backoff; if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { - pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); - ++backoffs_[order_minus_1]; + if(prune_thresholds_[order_minus_1 + 1] > 0) { + //Compute hash value for current context + uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex)); + + const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get()); + while(current_hash != hashed_backoff->hash_value && ++backoffs_[order_minus_1]) + hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get()); + + if(current_hash == hashed_backoff->hash_value) { + out_backoff = hashed_backoff->gamma; + ++backoffs_[order_minus_1]; + } else { + // Has been pruned away so it is not a context anymore + out_backoff = 1.0; + } + } else { + out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get()); + ++backoffs_[order_minus_1]; + } } else { - // Not a context. - pay.complete.backoff = 0.0; + // Not a context. + out_backoff = 1.0; } + + output_.Gram(order_minus_1, out_backoff, pay.complete); } void Exit(unsigned, const NGram &) const {} private: - FixedArray<util::stream::Stream> backoffs_; + util::FixedArray<util::stream::Stream> backoffs_; std::vector<float> probs_; + const std::vector<uint64_t>& prune_thresholds_; + + Output output_; }; } // namespace -Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) - : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} +Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool output_q) + : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>. + backoffs_(backoffs), + prune_thresholds_(prune_thresholds), + output_q_(output_q) {} // perform order-wise interpolation -void Interpolate::Run(const ChainPositions &positions) { +void Interpolate::Run(const util::stream::ChainPositions &positions) { assert(positions.size() == backoffs_.size() + 1); - Callback callback(uniform_prob_, backoffs_); - JointOrder<Callback, SuffixOrder>(positions, callback); + if (output_q_) { + typedef Callback<OutputQ> C; + C callback(uniform_prob_, backoffs_, prune_thresholds_); + JointOrder<C, SuffixOrder>(positions, callback); + } else { + typedef Callback<OutputProbBackoff> C; + C callback(uniform_prob_, backoffs_, prune_thresholds_); + JointOrder<C, SuffixOrder>(positions, callback); + } } }} // namespaces diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh index 9268d404..0acece92 100644 --- a/klm/lm/builder/interpolate.hh +++ b/klm/lm/builder/interpolate.hh @@ -1,9 +1,11 @@ -#ifndef LM_BUILDER_INTERPOLATE__ -#define LM_BUILDER_INTERPOLATE__ +#ifndef LM_BUILDER_INTERPOLATE_H +#define LM_BUILDER_INTERPOLATE_H -#include <stdint.h> +#include "util/stream/multi_stream.hh" + +#include <vector> -#include "lm/builder/multi_stream.hh" +#include <stdint.h> namespace lm { namespace builder { @@ -14,14 +16,18 @@ namespace lm { namespace builder { */ class Interpolate { public: - explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs); + // Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might + // be larger when the user specifies a consistent vocabulary size. + explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool output_q_); - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: float uniform_prob_; - ChainPositions backoffs_; + util::stream::ChainPositions backoffs_; + const std::vector<uint64_t> prune_thresholds_; + bool output_q_; }; }} // namespaces -#endif // LM_BUILDER_INTERPOLATE__ +#endif // LM_BUILDER_INTERPOLATE_H diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh index b5620144..7235d4f7 100644 --- a/klm/lm/builder/joint_order.hh +++ b/klm/lm/builder/joint_order.hh @@ -1,14 +1,14 @@ -#ifndef LM_BUILDER_JOINT_ORDER__ -#define LM_BUILDER_JOINT_ORDER__ +#ifndef LM_BUILDER_JOINT_ORDER_H +#define LM_BUILDER_JOINT_ORDER_H -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/lm_exception.hh" #include <string.h> namespace lm { namespace builder { -template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) { +template <class Callback, class Compare> void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) { // Allow matching to reference streams[-1]. NGramStreams streams_with_dummy; streams_with_dummy.InitWithDummy(positions); @@ -40,4 +40,4 @@ template <class Callback, class Compare> void JointOrder(const ChainPositions &p }} // namespaces -#endif // LM_BUILDER_JOINT_ORDER__ +#endif // LM_BUILDER_JOINT_ORDER_H diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index 2563deed..265dd216 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -1,4 +1,5 @@ #include "lm/builder/pipeline.hh" +#include "lm/lm_exception.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/usage.hh" @@ -7,6 +8,7 @@ #include <boost/program_options.hpp> #include <boost/version.hpp> +#include <vector> namespace { class SizeNotify { @@ -25,6 +27,57 @@ boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, co return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value); } +// Parse and validate pruning thresholds then return vector of threshold counts +// for each n-grams order. +std::vector<uint64_t> ParsePruning(const std::vector<std::string> ¶m, std::size_t order) { + // convert to vector of integers + std::vector<uint64_t> prune_thresholds; + prune_thresholds.reserve(order); + for (std::vector<std::string>::const_iterator it(param.begin()); it != param.end(); ++it) { + try { + prune_thresholds.push_back(boost::lexical_cast<uint64_t>(*it)); + } catch(const boost::bad_lexical_cast &) { + UTIL_THROW(util::Exception, "Bad pruning threshold " << *it); + } + } + + // Fill with zeros by default. + if (prune_thresholds.empty()) { + prune_thresholds.resize(order, 0); + return prune_thresholds; + } + + // validate pruning threshold if specified + // throw if each n-gram order has not threshold specified + UTIL_THROW_IF(prune_thresholds.size() > order, util::Exception, "You specified pruning thresholds for orders 1 through " << prune_thresholds.size() << " but the model only has order " << order); + // threshold for unigram can only be 0 (no pruning) + UTIL_THROW_IF(prune_thresholds[0] != 0, util::Exception, "Unigram pruning is not implemented, so the first pruning threshold must be 0."); + + // check if threshold are not in decreasing order + uint64_t lower_threshold = 0; + for (std::vector<uint64_t>::iterator it = prune_thresholds.begin(); it != prune_thresholds.end(); ++it) { + UTIL_THROW_IF(lower_threshold > *it, util::Exception, "Pruning thresholds should be in non-decreasing order. Otherwise substrings would be removed, which is bad for query-time data structures."); + lower_threshold = *it; + } + + // Pad to all orders using the last value. + prune_thresholds.resize(order, prune_thresholds.back()); + return prune_thresholds; +} + +lm::builder::Discount ParseDiscountFallback(const std::vector<std::string> ¶m) { + lm::builder::Discount ret; + UTIL_THROW_IF(param.size() > 3, util::Exception, "Specify at most three fallback discounts: 1, 2, and 3+"); + UTIL_THROW_IF(param.empty(), util::Exception, "Fallback discounting enabled, but no discount specified"); + ret.amount[0] = 0.0; + for (unsigned i = 0; i < 3; ++i) { + float discount = boost::lexical_cast<float>(param[i < param.size() ? i : (param.size() - 1)]); + UTIL_THROW_IF(discount < 0.0 || discount > static_cast<float>(i+1), util::Exception, "The discount for count " << (i+1) << " was parsed as " << discount << " which is not in the range [0, " << (i+1) << "]."); + ret.amount[i + 1] = discount; + } + return ret; +} + } // namespace int main(int argc, char *argv[]) { @@ -34,25 +87,36 @@ int main(int argc, char *argv[]) { lm::builder::PipelineConfig pipeline; std::string text, arpa; + std::vector<std::string> pruning; + std::vector<std::string> discount_fallback; + std::vector<std::string> discount_fallback_default; + discount_fallback_default.push_back("0.5"); + discount_fallback_default.push_back("1"); + discount_fallback_default.push_back("1.5"); options.add_options() - ("help", po::bool_switch(), "Show this help message") + ("help,h", po::bool_switch(), "Show this help message") ("order,o", po::value<std::size_t>(&pipeline.order) #if BOOST_VERSION >= 104200 ->required() #endif , "Order of the model") - ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") + ("interpolate_unigrams", po::value<bool>(&pipeline.initial_probs.interpolate_unigrams)->default_value(true)->implicit_value(true), "Interpolate the unigrams (default) as opposed to giving lots of mass to <unk> like SRI. If you want SRI's behavior with a large <unk> and the old lmplz default, use --interpolate_unigrams 0.") + ("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception") ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") - ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)") - ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") + ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") + ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes") + ("vocab_pad", po::value<uint64_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams") ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.") ("text", po::value<std::string>(&text), "Read text from a file instead of stdin") - ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout"); + ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout") + ("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.") + ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to --prune 0.") + ("discount_fallback", po::value<std::vector<std::string> >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail."); po::variables_map vm; po::store(po::parse_command_line(argc, argv, options), vm); @@ -95,6 +159,29 @@ int main(int argc, char *argv[]) { } #endif + if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) { + std::cerr << "--vocab_pad requires --interpolate_unigrams be on" << std::endl; + return 1; + } + + if (vm["skip_symbols"].as<bool>()) { + pipeline.disallowed_symbol_action = lm::COMPLAIN; + } else { + pipeline.disallowed_symbol_action = lm::THROW_UP; + } + + if (vm.count("discount_fallback")) { + pipeline.discount.fallback = ParseDiscountFallback(discount_fallback); + pipeline.discount.bad_action = lm::COMPLAIN; + } else { + // Unused, just here to prevent the compiler from complaining about uninitialized. + pipeline.discount.fallback = lm::builder::Discount(); + pipeline.discount.bad_action = lm::THROW_UP; + } + + // parse pruning thresholds. These depend on order, so it is not done as a notifier. + pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order); + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh index f5681516..0472bcb1 100644 --- a/klm/lm/builder/ngram.hh +++ b/klm/lm/builder/ngram.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_NGRAM__ -#define LM_BUILDER_NGRAM__ +#ifndef LM_BUILDER_NGRAM_H +#define LM_BUILDER_NGRAM_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -26,7 +26,7 @@ union Payload { class NGram { public: - NGram(void *begin, std::size_t order) + NGram(void *begin, std::size_t order) : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {} const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); } @@ -38,12 +38,12 @@ class NGram { end_ = begin_ + difference; } - // Would do operator++ but that can get confusing for a stream. + // Would do operator++ but that can get confusing for a stream. void NextInMemory() { ReBase(&Value() + 1); } - // Lower-case in deference to STL. + // Lower-case in deference to STL. const WordIndex *begin() const { return begin_; } WordIndex *begin() { return begin_; } const WordIndex *end() const { return end_; } @@ -61,7 +61,7 @@ class NGram { return order * sizeof(WordIndex) + sizeof(Payload); } std::size_t TotalSize() const { - // Compiler should optimize this. + // Compiler should optimize this. return TotalSize(Order()); } static std::size_t OrderFromSize(std::size_t size) { @@ -69,6 +69,31 @@ class NGram { assert(size == TotalSize(ret)); return ret; } + + // manipulate msb to signal that ngram can be pruned + /*mjd**********************************************************************/ + + bool IsMarked() const { + return Value().count >> (sizeof(Value().count) * 8 - 1); + } + + void Mark() { + Value().count |= (1ul << (sizeof(Value().count) * 8 - 1)); + } + + void Unmark() { + Value().count &= ~(1ul << (sizeof(Value().count) * 8 - 1)); + } + + uint64_t UnmarkedCount() const { + return Value().count & ~(1ul << (sizeof(Value().count) * 8 - 1)); + } + + uint64_t CutoffCount() const { + return IsMarked() ? 0 : UnmarkedCount(); + } + + /*mjd**********************************************************************/ private: WordIndex *begin_, *end_; @@ -81,4 +106,4 @@ const WordIndex kEOS = 2; } // namespace builder } // namespace lm -#endif // LM_BUILDER_NGRAM__ +#endif // LM_BUILDER_NGRAM_H diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh index 3c994664..ab42734c 100644 --- a/klm/lm/builder/ngram_stream.hh +++ b/klm/lm/builder/ngram_stream.hh @@ -1,8 +1,9 @@ -#ifndef LM_BUILDER_NGRAM_STREAM__ -#define LM_BUILDER_NGRAM_STREAM__ +#ifndef LM_BUILDER_NGRAM_STREAM_H +#define LM_BUILDER_NGRAM_STREAM_H #include "lm/builder/ngram.hh" #include "util/stream/chain.hh" +#include "util/stream/multi_stream.hh" #include "util/stream/stream.hh" #include <cstddef> @@ -51,5 +52,7 @@ inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream & return chain; } +typedef util::stream::GenericStreams<NGramStream> NGramStreams; + }} // namespaces -#endif // LM_BUILDER_NGRAM_STREAM__ +#endif // LM_BUILDER_NGRAM_STREAM_H diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc index 44a2313c..21064ab3 100644 --- a/klm/lm/builder/pipeline.cc +++ b/klm/lm/builder/pipeline.cc @@ -2,6 +2,7 @@ #include "lm/builder/adjust_counts.hh" #include "lm/builder/corpus_count.hh" +#include "lm/builder/hash_gamma.hh" #include "lm/builder/initial_probabilities.hh" #include "lm/builder/interpolate.hh" #include "lm/builder/print.hh" @@ -20,10 +21,13 @@ namespace lm { namespace builder { namespace { -void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) { +void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts) { std::cerr << "Statistics:\n"; for (size_t i = 0; i < counts.size(); ++i) { - std::cerr << (i + 1) << ' ' << counts[i]; + std::cerr << (i + 1) << ' ' << counts_pruned[i]; + if(counts[i] != counts_pruned[i]) + std::cerr << "/" << counts[i]; + for (size_t d = 1; d <= 3; ++d) std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d]; std::cerr << '\n'; @@ -39,7 +43,7 @@ class Master { const PipelineConfig &Config() const { return config_; } - Chains &MutableChains() { return chains_; } + util::stream::Chains &MutableChains() { return chains_; } template <class T> Master &operator>>(const T &worker) { chains_ >> worker; @@ -64,7 +68,7 @@ class Master { } // For initial probabilities, but this is generic. - void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) { + void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) { // Do merge first before allocating chain memory. for (std::size_t i = 1; i < config_.order; ++i) { sorts[i - 1].Merge(0); @@ -198,9 +202,9 @@ class Master { PipelineConfig config_; - Chains chains_; + util::stream::Chains chains_; // Often only unigrams, but sometimes all orders. - FixedArray<util::stream::FileBuffer> files_; + util::FixedArray<util::stream::FileBuffer> files_; }; void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { @@ -221,7 +225,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m WordIndex type_count = config.vocab_estimate; util::FilePiece text(text_file, NULL, &std::cerr); text_file_name = text.FileName(); - CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action); chain >> boost::ref(counter); util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); @@ -231,21 +235,22 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m master.InitForAdjust(sorter, type_count); } -void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, + util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds) { const PipelineConfig &config = master.Config(); - Chains second(config.order); + util::stream::Chains second(config.order); { Sorts<ContextOrder> sorts; master.SetupSorts(sorts); - PrintStatistics(counts, discounts); - lm::ngram::ShowSizes(counts); + PrintStatistics(counts, counts_pruned, discounts); + lm::ngram::ShowSizes(counts_pruned); std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; - master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); + master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in); } - Chains gamma_chains(config.order); - InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); + util::stream::Chains gamma_chains(config.order); + InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds); // Don't care about gamma for 0. gamma_chains[0] >> util::stream::kRecycle; gammas.Init(config.order - 1); @@ -257,19 +262,25 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector master.SetupSorts(primary); } -void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) { std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; const PipelineConfig &config = master.Config(); master.MaximumLazyInput(counts, primary); - Chains gamma_chains(config.order - 1); - util::stream::ChainConfig read_backoffs(config.read_backoffs); - read_backoffs.entry_size = sizeof(float); + util::stream::Chains gamma_chains(config.order - 1); for (std::size_t i = 0; i < config.order - 1; ++i) { + util::stream::ChainConfig read_backoffs(config.read_backoffs); + + // Add 1 because here we are skipping unigrams + if(config.prune_thresholds[i + 1] > 0) + read_backoffs.entry_size = sizeof(HashGamma); + else + read_backoffs.entry_size = sizeof(float); + gamma_chains.push_back(read_backoffs); gamma_chains.back() >> gammas[i].Source(); } - master >> Interpolate(counts[0], ChainPositions(gamma_chains)); + master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.output_q); gamma_chains >> util::stream::kRecycle; master.BufferFinal(counts); } @@ -291,32 +302,40 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) { "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); UTIL_TIMER("(%w s) Total wall time elapsed\n"); - Master master(config); - - util::scoped_fd vocab_file(config.vocab_file.empty() ? - util::MakeTemp(config.TempPrefix()) : - util::CreateOrThrow(config.vocab_file.c_str())); - uint64_t token_count; - std::string text_file_name; - CountText(text_file, vocab_file.get(), master, token_count, text_file_name); - std::vector<uint64_t> counts; - std::vector<Discount> discounts; - master >> AdjustCounts(counts, discounts); + Master master(config); + // master's destructor will wait for chains. But they might be deadlocked if + // this thread dies because e.g. it ran out of memory. + try { + util::scoped_fd vocab_file(config.vocab_file.empty() ? + util::MakeTemp(config.TempPrefix()) : + util::CreateOrThrow(config.vocab_file.c_str())); + uint64_t token_count; + std::string text_file_name; + CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + + std::vector<uint64_t> counts; + std::vector<uint64_t> counts_pruned; + std::vector<Discount> discounts; + master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, config.discount, discounts); + + { + util::FixedArray<util::stream::FileBuffer> gammas; + Sorts<SuffixOrder> primary; + InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds); + InterpolateProbabilities(counts_pruned, master, primary, gammas); + } - { - FixedArray<util::stream::FileBuffer> gammas; - Sorts<SuffixOrder> primary; - InitialProbabilities(counts, discounts, master, primary, gammas); - InterpolateProbabilities(counts, master, primary, gammas); + std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; + VocabReconstitute vocab(vocab_file.get()); + UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); + HeaderInfo header_info(text_file_name, token_count); + master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; + master.MutableChains().Wait(true); + } catch (const util::Exception &e) { + std::cerr << e.what() << std::endl; + abort(); } - - std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; - VocabReconstitute vocab(vocab_file.get()); - UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); - HeaderInfo header_info(text_file_name, token_count); - master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; - master.MutableChains().Wait(true); } }} // namespaces diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh index 845e5481..09e1a4d5 100644 --- a/klm/lm/builder/pipeline.hh +++ b/klm/lm/builder/pipeline.hh @@ -1,8 +1,10 @@ -#ifndef LM_BUILDER_PIPELINE__ -#define LM_BUILDER_PIPELINE__ +#ifndef LM_BUILDER_PIPELINE_H +#define LM_BUILDER_PIPELINE_H +#include "lm/builder/adjust_counts.hh" #include "lm/builder/initial_probabilities.hh" #include "lm/builder/header_info.hh" +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/stream/config.hh" #include "util/file_piece.hh" @@ -18,6 +20,8 @@ struct PipelineConfig { util::stream::SortConfig sort; InitialProbabilitiesConfig initial_probs; util::stream::ChainConfig read_backoffs; + + // Include a header in the ARPA with some statistics? bool verbose_header; // Estimated vocabulary size. Used for sizing CorpusCount memory and @@ -30,6 +34,34 @@ struct PipelineConfig { // Number of blocks to use. This will be overridden to 1 if everything fits. std::size_t block_count; + // n-gram count thresholds for pruning. 0 values means no pruning for + // corresponding n-gram order + std::vector<uint64_t> prune_thresholds; //mjd + + // What to do with discount failures. + DiscountConfig discount; + + // Compute collapsed q values instead of probability and backoff + bool output_q; + + /* Computing the perplexity of LMs with different vocabularies is hard. For + * example, the lowest perplexity is attained by a unigram model that + * predicts p(<unk>) = 1 and has no other vocabulary. Also, linearly + * interpolated models will sum to more than 1 because <unk> is duplicated + * (SRI just pretends p(<unk>) = 0 for these purposes, which makes it sum to + * 1 but comes with its own problems). This option will make the vocabulary + * a particular size by replicating <unk> multiple times for purposes of + * computing vocabulary size. It has no effect if the actual vocabulary is + * larger. This parameter serves the same purpose as IRSTLM's "dub". + */ + uint64_t vocab_size_for_unk; + + /* What to do the first time <s>, </s>, or <unk> appears in the input. If + * this is anything but THROW_UP, then the symbol will always be treated as + * whitespace. + */ + WarningAction disallowed_symbol_action; + const std::string &TempPrefix() const { return sort.temp_prefix; } std::size_t TotalMemory() const { return sort.total_memory; } }; @@ -38,4 +70,4 @@ struct PipelineConfig { void Pipeline(PipelineConfig config, int text_file, int out_arpa); }} // namespaces -#endif // LM_BUILDER_PIPELINE__ +#endif // LM_BUILDER_PIPELINE_H diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc index 84bd81ca..aee6e134 100644 --- a/klm/lm/builder/print.cc +++ b/klm/lm/builder/print.cc @@ -42,22 +42,22 @@ PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> util::WriteOrThrow(out_fd, as_string.data(), as_string.size()); } -void PrintARPA::Run(const ChainPositions &positions) { +void PrintARPA::Run(const util::stream::ChainPositions &positions) { util::scoped_fd closer(out_fd_); UTIL_TIMER("(%w s) Wrote ARPA file\n"); util::FakeOFStream out(out_fd_); for (unsigned order = 1; order <= positions.size(); ++order) { out << "\\" << order << "-grams:" << '\n'; for (NGramStream stream(positions[order - 1]); stream; ++stream) { - // Correcting for numerical precision issues. Take that IRST. - out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin()); + // Correcting for numerical precision issues. Take that IRST. + out << stream->Value().complete.prob << '\t' << vocab_.Lookup(*stream->begin()); for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { out << ' ' << vocab_.Lookup(*i); } - float backoff = stream->Value().complete.backoff; - if (backoff != 0.0) - out << '\t' << backoff; + if (order != positions.size()) + out << '\t' << stream->Value().complete.backoff; out << '\n'; + } out << '\n'; } diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh index adbbb94a..9856cea8 100644 --- a/klm/lm/builder/print.hh +++ b/klm/lm/builder/print.hh @@ -1,8 +1,8 @@ -#ifndef LM_BUILDER_PRINT__ -#define LM_BUILDER_PRINT__ +#ifndef LM_BUILDER_PRINT_H +#define LM_BUILDER_PRINT_H #include "lm/builder/ngram.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/builder/header_info.hh" #include "util/file.hh" #include "util/mmap.hh" @@ -59,7 +59,7 @@ template <class V> class Print { public: explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {} - void Run(const ChainPositions &chains) { + void Run(const util::stream::ChainPositions &chains) { NGramStreams streams(chains); for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { DumpStream(*s); @@ -92,7 +92,7 @@ class PrintARPA { // Takes ownership of out_fd upon Run(). explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd); - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: const VocabReconstitute &vocab_; @@ -100,4 +100,4 @@ class PrintARPA { }; }} // namespaces -#endif // LM_BUILDER_PRINT__ +#endif // LM_BUILDER_PRINT_H diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh index 9989389b..712bb8e3 100644 --- a/klm/lm/builder/sort.hh +++ b/klm/lm/builder/sort.hh @@ -1,7 +1,7 @@ -#ifndef LM_BUILDER_SORT__ -#define LM_BUILDER_SORT__ +#ifndef LM_BUILDER_SORT_H +#define LM_BUILDER_SORT_H -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/builder/ngram.hh" #include "lm/word_index.hh" #include "util/stream/sort.hh" @@ -14,24 +14,71 @@ namespace lm { namespace builder { +/** + * Abstract parent class for defining custom n-gram comparators. + */ template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit Comparator(std::size_t order) : order_(order) {} + /** + * Applies the comparator using the Compare method that must be defined in any class that inherits from this class. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + * + * @see ContextOrder::Compare + * @see PrefixOrder::Compare + * @see SuffixOrder::Compare + */ inline bool operator()(const void *lhs, const void *rhs) const { return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs)); } + /** Gets the n-gram order defined for this comparator. */ std::size_t Order() const { return order_; } protected: std::size_t order_; }; +/** + * N-gram comparator that compares n-grams according to their reverse (suffix) order. + * + * This comparator compares n-grams lexicographically, one word at a time, + * beginning with the last word of each n-gram and ending with the first word of each n-gram. + * + * Some examples of n-gram comparisons as defined by this comparator: + * - a b c == a b c + * - a b c < a b d + * - a b c > a d b + * - a b c > a b b + * - a b c > x a c + * - a b c < x y z + */ class SuffixOrder : public Comparator<SuffixOrder> { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {} + /** + * Compares two n-grams lexicographically, one word at a time, + * beginning with the last word of each n-gram and ending with the first word of each n-gram. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + */ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { for (std::size_t i = order_ - 1; i != 0; --i) { if (lhs[i] != rhs[i]) @@ -43,10 +90,40 @@ class SuffixOrder : public Comparator<SuffixOrder> { static const unsigned kMatchOffset = 1; }; + +/** + * N-gram comparator that compares n-grams according to the reverse (suffix) order of the n-gram context. + * + * This comparator compares n-grams lexicographically, one word at a time, + * beginning with the penultimate word of each n-gram and ending with the first word of each n-gram; + * finally, this comparator compares the last word of each n-gram. + * + * Some examples of n-gram comparisons as defined by this comparator: + * - a b c == a b c + * - a b c < a b d + * - a b c < a d b + * - a b c > a b b + * - a b c > x a c + * - a b c < x y z + */ class ContextOrder : public Comparator<ContextOrder> { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {} + /** + * Compares two n-grams lexicographically, one word at a time, + * beginning with the penultimate word of each n-gram and ending with the first word of each n-gram; + * finally, this comparator compares the last word of each n-gram. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + */ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { for (int i = order_ - 2; i >= 0; --i) { if (lhs[i] != rhs[i]) @@ -56,10 +133,37 @@ class ContextOrder : public Comparator<ContextOrder> { } }; +/** + * N-gram comparator that compares n-grams according to their natural (prefix) order. + * + * This comparator compares n-grams lexicographically, one word at a time, + * beginning with the first word of each n-gram and ending with the last word of each n-gram. + * + * Some examples of n-gram comparisons as defined by this comparator: + * - a b c == a b c + * - a b c < a b d + * - a b c < a d b + * - a b c > a b b + * - a b c < x a c + * - a b c < x y z + */ class PrefixOrder : public Comparator<PrefixOrder> { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {} + /** + * Compares two n-grams lexicographically, one word at a time, + * beginning with the first word of each n-gram and ending with the last word of each n-gram. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + */ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { for (std::size_t i = 0; i < order_; ++i) { if (lhs[i] != rhs[i]) @@ -84,15 +188,52 @@ struct AddCombiner { }; // The combiner is only used on a single chain, so I didn't bother to allow -// that template. -template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > { +// that template. +/** + * Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects. + * + * In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object + * for each n-gram order (ranging from 1 up to the maximum n-gram order being processed). + * Use in this manner would enable the n-grams each n-gram order to be sorted, in parallel. + * + * @tparam Compare An @ref Comparator "ngram comparator" to use during sorting. + */ +template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > { private: typedef util::stream::Sort<Compare> S; - typedef FixedArray<S> P; + typedef util::FixedArray<S> P; public: + + /** + * Constructs, but does not initialize. + * + * @ref util::FixedArray::Init() "Init" must be called before use. + * + * @see util::FixedArray::Init() + */ + Sorts() {} + + /** + * Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects. + * + * @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array" + * @see util::FixedArray::FixedArray() + */ + explicit Sorts(std::size_t number) : util::FixedArray<util::stream::Sort<Compare> >(number) {} + + /** + * Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array". + * + * The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator"; + * once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored + * in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain". + * + * @see util::stream::Sort::Sort() + * @see util::stream::Chain::operator>>() + */ void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { - new (P::end()) S(chain, config, compare); + new (P::end()) S(chain, config, compare); // use "placement new" syntax to initalize S in an already-allocated memory location P::Constructed(); } }; @@ -100,4 +241,4 @@ template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Comp } // namespace builder } // namespace lm -#endif // LM_BUILDER_SORT__ +#endif // LM_BUILDER_SORT_H diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 0de7b7c6..dab28123 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -1,5 +1,5 @@ -#ifndef LM_CONFIG__ -#define LM_CONFIG__ +#ifndef LM_CONFIG_H +#define LM_CONFIG_H #include "lm/lm_exception.hh" #include "util/mmap.hh" @@ -120,4 +120,4 @@ struct Config { } /* namespace ngram */ } /* namespace lm */ -#endif // LM_CONFIG__ +#endif // LM_CONFIG_H diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh index 27263621..f5ce7898 100644 --- a/klm/lm/enumerate_vocab.hh +++ b/klm/lm/enumerate_vocab.hh @@ -1,5 +1,5 @@ -#ifndef LM_ENUMERATE_VOCAB__ -#define LM_ENUMERATE_VOCAB__ +#ifndef LM_ENUMERATE_VOCAB_H +#define LM_ENUMERATE_VOCAB_H #include "lm/word_index.hh" #include "util/string_piece.hh" @@ -24,5 +24,5 @@ class EnumerateVocab { } // namespace lm -#endif // LM_ENUMERATE_VOCAB__ +#endif // LM_ENUMERATE_VOCAB_H diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh index de1551f1..8e12b62e 100644 --- a/klm/lm/facade.hh +++ b/klm/lm/facade.hh @@ -1,5 +1,5 @@ -#ifndef LM_FACADE__ -#define LM_FACADE__ +#ifndef LM_FACADE_H +#define LM_FACADE_H #include "lm/virtual_interface.hh" #include "util/string_piece.hh" @@ -70,4 +70,4 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ } // mamespace base } // namespace lm -#endif // LM_FACADE__ +#endif // LM_FACADE_H diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh index 602b5b31..99c97b11 100644 --- a/klm/lm/filter/arpa_io.hh +++ b/klm/lm/filter/arpa_io.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_ARPA_IO__ -#define LM_FILTER_ARPA_IO__ +#ifndef LM_FILTER_ARPA_IO_H +#define LM_FILTER_ARPA_IO_H /* Input and output for ARPA format language model files. */ #include "lm/read_arpa.hh" @@ -111,4 +111,4 @@ template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) { } // namespace lm -#endif // LM_FILTER_ARPA_IO__ +#endif // LM_FILTER_ARPA_IO_H diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh index d992026f..de894baf 100644 --- a/klm/lm/filter/count_io.hh +++ b/klm/lm/filter/count_io.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_COUNT_IO__ -#define LM_FILTER_COUNT_IO__ +#ifndef LM_FILTER_COUNT_IO_H +#define LM_FILTER_COUNT_IO_H #include <fstream> #include <iostream> @@ -86,4 +86,4 @@ template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) { } // namespace lm -#endif // LM_FILTER_COUNT_IO__ +#endif // LM_FILTER_COUNT_IO_H diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh index 7d8c28db..5a2e2db3 100644 --- a/klm/lm/filter/format.hh +++ b/klm/lm/filter/format.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_FORMAT_H__ -#define LM_FILTER_FORMAT_H__ +#ifndef LM_FILTER_FORMAT_H +#define LM_FILTER_FORMAT_H #include "lm/filter/arpa_io.hh" #include "lm/filter/count_io.hh" @@ -247,4 +247,4 @@ class MultipleOutputBuffer { } // namespace lm -#endif // LM_FILTER_FORMAT_H__ +#endif // LM_FILTER_FORMAT_H diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh index e8e85835..e5898c9a 100644 --- a/klm/lm/filter/phrase.hh +++ b/klm/lm/filter/phrase.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_PHRASE_H__ -#define LM_FILTER_PHRASE_H__ +#ifndef LM_FILTER_PHRASE_H +#define LM_FILTER_PHRASE_H #include "util/murmur_hash.hh" #include "util/string_piece.hh" @@ -165,4 +165,4 @@ class Multiple : public detail::ConditionCommon { } // namespace phrase } // namespace lm -#endif // LM_FILTER_PHRASE_H__ +#endif // LM_FILTER_PHRASE_H diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh index e785b263..6a6523f9 100644 --- a/klm/lm/filter/thread.hh +++ b/klm/lm/filter/thread.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_THREAD_H__ -#define LM_FILTER_THREAD_H__ +#ifndef LM_FILTER_THREAD_H +#define LM_FILTER_THREAD_H #include "util/thread_pool.hh" @@ -164,4 +164,4 @@ template <class Filter, class OutputBuffer, class RealOutput> class Controller : } // namespace lm -#endif // LM_FILTER_THREAD_H__ +#endif // LM_FILTER_THREAD_H diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh index 7f0fadaa..2ee6e1f8 100644 --- a/klm/lm/filter/vocab.hh +++ b/klm/lm/filter/vocab.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_VOCAB_H__ -#define LM_FILTER_VOCAB_H__ +#ifndef LM_FILTER_VOCAB_H +#define LM_FILTER_VOCAB_H // Vocabulary-based filters for language models. @@ -130,4 +130,4 @@ class Multiple { } // namespace vocab } // namespace lm -#endif // LM_FILTER_VOCAB_H__ +#endif // LM_FILTER_VOCAB_H diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh index eb657501..822c5c27 100644 --- a/klm/lm/filter/wrapper.hh +++ b/klm/lm/filter/wrapper.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_WRAPPER_H__ -#define LM_FILTER_WRAPPER_H__ +#ifndef LM_FILTER_WRAPPER_H +#define LM_FILTER_WRAPPER_H #include "util/string_piece.hh" @@ -53,4 +53,4 @@ template <class FilterT> class ContextFilter { } // namespace lm -#endif // LM_FILTER_WRAPPER_H__ +#endif // LM_FILTER_WRAPPER_H diff --git a/klm/lm/interpolate/arpa_to_stream.cc b/klm/lm/interpolate/arpa_to_stream.cc new file mode 100644 index 00000000..f2696f39 --- /dev/null +++ b/klm/lm/interpolate/arpa_to_stream.cc @@ -0,0 +1,47 @@ +#include "lm/interpolate/arpa_to_stream.hh" + +// TODO: should this move out of builder? +#include "lm/builder/ngram_stream.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" + +namespace lm { namespace interpolate { + +ARPAToStream::ARPAToStream(int fd, ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab) + : in_(fd), vocab_(vocab) { + + // Read the ARPA file header. + // + // After the following call, counts_ will be correctly initialized, + // and in_ will be positioned for reading the body of the ARPA file. + ReadARPACounts(in_, counts_); + +} + +void ARPAToStream::Run(const util::stream::ChainPositions &positions) { + // Make one stream for each order. + builder::NGramStreams streams(positions); + PositiveProbWarn warn; + + // Unigrams are handled specially because they're being inserted into the vocab. + ReadNGramHeader(in_, 1); + for (uint64_t i = 0; i < counts_[0]; ++i, ++streams[0]) { + streams[0]->begin()[0] = vocab_.FindOrInsert(Read1Gram(in_, streams[0]->Value().complete, warn)); + } + // Finish off the unigram stream. + streams[0].Poison(); + + // TODO: don't waste backoff field for highest order. + for (unsigned char n = 2; n <= counts_.size(); ++n) { + ReadNGramHeader(in_, n); + builder::NGramStream &stream = streams[n - 1]; + const uint64_t end = counts_[n - 1]; + for (std::size_t i = 0; i < end; ++i, ++stream) { + ReadNGram(in_, n, vocab_, stream->begin(), stream->Value().complete, warn); + } + // Finish the stream for n-grams.. + stream.Poison(); + } +} + +}} // namespaces diff --git a/klm/lm/interpolate/arpa_to_stream.hh b/klm/lm/interpolate/arpa_to_stream.hh new file mode 100644 index 00000000..4613998d --- /dev/null +++ b/klm/lm/interpolate/arpa_to_stream.hh @@ -0,0 +1,38 @@ +#include "lm/read_arpa.hh" +#include "util/file_piece.hh" + +#include <vector> + +#include <stdint.h> + +namespace util { namespace stream { class ChainPositions; } } + +namespace lm { + +namespace ngram { +template <class T> class GrowableVocab; +class WriteUniqueWords; +} // namespace ngram + +namespace interpolate { + +class ARPAToStream { + public: + // Takes ownership of fd. + explicit ARPAToStream(int fd, ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab); + + std::size_t Order() const { return counts_.size(); } + + const std::vector<uint64_t> &Counts() const { return counts_; } + + void Run(const util::stream::ChainPositions &positions); + + private: + util::FilePiece in_; + + std::vector<uint64_t> counts_; + + ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab_; +}; + +}} // namespaces diff --git a/klm/lm/interpolate/example_sort_main.cc b/klm/lm/interpolate/example_sort_main.cc new file mode 100644 index 00000000..4282255e --- /dev/null +++ b/klm/lm/interpolate/example_sort_main.cc @@ -0,0 +1,144 @@ +#include "lm/interpolate/arpa_to_stream.hh" + +#include "lm/builder/print.hh" +#include "lm/builder/sort.hh" +#include "lm/vocab.hh" +#include "util/file.hh" +#include "util/unistd.hh" + + +int main() { + + // TODO: Make these all command-line parameters + const std::size_t ONE_GB = 1 << 30; + const std::size_t SIXTY_FOUR_MB = 1 << 26; + const std::size_t NUMBER_OF_BLOCKS = 2; + + // Vocab strings will be written to this file, forgotten, and reconstituted + // later. This saves memory. + util::scoped_fd vocab_file(util::MakeTemp("/tmp/")); + std::vector<uint64_t> counts; + util::stream::Chains chains; + { + // Use consistent vocab ids across models. + lm::ngram::GrowableVocab<lm::ngram::WriteUniqueWords> vocab(10, vocab_file.get()); + lm::interpolate::ARPAToStream reader(STDIN_FILENO, vocab); + counts = reader.Counts(); + + // Configure a chain for each order. TODO: extract chain balance heuristics from lm/builder/pipeline.cc + chains.Init(reader.Order()); + + for (std::size_t i = 0; i < reader.Order(); ++i) { + + // The following call to chains.push_back() invokes the Chain constructor + // and appends the newly created Chain object to the chains array + chains.push_back(util::stream::ChainConfig(lm::builder::NGram::TotalSize(i + 1), NUMBER_OF_BLOCKS, ONE_GB)); + + } + + // The following call to the >> method of chains + // constructs a ChainPosition for each chain in chains using Chain::Add(); + // that function begins with a call to Chain::Start() + // that allocates memory for the chain. + // + // After the following call to the >> method of chains, + // a new thread will be running + // and will be executing the reader.Run() method + // to read through the body of the ARPA file from standard input. + // + // For each n-gram line in the ARPA file, + // the thread executing reader.Run() + // will write the probability, the n-gram, and the backoff + // to the appropriate location in the appropriate chain + // (for details, see the ReadNGram() method in read_arpa.hh). + // + // Normally >> copies then runs so inline >> works. But here we want a ref. + chains >> boost::ref(reader); + + + util::stream::SortConfig sort_config; + sort_config.temp_prefix = "/tmp/"; + sort_config.buffer_size = SIXTY_FOUR_MB; + sort_config.total_memory = ONE_GB; + + // Parallel sorts across orders (though somewhat limited because ARPA files are not being read in parallel across orders) + lm::builder::Sorts<lm::builder::SuffixOrder> sorts(reader.Order()); + for (std::size_t i = 0; i < reader.Order(); ++i) { + + // The following call to sorts.push_back() invokes the Sort constructor + // and appends the newly constructed Sort object to the sorts array. + // + // After the construction of the Sort object, + // two new threads will be running (each owned by the chains[i] object). + // + // The first new thread will execute BlockSorter.Run() to sort the n-gram entries of order (i+1) + // that were previously read into chains[i] by the ARPA input reader thread. + // + // The second new thread will execute WriteAndRecycle.Run() + // to write each sorted block of data to disk as a temporary file. + sorts.push_back(chains[i], sort_config, lm::builder::SuffixOrder(i + 1)); + + } + + // Output to the same chains. + for (std::size_t i = 0; i < reader.Order(); ++i) { + + // The following call to Chain::Wait() + // joins the threads owned by chains[i]. + // + // As such the following call won't return + // until all threads owned by chains[i] have completed. + // + // The following call also resets chain[i] + // so that it can be reused + // (including free'ing the memory previously used by the chain) + chains[i].Wait(); + + + // In an ideal world (without memory restrictions) + // we could merge all of the previously sorted blocks + // by reading them all completely into memory + // and then running merge sort over them. + // + // In the real world, we have memory restrictions; + // depending on how many blocks we have, + // and how much memory we can use to read from each block (sort_config.buffer_size) + // it may be the case that we have insufficient memory + // to read sort_config.buffer_size of data from each block from disk. + // + // If this occurs, then it will be necessary to perform one or more rounds of merge sort on disk; + // doing so will reduce the number of blocks that we will eventually need to read from + // when performing the final round of merge sort in memory. + // + // So, the following call determines whether it is necessary + // to perform one or more rounds of merge sort on disk; + // if such on-disk merge sorting is required, such sorting is performed. + // + // Finally, the following method launches a thread that calls OwningMergingReader.Run() + // to perform the final round of merge sort in memory. + // + // Merge sort could have be invoked directly + // so that merge sort memory doesn't coexist with Chain memory. + sorts[i].Output(chains[i]); + } + + // sorts can go out of scope even though it's still writing to the chains. + // note that vocab going out of scope flushes to vocab_file. + } + + + // Get the vocabulary mapping used for this ARPA file + lm::builder::VocabReconstitute reconstitute(vocab_file.get()); + + // After the following call to the << method of chains, + // a new thread will be running + // and will be executing the Run() method of PrintARPA + // to print the final sorted ARPA file to standard output. + chains >> lm::builder::PrintARPA(reconstitute, counts, NULL, STDOUT_FILENO); + + // Joins all threads that chains owns, + // and does a for loop over each chain object in chains, + // calling chain.Wait() on each such chain object + chains.Wait(true); + +} diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 85c1ea37..36d61369 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -35,8 +35,8 @@ * phrase, even if hypotheses are generated left-to-right. */ -#ifndef LM_LEFT__ -#define LM_LEFT__ +#ifndef LM_LEFT_H +#define LM_LEFT_H #include "lm/max_order.hh" #include "lm/state.hh" @@ -213,4 +213,4 @@ template <class M> class RuleScore { } // namespace ngram } // namespace lm -#endif // LM_LEFT__ +#endif // LM_LEFT_H diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh index f607ced1..8bb61081 100644 --- a/klm/lm/lm_exception.hh +++ b/klm/lm/lm_exception.hh @@ -1,5 +1,5 @@ -#ifndef LM_LM_EXCEPTION__ -#define LM_LM_EXCEPTION__ +#ifndef LM_LM_EXCEPTION_H +#define LM_LM_EXCEPTION_H // Named to avoid conflict with util/exception.hh. diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 3eb97ccd..f7344cde 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -1,9 +1,13 @@ -/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM. +#ifndef LM_MAX_ORDER_H +#define LM_MAX_ORDER_H +/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER_H, THEN CHANGE THE BUILD SYSTEM. * If not, this is the default maximum order. * Having this limit means that State can be * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER_H, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif + +#endif // LM_MAX_ORDER_H diff --git a/klm/lm/model.hh b/klm/lm/model.hh index e75da93b..6925a56d 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,5 +1,5 @@ -#ifndef LM_MODEL__ -#define LM_MODEL__ +#ifndef LM_MODEL_H +#define LM_MODEL_H #include "lm/bhiksha.hh" #include "lm/binary_format.hh" @@ -153,4 +153,4 @@ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), } // namespace ngram } // namespace lm -#endif // LM_MODEL__ +#endif // LM_MODEL_H diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 7005b05e..0f54724b 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -176,7 +176,7 @@ template <class M> void MinimalState(const M &model) { AppendTest("to", 1, -1.687872, false); AppendTest("look", 2, -0.2922095, true); BOOST_CHECK_EQUAL(2, state.length); - AppendTest("good", 3, -7, true); + AppendTest("a", 3, -7, true); } template <class M> void ExtendLeftTest(const M &model) { diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh index 8b35c793..fbe1117a 100644 --- a/klm/lm/model_type.hh +++ b/klm/lm/model_type.hh @@ -1,5 +1,5 @@ -#ifndef LM_MODEL_TYPE__ -#define LM_MODEL_TYPE__ +#ifndef LM_MODEL_TYPE_H +#define LM_MODEL_TYPE_H namespace lm { namespace ngram { @@ -20,4 +20,4 @@ const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE - TRIE); } // namespace ngram } // namespace lm -#endif // LM_MODEL_TYPE__ +#endif // LM_MODEL_TYPE_H diff --git a/klm/lm/neural/wordvecs.cc b/klm/lm/neural/wordvecs.cc new file mode 100644 index 00000000..09bb4260 --- /dev/null +++ b/klm/lm/neural/wordvecs.cc @@ -0,0 +1,23 @@ +#include "lm/neural/wordvecs.hh" + +#include "util/file_piece.hh" + +namespace lm { namespace neural { + +WordVecs::WordVecs(util::FilePiece &f) { + const unsigned long lines = f.ReadULong(); + const std::size_t vocab_mem = ngram::ProbingVocabulary::Size(lines, 1.5); + vocab_backing_.reset(util::CallocOrThrow(vocab_mem)); + vocab_.SetupMemory(vocab_backing_.get(), vocab_mem); + const unsigned long width = f.ReadULong(); + vecs_.resize(width, lines); + for (unsigned long i = 0; i < lines; ++i) { + WordIndex column = vocab_.Insert(f.ReadDelimited()); + for (unsigned int row = 0; row < width; ++row) { + vecs_(row,column) = f.ReadFloat(); + } + } + vocab_.FinishedLoading(); +} + +}} // namespaces diff --git a/klm/lm/neural/wordvecs.hh b/klm/lm/neural/wordvecs.hh new file mode 100644 index 00000000..921a2b22 --- /dev/null +++ b/klm/lm/neural/wordvecs.hh @@ -0,0 +1,38 @@ +#ifndef LM_NEURAL_WORDVECS_H +#define LM_NEURAL_WORDVECS_H + +#include "util/scoped.hh" +#include "lm/vocab.hh" + +#include <Eigen/Dense> + +namespace util { class FilePiece; } + +namespace lm { +namespace neural { + +class WordVecs { + public: + // Columns of the matrix are word vectors. The column index is the word. + typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor> Storage; + + /* The file should begin with a line stating the number of word vectors and + * the length of the vectors. Then it's followed by lines containing a + * word followed by floating-point values. + */ + explicit WordVecs(util::FilePiece &in); + + const Storage &Vectors() const { return vecs_; } + + WordIndex Index(StringPiece str) const { return vocab_.Index(str); } + + private: + util::scoped_malloc vocab_backing_; + ngram::ProbingVocabulary vocab_; + + Storage vecs_; +}; + +}} // namespaces + +#endif // LM_NEURAL_WORDVECS_H diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh index ec2590f4..5f330c5c 100644 --- a/klm/lm/ngram_query.hh +++ b/klm/lm/ngram_query.hh @@ -1,8 +1,9 @@ -#ifndef LM_NGRAM_QUERY__ -#define LM_NGRAM_QUERY__ +#ifndef LM_NGRAM_QUERY_H +#define LM_NGRAM_QUERY_H #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "util/file_piece.hh" #include "util/usage.hh" #include <cstdlib> @@ -16,64 +17,94 @@ namespace lm { namespace ngram { -template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { +struct BasicPrint { + void Word(StringPiece, WordIndex, const FullScoreReturn &) const {} + void Line(uint64_t oov, float total) const { + std::cout << "Total: " << total << " OOV: " << oov << '\n'; + } + void Summary(double, double, uint64_t, uint64_t) {} + +}; + +struct FullPrint : public BasicPrint { + void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) const { + std::cout << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; + } + + void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) { + std::cout << + "Perplexity including OOVs:\t" << ppl_including_oov << "\n" + "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n" + "OOVs:\t" << corpus_oov << "\n" + "Tokens:\t" << corpus_tokens << '\n' + ; + } +}; + +template <class Model, class Printer> void Query(const Model &model, bool sentence_context) { + Printer printer; typename Model::State state, out; lm::FullScoreReturn ret; - std::string word; + StringPiece word; + + util::FilePiece in(0); double corpus_total = 0.0; + double corpus_total_oov_only = 0.0; uint64_t corpus_oov = 0; uint64_t corpus_tokens = 0; - while (in_stream) { + while (true) { state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); float total = 0.0; - bool got = false; uint64_t oov = 0; - while (in_stream >> word) { - got = true; + + while (in.ReadWordSameLine(word)) { lm::WordIndex vocab = model.GetVocabulary().Index(word); - if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); + if (vocab == model.GetVocabulary().NotFound()) { + ++oov; + corpus_total_oov_only += ret.prob; + } total += ret.prob; - out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; + printer.Word(word, vocab, ret); ++corpus_tokens; state = out; - char c; - while (true) { - c = in_stream.get(); - if (!in_stream) break; - if (c == '\n') break; - if (!isspace(c)) { - in_stream.unget(); - break; - } - } - if (c == '\n') break; } - if (!got && !in_stream) break; + // If people don't have a newline after their last query, this won't add a </s>. + // Sue me. + try { + UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused."); + } catch (const util::EndOfFileException &e) { break; } if (sentence_context) { ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; ++corpus_tokens; - out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; + printer.Word("</s>", model.GetVocabulary().EndSentence(), ret); } - out_stream << "Total: " << total << " OOV: " << oov << '\n'; + printer.Line(oov, total); corpus_total += total; corpus_oov += oov; } - out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl; + printer.Summary( + pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))), // PPL including OOVs + pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))), // PPL excluding OOVs + corpus_oov, + corpus_tokens); } -template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { - Config config; - M model(file, config); - Query(model, sentence_context, in_stream, out_stream); +template <class Model> void Query(const char *file, const Config &config, bool sentence_context, bool show_words) { + Model model(file, config); + if (show_words) { + Query<Model, FullPrint>(model, sentence_context); + } else { + Query<Model, BasicPrint>(model, sentence_context); + } } } // namespace ngram } // namespace lm -#endif // LM_NGRAM_QUERY__ +#endif // LM_NGRAM_QUERY_H diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh index 1dede359..d8adc696 100644 --- a/klm/lm/partial.hh +++ b/klm/lm/partial.hh @@ -1,5 +1,5 @@ -#ifndef LM_PARTIAL__ -#define LM_PARTIAL__ +#ifndef LM_PARTIAL_H +#define LM_PARTIAL_H #include "lm/return.hh" #include "lm/state.hh" @@ -164,4 +164,4 @@ template <class Model> float Subsume(const Model &model, Left &first_left, const } // namespace ngram } // namespace lm -#endif // LM_PARTIAL__ +#endif // LM_PARTIAL_H diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 9d3a2f43..84a30872 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -1,5 +1,5 @@ -#ifndef LM_QUANTIZE_H__ -#define LM_QUANTIZE_H__ +#ifndef LM_QUANTIZE_H +#define LM_QUANTIZE_H #include "lm/blank.hh" #include "lm/config.hh" @@ -230,4 +230,4 @@ class SeparatelyQuantize { } // namespace ngram } // namespace lm -#endif // LM_QUANTIZE_H__ +#endif // LM_QUANTIZE_H diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc index bd4fde62..3013ff21 100644 --- a/klm/lm/query_main.cc +++ b/klm/lm/query_main.cc @@ -1,4 +1,5 @@ #include "lm/ngram_query.hh" +#include "util/getopt.hh" #ifdef WITH_NPLM #include "lm/wrappers/nplm.hh" @@ -7,47 +8,76 @@ #include <stdlib.h> void Usage(const char *name) { - std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; - std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl; - std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl; + std::cerr << + "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << ".\n" + "Usage: " << name << " [-n] [-s] lm_file\n" + "-n: Do not wrap the input in <s> and </s>.\n" + "-s: Sentence totals only.\n" + "-l lazy|populate|read|parallel: Load lazily, with populate, or malloc+read\n" + "The default loading method is populate on Linux and read on others.\n"; exit(1); } int main(int argc, char *argv[]) { + if (argc == 1 || (argc == 2 && !strcmp(argv[1], "--help"))) + Usage(argv[0]); + + lm::ngram::Config config; bool sentence_context = true; - const char *file = NULL; - for (char **arg = argv + 1; arg != argv + argc; ++arg) { - if (!strcmp(*arg, "-n")) { - sentence_context = false; - } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) { - Usage(argv[0]); - } else { - file = *arg; + bool show_words = true; + + int opt; + while ((opt = getopt(argc, argv, "hnsl:")) != -1) { + switch (opt) { + case 'n': + sentence_context = false; + break; + case 's': + show_words = false; + break; + case 'l': + if (!strcmp(optarg, "lazy")) { + config.load_method = util::LAZY; + } else if (!strcmp(optarg, "populate")) { + config.load_method = util::POPULATE_OR_READ; + } else if (!strcmp(optarg, "read")) { + config.load_method = util::READ; + } else if (!strcmp(optarg, "parallel")) { + config.load_method = util::PARALLEL_READ; + } else { + Usage(argv[0]); + } + break; + case 'h': + default: + Usage(argv[0]); } } - if (!file) Usage(argv[0]); + if (optind + 1 != argc) + Usage(argv[0]); + const char *file = argv[optind]; try { using namespace lm::ngram; ModelType model_type; if (RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: - Query<lm::ngram::ProbingModel>(file, sentence_context, std::cin, std::cout); + Query<lm::ngram::ProbingModel>(file, config, sentence_context, show_words); break; case REST_PROBING: - Query<lm::ngram::RestProbingModel>(file, sentence_context, std::cin, std::cout); + Query<lm::ngram::RestProbingModel>(file, config, sentence_context, show_words); break; case TRIE: - Query<TrieModel>(file, sentence_context, std::cin, std::cout); + Query<TrieModel>(file, config, sentence_context, show_words); break; case QUANT_TRIE: - Query<QuantTrieModel>(file, sentence_context, std::cin, std::cout); + Query<QuantTrieModel>(file, config, sentence_context, show_words); break; case ARRAY_TRIE: - Query<ArrayTrieModel>(file, sentence_context, std::cin, std::cout); + Query<ArrayTrieModel>(file, config, sentence_context, show_words); break; case QUANT_ARRAY_TRIE: - Query<QuantArrayTrieModel>(file, sentence_context, std::cin, std::cout); + Query<QuantArrayTrieModel>(file, config, sentence_context, show_words); break; default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; @@ -56,12 +86,15 @@ int main(int argc, char *argv[]) { #ifdef WITH_NPLM } else if (lm::np::Model::Recognize(file)) { lm::np::Model model(file); - Query(model, sentence_context, std::cin, std::cout); + if (show_words) { + Query<lm::np::Model, lm::ngram::FullPrint>(model, sentence_context); + } else { + Query<lm::np::Model, lm::ngram::BasicPrint>(model, sentence_context); + } #endif } else { - Query<ProbingModel>(file, sentence_context, std::cin, std::cout); + Query<ProbingModel>(file, config, sentence_context, show_words); } - std::cerr << "Total time including destruction:\n"; util::PrintUsage(std::cerr); } catch (const std::exception &e) { std::cerr << e.what() << std::endl; diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index 234d130c..64eeef30 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -1,5 +1,5 @@ -#ifndef LM_READ_ARPA__ -#define LM_READ_ARPA__ +#ifndef LM_READ_ARPA_H +#define LM_READ_ARPA_H #include "lm/lm_exception.hh" #include "lm/word_index.hh" @@ -28,7 +28,7 @@ void ReadEnd(util::FilePiece &in); extern const bool kARPASpaces[256]; -// Positive log probability warning. +// Positive log probability warning. class PositiveProbWarn { public: PositiveProbWarn() : action_(THROW_UP) {} @@ -48,17 +48,17 @@ template <class Voc, class Weights> void Read1Gram(util::FilePiece &f, Voc &voca warn.Warn(prob); prob = 0.0; } - if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); - Weights &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; - value.prob = prob; - ReadBackoff(f, value); + UTIL_THROW_IF(f.get() != '\t', FormatLoadException, "Expected tab after probability"); + WordIndex word = vocab.Insert(f.ReadDelimited(kARPASpaces)); + Weights &w = unigrams[word]; + w.prob = prob; + ReadBackoff(f, w); } catch(util::Exception &e) { e << " in the 1-gram at byte " << f.Offset(); throw; } } -// Return true if a positive log probability came out. template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) { ReadNGramHeader(f, 1); for (std::size_t i = 0; i < count; ++i) { @@ -67,16 +67,21 @@ template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::siz vocab.FinishedLoading(unigrams); } -// Return true if a positive log probability came out. -template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights, PositiveProbWarn &warn) { +// Read ngram, write vocab ids to indices_out. +template <class Voc, class Weights, class Iterator> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, Iterator indices_out, Weights &weights, PositiveProbWarn &warn) { try { weights.prob = f.ReadFloat(); if (weights.prob > 0.0) { warn.Warn(weights.prob); weights.prob = 0.0; } - for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) { - *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces)); + for (unsigned char i = 0; i < n; ++i, ++indices_out) { + StringPiece word(f.ReadDelimited(kARPASpaces)); + WordIndex index = vocab.Index(word); + *indices_out = index; + // Check for words mapped to <unk> that are not the string <unk>. + UTIL_THROW_IF(index == 0 /* mapped to <unk> */ && (word != StringPiece("<unk>", 5)) && (word != StringPiece("<UNK>", 5)), + FormatLoadException, "Word " << word << " was not seen in the unigrams (which are supposed to list the entire vocabulary) but appears"); } ReadBackoff(f, weights); } catch(util::Exception &e) { @@ -87,4 +92,4 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns } // namespace lm -#endif // LM_READ_ARPA__ +#endif // LM_READ_ARPA_H diff --git a/klm/lm/return.hh b/klm/lm/return.hh index 622320ce..982ffd66 100644 --- a/klm/lm/return.hh +++ b/klm/lm/return.hh @@ -1,5 +1,5 @@ -#ifndef LM_RETURN__ -#define LM_RETURN__ +#ifndef LM_RETURN_H +#define LM_RETURN_H #include <stdint.h> @@ -39,4 +39,4 @@ struct FullScoreReturn { }; } // namespace lm -#endif // LM_RETURN__ +#endif // LM_RETURN_H diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 354a56b4..7e63e006 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -178,7 +178,7 @@ template <class Build, class Activate, class Store> void ReadNGrams( typename Store::Entry entry; std::vector<typename Value::Weights *> between; for (size_t i = 0; i < count; ++i) { - ReadNGram(f, n, vocab, &*vocab_ids.begin(), entry.value, warn); + ReadNGram(f, n, vocab, vocab_ids.rbegin(), entry.value, warn); build.SetRest(&*vocab_ids.begin(), n, entry.value); keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 8193262b..9dc84454 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -1,5 +1,5 @@ -#ifndef LM_SEARCH_HASHED__ -#define LM_SEARCH_HASHED__ +#ifndef LM_SEARCH_HASHED_H +#define LM_SEARCH_HASHED_H #include "lm/model_type.hh" #include "lm/config.hh" @@ -189,4 +189,4 @@ template <class Value> class HashedSearch { } // namespace ngram } // namespace lm -#endif // LM_SEARCH_HASHED__ +#endif // LM_SEARCH_HASHED_H diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 4a88194e..7fc70f4e 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -561,6 +561,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup } // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (unsigned char i = counts.size() - 1; i >= 2; --i) { + // use "placement new" syntax to initalize Middle in an already-allocated memory location new (middle_begin_ + i - 2) Middle( middle_starts[i-2], quant_.MiddleBits(config), diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 299262a5..d8838d2b 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -1,5 +1,5 @@ -#ifndef LM_SEARCH_TRIE__ -#define LM_SEARCH_TRIE__ +#ifndef LM_SEARCH_TRIE_H +#define LM_SEARCH_TRIE_H #include "lm/config.hh" #include "lm/model_type.hh" @@ -127,4 +127,4 @@ template <class Quant, class Bhiksha> class TrieSearch { } // namespace ngram } // namespace lm -#endif // LM_SEARCH_TRIE__ +#endif // LM_SEARCH_TRIE_H diff --git a/klm/lm/sizes.hh b/klm/lm/sizes.hh index 85abade7..eb7e99de 100644 --- a/klm/lm/sizes.hh +++ b/klm/lm/sizes.hh @@ -1,5 +1,5 @@ -#ifndef LM_SIZES__ -#define LM_SIZES__ +#ifndef LM_SIZES_H +#define LM_SIZES_H #include <vector> @@ -14,4 +14,4 @@ void ShowSizes(const std::vector<uint64_t> &counts); void ShowSizes(const char *file, const lm::ngram::Config &config); }} // namespaces -#endif // LM_SIZES__ +#endif // LM_SIZES_H diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 543df37c..f6c51d6f 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -1,5 +1,5 @@ -#ifndef LM_STATE__ -#define LM_STATE__ +#ifndef LM_STATE_H +#define LM_STATE_H #include "lm/max_order.hh" #include "lm/word_index.hh" @@ -122,4 +122,4 @@ inline uint64_t hash_value(const ChartState &state) { } // namespace ngram } // namespace lm -#endif // LM_STATE__ +#endif // LM_STATE_H diff --git a/klm/lm/test.arpa b/klm/lm/test.arpa index ef214eae..c4d2e6df 100644 --- a/klm/lm/test.arpa +++ b/klm/lm/test.arpa @@ -105,7 +105,7 @@ ngram 5=4 -0.04835128 looking on a -0.4771212 -3 also would consider -7 -6 <unk> however <unk> -12 --7 to look good +-7 to look a \4-grams: -0.009249173 looking on a little -0.4771212 diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa index 060733d9..e38fc854 100644 --- a/klm/lm/test_nounk.arpa +++ b/klm/lm/test_nounk.arpa @@ -101,7 +101,7 @@ ngram 5=4 -0.1892331 little more loin -0.04835128 looking on a -0.4771212 -3 also would consider -7 --7 to look good +-7 to look a \4-grams: -0.009249173 looking on a little -0.4771212 diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index d9895f89..5f8e7ce7 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -99,8 +99,11 @@ template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordInd } template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) { - uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); - bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); + // Write at insert_index. . . + uint64_t last_next_write = insert_index_ * total_bits_ + + // at the offset where the next pointers are stored. + (total_bits_ - bhiksha_.InlineBits()); + bhiksha_.WriteNext(base_, last_next_write, insert_index_, next_end); bhiksha_.FinishedLoading(config); } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index d858ab5e..cd39298b 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -1,5 +1,5 @@ -#ifndef LM_TRIE__ -#define LM_TRIE__ +#ifndef LM_TRIE_H +#define LM_TRIE_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -143,4 +143,4 @@ class BitPackedLongest : public BitPacked { } // namespace ngram } // namespace lm -#endif // LM_TRIE__ +#endif // LM_TRIE_H diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 126d43ab..c3f46874 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -16,6 +16,7 @@ #include <cstdio> #include <cstdlib> #include <deque> +#include <iterator> #include <limits> #include <vector> @@ -106,14 +107,20 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_pre } struct ThrowCombine { - void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const { - UTIL_THROW(FormatLoadException, "Duplicate n-gram detected."); + void operator()(std::size_t entry_size, unsigned char order, const void *first, const void *second, FILE * /*out*/) const { + const WordIndex *base = reinterpret_cast<const WordIndex*>(first); + FormatLoadException e; + e << "Duplicate n-gram detected with vocab ids"; + for (const WordIndex *i = base; i != base + order; ++i) { + e << ' ' << *i; + } + throw e; } }; // Useful for context files that just contain records with no value. struct FirstCombine { - void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { + void operator()(std::size_t entry_size, unsigned char /*order*/, const void *first, const void * /*second*/, FILE *out) const { util::WriteOrThrow(out, first, entry_size); } }; @@ -133,7 +140,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f util::WriteOrThrow(out_file.get(), second.Data(), entry_size); ++second; } else { - combine(entry_size, first.Data(), second.Data(), out_file.get()); + combine(entry_size, order, first.Data(), second.Data(), out_file.get()); ++first; ++second; } } @@ -248,11 +255,13 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; if (order == counts.size()) { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); + std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order); + ReadNGram(f, order, vocab, it, *reinterpret_cast<Prob*>(out + words_size), warn); } } else { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); + std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order); + ReadNGram(f, order, vocab, it, *reinterpret_cast<ProbBackoff*>(out + words_size), warn); } } // Sort full records by full n-gram. diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 1afd9562..e5406d9b 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -1,7 +1,7 @@ // Step of trie builder: create sorted files. -#ifndef LM_TRIE_SORT__ -#define LM_TRIE_SORT__ +#ifndef LM_TRIE_SORT_H +#define LM_TRIE_SORT_H #include "lm/max_order.hh" #include "lm/word_index.hh" @@ -111,4 +111,4 @@ class SortedFiles { } // namespace ngram } // namespace lm -#endif // LM_TRIE_SORT__ +#endif // LM_TRIE_SORT_H diff --git a/klm/lm/value.hh b/klm/lm/value.hh index ba716713..36e87084 100644 --- a/klm/lm/value.hh +++ b/klm/lm/value.hh @@ -1,5 +1,5 @@ -#ifndef LM_VALUE__ -#define LM_VALUE__ +#ifndef LM_VALUE_H +#define LM_VALUE_H #include "lm/model_type.hh" #include "lm/value_build.hh" @@ -154,4 +154,4 @@ struct RestValue { } // namespace ngram } // namespace lm -#endif // LM_VALUE__ +#endif // LM_VALUE_H diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh index 461e6a5c..6fd26ef8 100644 --- a/klm/lm/value_build.hh +++ b/klm/lm/value_build.hh @@ -1,5 +1,5 @@ -#ifndef LM_VALUE_BUILD__ -#define LM_VALUE_BUILD__ +#ifndef LM_VALUE_BUILD_H +#define LM_VALUE_BUILD_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -94,4 +94,4 @@ template <class Model> class LowerRestBuild { } // namespace ngram } // namespace lm -#endif // LM_VALUE_BUILD__ +#endif // LM_VALUE_BUILD_H diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 7a3e2379..2a2690e1 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -1,5 +1,5 @@ -#ifndef LM_VIRTUAL_INTERFACE__ -#define LM_VIRTUAL_INTERFACE__ +#ifndef LM_VIRTUAL_INTERFACE_H +#define LM_VIRTUAL_INTERFACE_H #include "lm/return.hh" #include "lm/word_index.hh" @@ -157,4 +157,4 @@ class Model { } // mamespace base } // namespace lm -#endif // LM_VIRTUAL_INTERFACE__ +#endif // LM_VIRTUAL_INTERFACE_H diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7f0878f4..2285d279 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -170,11 +170,15 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} +uint64_t ProbingVocabulary::Size(uint64_t entries, float probing_multiplier) { + return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, probing_multiplier); +} + uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) { - return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); + return Size(entries, config.probing_multiplier); } -void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { +void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated) { header_ = static_cast<detail::ProbingVocabularyHeader*>(start); lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated); bound_ = 1; @@ -201,12 +205,12 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { return 0; } else { if (enumerate_) enumerate_->Add(bound_, str); - lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_)); + lookup_.Insert(ProbingVocabularyEntry::Make(hashed, bound_)); return bound_++; } } -void ProbingVocabulary::InternalFinishedLoading() { +void ProbingVocabulary::FinishedLoading() { lookup_.FinishedInserting(); header_->bound = bound_; header_->version = kProbingVocabularyVersion; diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 074b74d8..d6ae07b8 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -1,9 +1,11 @@ -#ifndef LM_VOCAB__ -#define LM_VOCAB__ +#ifndef LM_VOCAB_H +#define LM_VOCAB_H #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/fake_ofstream.hh" +#include "util/murmur_hash.hh" #include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" @@ -104,17 +106,16 @@ class SortedVocabulary : public base::Vocabulary { #pragma pack(push) #pragma pack(4) -struct ProbingVocabuaryEntry { +struct ProbingVocabularyEntry { uint64_t key; WordIndex value; typedef uint64_t Key; - uint64_t GetKey() const { - return key; - } + uint64_t GetKey() const { return key; } + void SetKey(uint64_t to) { key = to; } - static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { - ProbingVocabuaryEntry ret; + static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { + ProbingVocabularyEntry ret; ret.key = key; ret.value = value; return ret; @@ -132,13 +133,18 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } + static uint64_t Size(uint64_t entries, float probing_multiplier); + // This just unwraps Config to get the probing_multiplier. static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. - void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void SetupMemory(void *start, std::size_t allocated); + void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { + SetupMemory(start, allocated); + } void Relocate(void *new_start); @@ -147,8 +153,9 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Insert(const StringPiece &str); template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) { - InternalFinishedLoading(); + FinishedLoading(); } + void FinishedLoading(); std::size_t UnkCountChangePadding() const { return 0; } @@ -157,9 +164,7 @@ class ProbingVocabulary : public base::Vocabulary { void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: - void InternalFinishedLoading(); - - typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup; + typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup; Lookup lookup_; @@ -181,7 +186,64 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); } +class WriteUniqueWords { + public: + explicit WriteUniqueWords(int fd) : word_list_(fd) {} + + void operator()(const StringPiece &word) { + word_list_ << word << '\0'; + } + + private: + util::FakeOFStream word_list_; +}; + +class NoOpUniqueWords { + public: + NoOpUniqueWords() {} + void operator()(const StringPiece &word) {} +}; + +template <class NewWordAction = NoOpUniqueWords> class GrowableVocab { + public: + static std::size_t MemUsage(WordIndex content) { + return Lookup::MemUsage(content > 2 ? content : 2); + } + + // Does not take ownership of write_wordi + template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) + : lookup_(initial_size), new_word_(new_word_construct) { + FindOrInsert("<unk>"); // Force 0 + FindOrInsert("<s>"); // Force 1 + FindOrInsert("</s>"); // Force 2 + } + + WordIndex Index(const StringPiece &str) const { + Lookup::ConstIterator i; + return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; + } + + WordIndex FindOrInsert(const StringPiece &word) { + ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); + Lookup::MutableIterator it; + if (!lookup_.FindOrInsert(entry, it)) { + new_word_(word); + UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh"); + } + return it->value; + } + + WordIndex Size() const { return lookup_.Size(); } + + private: + typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup; + + Lookup lookup_; + + NewWordAction new_word_; +}; + } // namespace ngram } // namespace lm -#endif // LM_VOCAB__ +#endif // LM_VOCAB_H diff --git a/klm/lm/weights.hh b/klm/lm/weights.hh index bd5d8034..da1963d8 100644 --- a/klm/lm/weights.hh +++ b/klm/lm/weights.hh @@ -1,5 +1,5 @@ -#ifndef LM_WEIGHTS__ -#define LM_WEIGHTS__ +#ifndef LM_WEIGHTS_H +#define LM_WEIGHTS_H // Weights for n-grams. Probability and possibly a backoff. @@ -19,4 +19,4 @@ struct RestWeights { }; } // namespace lm -#endif // LM_WEIGHTS__ +#endif // LM_WEIGHTS_H diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index e09557a7..a5a0fda8 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -1,6 +1,6 @@ // Separate header because this is used often. -#ifndef LM_WORD_INDEX__ -#define LM_WORD_INDEX__ +#ifndef LM_WORD_INDEX_H +#define LM_WORD_INDEX_H #include <limits.h> diff --git a/klm/lm/wrappers/README b/klm/lm/wrappers/README new file mode 100644 index 00000000..56c34c23 --- /dev/null +++ b/klm/lm/wrappers/README @@ -0,0 +1,3 @@ +This directory is for wrappers around other people's LMs, presenting an interface similar to KenLM's. You will need to have their LM installed. + +NPLM is a work in progress. diff --git a/klm/lm/wrappers/nplm.cc b/klm/lm/wrappers/nplm.cc new file mode 100644 index 00000000..70622bd2 --- /dev/null +++ b/klm/lm/wrappers/nplm.cc @@ -0,0 +1,90 @@ +#include "lm/wrappers/nplm.hh" +#include "util/exception.hh" +#include "util/file.hh" + +#include <algorithm> + +#include <string.h> + +#include "neuralLM.h" + +namespace lm { +namespace np { + +Vocabulary::Vocabulary(const nplm::vocabulary &vocab) + : base::Vocabulary(vocab.lookup_word("<s>"), vocab.lookup_word("</s>"), vocab.lookup_word("<unk>")), + vocab_(vocab), null_word_(vocab.lookup_word("<null>")) {} + +Vocabulary::~Vocabulary() {} + +WordIndex Vocabulary::Index(const std::string &str) const { + return vocab_.lookup_word(str); +} + +bool Model::Recognize(const std::string &name) { + try { + util::scoped_fd file(util::OpenReadOrThrow(name.c_str())); + char magic_check[16]; + util::ReadOrThrow(file.get(), magic_check, sizeof(magic_check)); + const char nnlm_magic[] = "\\config\nversion "; + return !memcmp(magic_check, nnlm_magic, 16); + } catch (const util::Exception &) { + return false; + } +} + +Model::Model(const std::string &file, std::size_t cache) + : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) { + UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ". Change the defintion of NPLM_MAX_ORDER and recompile."); + // log10 compatible with backoff models. + base_instance_->set_log_base(10.0); + State begin_sentence, null_context; + std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("<s>")); + null_word_ = base_instance_->lookup_word("<null>"); + std::fill(null_context.words, null_context.words + NPLM_MAX_ORDER - 1, null_word_); + + Init(begin_sentence, null_context, vocab_, base_instance_->get_order()); +} + +Model::~Model() {} + +FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, State &out_state) const { + nplm::neuralLM *lm = backend_.get(); + if (!lm) { + lm = new nplm::neuralLM(*base_instance_); + backend_.reset(lm); + lm->set_cache(cache_size_); + } + // State is in natural word order. + FullScoreReturn ret; + for (int i = 0; i < lm->get_order() - 1; ++i) { + lm->staging_ngram()(i) = from.words[i]; + } + lm->staging_ngram()(lm->get_order() - 1) = new_word; + ret.prob = lm->lookup_from_staging(); + // Always say full order. + ret.ngram_length = lm->get_order(); + // Shift everything down by one. + memcpy(out_state.words, from.words + 1, sizeof(WordIndex) * (lm->get_order() - 2)); + out_state.words[lm->get_order() - 2] = new_word; + // Fill in trailing words with zeros so state comparison works. + memset(out_state.words + lm->get_order() - 1, 0, sizeof(WordIndex) * (NPLM_MAX_ORDER - lm->get_order())); + return ret; +} + +// TODO: optimize with direct call? +FullScoreReturn Model::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { + // State is in natural word order. The API here specifies reverse order. + std::size_t state_length = std::min<std::size_t>(Order() - 1, context_rend - context_rbegin); + State state; + // Pad with null words. + for (lm::WordIndex *i = state.words; i < state.words + Order() - 1 - state_length; ++i) { + *i = null_word_; + } + // Put new words at the end. + std::reverse_copy(context_rbegin, context_rbegin + state_length, state.words + Order() - 1 - state_length); + return FullScore(state, new_word, out_state); +} + +} // namespace np +} // namespace lm diff --git a/klm/lm/wrappers/nplm.hh b/klm/lm/wrappers/nplm.hh new file mode 100644 index 00000000..b7dd4a21 --- /dev/null +++ b/klm/lm/wrappers/nplm.hh @@ -0,0 +1,83 @@ +#ifndef LM_WRAPPERS_NPLM_H +#define LM_WRAPPERS_NPLM_H + +#include "lm/facade.hh" +#include "lm/max_order.hh" +#include "util/string_piece.hh" + +#include <boost/thread/tss.hpp> +#include <boost/scoped_ptr.hpp> + +/* Wrapper to NPLM "by Ashish Vaswani, with contributions from David Chiang + * and Victoria Fossum." + * http://nlg.isi.edu/software/nplm/ + */ + +namespace nplm { +class vocabulary; +class neuralLM; +} // namespace nplm + +namespace lm { +namespace np { + +class Vocabulary : public base::Vocabulary { + public: + Vocabulary(const nplm::vocabulary &vocab); + + ~Vocabulary(); + + WordIndex Index(const std::string &str) const; + + // TODO: lobby them to support StringPiece + WordIndex Index(const StringPiece &str) const { + return Index(std::string(str.data(), str.size())); + } + + lm::WordIndex NullWord() const { return null_word_; } + + private: + const nplm::vocabulary &vocab_; + + const lm::WordIndex null_word_; +}; + +// Sorry for imposing my limitations on your code. +#define NPLM_MAX_ORDER 7 + +struct State { + WordIndex words[NPLM_MAX_ORDER - 1]; +}; + +class Model : public lm::base::ModelFacade<Model, State, Vocabulary> { + private: + typedef lm::base::ModelFacade<Model, State, Vocabulary> P; + + public: + // Does this look like an NPLM? + static bool Recognize(const std::string &file); + + explicit Model(const std::string &file, std::size_t cache_size = 1 << 20); + + ~Model(); + + FullScoreReturn FullScore(const State &from, const WordIndex new_word, State &out_state) const; + + FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; + + private: + boost::scoped_ptr<nplm::neuralLM> base_instance_; + + mutable boost::thread_specific_ptr<nplm::neuralLM> backend_; + + Vocabulary vocab_; + + lm::WordIndex null_word_; + + const std::size_t cache_size_; +}; + +} // namespace np +} // namespace lm + +#endif // LM_WRAPPERS_NPLM_H diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5e650af7..5db6e340 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -1,21 +1,21 @@ -#noinst_PROGRAMS = \ +noinst_PROGRAMS = cat_compressed + +cat_compressed_SOURCES = cat_compressed_main.cc +cat_compressed_LDADD = libklm_util.a + +#TESTS = \ # file_piece_test \ # joint_sort_test \ # key_value_packing_test \ # probing_hash_table_test \ # sorted_uniform_test - -#TESTS = \ # file_piece_test \ # joint_sort_test \ # key_value_packing_test \ # probing_hash_table_test \ # sorted_uniform_test -#file_piece_test_SOURCES = file_piece_test.cc -#file_piece_test_LDADD = libklm_util.a - noinst_LIBRARIES = libklm_util.a libklm_util_a_SOURCES = \ @@ -30,6 +30,8 @@ libklm_util_a_SOURCES = \ file.hh \ file_piece.cc \ file_piece.hh \ + fixed_array.hh \ + getopt.c \ getopt.hh \ have.hh \ joint_sort.hh \ @@ -38,6 +40,8 @@ libklm_util_a_SOURCES = \ multi_intersection.hh \ murmur_hash.cc \ murmur_hash.hh \ + parallel_read.cc \ + parallel_read.hh \ pcqueue.hh \ pool.cc \ pool.hh \ @@ -54,7 +58,9 @@ libklm_util_a_SOURCES = \ string_piece_hash.hh \ thread_pool.hh \ tokenize_piece.hh \ + unistd.hh \ usage.cc \ usage.hh AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm -I$(top_srcdir)/klm/util/double-conversion + diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index dcbd814c..1e34d9ab 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_BIT_PACKING__ -#define UTIL_BIT_PACKING__ +#ifndef UTIL_BIT_PACKING_H +#define UTIL_BIT_PACKING_H /* Bit-level packing routines * @@ -183,4 +183,4 @@ struct BitAddress { } // namespace util -#endif // UTIL_BIT_PACKING__ +#endif // UTIL_BIT_PACKING_H diff --git a/klm/util/cat_compressed_main.cc b/klm/util/cat_compressed_main.cc new file mode 100644 index 00000000..2b4d7292 --- /dev/null +++ b/klm/util/cat_compressed_main.cc @@ -0,0 +1,47 @@ +// Like cat but interprets compressed files. +#include "util/file.hh" +#include "util/read_compressed.hh" + +#include <string.h> +#include <iostream> + +namespace { +const std::size_t kBufSize = 16384; +void Copy(util::ReadCompressed &from, int to) { + util::scoped_malloc buffer(util::MallocOrThrow(kBufSize)); + while (std::size_t amount = from.Read(buffer.get(), kBufSize)) { + util::WriteOrThrow(to, buffer.get(), amount); + } +} +} // namespace + +int main(int argc, char *argv[]) { + // Lane Schwartz likes -h and --help + for (int i = 1; i < argc; ++i) { + char *arg = argv[i]; + if (!strcmp(arg, "--")) break; + if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) { + std::cerr << + "A cat implementation that interprets compressed files.\n" + "Usage: " << argv[0] << " [file1] [file2] ...\n" + "If no file is provided, then stdin is read.\n"; + return 1; + } + } + + try { + if (argc == 1) { + util::ReadCompressed in(0); + Copy(in, 1); + } else { + for (int i = 1; i < argc; ++i) { + util::ReadCompressed in(util::OpenReadOrThrow(argv[i])); + Copy(in, 1); + } + } + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 2; + } + return 0; +} diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index b94399a8..535dbde2 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_ERSATZ_PROGRESS__ -#define UTIL_ERSATZ_PROGRESS__ +#ifndef UTIL_ERSATZ_PROGRESS_H +#define UTIL_ERSATZ_PROGRESS_H #include <iostream> #include <string> @@ -55,4 +55,4 @@ class ErsatzProgress { } // namespace util -#endif // UTIL_ERSATZ_PROGRESS__ +#endif // UTIL_ERSATZ_PROGRESS_H diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 0298272b..4e50a6f3 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_EXCEPTION__ -#define UTIL_EXCEPTION__ +#ifndef UTIL_EXCEPTION_H +#define UTIL_EXCEPTION_H #include <exception> #include <limits> @@ -83,6 +83,9 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep #define UTIL_THROW(Exception, Modify) \ UTIL_THROW_BACKEND(NULL, Exception, , Modify); +#define UTIL_THROW2(Modify) \ + UTIL_THROW_BACKEND(NULL, util::Exception, , Modify); + #if __GNUC__ >= 3 #define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) #else @@ -143,4 +146,4 @@ inline std::size_t CheckOverflow(uint64_t value) { } // namespace util -#endif // UTIL_EXCEPTION__ +#endif // UTIL_EXCEPTION_H diff --git a/klm/util/fake_ofstream.hh b/klm/util/fake_ofstream.hh index bcdebe45..eefb1edc 100644 --- a/klm/util/fake_ofstream.hh +++ b/klm/util/fake_ofstream.hh @@ -2,6 +2,9 @@ * Does not support many data types. Currently, it's targeted at writing ARPA * files quickly. */ +#ifndef UTIL_FAKE_OFSTREAM_H +#define UTIL_FAKE_OFSTREAM_H + #include "util/double-conversion/double-conversion.h" #include "util/double-conversion/utils.h" #include "util/file.hh" @@ -17,7 +20,8 @@ class FakeOFStream { static const std::size_t kOutBuf = 1048576; // Does not take ownership of out. - explicit FakeOFStream(int out) + // Allows default constructor, but must call SetFD. + explicit FakeOFStream(int out = -1) : buf_(util::MallocOrThrow(kOutBuf)), builder_(static_cast<char*>(buf_.get()), kOutBuf), // Mostly the default but with inf instead. And no flags. @@ -28,6 +32,11 @@ class FakeOFStream { if (buf_.get()) Flush(); } + void SetFD(int to) { + if (builder_.position()) Flush(); + fd_ = to; + } + FakeOFStream &operator<<(float value) { // Odd, but this is the largest number found in the comments. EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); @@ -92,3 +101,5 @@ class FakeOFStream { }; } // namespace + +#endif diff --git a/klm/util/file.cc b/klm/util/file.cc index 51eaf972..aa61cf9a 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -5,28 +5,29 @@ #include "util/exception.hh" +#include <algorithm> #include <cstdlib> #include <cstdio> -#include <sstream> #include <iostream> +#include <limits> +#include <sstream> + #include <assert.h> #include <errno.h> +#include <limits.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <stdint.h> -#if defined __MINGW32__ +#if defined(__MINGW32__) #include <windows.h> #include <unistd.h> #warning "The file functions on MinGW have not been tested for file sizes above 2^31 - 1. Please read https://stackoverflow.com/questions/12539488/determine-64-bit-file-size-in-c-on-mingw-32-bit and fix" #elif defined(_WIN32) || defined(_WIN64) #include <windows.h> #include <io.h> -#include <algorithm> -#include <limits.h> -#include <limits> #else #include <unistd.h> #endif @@ -40,9 +41,9 @@ scoped_fd::~scoped_fd() { } } -scoped_FILE::~scoped_FILE() { - if (file_ && std::fclose(file_)) { - std::cerr << "Could not close file " << std::endl; +void scoped_FILE_closer::Close(std::FILE *file) { + if (file && std::fclose(file)) { + std::cerr << "Could not close file " << file << std::endl; std::abort(); } } @@ -111,7 +112,7 @@ uint64_t SizeOrThrow(int fd) { void ResizeOrThrow(int fd, uint64_t to) { #if defined __MINGW32__ - // Does this handle 64-bit? + // Does this handle 64-bit? int ret = ftruncate #elif defined(_WIN32) || defined(_WIN64) errno_t ret = _chsize_s @@ -128,8 +129,10 @@ namespace { std::size_t GuardLarge(std::size_t size) { // The following operating systems have broken read/write/pread/pwrite that // only supports up to 2^31. + // OS X man pages claim to support 64-bit, but Kareem M. Darwish had problems + // building with larger files, so APPLE is also here. #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) || defined(__MINGW32__) - return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size); + return size < INT_MAX ? size : INT_MAX; #else return size; #endif @@ -172,13 +175,44 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { return amount; } -void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) { - uint8_t *to = static_cast<uint8_t*>(to_void); +void WriteOrThrow(int fd, const void *data_void, std::size_t size) { + const uint8_t *data = static_cast<const uint8_t*>(data_void); + while (size) { #if defined(_WIN32) || defined(_WIN64) - UTIL_THROW(Exception, "This pread implementation for windows is broken. Please send me a patch that does not change the file pointer. Atomically. Or send me an implementation of pwrite that is allowed to change the file pointer but can be called concurrently with pread."); - const std::size_t kMaxDWORD = static_cast<std::size_t>(4294967295UL); + int ret; +#else + ssize_t ret; #endif - for (;size ;) { + errno = 0; + do { + ret = +#if defined(_WIN32) || defined(_WIN64) + _write +#else + write +#endif + (fd, data, GuardLarge(size)); + } while (ret == -1 && errno == EINTR); + UTIL_THROW_IF_ARG(ret < 1, FDException, (fd), "while writing " << size << " bytes"); + data += ret; + size -= ret; + } +} + +void WriteOrThrow(FILE *to, const void *data, std::size_t size) { + if (!size) return; + UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), ErrnoException, "Short write; requested size " << size); +} + +#if defined(_WIN32) || defined(_WIN64) +namespace { +const std::size_t kMaxDWORD = static_cast<std::size_t>(4294967295UL); +} // namespace +#endif + +void ErsatzPRead(int fd, void *to_void, std::size_t size, uint64_t off) { + uint8_t *to = static_cast<uint8_t*>(to_void); + while (size) { #if defined(_WIN32) || defined(_WIN64) /* BROKEN: changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() which lmplz does. */ // size_t might be 64-bit. DWORD is always 32. @@ -192,16 +226,15 @@ void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) { #else ssize_t ret; errno = 0; - do { - ret = + ret = #ifdef OS_ANDROID - pread64 + pread64 #else - pread + pread #endif - (fd, to, GuardLarge(size), off); - } while (ret == -1 && errno == EINTR); + (fd, to, GuardLarge(size), off); if (ret <= 0) { + if (ret == -1 && errno == EINTR) continue; UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd)); UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off); } @@ -212,34 +245,41 @@ void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) { } } -void WriteOrThrow(int fd, const void *data_void, std::size_t size) { - const uint8_t *data = static_cast<const uint8_t*>(data_void); - while (size) { +void ErsatzPWrite(int fd, const void *from_void, std::size_t size, uint64_t off) { + const uint8_t *from = static_cast<const uint8_t*>(from_void); + while(size) { #if defined(_WIN32) || defined(_WIN64) - int ret; + /* Changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() */ + // size_t might be 64-bit. DWORD is always 32. + DWORD writing = static_cast<DWORD>(std::min<std::size_t>(kMaxDWORD, size)); + DWORD ret; + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(OVERLAPPED)); + overlapped.Offset = static_cast<DWORD>(off); + overlapped.OffsetHigh = static_cast<DWORD>(off >> 32); + UTIL_THROW_IF(!WriteFile((HANDLE)_get_osfhandle(fd), from, writing, &ret, &overlapped), Exception, "WriteFile failed for offset " << off); #else ssize_t ret; -#endif errno = 0; - do { - ret = -#if defined(_WIN32) || defined(_WIN64) - _write + ret = +#ifdef OS_ANDROID + pwrite64 #else - write + pwrite +#endif + (fd, from, GuardLarge(size), off); + if (ret <= 0) { + if (ret == -1 && errno == EINTR) continue; + UTIL_THROW_IF(ret == 0, EndOfFileException, " for writing " << size << " bytes at " << off << " from " << NameFromFD(fd)); + UTIL_THROW_ARG(FDException, (fd), "while writing " << size << " bytes at offset " << off); + } #endif - (fd, data, GuardLarge(size)); - } while (ret == -1 && errno == EINTR); - UTIL_THROW_IF_ARG(ret < 1, FDException, (fd), "while writing " << size << " bytes"); - data += ret; size -= ret; + off += ret; + from += ret; } } -void WriteOrThrow(FILE *to, const void *data, std::size_t size) { - if (!size) return; - UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), ErrnoException, "Short write; requested size " << size); -} void FSyncOrThrow(int fd) { // Apparently windows doesn't have fsync? @@ -443,8 +483,8 @@ void NormalizeTempPrefix(std::string &base) { ) base += '/'; } -int MakeTemp(const std::string &base) { - std::string name(base); +int MakeTemp(const StringPiece &base) { + std::string name(base.data(), base.size()); name += "XXXXXX"; name.push_back(0); int ret; @@ -452,7 +492,7 @@ int MakeTemp(const std::string &base) { return ret; } -std::FILE *FMakeTemp(const std::string &base) { +std::FILE *FMakeTemp(const StringPiece &base) { util::scoped_fd file(MakeTemp(base)); return FDOpenOrThrow(file); } @@ -478,14 +518,18 @@ bool TryName(int fd, std::string &out) { if (-1 == lstat(name.c_str(), &sb)) return false; out.resize(sb.st_size + 1); - ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1); - if (-1 == ret) - return false; - if (ret > sb.st_size) { - // Increased in size?! - return false; + // lstat gave us a size, but I've seen it grow, possibly due to symlinks on top of symlinks. + while (true) { + ssize_t ret = readlink(name.c_str(), &out[0], out.size()); + if (-1 == ret) + return false; + if ((size_t)ret < out.size()) { + out.resize(ret); + break; + } + // Exponential growth. + out.resize(out.size() * 2); } - out.resize(ret); // Don't use the non-file names. if (!out.empty() && out[0] != '/') return false; diff --git a/klm/util/file.hh b/klm/util/file.hh index be88431d..7204b6a0 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -1,7 +1,9 @@ -#ifndef UTIL_FILE__ -#define UTIL_FILE__ +#ifndef UTIL_FILE_H +#define UTIL_FILE_H #include "util/exception.hh" +#include "util/scoped.hh" +#include "util/string_piece.hh" #include <cstddef> #include <cstdio> @@ -41,29 +43,10 @@ class scoped_fd { scoped_fd &operator=(const scoped_fd &); }; -class scoped_FILE { - public: - explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} - - ~scoped_FILE(); - - std::FILE *get() { return file_; } - const std::FILE *get() const { return file_; } - - void reset(std::FILE *to = NULL) { - scoped_FILE other(file_); - file_ = to; - } - - std::FILE *release() { - std::FILE *ret = file_; - file_ = NULL; - return ret; - } - - private: - std::FILE *file_; +struct scoped_FILE_closer { + static void Close(std::FILE *file); }; +typedef scoped<std::FILE, scoped_FILE_closer> scoped_FILE; /* Thrown for any operation where the fd is known. */ class FDException : public ErrnoException { @@ -106,12 +89,20 @@ void ResizeOrThrow(int fd, uint64_t to); std::size_t PartialRead(int fd, void *to, std::size_t size); void ReadOrThrow(int fd, void *to, std::size_t size); std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size); -// Positioned: unix only for now. -void PReadOrThrow(int fd, void *to, std::size_t size, uint64_t off); void WriteOrThrow(int fd, const void *data_void, std::size_t size); void WriteOrThrow(FILE *to, const void *data, std::size_t size); +/* These call pread/pwrite in a loop. However, on Windows they call ReadFile/ + * WriteFile which changes the file pointer. So it's safe to call ErsatzPRead + * and ErsatzPWrite concurrently (or any combination thereof). But it changes + * the file pointer on windows, so it's not safe to call concurrently with + * anything that uses the implicit file pointer e.g. the Read/Write functions + * above. + */ +void ErsatzPRead(int fd, void *to, std::size_t size, uint64_t off); +void ErsatzPWrite(int fd, const void *data_void, std::size_t size, uint64_t off); + void FSyncOrThrow(int fd); // Seeking @@ -125,8 +116,8 @@ std::FILE *FDOpenReadOrThrow(scoped_fd &file); // Temporary files // Append a / if base is a directory. void NormalizeTempPrefix(std::string &base); -int MakeTemp(const std::string &prefix); -std::FILE *FMakeTemp(const std::string &prefix); +int MakeTemp(const StringPiece &prefix); +std::FILE *FMakeTemp(const StringPiece &prefix); // dup an fd. int DupOrThrow(int fd); @@ -139,4 +130,4 @@ std::string NameFromFD(int fd); } // namespace util -#endif // UTIL_FILE__ +#endif // UTIL_FILE_H diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 9c7e00c4..4aaa250e 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -84,6 +84,13 @@ StringPiece FilePiece::ReadLine(char delim) { } } +bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim) { + try { + to = ReadLine(delim); + } catch (const util::EndOfFileException &e) { return false; } + return true; +} + float FilePiece::ReadFloat() { return ReadNumber<float>(); } diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index ed3dc5ad..5495ddcc 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_FILE_PIECE__ -#define UTIL_FILE_PIECE__ +#ifndef UTIL_FILE_PIECE_H +#define UTIL_FILE_PIECE_H #include "util/ersatz_progress.hh" #include "util/exception.hh" @@ -56,10 +56,33 @@ class FilePiece { return Consume(FindDelimiterOrEOF(delim)); } + // Read word until the line or file ends. + bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) { + assert(delim[static_cast<unsigned char>('\n')]); + // Skip non-enter spaces. + for (; ; ++position_) { + if (position_ == position_end_) { + try { + Shift(); + } catch (const util::EndOfFileException &e) { return false; } + // And break out at end of file. + if (position_ == position_end_) return false; + } + if (!delim[static_cast<unsigned char>(*position_)]) break; + if (*position_ == '\n') return false; + } + // We can't be at the end of file because there's at least one character open. + to = Consume(FindDelimiterOrEOF(delim)); + return true; + } + // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. // It is similar to getline in that way. StringPiece ReadLine(char delim = '\n'); + // Doesn't throw EndOfFileException, just returns false. + bool ReadLineOrEOF(StringPiece &to, char delim = '\n'); + float ReadFloat(); double ReadDouble(); long int ReadLong(); @@ -132,4 +155,4 @@ class FilePiece { } // namespace util -#endif // UTIL_FILE_PIECE__ +#endif // UTIL_FILE_PIECE_H diff --git a/klm/util/fixed_array.hh b/klm/util/fixed_array.hh new file mode 100644 index 00000000..416b92f4 --- /dev/null +++ b/klm/util/fixed_array.hh @@ -0,0 +1,153 @@ +#ifndef UTIL_FIXED_ARRAY_H +#define UTIL_FIXED_ARRAY_H + +#include "util/scoped.hh" + +#include <cstddef> + +#include <assert.h> +#include <stdlib.h> + +namespace util { + +/** + * Defines a fixed-size collection. + * + * Ever want an array of things by they don't have a default constructor or are + * non-copyable? FixedArray allows constructing one at a time. + */ +template <class T> class FixedArray { + public: + /** Initialize with a given size bound but do not construct the objects. */ + explicit FixedArray(std::size_t limit) { + Init(limit); + } + + /** + * Constructs an instance, but does not initialize it. + * + * Any objects constructed in this manner must be subsequently @ref FixedArray::Init() "initialized" prior to use. + * + * @see FixedArray::Init() + */ + FixedArray() + : newed_end_(NULL) +#ifndef NDEBUG + , allocated_end_(NULL) +#endif + {} + + /** + * Initialize with a given size bound but do not construct the objects. + * + * This method is responsible for allocating memory. + * Objects stored in this array will be constructed in a location within this allocated memory. + */ + void Init(std::size_t count) { + assert(!block_.get()); + block_.reset(malloc(sizeof(T) * count)); + if (!block_.get()) throw std::bad_alloc(); + newed_end_ = begin(); +#ifndef NDEBUG + allocated_end_ = begin() + count; +#endif + } + + /** + * Constructs a copy of the provided array. + * + * @param from Array whose elements should be copied into this newly-constructed data structure. + */ + FixedArray(const FixedArray &from) { + std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get()); + Init(size); + for (std::size_t i = 0; i < size; ++i) { + push_back(from[i]); + } + } + + /** + * Frees the memory held by this object. + */ + ~FixedArray() { clear(); } + + /** Gets a pointer to the first object currently stored in this data structure. */ + T *begin() { return static_cast<T*>(block_.get()); } + + /** Gets a const pointer to the last object currently stored in this data structure. */ + const T *begin() const { return static_cast<const T*>(block_.get()); } + + /** Gets a pointer to the last object currently stored in this data structure. */ + T *end() { return newed_end_; } + + /** Gets a const pointer to the last object currently stored in this data structure. */ + const T *end() const { return newed_end_; } + + /** Gets a reference to the last object currently stored in this data structure. */ + T &back() { return *(end() - 1); } + + /** Gets a const reference to the last object currently stored in this data structure. */ + const T &back() const { return *(end() - 1); } + + /** Gets the number of objects currently stored in this data structure. */ + std::size_t size() const { return end() - begin(); } + + /** Returns true if there are no objects currently stored in this data structure. */ + bool empty() const { return begin() == end(); } + + /** + * Gets a reference to the object with index i currently stored in this data structure. + * + * @param i Index of the object to reference + */ + T &operator[](std::size_t i) { return begin()[i]; } + + /** + * Gets a const reference to the object with index i currently stored in this data structure. + * + * @param i Index of the object to reference + */ + const T &operator[](std::size_t i) const { return begin()[i]; } + + /** + * Constructs a new object using the provided parameter, + * and stores it in this data structure. + * + * The memory backing the constructed object is managed by this data structure. + */ + template <class C> void push_back(const C &c) { + new (end()) T(c); // use "placement new" syntax to initalize T in an already-allocated memory location + Constructed(); + } + + /** + * Removes all elements from this array. + */ + void clear() { + for (T *i = begin(); i != end(); ++i) + i->~T(); + newed_end_ = begin(); + } + + protected: + // Always call Constructed after successful completion of new. + void Constructed() { + ++newed_end_; +#ifndef NDEBUG + assert(newed_end_ <= allocated_end_); +#endif + } + + private: + util::scoped_malloc block_; + + T *newed_end_; + +#ifndef NDEBUG + T *allocated_end_; +#endif +}; + +} // namespace util + +#endif // UTIL_FIXED_ARRAY_H diff --git a/klm/util/getopt.hh b/klm/util/getopt.hh index 6ad97732..50eab56f 100644 --- a/klm/util/getopt.hh +++ b/klm/util/getopt.hh @@ -11,8 +11,8 @@ Code given out at the 1985 UNIFORUM conference in Dallas. #endif #ifndef __GNUC__ -#ifndef _WINGETOPT_H_ -#define _WINGETOPT_H_ +#ifndef UTIL_GETOPT_H +#define UTIL_GETOPT_H #ifdef __cplusplus extern "C" { @@ -28,6 +28,6 @@ extern int getopt(int argc, char **argv, char *opts); } #endif -#endif /* _GETOPT_H_ */ +#endif /* UTIL_GETOPT_H */ #endif /* __GNUC__ */ diff --git a/klm/util/have.hh b/klm/util/have.hh index 6e18529d..dc3f6330 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -1,6 +1,6 @@ /* Optional packages. You might want to integrate this with your build system e.g. config.h from ./configure. */ -#ifndef UTIL_HAVE__ -#define UTIL_HAVE__ +#ifndef UTIL_HAVE_H +#define UTIL_HAVE_H #ifdef HAVE_CONFIG_H #include "config.h" @@ -10,4 +10,4 @@ //#define HAVE_ICU #endif -#endif // UTIL_HAVE__ +#endif // UTIL_HAVE_H diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh index b1ec48e2..de4b554f 100644 --- a/klm/util/joint_sort.hh +++ b/klm/util/joint_sort.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_JOINT_SORT__ -#define UTIL_JOINT_SORT__ +#ifndef UTIL_JOINT_SORT_H +#define UTIL_JOINT_SORT_H /* A terrifying amount of C++ to coax std::sort into soring one range while * also permuting another range the same way. @@ -143,4 +143,4 @@ template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begi } // namespace util -#endif // UTIL_JOINT_SORT__ +#endif // UTIL_JOINT_SORT_H diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index cee6a970..a3c8a022 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -6,6 +6,7 @@ #include "util/exception.hh" #include "util/file.hh" +#include "util/parallel_read.hh" #include "util/scoped.hh" #include <iostream> @@ -40,7 +41,7 @@ void SyncOrThrow(void *start, size_t length) { #if defined(_WIN32) || defined(_WIN64) UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap"); #else - UTIL_THROW_IF(msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap"); + UTIL_THROW_IF(length && msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap"); #endif } @@ -154,6 +155,10 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope SeekOrThrow(fd, offset); ReadOrThrow(fd, out.get(), size); break; + case PARALLEL_READ: + out.reset(MallocOrThrow(size), size, scoped_memory::MALLOC_ALLOCATED); + ParallelRead(fd, out.get(), size, offset); + break; } } @@ -189,4 +194,66 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) { } } +Rolling::Rolling(const Rolling ©_from, uint64_t increase) { + *this = copy_from; + IncreaseBase(increase); +} + +Rolling &Rolling::operator=(const Rolling ©_from) { + fd_ = copy_from.fd_; + file_begin_ = copy_from.file_begin_; + file_end_ = copy_from.file_end_; + for_write_ = copy_from.for_write_; + block_ = copy_from.block_; + read_bound_ = copy_from.read_bound_; + + current_begin_ = 0; + if (copy_from.IsPassthrough()) { + current_end_ = copy_from.current_end_; + ptr_ = copy_from.ptr_; + } else { + // Force call on next mmap. + current_end_ = 0; + ptr_ = NULL; + } + return *this; +} + +Rolling::Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount) { + current_begin_ = 0; + current_end_ = 0; + fd_ = fd; + file_begin_ = offset; + file_end_ = offset + amount; + for_write_ = for_write; + block_ = block; + read_bound_ = read_bound; +} + +void *Rolling::ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size) { + out.reset(); + if (IsPassthrough()) return static_cast<uint8_t*>(get()) + index; + uint64_t offset = index + file_begin_; + // Round down to multiple of page size. + uint64_t cruft = offset % static_cast<uint64_t>(SizePage()); + std::size_t map_size = static_cast<std::size_t>(size + cruft); + out.reset(MapOrThrow(map_size, for_write_, kFileFlags, true, fd_, offset - cruft), map_size, scoped_memory::MMAP_ALLOCATED); + return static_cast<uint8_t*>(out.get()) + static_cast<std::size_t>(cruft); +} + +void Rolling::Roll(uint64_t index) { + assert(!IsPassthrough()); + std::size_t amount; + if (file_end_ - (index + file_begin_) > static_cast<uint64_t>(block_)) { + amount = block_; + current_end_ = index + amount - read_bound_; + } else { + amount = file_end_ - (index + file_begin_); + current_end_ = index + amount; + } + ptr_ = static_cast<uint8_t*>(ExtractNonRolling(mem_, index, amount)) - index; + + current_begin_ = index; +} + } // namespace util diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh index b218c4d1..9b1e120f 100644 --- a/klm/util/mmap.hh +++ b/klm/util/mmap.hh @@ -1,8 +1,9 @@ -#ifndef UTIL_MMAP__ -#define UTIL_MMAP__ +#ifndef UTIL_MMAP_H +#define UTIL_MMAP_H // Utilities for mmaped files. #include <cstddef> +#include <limits> #include <stdint.h> #include <sys/types.h> @@ -52,6 +53,9 @@ class scoped_memory { public: typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc; + scoped_memory(void *data, std::size_t size, Alloc source) + : data_(data), size_(size), source_(source) {} + scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {} ~scoped_memory() { reset(); } @@ -72,7 +76,6 @@ class scoped_memory { void call_realloc(std::size_t to); private: - void *data_; std::size_t size_; @@ -90,7 +93,9 @@ typedef enum { // Populate on Linux. malloc and read on non-Linux. POPULATE_OR_READ, // malloc and read. - READ + READ, + // malloc and read in parallel (recommended for Lustre) + PARALLEL_READ, } LoadMethod; extern const int kFileFlags; @@ -109,6 +114,79 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file); // msync wrapper void SyncOrThrow(void *start, size_t length); +// Forward rolling memory map with no overlap. +class Rolling { + public: + Rolling() {} + + explicit Rolling(void *data) { Init(data); } + + Rolling(const Rolling ©_from, uint64_t increase = 0); + Rolling &operator=(const Rolling ©_from); + + // For an actual rolling mmap. + explicit Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount); + + // For a static mapping + void Init(void *data) { + ptr_ = data; + current_end_ = std::numeric_limits<uint64_t>::max(); + current_begin_ = 0; + // Mark as a pass-through. + fd_ = -1; + } + + void IncreaseBase(uint64_t by) { + file_begin_ += by; + ptr_ = static_cast<uint8_t*>(ptr_) + by; + if (!IsPassthrough()) current_end_ = 0; + } + + void DecreaseBase(uint64_t by) { + file_begin_ -= by; + ptr_ = static_cast<uint8_t*>(ptr_) - by; + if (!IsPassthrough()) current_end_ = 0; + } + + void *ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size); + + // Returns base pointer + void *get() const { return ptr_; } + + // Returns base pointer. + void *CheckedBase(uint64_t index) { + if (index >= current_end_ || index < current_begin_) { + Roll(index); + } + return ptr_; + } + + // Returns indexed pointer. + void *CheckedIndex(uint64_t index) { + return static_cast<uint8_t*>(CheckedBase(index)) + index; + } + + private: + void Roll(uint64_t index); + + // True if this is just a thin wrapper on a pointer. + bool IsPassthrough() const { return fd_ == -1; } + + void *ptr_; + uint64_t current_begin_; + uint64_t current_end_; + + scoped_memory mem_; + + int fd_; + uint64_t file_begin_; + uint64_t file_end_; + + bool for_write_; + std::size_t block_; + std::size_t read_bound_; +}; + } // namespace util -#endif // UTIL_MMAP__ +#endif // UTIL_MMAP_H diff --git a/klm/util/multi_intersection.hh b/klm/util/multi_intersection.hh index 04678352..2955acc7 100644 --- a/klm/util/multi_intersection.hh +++ b/klm/util/multi_intersection.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_MULTI_INTERSECTION__ -#define UTIL_MULTI_INTERSECTION__ +#ifndef UTIL_MULTI_INTERSECTION_H +#define UTIL_MULTI_INTERSECTION_H #include <boost/optional.hpp> #include <boost/range/iterator_range.hpp> @@ -77,4 +77,4 @@ template <class Iterator, class Output> void AllIntersection(std::vector<boost:: } // namespace util -#endif // UTIL_MULTI_INTERSECTION__ +#endif // UTIL_MULTI_INTERSECTION_H diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh index 4891833e..f17157cd 100644 --- a/klm/util/murmur_hash.hh +++ b/klm/util/murmur_hash.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_MURMUR_HASH__ -#define UTIL_MURMUR_HASH__ +#ifndef UTIL_MURMUR_HASH_H +#define UTIL_MURMUR_HASH_H #include <cstddef> #include <stdint.h> @@ -15,4 +15,4 @@ uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0); } // namespace util -#endif // UTIL_MURMUR_HASH__ +#endif // UTIL_MURMUR_HASH_H diff --git a/klm/util/parallel_read.cc b/klm/util/parallel_read.cc new file mode 100644 index 00000000..6435eb84 --- /dev/null +++ b/klm/util/parallel_read.cc @@ -0,0 +1,69 @@ +#include "util/parallel_read.hh" + +#include "util/file.hh" + +#ifdef WITH_THREADS +#include "util/thread_pool.hh" + +namespace util { +namespace { + +class Reader { + public: + explicit Reader(int fd) : fd_(fd) {} + + struct Request { + void *to; + std::size_t size; + uint64_t offset; + + bool operator==(const Request &other) const { + return (to == other.to) && (size == other.size) && (offset == other.offset); + } + }; + + void operator()(const Request &request) { + util::ErsatzPRead(fd_, request.to, request.size, request.offset); + } + + private: + int fd_; +}; + +} // namespace + +void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) { + Reader::Request poison; + poison.to = NULL; + poison.size = 0; + poison.offset = 0; + unsigned threads = boost::thread::hardware_concurrency(); + if (!threads) threads = 2; + ThreadPool<Reader> pool(2 /* don't need much of a queue */, threads, fd, poison); + const std::size_t kBatch = 1ULL << 25; // 32 MB + Reader::Request request; + request.to = to; + request.size = kBatch; + request.offset = offset; + for (; amount > kBatch; amount -= kBatch) { + pool.Produce(request); + request.to = reinterpret_cast<uint8_t*>(request.to) + kBatch; + request.offset += kBatch; + } + request.size = amount; + if (request.size) { + pool.Produce(request); + } +} + +} // namespace util + +#else // WITH_THREADS + +namespace util { +void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) { + util::ErsatzPRead(fd, to, amount, offset); +} +} // namespace util + +#endif diff --git a/klm/util/parallel_read.hh b/klm/util/parallel_read.hh new file mode 100644 index 00000000..1e96e790 --- /dev/null +++ b/klm/util/parallel_read.hh @@ -0,0 +1,16 @@ +#ifndef UTIL_PARALLEL_READ__ +#define UTIL_PARALLEL_READ__ + +/* Read pieces of a file in parallel. This has a very specific use case: + * reading files from Lustre is CPU bound so multiple threads actually + * increases throughput. Speed matters when an LM takes a terabyte. + */ + +#include <cstddef> +#include <stdint.h> + +namespace util { +void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset); +} // namespace util + +#endif // UTIL_PARALLEL_READ__ diff --git a/klm/util/pcqueue.hh b/klm/util/pcqueue.hh index 07e4146f..d2ffee77 100644 --- a/klm/util/pcqueue.hh +++ b/klm/util/pcqueue.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_PCQUEUE__ -#define UTIL_PCQUEUE__ +#ifndef UTIL_PCQUEUE_H +#define UTIL_PCQUEUE_H #include "util/exception.hh" @@ -72,7 +72,8 @@ inline void WaitSemaphore (Semaphore &on) { #endif // __APPLE__ -/* Producer consumer queue safe for multiple producers and multiple consumers. +/** + * Producer consumer queue safe for multiple producers and multiple consumers. * T must be default constructable and have operator=. * The value is copied twice for Consume(T &out) or three times for Consume(), * so larger objects should be passed via pointer. @@ -152,4 +153,4 @@ template <class T> class PCQueue : boost::noncopyable { } // namespace util -#endif // UTIL_PCQUEUE__ +#endif // UTIL_PCQUEUE_H diff --git a/klm/util/pool.hh b/klm/util/pool.hh index 72f8a0c8..89e793d7 100644 --- a/klm/util/pool.hh +++ b/klm/util/pool.hh @@ -1,8 +1,8 @@ // Very simple pool. It can only allocate memory. And all of the memory it // allocates must be freed at the same time. -#ifndef UTIL_POOL__ -#define UTIL_POOL__ +#ifndef UTIL_POOL_H +#define UTIL_POOL_H #include <vector> @@ -42,4 +42,4 @@ class Pool { } // namespace util -#endif // UTIL_POOL__ +#endif // UTIL_POOL_H diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 38524806..ea228dd9 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_PROBING_HASH_TABLE__ -#define UTIL_PROBING_HASH_TABLE__ +#ifndef UTIL_PROBING_HASH_TABLE_H +#define UTIL_PROBING_HASH_TABLE_H #include "util/exception.hh" #include "util/scoped.hh" @@ -258,6 +258,10 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry private: typedef ProbingHashTable<EntryT, HashT, EqualT> Backend; public: + static std::size_t MemUsage(std::size_t size, float multiplier = 1.5) { + return Backend::Size(size, multiplier); + } + typedef EntryT Entry; typedef typename Entry::Key Key; typedef const Entry *ConstIterator; @@ -268,6 +272,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) : allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) { threshold_ = initial_size * 1.2; + Clear(); } // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup. @@ -323,4 +328,4 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry } // namespace util -#endif // UTIL_PROBING_HASH_TABLE__ +#endif // UTIL_PROBING_HASH_TABLE_H diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh index a2810a47..8aa697bf 100644 --- a/klm/util/proxy_iterator.hh +++ b/klm/util/proxy_iterator.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_PROXY_ITERATOR__ -#define UTIL_PROXY_ITERATOR__ +#ifndef UTIL_PROXY_ITERATOR_H +#define UTIL_PROXY_ITERATOR_H #include <cstddef> #include <iterator> @@ -98,4 +98,4 @@ template <class Proxy> ProxyIterator<Proxy> operator+(std::ptrdiff_t amount, con } // namespace util -#endif // UTIL_PROXY_ITERATOR__ +#endif // UTIL_PROXY_ITERATOR_H diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc index b62a6e83..cee98040 100644 --- a/klm/util/read_compressed.cc +++ b/klm/util/read_compressed.cc @@ -49,6 +49,8 @@ class ReadBase { thunk.internal_.reset(with); } + ReadBase *Current(ReadCompressed &thunk) { return thunk.internal_.get(); } + static uint64_t &ReadCount(ReadCompressed &thunk) { return thunk.raw_amount_; } @@ -56,6 +58,8 @@ class ReadBase { namespace { +ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, std::size_t already_size, bool require_compressed); + // Completed file that other classes can thunk to. class Complete : public ReadBase { public: @@ -80,7 +84,7 @@ class Uncompressed : public ReadBase { class UncompressedWithHeader : public ReadBase { public: - UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) { + UncompressedWithHeader(int fd, const void *already_data, std::size_t already_size) : fd_(fd) { assert(already_size); buf_.reset(malloc(already_size)); if (!buf_.get()) throw std::bad_alloc(); @@ -91,6 +95,7 @@ class UncompressedWithHeader : public ReadBase { std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { assert(buf_.get()); + assert(remain_ != end_); std::size_t sending = std::min<std::size_t>(amount, end_ - remain_); memcpy(to, remain_, sending); remain_ += sending; @@ -108,23 +113,51 @@ class UncompressedWithHeader : public ReadBase { scoped_fd fd_; }; -#ifdef HAVE_ZLIB -class GZip : public ReadBase { +static const std::size_t kInputBuffer = 16384; + +template <class Compression> class StreamCompressed : public ReadBase { + public: + StreamCompressed(int fd, const void *already_data, std::size_t already_size) + : file_(fd), + in_buffer_(MallocOrThrow(kInputBuffer)), + back_(memcpy(in_buffer_.get(), already_data, already_size), already_size) {} + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + back_.SetOutput(to, amount); + do { + if (!back_.Stream().avail_in) ReadInput(thunk); + if (!back_.Process()) { + // reached end, at least for the compressed portion. + std::size_t ret = static_cast<const uint8_t *>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to); + ReplaceThis(ReadFactory(file_.release(), ReadCount(thunk), back_.Stream().next_in, back_.Stream().avail_in, true), thunk); + if (ret) return ret; + // We did not read anything this round, so clients might think EOF. Transfer responsibility to the next reader. + return Current(thunk)->Read(to, amount, thunk); + } + } while (back_.Stream().next_out == to); + return static_cast<const uint8_t*>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to); + } + private: - static const std::size_t kInputBuffer = 16384; + void ReadInput(ReadCompressed &thunk) { + assert(!back_.Stream().avail_in); + std::size_t got = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + back_.SetInput(in_buffer_.get(), got); + ReadCount(thunk) += got; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + + Compression back_; +}; + +#ifdef HAVE_ZLIB +class GZip { public: - GZip(int fd, void *already_data, std::size_t already_size) - : file_(fd), in_buffer_(malloc(kInputBuffer)) { - if (!in_buffer_.get()) throw std::bad_alloc(); - assert(already_size < kInputBuffer); - if (already_size) { - memcpy(in_buffer_.get(), already_data, already_size); - stream_.next_in = static_cast<Bytef *>(in_buffer_.get()); - stream_.avail_in = already_size; - stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size); - } else { - stream_.avail_in = 0; - } + GZip(const void *base, std::size_t amount) { + SetInput(base, amount); stream_.zalloc = Z_NULL; stream_.zfree = Z_NULL; stream_.opaque = Z_NULL; @@ -141,227 +174,154 @@ class GZip : public ReadBase { } } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - if (amount == 0) return 0; + void SetOutput(void *to, std::size_t amount) { stream_.next_out = static_cast<Bytef*>(to); stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount); - do { - if (!stream_.avail_in) ReadInput(thunk); - int result = inflate(&stream_, 0); - switch (result) { - case Z_OK: - break; - case Z_STREAM_END: - { - std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); - ReplaceThis(new Complete(), thunk); - return ret; - } - case Z_ERRNO: - UTIL_THROW(ErrnoException, "zlib error"); - default: - UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); - } - } while (stream_.next_out == to); - return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); } - private: - void ReadInput(ReadCompressed &thunk) { - assert(!stream_.avail_in); - stream_.next_in = static_cast<Bytef *>(in_buffer_.get()); - stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); - ReadCount(thunk) += stream_.avail_in; + void SetInput(const void *base, std::size_t amount) { + assert(amount < static_cast<std::size_t>(std::numeric_limits<uInt>::max())); + stream_.next_in = const_cast<Bytef*>(static_cast<const Bytef*>(base)); + stream_.avail_in = amount; } - scoped_fd file_; - scoped_malloc in_buffer_; + const z_stream &Stream() const { return stream_; } + + bool Process() { + int result = inflate(&stream_, 0); + switch (result) { + case Z_OK: + return true; + case Z_STREAM_END: + return false; + case Z_ERRNO: + UTIL_THROW(ErrnoException, "zlib error"); + default: + UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); + } + } + + private: z_stream stream_; }; #endif // HAVE_ZLIB -const uint8_t kBZMagic[3] = {'B', 'Z', 'h'}; - #ifdef HAVE_BZLIB -class BZip : public ReadBase { +class BZip { public: - BZip(int fd, void *already_data, std::size_t already_size) { - scoped_fd hold(fd); - closer_.reset(FDOpenReadOrThrow(hold)); - file_ = NULL; - Open(already_data, already_size); + BZip(const void *base, std::size_t amount) { + memset(&stream_, 0, sizeof(stream_)); + SetInput(base, amount); + HandleError(BZ2_bzDecompressInit(&stream_, 0, 0)); } - BZip(FILE *file, void *already_data, std::size_t already_size) { - closer_.reset(file); - file_ = NULL; - Open(already_data, already_size); + ~BZip() { + try { + HandleError(BZ2_bzDecompressEnd(&stream_)); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } } - ~BZip() { - Close(file_); + bool Process() { + int ret = BZ2_bzDecompress(&stream_); + if (ret == BZ_STREAM_END) return false; + HandleError(ret); + return true; } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - assert(file_); - int bzerror = BZ_OK; - int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount)); - long pos = ftell(closer_.get()); - if (pos != -1) ReadCount(thunk) = pos; - switch (bzerror) { - case BZ_STREAM_END: - /* bzip2 files can be concatenated by e.g. pbzip2. Annoyingly, the - * library doesn't handle this internally. This gets the trailing - * data, grows it up to magic as needed, validates the magic, and - * reopens. - */ - { - bzerror = BZ_OK; - void *trailing_data; - int trailing_size; - BZ2_bzReadGetUnused(&bzerror, file_, &trailing_data, &trailing_size); - UTIL_THROW_IF(bzerror != BZ_OK, BZException, "bzip2 error in BZ2_bzReadGetUnused " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); - std::string trailing(static_cast<const char*>(trailing_data), trailing_size); - Close(file_); - - if (trailing_size < (int)sizeof(kBZMagic)) { - trailing.resize(sizeof(kBZMagic)); - if (1 != fread(&trailing[trailing_size], sizeof(kBZMagic) - trailing_size, 1, closer_.get())) { - UTIL_THROW_IF(trailing_size, BZException, "File has trailing cruft"); - // Legitimate end of file. - ReplaceThis(new Complete(), thunk); - return ret; - } - } - UTIL_THROW_IF(memcmp(trailing.data(), kBZMagic, sizeof(kBZMagic)), BZException, "Trailing cruft is not another bzip2 stream"); - Open(&trailing[0], trailing.size()); - } - return ret; - case BZ_OK: - return ret; - default: - UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); - } + void SetOutput(void *base, std::size_t amount) { + stream_.next_out = static_cast<char*>(base); + stream_.avail_out = std::min<std::size_t>(std::numeric_limits<unsigned int>::max(), amount); } + void SetInput(const void *base, std::size_t amount) { + stream_.next_in = const_cast<char*>(static_cast<const char*>(base)); + stream_.avail_in = amount; + } + + const bz_stream &Stream() const { return stream_; } + private: - void Open(void *already_data, std::size_t already_size) { - assert(!file_); - int bzerror = BZ_OK; - file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size); - switch (bzerror) { + void HandleError(int value) { + switch(value) { case BZ_OK: return; case BZ_CONFIG_ERROR: - UTIL_THROW(BZException, "Looks like bzip2 was miscompiled."); + UTIL_THROW(BZException, "bzip2 seems to be miscompiled."); case BZ_PARAM_ERROR: - UTIL_THROW(BZException, "Parameter error"); - case BZ_IO_ERROR: - UTIL_THROW(BZException, "IO error reading file"); + UTIL_THROW(BZException, "bzip2 Parameter error"); + case BZ_DATA_ERROR: + UTIL_THROW(BZException, "bzip2 detected a corrupt file"); + case BZ_DATA_ERROR_MAGIC: + UTIL_THROW(BZException, "bzip2 detected bad magic bytes. Perhaps this was not a bzip2 file after all?"); case BZ_MEM_ERROR: throw std::bad_alloc(); default: - UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror); + UTIL_THROW(BZException, "Unknown bzip2 error code " << value); } - assert(file_); } - static void Close(BZFILE *&file) { - if (file == NULL) return; - int bzerror = BZ_OK; - BZ2_bzReadClose(&bzerror, file); - if (bzerror != BZ_OK) { - std::cerr << "bz2 readclose error number " << bzerror << std::endl; - abort(); - } - file = NULL; - } - - scoped_FILE closer_; - BZFILE *file_; + bz_stream stream_; }; #endif // HAVE_BZLIB #ifdef HAVE_XZLIB -class XZip : public ReadBase { - private: - static const std::size_t kInputBuffer = 16384; +class XZip { public: - XZip(int fd, void *already_data, std::size_t already_size) - : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) { - if (!in_buffer_.get()) throw std::bad_alloc(); - assert(already_size < kInputBuffer); - if (already_size) { - memcpy(in_buffer_.get(), already_data, already_size); - stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get()); - stream_.avail_in = already_size; - stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size); - } else { - stream_.avail_in = 0; - } - stream_.allocator = NULL; - lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED); - switch (ret) { - case LZMA_OK: - break; - case LZMA_MEM_ERROR: - UTIL_THROW(ErrnoException, "xz open error"); - default: - UTIL_THROW(XZException, "xz error code " << ret); - } + XZip(const void *base, std::size_t amount) + : stream_(), action_(LZMA_RUN) { + memset(&stream_, 0, sizeof(stream_)); + SetInput(base, amount); + HandleError(lzma_stream_decoder(&stream_, UINT64_MAX, 0)); } ~XZip() { lzma_end(&stream_); } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - if (amount == 0) return 0; - stream_.next_out = static_cast<uint8_t*>(to); + void SetOutput(void *base, std::size_t amount) { + stream_.next_out = static_cast<uint8_t*>(base); stream_.avail_out = amount; - do { - if (!stream_.avail_in) ReadInput(thunk); - lzma_ret status = lzma_code(&stream_, action_); - switch (status) { - case LZMA_OK: - break; - case LZMA_STREAM_END: - UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet."); - { - std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); - ReplaceThis(new Complete(), thunk); - return ret; - } - case LZMA_MEM_ERROR: - throw std::bad_alloc(); - case LZMA_FORMAT_ERROR: - UTIL_THROW(XZException, "xzlib says file format not recognized"); - case LZMA_OPTIONS_ERROR: - UTIL_THROW(XZException, "xzlib says unsupported compression options"); - case LZMA_DATA_ERROR: - UTIL_THROW(XZException, "xzlib says this file is corrupt"); - case LZMA_BUF_ERROR: - UTIL_THROW(XZException, "xzlib says unexpected end of input"); - default: - UTIL_THROW(XZException, "unrecognized xzlib error " << status); - } - } while (stream_.next_out == to); - return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); + } + + void SetInput(const void *base, std::size_t amount) { + stream_.next_in = static_cast<const uint8_t*>(base); + stream_.avail_in = amount; + if (!amount) action_ = LZMA_FINISH; + } + + const lzma_stream &Stream() const { return stream_; } + + bool Process() { + lzma_ret status = lzma_code(&stream_, action_); + if (status == LZMA_STREAM_END) return false; + HandleError(status); + return true; } private: - void ReadInput(ReadCompressed &thunk) { - assert(!stream_.avail_in); - stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get()); - stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); - if (!stream_.avail_in) action_ = LZMA_FINISH; - ReadCount(thunk) += stream_.avail_in; + void HandleError(lzma_ret value) { + switch (value) { + case LZMA_OK: + return; + case LZMA_MEM_ERROR: + throw std::bad_alloc(); + case LZMA_FORMAT_ERROR: + UTIL_THROW(XZException, "xzlib says file format not recognized"); + case LZMA_OPTIONS_ERROR: + UTIL_THROW(XZException, "xzlib says unsupported compression options"); + case LZMA_DATA_ERROR: + UTIL_THROW(XZException, "xzlib says this file is corrupt"); + case LZMA_BUF_ERROR: + UTIL_THROW(XZException, "xzlib says unexpected end of input"); + default: + UTIL_THROW(XZException, "unrecognized xzlib error " << value); + } } - scoped_fd file_; - scoped_malloc in_buffer_; lzma_stream stream_; - lzma_action action_; }; #endif // HAVE_XZLIB @@ -384,66 +344,67 @@ class IStreamReader : public ReadBase { }; enum MagicResult { - UNKNOWN, GZIP, BZIP, XZIP + UTIL_UNKNOWN, UTIL_GZIP, UTIL_BZIP, UTIL_XZIP }; -MagicResult DetectMagic(const void *from_void) { +MagicResult DetectMagic(const void *from_void, std::size_t length) { const uint8_t *header = static_cast<const uint8_t*>(from_void); - if (header[0] == 0x1f && header[1] == 0x8b) { - return GZIP; + if (length >= 2 && header[0] == 0x1f && header[1] == 0x8b) { + return UTIL_GZIP; } - if (!memcmp(header, kBZMagic, sizeof(kBZMagic))) { - return BZIP; + const uint8_t kBZMagic[3] = {'B', 'Z', 'h'}; + if (length >= sizeof(kBZMagic) && !memcmp(header, kBZMagic, sizeof(kBZMagic))) { + return UTIL_BZIP; } const uint8_t kXZMagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 }; - if (!memcmp(header, kXZMagic, sizeof(kXZMagic))) { - return XZIP; + if (length >= sizeof(kXZMagic) && !memcmp(header, kXZMagic, sizeof(kXZMagic))) { + return UTIL_XZIP; } - return UNKNOWN; + return UTIL_UNKNOWN; } -ReadBase *ReadFactory(int fd, uint64_t &raw_amount) { +ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, const std::size_t already_size, bool require_compressed) { scoped_fd hold(fd); - unsigned char header[ReadCompressed::kMagicSize]; - raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize); - if (!raw_amount) - return new Uncompressed(hold.release()); - if (raw_amount != ReadCompressed::kMagicSize) - return new UncompressedWithHeader(hold.release(), header, raw_amount); - switch (DetectMagic(header)) { - case GZIP: + std::string header(reinterpret_cast<const char*>(already_data), already_size); + if (header.size() < ReadCompressed::kMagicSize) { + std::size_t original = header.size(); + header.resize(ReadCompressed::kMagicSize); + std::size_t got = ReadOrEOF(fd, &header[original], ReadCompressed::kMagicSize - original); + raw_amount += got; + header.resize(original + got); + } + if (header.empty()) { + return new Complete(); + } + switch (DetectMagic(&header[0], header.size())) { + case UTIL_GZIP: #ifdef HAVE_ZLIB - return new GZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed<GZip>(hold.release(), header.data(), header.size()); #else UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in."); #endif - case BZIP: + case UTIL_BZIP: #ifdef HAVE_BZLIB - return new BZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed<BZip>(hold.release(), &header[0], header.size()); #else - UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in."); + UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZh), but bzip support was not compiled in."); #endif - case XZIP: + case UTIL_XZIP: #ifdef HAVE_XZLIB - return new XZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed<XZip>(hold.release(), header.data(), header.size()); #else UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in."); #endif - case UNKNOWN: - break; - } - try { - SeekOrThrow(fd, 0); - } catch (const util::ErrnoException &e) { - return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); + default: + UTIL_THROW_IF(require_compressed, CompressedException, "Uncompressed data detected after a compresssed file. This could be supported but usually indicates an error."); + return new UncompressedWithHeader(hold.release(), header.data(), header.size()); } - return new Uncompressed(hold.release()); } } // namespace bool ReadCompressed::DetectCompressedMagic(const void *from_void) { - return DetectMagic(from_void) != UNKNOWN; + return DetectMagic(from_void, kMagicSize) != UTIL_UNKNOWN; } ReadCompressed::ReadCompressed(int fd) { @@ -459,8 +420,9 @@ ReadCompressed::ReadCompressed() {} ReadCompressed::~ReadCompressed() {} void ReadCompressed::Reset(int fd) { + raw_amount_ = 0; internal_.reset(); - internal_.reset(ReadFactory(fd, raw_amount_)); + internal_.reset(ReadFactory(fd, raw_amount_, NULL, 0, false)); } void ReadCompressed::Reset(std::istream &in) { @@ -472,4 +434,15 @@ std::size_t ReadCompressed::Read(void *to, std::size_t amount) { return internal_->Read(to, amount, *this); } +std::size_t ReadCompressed::ReadOrEOF(void *const to_in, std::size_t amount) { + uint8_t *to = reinterpret_cast<uint8_t*>(to_in); + while (amount) { + std::size_t got = Read(to, amount); + if (!got) break; + to += got; + amount -= got; + } + return to - reinterpret_cast<uint8_t*>(to_in); +} + } // namespace util diff --git a/klm/util/read_compressed.hh b/klm/util/read_compressed.hh index 8b54c9e8..767ee94b 100644 --- a/klm/util/read_compressed.hh +++ b/klm/util/read_compressed.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_READ_COMPRESSED__ -#define UTIL_READ_COMPRESSED__ +#ifndef UTIL_READ_COMPRESSED_H +#define UTIL_READ_COMPRESSED_H #include "util/exception.hh" #include "util/scoped.hh" @@ -62,6 +62,10 @@ class ReadCompressed { std::size_t Read(void *to, std::size_t amount); + // Repeatedly call read to fill a buffer unless EOF is hit. + // Return number of bytes read. + std::size_t ReadOrEOF(void *const to, std::size_t amount); + uint64_t RawAmount() const { return raw_amount_; } private: @@ -78,4 +82,4 @@ class ReadCompressed { } // namespace util -#endif // UTIL_READ_COMPRESSED__ +#endif // UTIL_READ_COMPRESSED_H diff --git a/klm/util/read_compressed_test.cc b/klm/util/read_compressed_test.cc index 50450a02..301e8f4b 100644 --- a/klm/util/read_compressed_test.cc +++ b/klm/util/read_compressed_test.cc @@ -113,6 +113,11 @@ BOOST_AUTO_TEST_CASE(ReadXZ) { } #endif +#ifdef HAVE_ZLIB +BOOST_AUTO_TEST_CASE(AppendGZ) { +} +#endif + BOOST_AUTO_TEST_CASE(IStream) { std::string name(WriteRandom()); std::fstream stream(name.c_str(), std::ios::in); diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc index 6c5b0c2d..de1d9e94 100644 --- a/klm/util/scoped.cc +++ b/klm/util/scoped.cc @@ -32,10 +32,6 @@ void *CallocOrThrow(std::size_t requested) { return InspectAddr(std::calloc(1, requested), requested, "calloc"); } -scoped_malloc::~scoped_malloc() { - std::free(p_); -} - void scoped_malloc::call_realloc(std::size_t requested) { p_ = InspectAddr(std::realloc(p_, requested), requested, "realloc"); } diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index b642d064..60c36c36 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,9 +1,10 @@ -#ifndef UTIL_SCOPED__ -#define UTIL_SCOPED__ +#ifndef UTIL_SCOPED_H +#define UTIL_SCOPED_H /* Other scoped objects in the style of scoped_ptr. */ #include "util/exception.hh" #include <cstddef> +#include <cstdlib> namespace util { @@ -16,89 +17,93 @@ class MallocException : public ErrnoException { void *MallocOrThrow(std::size_t requested); void *CallocOrThrow(std::size_t requested); -class scoped_malloc { +/* Unfortunately, defining the operator* for void * makes the compiler complain. + * So scoped is specialized to void. This includes the functionality common to + * both, namely everything except reference. + */ +template <class T, class Closer> class scoped_base { public: - scoped_malloc() : p_(NULL) {} + explicit scoped_base(T *p = NULL) : p_(p) {} - scoped_malloc(void *p) : p_(p) {} + ~scoped_base() { Closer::Close(p_); } - ~scoped_malloc(); - - void reset(void *p = NULL) { - scoped_malloc other(p_); + void reset(T *p = NULL) { + scoped_base other(p_); p_ = p; } - void call_realloc(std::size_t to); - - void *get() { return p_; } - const void *get() const { return p_; } - - private: - void *p_; - - scoped_malloc(const scoped_malloc &); - scoped_malloc &operator=(const scoped_malloc &); -}; - -// Hat tip to boost. -template <class T> class scoped_array { - public: - explicit scoped_array(T *content = NULL) : c_(content) {} - - ~scoped_array() { delete [] c_; } - - T *get() { return c_; } - const T* get() const { return c_; } + T *get() { return p_; } + const T *get() const { return p_; } - T &operator*() { return *c_; } - const T&operator*() const { return *c_; } + T *operator->() { return p_; } + const T *operator->() const { return p_; } - T &operator[](std::size_t idx) { return c_[idx]; } - const T &operator[](std::size_t idx) const { return c_[idx]; } - - void reset(T *to = NULL) { - scoped_array<T> other(c_); - c_ = to; + T *release() { + T *ret = p_; + p_ = NULL; + return ret; } - private: - T *c_; + protected: + T *p_; - scoped_array(const scoped_array &); - void operator=(const scoped_array &); + private: + scoped_base(const scoped_base &); + scoped_base &operator=(const scoped_base &); }; -template <class T> class scoped_ptr { +template <class T, class Closer> class scoped : public scoped_base<T, Closer> { public: - explicit scoped_ptr(T *content = NULL) : c_(content) {} + explicit scoped(T *p = NULL) : scoped_base<T, Closer>(p) {} - ~scoped_ptr() { delete c_; } + T &operator*() { return *scoped_base<T, Closer>::p_; } + const T&operator*() const { return *scoped_base<T, Closer>::p_; } +}; - T *get() { return c_; } - const T* get() const { return c_; } +template <class Closer> class scoped<void, Closer> : public scoped_base<void, Closer> { + public: + explicit scoped(void *p = NULL) : scoped_base<void, Closer>(p) {} +}; - T &operator*() { return *c_; } - const T&operator*() const { return *c_; } +/* Closer for c functions like std::free and cmph cleanup functions */ +template <class T, void (*clean)(T*)> struct scoped_c_forward { + static void Close(T *p) { clean(p); } +}; +// Call a C function to delete stuff +template <class T, void (*clean)(T*)> class scoped_c : public scoped<T, scoped_c_forward<T, clean> > { + public: + explicit scoped_c(T *p = NULL) : scoped<T, scoped_c_forward<T, clean> >(p) {} +}; - T *operator->() { return c_; } - const T*operator->() const { return c_; } +class scoped_malloc : public scoped_c<void, std::free> { + public: + explicit scoped_malloc(void *p = NULL) : scoped_c<void, std::free>(p) {} - T &operator[](std::size_t idx) { return c_[idx]; } - const T &operator[](std::size_t idx) const { return c_[idx]; } + void call_realloc(std::size_t to); +}; - void reset(T *to = NULL) { - scoped_ptr<T> other(c_); - c_ = to; - } +/* scoped_array using delete[] */ +struct scoped_delete_array_forward { + template <class T> static void Close(T *p) { delete [] p; } +}; +// Hat tip to boost. +template <class T> class scoped_array : public scoped<T, scoped_delete_array_forward> { + public: + explicit scoped_array(T *p = NULL) : scoped<T, scoped_delete_array_forward>(p) {} - private: - T *c_; + T &operator[](std::size_t idx) { return scoped<T, scoped_delete_array_forward>::p_[idx]; } + const T &operator[](std::size_t idx) const { return scoped<T, scoped_delete_array_forward>::p_[idx]; } +}; - scoped_ptr(const scoped_ptr &); - void operator=(const scoped_ptr &); +/* scoped_ptr using delete. If only there were a template typedef. */ +struct scoped_delete_forward { + template <class T> static void Close(T *p) { delete p; } +}; +template <class T> class scoped_ptr : public scoped<T, scoped_delete_forward> { + public: + explicit scoped_ptr(T *p = NULL) : scoped<T, scoped_delete_forward>(p) {} }; } // namespace util -#endif // UTIL_SCOPED__ +#endif // UTIL_SCOPED_H diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh index a72657b5..75f6886f 100644 --- a/klm/util/sized_iterator.hh +++ b/klm/util/sized_iterator.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_SIZED_ITERATOR__ -#define UTIL_SIZED_ITERATOR__ +#ifndef UTIL_SIZED_ITERATOR_H +#define UTIL_SIZED_ITERATOR_H #include "util/proxy_iterator.hh" @@ -117,4 +117,4 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public }; } // namespace util -#endif // UTIL_SIZED_ITERATOR__ +#endif // UTIL_SIZED_ITERATOR_H diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 7700d9e6..a3f6d021 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_SORTED_UNIFORM__ -#define UTIL_SORTED_UNIFORM__ +#ifndef UTIL_SORTED_UNIFORM_H +#define UTIL_SORTED_UNIFORM_H #include <algorithm> #include <cstddef> @@ -101,27 +101,6 @@ template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(co return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out); } -// May return begin - 1. -template <class Iterator, class Accessor> Iterator BinaryBelow( - const Accessor &accessor, - Iterator begin, - Iterator end, - const typename Accessor::Key key) { - while (end > begin) { - Iterator pivot(begin + (end - begin) / 2); - typename Accessor::Key mid(accessor(pivot)); - if (mid < key) { - begin = pivot + 1; - } else if (mid > key) { - end = pivot; - } else { - for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {} - return pivot - 1; - } - } - return begin - 1; -} - } // namespace util -#endif // UTIL_SORTED_UNIFORM__ +#endif // UTIL_SORTED_UNIFORM_H diff --git a/klm/util/stream/Makefile.am b/klm/util/stream/Makefile.am index f18cbedb..25817b50 100644 --- a/klm/util/stream/Makefile.am +++ b/klm/util/stream/Makefile.am @@ -11,6 +11,7 @@ libklm_util_stream_a_SOURCES = \ line_input.hh \ multi_progress.cc \ multi_progress.hh \ + multi_stream.hh \ sort.hh \ stream.hh \ timer.hh diff --git a/klm/util/stream/block.hh b/klm/util/stream/block.hh index 11aa991e..aa7e28bb 100644 --- a/klm/util/stream/block.hh +++ b/klm/util/stream/block.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_BLOCK__ -#define UTIL_STREAM_BLOCK__ +#ifndef UTIL_STREAM_BLOCK_H +#define UTIL_STREAM_BLOCK_H #include <cstddef> #include <stdint.h> @@ -7,28 +7,77 @@ namespace util { namespace stream { +/** + * Encapsulates a block of memory. + */ class Block { public: + + /** + * Constructs an empty block. + */ Block() : mem_(NULL), valid_size_(0) {} + /** + * Constructs a block that encapsulates a segment of memory. + * + * @param[in] mem The segment of memory to encapsulate + * @param[in] size The size of the memory segment in bytes + */ Block(void *mem, std::size_t size) : mem_(mem), valid_size_(size) {} + /** + * Set the number of bytes in this block that should be interpreted as valid. + * + * @param[in] to Number of bytes + */ void SetValidSize(std::size_t to) { valid_size_ = to; } - // Read might fill in less than Allocated at EOF. + + /** + * Gets the number of bytes in this block that should be interpreted as valid. + * This is important because read might fill in less than Allocated at EOF. + */ std::size_t ValidSize() const { return valid_size_; } + /** Gets a void pointer to the memory underlying this block. */ void *Get() { return mem_; } + + /** Gets a const void pointer to the memory underlying this block. */ const void *Get() const { return mem_; } + + /** + * Gets a const void pointer to the end of the valid section of memory + * encapsulated by this block. + */ const void *ValidEnd() const { return reinterpret_cast<const uint8_t*>(mem_) + valid_size_; } + /** + * Returns true if this block encapsulates a valid (non-NULL) block of memory. + * + * This method is a user-defined implicit conversion function to boolean; + * among other things, this method enables bare instances of this class + * to be used as the condition of an if statement. + */ operator bool() const { return mem_ != NULL; } + + /** + * Returns true if this block is empty. + * + * In other words, if Get()==NULL, this method will return true. + */ bool operator!() const { return mem_ == NULL; } private: friend class Link; + + /** + * Points this block's memory at NULL. + * + * This class defines poison as a block whose memory pointer is NULL. + */ void SetToPoison() { mem_ = NULL; } @@ -40,4 +89,4 @@ class Block { } // namespace stream } // namespace util -#endif // UTIL_STREAM_BLOCK__ +#endif // UTIL_STREAM_BLOCK_H diff --git a/klm/util/stream/chain.cc b/klm/util/stream/chain.cc index 46708c60..4596af7a 100644 --- a/klm/util/stream/chain.cc +++ b/klm/util/stream/chain.cc @@ -59,6 +59,11 @@ Chain &Chain::operator>>(const WriteAndRecycle &writer) { return *this; } +Chain &Chain::operator>>(const PWriteAndRecycle &writer) { + threads_.push_back(new Thread(Complete(), writer)); + return *this; +} + void Chain::Wait(bool release_memory) { if (queues_.empty()) { assert(threads_.empty()); @@ -126,7 +131,12 @@ Link::~Link() { // abort(); } else { if (!poisoned_) { - // Pass the poison! + // Poison is a block whose memory pointer is NULL. + // + // Because we're in the else block, + // we know that the memory pointer of current_ is NULL. + // + // Pass the current (poison) block! out_->Produce(current_); } } diff --git a/klm/util/stream/chain.hh b/klm/util/stream/chain.hh index 0cc83a85..50865086 100644 --- a/klm/util/stream/chain.hh +++ b/klm/util/stream/chain.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_CHAIN__ -#define UTIL_STREAM_CHAIN__ +#ifndef UTIL_STREAM_CHAIN_H +#define UTIL_STREAM_CHAIN_H #include "util/stream/block.hh" #include "util/stream/config.hh" @@ -24,7 +24,12 @@ class ChainConfigException : public Exception { }; class Chain; -// Specifies position in chain for Link constructor. + +/** + * Encapsulates a @ref PCQueue "producer queue" and a @ref PCQueue "consumer queue" within a @ref Chain "chain". + * + * Specifies position in chain for Link constructor. + */ class ChainPosition { public: const Chain &GetChain() const { return *chain_; } @@ -41,14 +46,32 @@ class ChainPosition { WorkerProgress progress_; }; -// Position is usually ChainPosition but if there are multiple streams involved, this can be ChainPositions. + +/** + * Encapsulates a worker thread processing data at a given position in the chain. + * + * Each instance of this class owns one boost thread in which the worker is Run(). + */ class Thread { public: + + /** + * Constructs a new Thread in which the provided Worker is Run(). + * + * Position is usually ChainPosition but if there are multiple streams involved, this can be ChainPositions. + * + * After a call to this constructor, the provided worker will be running within a boost thread owned by the newly constructed Thread object. + */ template <class Position, class Worker> Thread(const Position &position, const Worker &worker) : thread_(boost::ref(*this), position, worker) {} ~Thread(); + /** + * Launches the provided worker in this object's boost thread. + * + * This method is called automatically by this class's @ref Thread() "constructor". + */ template <class Position, class Worker> void operator()(const Position &position, Worker &worker) { try { worker.Run(position); @@ -63,14 +86,27 @@ class Thread { boost::thread thread_; }; +/** + * This resets blocks to full valid size. Used to close the loop in Chain by recycling blocks. + */ class Recycler { public: + /** + * Resets the blocks in the chain such that the blocks' respective valid sizes match the chain's block size. + * + * @see Block::SetValidSize() + * @see Chain::BlockSize() + */ void Run(const ChainPosition &position); }; extern const Recycler kRecycle; class WriteAndRecycle; - +class PWriteAndRecycle; + +/** + * Represents a sequence of workers, through which @ref Block "blocks" can pass. + */ class Chain { private: template <class T, void (T::*ptr)(const ChainPosition &) = &T::Run> struct CheckForRun { @@ -78,8 +114,20 @@ class Chain { }; public: + + /** + * Constructs a configured Chain. + * + * @param config Specifies how to configure the Chain. + */ explicit Chain(const ChainConfig &config); + /** + * Destructs a Chain. + * + * This method waits for the chain's threads to complete, + * and frees the memory held by this chain. + */ ~Chain(); void ActivateProgress() { @@ -91,24 +139,49 @@ class Chain { progress_.SetTarget(target); } + /** + * Gets the number of bytes in each record of a Block. + * + * @see ChainConfig::entry_size + */ std::size_t EntrySize() const { return config_.entry_size; } + + /** + * Gets the inital @ref Block::ValidSize "valid size" for @ref Block "blocks" in this chain. + * + * @see Block::ValidSize + */ std::size_t BlockSize() const { return block_size_; } - // Two ways to add to the chain: Add() or operator>>. + /** Two ways to add to the chain: Add() or operator>>. */ ChainPosition Add(); - // This is for adding threaded workers with a Run method. + /** + * Adds a new worker to this chain, + * and runs that worker in a new Thread owned by this chain. + * + * The worker must have a Run method that accepts a position argument. + * + * @see Thread::operator()() + */ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) { assert(!complete_called_); threads_.push_back(new Thread(Add(), worker)); return *this; } - // Avoid copying the worker. + /** + * Adds a new worker to this chain (but avoids copying that worker), + * and runs that worker in a new Thread owned by this chain. + * + * The worker must have a Run method that accepts a position argument. + * + * @see Thread::operator()() + */ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) { assert(!complete_called_); threads_.push_back(new Thread(Add(), worker)); @@ -122,12 +195,21 @@ class Chain { threads_.push_back(new Thread(Complete(), kRecycle)); } + /** + * Adds a Recycler worker to this chain, + * and runs that worker in a new Thread owned by this chain. + */ Chain &operator>>(const Recycler &) { CompleteLoop(); return *this; } + /** + * Adds a WriteAndRecycle worker to this chain, + * and runs that worker in a new Thread owned by this chain. + */ Chain &operator>>(const WriteAndRecycle &writer); + Chain &operator>>(const PWriteAndRecycle &writer); // Chains are reusable. Call Wait to wait for everything to finish and free memory. void Wait(bool release_memory = true); @@ -156,28 +238,87 @@ class Chain { }; // Create the link in the worker thread using the position token. +/** + * Represents a C++ style iterator over @ref Block "blocks". + */ class Link { public: + // Either default construct and Init or just construct all at once. + + /** + * Constructs an @ref Init "initialized" link. + * + * @see Init + */ + explicit Link(const ChainPosition &position); + + /** + * Constructs a link that must subsequently be @ref Init "initialized". + * + * @see Init + */ Link(); + + /** + * Initializes the link with the input @ref PCQueue "consumer queue" and output @ref PCQueue "producer queue" at a given @ref ChainPosition "position" in the @ref Chain "chain". + * + * @see Link() + */ void Init(const ChainPosition &position); - explicit Link(const ChainPosition &position); - + /** + * Destructs the link object. + * + * If necessary, this method will pass a poison block + * to this link's output @ref PCQueue "producer queue". + * + * @see Block::SetToPoison() + */ ~Link(); + /** + * Gets a reference to the @ref Block "block" at this link. + */ Block &operator*() { return current_; } + + /** + * Gets a const reference to the @ref Block "block" at this link. + */ const Block &operator*() const { return current_; } + /** + * Gets a pointer to the @ref Block "block" at this link. + */ Block *operator->() { return ¤t_; } + + /** + * Gets a const pointer to the @ref Block "block" at this link. + */ const Block *operator->() const { return ¤t_; } + /** + * Gets the link at the next @ref ChainPosition "position" in the @ref Chain "chain". + */ Link &operator++(); + /** + * Returns true if the @ref Block "block" at this link encapsulates a valid (non-NULL) block of memory. + * + * This method is a user-defined implicit conversion function to boolean; + * among other things, this method enables bare instances of this class + * to be used as the condition of an if statement. + */ operator bool() const { return current_; } + /** + * @ref Block::SetToPoison() "Poisons" the @ref Block "block" at this link, + * and passes this now-poisoned block to this link's output @ref PCQueue "producer queue". + * + * @see Block::SetToPoison() + */ void Poison(); - + private: Block current_; PCQueue<Block> *in_, *out_; @@ -195,4 +336,4 @@ inline Chain &operator>>(Chain &chain, Link &link) { } // namespace stream } // namespace util -#endif // UTIL_STREAM_CHAIN__ +#endif // UTIL_STREAM_CHAIN_H diff --git a/klm/util/stream/config.hh b/klm/util/stream/config.hh index 1eeb3a8a..6bad36bc 100644 --- a/klm/util/stream/config.hh +++ b/klm/util/stream/config.hh @@ -1,32 +1,63 @@ -#ifndef UTIL_STREAM_CONFIG__ -#define UTIL_STREAM_CONFIG__ +#ifndef UTIL_STREAM_CONFIG_H +#define UTIL_STREAM_CONFIG_H #include <cstddef> #include <string> namespace util { namespace stream { +/** + * Represents how a chain should be configured. + */ struct ChainConfig { + + /** Constructs an configuration with underspecified (or default) parameters. */ ChainConfig() {} + /** + * Constructs a chain configuration object. + * + * @param [in] in_entry_size Number of bytes in each record. + * @param [in] in_block_count Number of blocks in the chain. + * @param [in] in_total_memory Total number of bytes available to the chain. + * This value will be divided amongst the blocks in the chain. + */ ChainConfig(std::size_t in_entry_size, std::size_t in_block_count, std::size_t in_total_memory) : entry_size(in_entry_size), block_count(in_block_count), total_memory(in_total_memory) {} + /** + * Number of bytes in each record. + */ std::size_t entry_size; + + /** + * Number of blocks in the chain. + */ std::size_t block_count; - // Chain's constructor will make this a multiple of entry_size. + + /** + * Total number of bytes available to the chain. + * This value will be divided amongst the blocks in the chain. + * Chain's constructor will make this a multiple of entry_size. + */ std::size_t total_memory; }; + +/** + * Represents how a sorter should be configured. + */ struct SortConfig { + + /** Filename prefix where temporary files should be placed. */ std::string temp_prefix; - // Size of each input/output buffer. + /** Size of each input/output buffer. */ std::size_t buffer_size; - // Total memory to use when running alone. + /** Total memory to use when running alone. */ std::size_t total_memory; }; }} // namespaces -#endif // UTIL_STREAM_CONFIG__ +#endif // UTIL_STREAM_CONFIG_H diff --git a/klm/util/stream/io.cc b/klm/util/stream/io.cc index 0459f706..c64004c0 100644 --- a/klm/util/stream/io.cc +++ b/klm/util/stream/io.cc @@ -36,12 +36,12 @@ void PRead::Run(const ChainPosition &position) { Link link(position); uint64_t offset = 0; for (; offset + block_size64 < size; offset += block_size64, ++link) { - PReadOrThrow(file_, link->Get(), block_size, offset); + ErsatzPRead(file_, link->Get(), block_size, offset); link->SetValidSize(block_size); } // size - offset is <= block_size, so it casts to 32-bit fine. if (size - offset) { - PReadOrThrow(file_, link->Get(), size - offset, offset); + ErsatzPRead(file_, link->Get(), size - offset, offset); link->SetValidSize(size - offset); ++link; } @@ -62,5 +62,15 @@ void WriteAndRecycle::Run(const ChainPosition &position) { } } +void PWriteAndRecycle::Run(const ChainPosition &position) { + const std::size_t block_size = position.GetChain().BlockSize(); + uint64_t offset = 0; + for (Link link(position); link; ++link) { + ErsatzPWrite(file_, link->Get(), link->ValidSize(), offset); + offset += link->ValidSize(); + link->SetValidSize(block_size); + } +} + } // namespace stream } // namespace util diff --git a/klm/util/stream/io.hh b/klm/util/stream/io.hh index 934b6b3f..8dae2cbf 100644 --- a/klm/util/stream/io.hh +++ b/klm/util/stream/io.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_IO__ -#define UTIL_STREAM_IO__ +#ifndef UTIL_STREAM_IO_H +#define UTIL_STREAM_IO_H #include "util/exception.hh" #include "util/file.hh" @@ -41,6 +41,8 @@ class Write { int file_; }; +// It's a common case that stuff is written and then recycled. So rather than +// spawn another thread to Recycle, this combines the two roles. class WriteAndRecycle { public: explicit WriteAndRecycle(int fd) : file_(fd) {} @@ -49,14 +51,23 @@ class WriteAndRecycle { int file_; }; +class PWriteAndRecycle { + public: + explicit PWriteAndRecycle(int fd) : file_(fd) {} + void Run(const ChainPosition &position); + private: + int file_; +}; + + // Reuse the same file over and over again to buffer output. class FileBuffer { public: explicit FileBuffer(int fd) : file_(fd) {} - WriteAndRecycle Sink() const { + PWriteAndRecycle Sink() const { util::SeekOrThrow(file_.get(), 0); - return WriteAndRecycle(file_.get()); + return PWriteAndRecycle(file_.get()); } PRead Source() const { @@ -73,4 +84,4 @@ class FileBuffer { } // namespace stream } // namespace util -#endif // UTIL_STREAM_IO__ +#endif // UTIL_STREAM_IO_H diff --git a/klm/util/stream/line_input.hh b/klm/util/stream/line_input.hh index 86db1dd0..a870a664 100644 --- a/klm/util/stream/line_input.hh +++ b/klm/util/stream/line_input.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_LINE_INPUT__ -#define UTIL_STREAM_LINE_INPUT__ +#ifndef UTIL_STREAM_LINE_INPUT_H +#define UTIL_STREAM_LINE_INPUT_H namespace util {namespace stream { class ChainPosition; @@ -19,4 +19,4 @@ class LineInput { }; }} // namespaces -#endif // UTIL_STREAM_LINE_INPUT__ +#endif // UTIL_STREAM_LINE_INPUT_H diff --git a/klm/util/stream/multi_progress.hh b/klm/util/stream/multi_progress.hh index c4dd45a9..82e698a5 100644 --- a/klm/util/stream/multi_progress.hh +++ b/klm/util/stream/multi_progress.hh @@ -1,6 +1,6 @@ /* Progress bar suitable for chains of workers */ -#ifndef UTIL_MULTI_PROGRESS__ -#define UTIL_MULTI_PROGRESS__ +#ifndef UTIL_STREAM_MULTI_PROGRESS_H +#define UTIL_STREAM_MULTI_PROGRESS_H #include <boost/thread/mutex.hpp> @@ -87,4 +87,4 @@ class WorkerProgress { }} // namespaces -#endif // UTIL_MULTI_PROGRESS__ +#endif // UTIL_STREAM_MULTI_PROGRESS_H diff --git a/klm/util/stream/multi_stream.hh b/klm/util/stream/multi_stream.hh new file mode 100644 index 00000000..0ee7fab6 --- /dev/null +++ b/klm/util/stream/multi_stream.hh @@ -0,0 +1,127 @@ +#ifndef UTIL_STREAM_MULTI_STREAM_H +#define UTIL_STREAM_MULTI_STREAM_H + +#include "util/fixed_array.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#include <cstddef> +#include <new> + +#include <assert.h> +#include <stdlib.h> + +namespace util { namespace stream { + +class Chains; + +class ChainPositions : public util::FixedArray<util::stream::ChainPosition> { + public: + ChainPositions() {} + + void Init(Chains &chains); + + explicit ChainPositions(Chains &chains) { + Init(chains); + } +}; + +class Chains : public util::FixedArray<util::stream::Chain> { + private: + template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun { + typedef Chains type; + }; + + public: + // Must call Init. + Chains() {} + + explicit Chains(std::size_t limit) : util::FixedArray<util::stream::Chain>(limit) {} + + template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) { + threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); + return *this; + } + + template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) { + threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); + return *this; + } + + Chains &operator>>(const util::stream::Recycler &recycler) { + for (util::stream::Chain *i = begin(); i != end(); ++i) + *i >> recycler; + return *this; + } + + void Wait(bool release_memory = true) { + threads_.clear(); + for (util::stream::Chain *i = begin(); i != end(); ++i) { + i->Wait(release_memory); + } + } + + private: + boost::ptr_vector<util::stream::Thread> threads_; + + Chains(const Chains &); + void operator=(const Chains &); +}; + +inline void ChainPositions::Init(Chains &chains) { + util::FixedArray<util::stream::ChainPosition>::Init(chains.size()); + for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) { + // use "placement new" syntax to initalize ChainPosition in an already-allocated memory location + new (end()) util::stream::ChainPosition(i->Add()); Constructed(); + } +} + +inline Chains &operator>>(Chains &chains, ChainPositions &positions) { + positions.Init(chains); + return chains; +} + +template <class T> class GenericStreams : public util::FixedArray<T> { + private: + typedef util::FixedArray<T> P; + public: + GenericStreams() {} + + // This puts a dummy T at the beginning (useful to algorithms that need to reference something at the beginning). + void InitWithDummy(const ChainPositions &positions) { + P::Init(positions.size() + 1); + new (P::end()) T(); // use "placement new" syntax to initalize T in an already-allocated memory location + P::Constructed(); + for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) { + P::push_back(*i); + } + } + + // Limit restricts to positions[0,limit) + void Init(const ChainPositions &positions, std::size_t limit) { + P::Init(limit); + for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) { + P::push_back(*i); + } + } + void Init(const ChainPositions &positions) { + Init(positions, positions.size()); + } + + GenericStreams(const ChainPositions &positions) { + Init(positions); + } +}; + +template <class T> inline Chains &operator>>(Chains &chains, GenericStreams<T> &streams) { + ChainPositions positions; + chains >> positions; + streams.Init(positions); + return chains; +} + +typedef GenericStreams<Stream> Streams; + +}} // namespaces +#endif // UTIL_STREAM_MULTI_STREAM_H diff --git a/klm/util/stream/sort.hh b/klm/util/stream/sort.hh index 16aa6a03..9082cfdd 100644 --- a/klm/util/stream/sort.hh +++ b/klm/util/stream/sort.hh @@ -15,8 +15,8 @@ * sort. Use a hash table for that. */ -#ifndef UTIL_STREAM_SORT__ -#define UTIL_STREAM_SORT__ +#ifndef UTIL_STREAM_SORT_H +#define UTIL_STREAM_SORT_H #include "util/stream/chain.hh" #include "util/stream/config.hh" @@ -182,7 +182,7 @@ template <class Compare> class MergeQueue { amount = remaining_; buffer_end_ = current_ + remaining_; } - PReadOrThrow(fd, current_, amount, offset_); + ErsatzPRead(fd, current_, amount, offset_); offset_ += amount; assert(current_ <= buffer_end_); remaining_ -= amount; @@ -307,10 +307,10 @@ template <class Compare, class Combine> class MergingReader { const uint64_t block_size = position.GetChain().BlockSize(); Link l(position); for (; offset + block_size < end; ++l, offset += block_size) { - PReadOrThrow(in_, l->Get(), block_size, offset); + ErsatzPRead(in_, l->Get(), block_size, offset); l->SetValidSize(block_size); } - PReadOrThrow(in_, l->Get(), end - offset, offset); + ErsatzPRead(in_, l->Get(), end - offset, offset); l->SetValidSize(end - offset); (++l).Poison(); return; @@ -388,8 +388,10 @@ class BadSortConfig : public Exception { ~BadSortConfig() throw() {} }; +/** Sort */ template <class Compare, class Combine = NeverCombine> class Sort { public: + /** Constructs an object capable of sorting */ Sort(Chain &in, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = Combine()) : config_(config), data_(MakeTemp(config.temp_prefix)), @@ -545,4 +547,4 @@ template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, cons } // namespace stream } // namespace util -#endif // UTIL_STREAM_SORT__ +#endif // UTIL_STREAM_SORT_H diff --git a/klm/util/stream/stream.hh b/klm/util/stream/stream.hh index 6ff45b82..7ea1c9f7 100644 --- a/klm/util/stream/stream.hh +++ b/klm/util/stream/stream.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_STREAM__ -#define UTIL_STREAM_STREAM__ +#ifndef UTIL_STREAM_STREAM_H +#define UTIL_STREAM_STREAM_H #include "util/stream/chain.hh" @@ -56,6 +56,9 @@ class Stream : boost::noncopyable { end_ = current_ + block_it_->ValidSize(); } + // The following are pointers to raw memory + // current_ is the current record + // end_ is the end of the block (so we know when to move to the next block) uint8_t *current_, *end_; std::size_t entry_size_; @@ -71,4 +74,4 @@ inline Chain &operator>>(Chain &chain, Stream &stream) { } // namespace stream } // namespace util -#endif // UTIL_STREAM_STREAM__ +#endif // UTIL_STREAM_STREAM_H diff --git a/klm/util/stream/timer.hh b/klm/util/stream/timer.hh index 7e1a5885..06488a17 100644 --- a/klm/util/stream/timer.hh +++ b/klm/util/stream/timer.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_TIMER__ -#define UTIL_STREAM_TIMER__ +#ifndef UTIL_STREAM_TIMER_H +#define UTIL_STREAM_TIMER_H // Sorry Jon, this was adding library dependencies in Moses and people complained. @@ -13,4 +13,4 @@ #define UTIL_TIMER(str) //#endif -#endif // UTIL_STREAM_TIMER__ +#endif // UTIL_STREAM_TIMER_H diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc index 973091c4..62694a35 100644 --- a/klm/util/string_piece.cc +++ b/klm/util/string_piece.cc @@ -1,2 +1 @@ -// this has been moved to utils/ in cdec - +// moved to cdec/utils diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index 696ca084..a49779aa 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -1,2 +1 @@ #include "utils/string_piece.hh" - diff --git a/klm/util/string_piece_hash.hh b/klm/util/string_piece_hash.hh index f206b1d8..5c8c525e 100644 --- a/klm/util/string_piece_hash.hh +++ b/klm/util/string_piece_hash.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STRING_PIECE_HASH__ -#define UTIL_STRING_PIECE_HASH__ +#ifndef UTIL_STRING_PIECE_HASH_H +#define UTIL_STRING_PIECE_HASH_H #include "util/string_piece.hh" @@ -40,4 +40,4 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece #endif } -#endif // UTIL_STRING_PIECE_HASH__ +#endif // UTIL_STRING_PIECE_HASH_H diff --git a/klm/util/thread_pool.hh b/klm/util/thread_pool.hh index 84e257ea..d1a883a0 100644 --- a/klm/util/thread_pool.hh +++ b/klm/util/thread_pool.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_THREAD_POOL__ -#define UTIL_THREAD_POOL__ +#ifndef UTIL_THREAD_POOL_H +#define UTIL_THREAD_POOL_H #include "util/pcqueue.hh" @@ -18,8 +18,8 @@ template <class HandlerT> class Worker : boost::noncopyable { typedef HandlerT Handler; typedef typename Handler::Request Request; - template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, Request &poison) - : in_(in), handler_(construct), thread_(boost::ref(*this)), poison_(poison) {} + template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, const Request &poison) + : in_(in), handler_(construct), poison_(poison), thread_(boost::ref(*this)) {} // Only call from thread. void operator()() { @@ -30,7 +30,7 @@ template <class HandlerT> class Worker : boost::noncopyable { try { (*handler_)(request); } - catch(std::exception &e) { + catch(const std::exception &e) { std::cerr << "Handler threw " << e.what() << std::endl; abort(); } @@ -49,10 +49,10 @@ template <class HandlerT> class Worker : boost::noncopyable { PCQueue<Request> &in_; boost::optional<Handler> handler_; + + const Request poison_; boost::thread thread_; - - Request poison_; }; template <class HandlerT> class ThreadPool : boost::noncopyable { @@ -92,4 +92,4 @@ template <class HandlerT> class ThreadPool : boost::noncopyable { } // namespace util -#endif // UTIL_THREAD_POOL__ +#endif // UTIL_THREAD_POOL_H diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index 24eae8fb..908c8daf 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_TOKENIZE_PIECE__ -#define UTIL_TOKENIZE_PIECE__ +#ifndef UTIL_TOKENIZE_PIECE_H +#define UTIL_TOKENIZE_PIECE_H #include "util/exception.hh" #include "util/string_piece.hh" @@ -7,7 +7,8 @@ #include <boost/iterator/iterator_facade.hpp> #include <algorithm> -#include <iostream> + +#include <string.h> namespace util { @@ -71,6 +72,13 @@ class BoolCharacter { return StringPiece(in.data() + in.size(), 0); } + template <unsigned Length> static void Build(const char (&characters)[Length], bool (&out)[256]) { + memset(out, 0, sizeof(out)); + for (const char *i = characters; i != characters + Length; ++i) { + out[static_cast<unsigned char>(*i)] = true; + } + } + private: const bool *delimiter_; }; @@ -140,4 +148,4 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it } // namespace util -#endif // UTIL_TOKENIZE_PIECE__ +#endif // UTIL_TOKENIZE_PIECE_H diff --git a/klm/util/unistd.hh b/klm/util/unistd.hh new file mode 100644 index 00000000..0379c491 --- /dev/null +++ b/klm/util/unistd.hh @@ -0,0 +1,22 @@ +#ifndef UTIL_UNISTD_H +#define UTIL_UNISTD_H + +#if defined(_WIN32) || defined(_WIN64) + +// Windows doesn't define <unistd.h> +// +// So we define what we need here instead: +// +#define STDIN_FILENO=0 +#define STDOUT_FILENO=1 + + +#else // Huzzah for POSIX! + +#include <unistd.h> + +#endif + + + +#endif // UTIL_UNISTD_H diff --git a/klm/util/usage.cc b/klm/util/usage.cc index e68d7c7c..2a4aa47d 100644 --- a/klm/util/usage.cc +++ b/klm/util/usage.cc @@ -30,6 +30,8 @@ typedef struct DWORDLONG ullAvailVirtual; DWORDLONG ullAvailExtendedVirtual; } lMEMORYSTATUSEX; +// Is this really supposed to be defined like this? +typedef int WINBOOL; typedef WINBOOL (WINAPI *PFN_MS_EX) (lMEMORYSTATUSEX*); #else #include <sys/resource.h> @@ -196,7 +198,7 @@ uint64_t GuessPhysicalMemory() { #if defined(_WIN32) || defined(_WIN64) { /* this works on windows */ PFN_MS_EX pfnex; - HMODULE h = GetModuleHandle ("kernel32.dll"); + HMODULE h = GetModuleHandle (TEXT("kernel32.dll")); if (!h) return 0; diff --git a/klm/util/usage.hh b/klm/util/usage.hh index da53b9e3..e578b0a6 100644 --- a/klm/util/usage.hh +++ b/klm/util/usage.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_USAGE__ -#define UTIL_USAGE__ +#ifndef UTIL_USAGE_H +#define UTIL_USAGE_H #include <cstddef> #include <iosfwd> #include <string> @@ -18,4 +18,4 @@ uint64_t GuessPhysicalMemory(); // Parse a size like unix sort. Sadly, this means the default multiplier is K. uint64_t ParseSize(const std::string &arg); } // namespace util -#endif // UTIL_USAGE__ +#endif // UTIL_USAGE_H diff --git a/m4/ax_pthread.m4 b/m4/ax_pthread.m4 new file mode 100644 index 00000000..d383ad5c --- /dev/null +++ b/m4/ax_pthread.m4 @@ -0,0 +1,332 @@ +# =========================================================================== +# http://www.gnu.org/software/autoconf-archive/ax_pthread.html +# =========================================================================== +# +# SYNOPSIS +# +# AX_PTHREAD([ACTION-IF-FOUND[, ACTION-IF-NOT-FOUND]]) +# +# DESCRIPTION +# +# This macro figures out how to build C programs using POSIX threads. It +# sets the PTHREAD_LIBS output variable to the threads library and linker +# flags, and the PTHREAD_CFLAGS output variable to any special C compiler +# flags that are needed. (The user can also force certain compiler +# flags/libs to be tested by setting these environment variables.) +# +# Also sets PTHREAD_CC to any special C compiler that is needed for +# multi-threaded programs (defaults to the value of CC otherwise). (This +# is necessary on AIX to use the special cc_r compiler alias.) +# +# NOTE: You are assumed to not only compile your program with these flags, +# but also link it with them as well. e.g. you should link with +# $PTHREAD_CC $CFLAGS $PTHREAD_CFLAGS $LDFLAGS ... $PTHREAD_LIBS $LIBS +# +# If you are only building threads programs, you may wish to use these +# variables in your default LIBS, CFLAGS, and CC: +# +# LIBS="$PTHREAD_LIBS $LIBS" +# CFLAGS="$CFLAGS $PTHREAD_CFLAGS" +# CC="$PTHREAD_CC" +# +# In addition, if the PTHREAD_CREATE_JOINABLE thread-attribute constant +# has a nonstandard name, defines PTHREAD_CREATE_JOINABLE to that name +# (e.g. PTHREAD_CREATE_UNDETACHED on AIX). +# +# Also HAVE_PTHREAD_PRIO_INHERIT is defined if pthread is found and the +# PTHREAD_PRIO_INHERIT symbol is defined when compiling with +# PTHREAD_CFLAGS. +# +# ACTION-IF-FOUND is a list of shell commands to run if a threads library +# is found, and ACTION-IF-NOT-FOUND is a list of commands to run it if it +# is not found. If ACTION-IF-FOUND is not specified, the default action +# will define HAVE_PTHREAD. +# +# Please let the authors know if this macro fails on any platform, or if +# you have any other suggestions or comments. This macro was based on work +# by SGJ on autoconf scripts for FFTW (http://www.fftw.org/) (with help +# from M. Frigo), as well as ac_pthread and hb_pthread macros posted by +# Alejandro Forero Cuervo to the autoconf macro repository. We are also +# grateful for the helpful feedback of numerous users. +# +# Updated for Autoconf 2.68 by Daniel Richard G. +# +# LICENSE +# +# Copyright (c) 2008 Steven G. Johnson <stevenj@alum.mit.edu> +# Copyright (c) 2011 Daniel Richard G. <skunk@iSKUNK.ORG> +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program. If not, see <http://www.gnu.org/licenses/>. +# +# As a special exception, the respective Autoconf Macro's copyright owner +# gives unlimited permission to copy, distribute and modify the configure +# scripts that are the output of Autoconf when processing the Macro. You +# need not follow the terms of the GNU General Public License when using +# or distributing such scripts, even though portions of the text of the +# Macro appear in them. The GNU General Public License (GPL) does govern +# all other use of the material that constitutes the Autoconf Macro. +# +# This special exception to the GPL applies to versions of the Autoconf +# Macro released by the Autoconf Archive. When you make and distribute a +# modified version of the Autoconf Macro, you may extend this special +# exception to the GPL to apply to your modified version as well. + +#serial 21 + +AU_ALIAS([ACX_PTHREAD], [AX_PTHREAD]) +AC_DEFUN([AX_PTHREAD], [ +AC_REQUIRE([AC_CANONICAL_HOST]) +AC_LANG_PUSH([C]) +ax_pthread_ok=no + +# We used to check for pthread.h first, but this fails if pthread.h +# requires special compiler flags (e.g. on True64 or Sequent). +# It gets checked for in the link test anyway. + +# First of all, check if the user has set any of the PTHREAD_LIBS, +# etcetera environment variables, and if threads linking works using +# them: +if test x"$PTHREAD_LIBS$PTHREAD_CFLAGS" != x; then + save_CFLAGS="$CFLAGS" + CFLAGS="$CFLAGS $PTHREAD_CFLAGS" + save_LIBS="$LIBS" + LIBS="$PTHREAD_LIBS $LIBS" + AC_MSG_CHECKING([for pthread_join in LIBS=$PTHREAD_LIBS with CFLAGS=$PTHREAD_CFLAGS]) + AC_TRY_LINK_FUNC([pthread_join], [ax_pthread_ok=yes]) + AC_MSG_RESULT([$ax_pthread_ok]) + if test x"$ax_pthread_ok" = xno; then + PTHREAD_LIBS="" + PTHREAD_CFLAGS="" + fi + LIBS="$save_LIBS" + CFLAGS="$save_CFLAGS" +fi + +# We must check for the threads library under a number of different +# names; the ordering is very important because some systems +# (e.g. DEC) have both -lpthread and -lpthreads, where one of the +# libraries is broken (non-POSIX). + +# Create a list of thread flags to try. Items starting with a "-" are +# C compiler flags, and other items are library names, except for "none" +# which indicates that we try without any flags at all, and "pthread-config" +# which is a program returning the flags for the Pth emulation library. + +ax_pthread_flags="pthreads none -Kthread -kthread lthread -pthread -pthreads -mthreads pthread --thread-safe -mt pthread-config" + +# The ordering *is* (sometimes) important. Some notes on the +# individual items follow: + +# pthreads: AIX (must check this before -lpthread) +# none: in case threads are in libc; should be tried before -Kthread and +# other compiler flags to prevent continual compiler warnings +# -Kthread: Sequent (threads in libc, but -Kthread needed for pthread.h) +# -kthread: FreeBSD kernel threads (preferred to -pthread since SMP-able) +# lthread: LinuxThreads port on FreeBSD (also preferred to -pthread) +# -pthread: Linux/gcc (kernel threads), BSD/gcc (userland threads) +# -pthreads: Solaris/gcc +# -mthreads: Mingw32/gcc, Lynx/gcc +# -mt: Sun Workshop C (may only link SunOS threads [-lthread], but it +# doesn't hurt to check since this sometimes defines pthreads too; +# also defines -D_REENTRANT) +# ... -mt is also the pthreads flag for HP/aCC +# pthread: Linux, etcetera +# --thread-safe: KAI C++ +# pthread-config: use pthread-config program (for GNU Pth library) + +case ${host_os} in + solaris*) + + # On Solaris (at least, for some versions), libc contains stubbed + # (non-functional) versions of the pthreads routines, so link-based + # tests will erroneously succeed. (We need to link with -pthreads/-mt/ + # -lpthread.) (The stubs are missing pthread_cleanup_push, or rather + # a function called by this macro, so we could check for that, but + # who knows whether they'll stub that too in a future libc.) So, + # we'll just look for -pthreads and -lpthread first: + + ax_pthread_flags="-pthreads pthread -mt -pthread $ax_pthread_flags" + ;; + + darwin*) + ax_pthread_flags="-pthread $ax_pthread_flags" + ;; +esac + +# Clang doesn't consider unrecognized options an error unless we specify +# -Werror. We throw in some extra Clang-specific options to ensure that +# this doesn't happen for GCC, which also accepts -Werror. + +AC_MSG_CHECKING([if compiler needs -Werror to reject unknown flags]) +save_CFLAGS="$CFLAGS" +ax_pthread_extra_flags="-Werror" +CFLAGS="$CFLAGS $ax_pthread_extra_flags -Wunknown-warning-option -Wsizeof-array-argument" +AC_COMPILE_IFELSE([AC_LANG_PROGRAM([int foo(void);],[foo()])], + [AC_MSG_RESULT([yes])], + [ax_pthread_extra_flags= + AC_MSG_RESULT([no])]) +CFLAGS="$save_CFLAGS" + +if test x"$ax_pthread_ok" = xno; then +for flag in $ax_pthread_flags; do + + case $flag in + none) + AC_MSG_CHECKING([whether pthreads work without any flags]) + ;; + + -*) + AC_MSG_CHECKING([whether pthreads work with $flag]) + PTHREAD_CFLAGS="$flag" + ;; + + pthread-config) + AC_CHECK_PROG([ax_pthread_config], [pthread-config], [yes], [no]) + if test x"$ax_pthread_config" = xno; then continue; fi + PTHREAD_CFLAGS="`pthread-config --cflags`" + PTHREAD_LIBS="`pthread-config --ldflags` `pthread-config --libs`" + ;; + + *) + AC_MSG_CHECKING([for the pthreads library -l$flag]) + PTHREAD_LIBS="-l$flag" + ;; + esac + + save_LIBS="$LIBS" + save_CFLAGS="$CFLAGS" + LIBS="$PTHREAD_LIBS $LIBS" + CFLAGS="$CFLAGS $PTHREAD_CFLAGS $ax_pthread_extra_flags" + + # Check for various functions. We must include pthread.h, + # since some functions may be macros. (On the Sequent, we + # need a special flag -Kthread to make this header compile.) + # We check for pthread_join because it is in -lpthread on IRIX + # while pthread_create is in libc. We check for pthread_attr_init + # due to DEC craziness with -lpthreads. We check for + # pthread_cleanup_push because it is one of the few pthread + # functions on Solaris that doesn't have a non-functional libc stub. + # We try pthread_create on general principles. + AC_LINK_IFELSE([AC_LANG_PROGRAM([#include <pthread.h> + static void routine(void *a) { a = 0; } + static void *start_routine(void *a) { return a; }], + [pthread_t th; pthread_attr_t attr; + pthread_create(&th, 0, start_routine, 0); + pthread_join(th, 0); + pthread_attr_init(&attr); + pthread_cleanup_push(routine, 0); + pthread_cleanup_pop(0) /* ; */])], + [ax_pthread_ok=yes], + []) + + LIBS="$save_LIBS" + CFLAGS="$save_CFLAGS" + + AC_MSG_RESULT([$ax_pthread_ok]) + if test "x$ax_pthread_ok" = xyes; then + break; + fi + + PTHREAD_LIBS="" + PTHREAD_CFLAGS="" +done +fi + +# Various other checks: +if test "x$ax_pthread_ok" = xyes; then + save_LIBS="$LIBS" + LIBS="$PTHREAD_LIBS $LIBS" + save_CFLAGS="$CFLAGS" + CFLAGS="$CFLAGS $PTHREAD_CFLAGS" + + # Detect AIX lossage: JOINABLE attribute is called UNDETACHED. + AC_MSG_CHECKING([for joinable pthread attribute]) + attr_name=unknown + for attr in PTHREAD_CREATE_JOINABLE PTHREAD_CREATE_UNDETACHED; do + AC_LINK_IFELSE([AC_LANG_PROGRAM([#include <pthread.h>], + [int attr = $attr; return attr /* ; */])], + [attr_name=$attr; break], + []) + done + AC_MSG_RESULT([$attr_name]) + if test "$attr_name" != PTHREAD_CREATE_JOINABLE; then + AC_DEFINE_UNQUOTED([PTHREAD_CREATE_JOINABLE], [$attr_name], + [Define to necessary symbol if this constant + uses a non-standard name on your system.]) + fi + + AC_MSG_CHECKING([if more special flags are required for pthreads]) + flag=no + case ${host_os} in + aix* | freebsd* | darwin*) flag="-D_THREAD_SAFE";; + osf* | hpux*) flag="-D_REENTRANT";; + solaris*) + if test "$GCC" = "yes"; then + flag="-D_REENTRANT" + else + # TODO: What about Clang on Solaris? + flag="-mt -D_REENTRANT" + fi + ;; + esac + AC_MSG_RESULT([$flag]) + if test "x$flag" != xno; then + PTHREAD_CFLAGS="$flag $PTHREAD_CFLAGS" + fi + + AC_CACHE_CHECK([for PTHREAD_PRIO_INHERIT], + [ax_cv_PTHREAD_PRIO_INHERIT], [ + AC_LINK_IFELSE([AC_LANG_PROGRAM([[#include <pthread.h>]], + [[int i = PTHREAD_PRIO_INHERIT;]])], + [ax_cv_PTHREAD_PRIO_INHERIT=yes], + [ax_cv_PTHREAD_PRIO_INHERIT=no]) + ]) + AS_IF([test "x$ax_cv_PTHREAD_PRIO_INHERIT" = "xyes"], + [AC_DEFINE([HAVE_PTHREAD_PRIO_INHERIT], [1], [Have PTHREAD_PRIO_INHERIT.])]) + + LIBS="$save_LIBS" + CFLAGS="$save_CFLAGS" + + # More AIX lossage: compile with *_r variant + if test "x$GCC" != xyes; then + case $host_os in + aix*) + AS_CASE(["x/$CC"], + [x*/c89|x*/c89_128|x*/c99|x*/c99_128|x*/cc|x*/cc128|x*/xlc|x*/xlc_v6|x*/xlc128|x*/xlc128_v6], + [#handle absolute path differently from PATH based program lookup + AS_CASE(["x$CC"], + [x/*], + [AS_IF([AS_EXECUTABLE_P([${CC}_r])],[PTHREAD_CC="${CC}_r"])], + [AC_CHECK_PROGS([PTHREAD_CC],[${CC}_r],[$CC])])]) + ;; + esac + fi +fi + +test -n "$PTHREAD_CC" || PTHREAD_CC="$CC" + +AC_SUBST([PTHREAD_LIBS]) +AC_SUBST([PTHREAD_CFLAGS]) +AC_SUBST([PTHREAD_CC]) + +# Finally, execute ACTION-IF-FOUND/ACTION-IF-NOT-FOUND: +if test x"$ax_pthread_ok" = xyes; then + ifelse([$1],,[AC_DEFINE([HAVE_PTHREAD],[1],[Define if you have POSIX threads libraries and header files.])],[$1]) + : +else + ax_pthread_ok=no + $2 +fi +AC_LANG_POP +])dnl AX_PTHREAD diff --git a/mteval/aer_scorer.h b/mteval/aer_scorer.h index 6d53d359..cd1238f3 100644 --- a/mteval/aer_scorer.h +++ b/mteval/aer_scorer.h @@ -1,5 +1,5 @@ -#ifndef _AER_SCORER_ -#define _AER_SCORER_ +#ifndef AER_SCORER_ +#define AER_SCORER_ #include <boost/shared_ptr.hpp> diff --git a/mteval/comb_scorer.h b/mteval/comb_scorer.h index 346be576..d17d089d 100644 --- a/mteval/comb_scorer.h +++ b/mteval/comb_scorer.h @@ -1,5 +1,5 @@ -#ifndef _COMB_SCORER_ -#define _COMB_SCORER_ +#ifndef COMB_SCORER_H_ +#define COMB_SCORER_H_ #include "scorer.h" diff --git a/mteval/external_scorer.h b/mteval/external_scorer.h index 85535655..9565d5af 100644 --- a/mteval/external_scorer.h +++ b/mteval/external_scorer.h @@ -1,5 +1,5 @@ -#ifndef _EXTERNAL_SCORER_H_ -#define _EXTERNAL_SCORER_H_ +#ifndef EXTERNAL_SCORER_H_ +#define EXTERNAL_SCORER_H_ #include <vector> #include <string> diff --git a/mteval/levenshtein.h b/mteval/levenshtein.h index 13a97047..3ae56cf5 100644 --- a/mteval/levenshtein.h +++ b/mteval/levenshtein.h @@ -1,5 +1,5 @@ -#ifndef _LEVENSHTEIN_H_ -#define _LEVENSHTEIN_H_ +#ifndef LEVENSHTEIN_H_ +#define LEVENSHTEIN_H_ namespace cdec { diff --git a/mteval/ns.h b/mteval/ns.h index 153bf0b8..f6329b65 100644 --- a/mteval/ns.h +++ b/mteval/ns.h @@ -1,5 +1,5 @@ -#ifndef _NS_H_ -#define _NS_H_ +#ifndef NS_H_ +#define NS_H_ #include <string> #include <vector> diff --git a/mteval/ns_cer.h b/mteval/ns_cer.h index cb2b4b4a..d9927f78 100644 --- a/mteval/ns_cer.h +++ b/mteval/ns_cer.h @@ -1,5 +1,5 @@ -#ifndef _NS_CER_H_ -#define _NS_CER_H_ +#ifndef NS_CER_H_ +#define NS_CER_H_ #include "ns.h" diff --git a/mteval/ns_comb.h b/mteval/ns_comb.h index 140e7e6a..22cba169 100644 --- a/mteval/ns_comb.h +++ b/mteval/ns_comb.h @@ -1,5 +1,5 @@ -#ifndef _NS_COMB_H_ -#define _NS_COMB_H_ +#ifndef NS_COMB_H_ +#define NS_COMB_H_ #include "ns.h" diff --git a/mteval/ns_docscorer.h b/mteval/ns_docscorer.h index b3c28fc9..5feae2df 100644 --- a/mteval/ns_docscorer.h +++ b/mteval/ns_docscorer.h @@ -1,5 +1,5 @@ -#ifndef _NS_DOC_SCORER_H_ -#define _NS_DOC_SCORER_H_ +#ifndef NS_DOC_SCORER_H_ +#define NS_DOC_SCORER_H_ #include <vector> #include <string> diff --git a/mteval/ns_ext.h b/mteval/ns_ext.h index 78badb2e..77be14b9 100644 --- a/mteval/ns_ext.h +++ b/mteval/ns_ext.h @@ -1,5 +1,5 @@ -#ifndef _NS_EXTERNAL_SCORER_H_ -#define _NS_EXTERNAL_SCORER_H_ +#ifndef NS_EXTERNAL_SCORER_H_ +#define NS_EXTERNAL_SCORER_H_ #include "ns.h" diff --git a/mteval/ns_ssk.h b/mteval/ns_ssk.h index 0d418770..fdace6eb 100644 --- a/mteval/ns_ssk.h +++ b/mteval/ns_ssk.h @@ -1,5 +1,5 @@ -#ifndef _NS_SSK_H_ -#define _NS_SSK_H_ +#ifndef NS_SSK_H_ +#define NS_SSK_H_ #include "ns.h" diff --git a/mteval/ns_ter.h b/mteval/ns_ter.h index c5c25413..cffd1bd7 100644 --- a/mteval/ns_ter.h +++ b/mteval/ns_ter.h @@ -1,5 +1,5 @@ -#ifndef _NS_TER_H_ -#define _NS_TER_H_ +#ifndef NS_TER_H_ +#define NS_TER_H_ #include "ns.h" diff --git a/mteval/ns_wer.h b/mteval/ns_wer.h index 24c85d83..45da70c5 100644 --- a/mteval/ns_wer.h +++ b/mteval/ns_wer.h @@ -1,5 +1,5 @@ -#ifndef _NS_WER_H_ -#define _NS_WER_H_ +#ifndef NS_WER_H_ +#define NS_WER_H_ #include "ns.h" diff --git a/mteval/ter.h b/mteval/ter.h index 43314791..0758c6b6 100644 --- a/mteval/ter.h +++ b/mteval/ter.h @@ -1,5 +1,5 @@ -#ifndef _TER_H_ -#define _TER_H_ +#ifndef TER_H_ +#define TER_H_ #include "scorer.h" diff --git a/python/cdec/hypergraph.pxd b/python/cdec/hypergraph.pxd index 1e150bbc..9780cf8b 100644 --- a/python/cdec/hypergraph.pxd +++ b/python/cdec/hypergraph.pxd @@ -63,7 +63,8 @@ cdef extern from "decoder/viterbi.h": cdef extern from "decoder/hg_io.h" namespace "HypergraphIO": # Hypergraph JSON I/O bint ReadFromJSON(istream* inp, Hypergraph* out) - bint WriteToJSON(Hypergraph& hg, bint remove_rules, ostream* out) + bint ReadFromBinary(istream* inp, Hypergraph* out) + bint WriteToBinary(Hypergraph& hg, ostream* out) # Hypergraph PLF I/O void ReadFromPLF(string& inp, Hypergraph* out) string AsPLF(Hypergraph& hg, bint include_global_parentheses) diff --git a/python/cdec/sa/compile.py b/python/cdec/sa/compile.py index a5bd0699..78ab729d 100644 --- a/python/cdec/sa/compile.py +++ b/python/cdec/sa/compile.py @@ -119,7 +119,7 @@ def main(): a = cdec.sa.Alignment(from_text=args.alignment) a.write_binary(a_bin) stop_time = monitor_cpu() - logger.info('Compiling alignment took %f seonds', stop_time - start_time) + logger.info('Compiling alignment took %f seconds', stop_time - start_time) start_time = monitor_cpu() logger.info('Compiling bilexical dictionary') diff --git a/tests/system_tests/cfg_rescore/input.txt b/tests/system_tests/cfg_rescore/input.txt index 2999a5fb..99624d85 100644 --- a/tests/system_tests/cfg_rescore/input.txt +++ b/tests/system_tests/cfg_rescore/input.txt @@ -1 +1 @@ -{"rules":[1,"[S] ||| [NP1] [VP] ||| [1] [2] ||| Active=1",2,"[S] ||| [NP2] [VPSV] by [NP1] ||| [1] [2] by [3] ||| Passive=1",3,"[VP] ||| [V] [NP2] ||| [1] [2]",4,"[V] ||| ate ||| ate",5,"[VPSV] ||| was eaten ||| was eaten",6,"[NP1] ||| John ||| John",7,"[NP2] ||| broccoli ||| broccoli",8,"[NP2] ||| the broccoli ||| the broccoli ||| Definite=1",9,"[Goal] ||| [X] ||| [1]"],"features":["PhraseModel_0","PhraseModel_1","PhraseModel_2","PhraseModel_3","PhraseModel_4","PhraseModel_5","PhraseModel_6","PhraseModel_7","PhraseModel_8","PhraseModel_9","PhraseModel_10","PhraseModel_11","PhraseModel_12","PhraseModel_13","PhraseModel_14","PhraseModel_15","PhraseModel_16","PhraseModel_17","PhraseModel_18","PhraseModel_19","PhraseModel_20","PhraseModel_21","PhraseModel_22","PhraseModel_23","PhraseModel_24","PhraseModel_25","PhraseModel_26","PhraseModel_27","PhraseModel_28","PhraseModel_29","PhraseModel_30","PhraseModel_31","PhraseModel_32","PhraseModel_33","PhraseModel_34","PhraseModel_35","PhraseModel_36","PhraseModel_37","PhraseModel_38","PhraseModel_39","PhraseModel_40","PhraseModel_41","PhraseModel_42","PhraseModel_43","PhraseModel_44","PhraseModel_45","PhraseModel_46","PhraseModel_47","PhraseModel_48","PhraseModel_49","PhraseModel_50","PhraseModel_51","PhraseModel_52","PhraseModel_53","PhraseModel_54","PhraseModel_55","PhraseModel_56","PhraseModel_57","PhraseModel_58","PhraseModel_59","PhraseModel_60","PhraseModel_61","PhraseModel_62","PhraseModel_63","PhraseModel_64","PhraseModel_65","PhraseModel_66","PhraseModel_67","PhraseModel_68","PhraseModel_69","PhraseModel_70","PhraseModel_71","PhraseModel_72","PhraseModel_73","PhraseModel_74","PhraseModel_75","PhraseModel_76","PhraseModel_77","PhraseModel_78","PhraseModel_79","PhraseModel_80","PhraseModel_81","PhraseModel_82","PhraseModel_83","PhraseModel_84","PhraseModel_85","PhraseModel_86","PhraseModel_87","PhraseModel_88","PhraseModel_89","PhraseModel_90","PhraseModel_91","PhraseModel_92","PhraseModel_93","PhraseModel_94","PhraseModel_95","PhraseModel_96","PhraseModel_97","PhraseModel_98","PhraseModel_99","Active","Passive","Definite"],"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":6}],"node":{"in_edges":[0],"cat":"NP1","node_hash":"0000000000000006"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":4}],"node":{"in_edges":[1],"cat":"V","node_hash":"0000000000000004"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":7},{"tail":[],"spans":[-1,-1,-1,-1],"feats":[102,1],"rule":8}],"node":{"in_edges":[2,3],"cat":"NP2","node_hash":"0000000000000008"},"edges":[{"tail":[1,2],"spans":[-1,-1,-1,-1],"feats":[],"rule":3}],"node":{"in_edges":[4],"cat":"VP","node_hash":"0000000000000003"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":5}],"node":{"in_edges":[5],"cat":"VPSV","node_hash":"0000000000000005"},"edges":[{"tail":[0,3],"spans":[-1,-1,-1,-1],"feats":[100,1],"rule":1},{"tail":[2,4,0],"spans":[-1,-1,-1,-1],"feats":[101,1],"rule":2}],"node":{"in_edges":[6,7],"cat":"S","node_hash":"0000000000000002"},"edges":[{"tail":[5],"spans":[-1,-1,-1,-1],"feats":[],"rule":9}],"node":{"in_edges":[8],"cat":"Goal","node_hash":"000000000000003D"}} +::forest:: input0.hg.bin.gz diff --git a/tests/system_tests/cfg_rescore/input0.hg.bin.gz b/tests/system_tests/cfg_rescore/input0.hg.bin.gz Binary files differnew file mode 100644 index 00000000..051e1e32 --- /dev/null +++ b/tests/system_tests/cfg_rescore/input0.hg.bin.gz diff --git a/tests/system_tests/conll/README b/tests/system_tests/conll/README new file mode 100644 index 00000000..261e6a05 --- /dev/null +++ b/tests/system_tests/conll/README @@ -0,0 +1,8 @@ +To generate the input file, run: + + ~/cdec/corpus/conll2cdec.pl input.conll > input.txt + +This will create a training corpus (i.e., an input is present as well as +gold standard output is present) in input.txt. + +See cdec.ini for examples of how to include features in the model. diff --git a/tests/system_tests/conll/cdec.ini b/tests/system_tests/conll/cdec.ini new file mode 100644 index 00000000..f214857a --- /dev/null +++ b/tests/system_tests/conll/cdec.ini @@ -0,0 +1,13 @@ +formalism=tagger +tagger_tagset=tagset.txt + +# grab the second feature column from the conll input (-w 2) and +# create a feature of i-1,i-2 conjoined with y_i +feature_function=CoNLLFeatures -w 2 -t xxy:%x[-1]_%x[0]:%y[0] + +# grab the second feature column from the conll input (-w 2) and +# create a feature of i-1,i-2 conjoined with y_i +feature_function=CoNLLFeatures -w 1 -t xy:%x[0]:%y[0] + +intersection_strategy=full + diff --git a/tests/system_tests/conll/gold.statistics b/tests/system_tests/conll/gold.statistics new file mode 100644 index 00000000..17366689 --- /dev/null +++ b/tests/system_tests/conll/gold.statistics @@ -0,0 +1,20 @@ +-lm_nodes 12 +-lm_edges 24 +-lm_paths 729 ++lm_nodes 12 ++lm_edges 24 ++lm_paths 729 ++lm_trans O O O B I O +constr_nodes 12 +constr_edges 12 +constr_paths 1 +-lm_nodes 10 +-lm_edges 20 +-lm_paths 243 ++lm_nodes 10 ++lm_edges 20 ++lm_paths 243 ++lm_trans O B I I O +constr_nodes 10 +constr_edges 10 +constr_paths 1 diff --git a/tests/system_tests/conll/gold.stdout b/tests/system_tests/conll/gold.stdout new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/system_tests/conll/gold.stdout diff --git a/tests/system_tests/conll/input.conll b/tests/system_tests/conll/input.conll new file mode 100644 index 00000000..507510ca --- /dev/null +++ b/tests/system_tests/conll/input.conll @@ -0,0 +1,13 @@ +the the DT O +angry angry JJ O +dog dog NN O +bit bite VBD B +me I PRN I +. . . O + +i i PRN O +ate eat VBD B +some some DT I +pie pie NN I +. . . O + diff --git a/tests/system_tests/conll/input.txt b/tests/system_tests/conll/input.txt new file mode 100644 index 00000000..6a1a0230 --- /dev/null +++ b/tests/system_tests/conll/input.txt @@ -0,0 +1,2 @@ +<seg id="0" feat1="the angry dog bite I ." feat2="DT JJ NN VBD PRN ."> the angry dog bit me . ||| O O O B I O </seg> +<seg id="1" feat1="i eat some pie ." feat2="PRN VBD DT NN ."> i ate some pie . ||| O B I I O </seg> diff --git a/tests/system_tests/conll/tagset.txt b/tests/system_tests/conll/tagset.txt new file mode 100644 index 00000000..bd0e6b60 --- /dev/null +++ b/tests/system_tests/conll/tagset.txt @@ -0,0 +1 @@ +B I O diff --git a/tests/system_tests/conll/weights b/tests/system_tests/conll/weights new file mode 100644 index 00000000..de130cb6 --- /dev/null +++ b/tests/system_tests/conll/weights @@ -0,0 +1,64 @@ +# Objective = 7.63544 (eval count=5) +xxy:<s>_DT:B -0.19295226006843877 +xy:the:B -0.19295226006843877 +xxy:<s>_DT:I -0.19295226006843877 +xy:the:I -0.19295226006843877 +xxy:<s>_DT:O 0.38590452013687793 +xy:the:O 0.38590452013687793 +xxy:DT_JJ:B -0.19295226006843877 +xy:angry:B -0.19295226006843877 +xxy:DT_JJ:I -0.19295226006843877 +xy:angry:I -0.19295226006843877 +xxy:DT_JJ:O 0.38590452013687793 +xy:angry:O 0.38590452013687793 +xxy:JJ_NN:B -0.19295226006843885 +xy:dog:B -0.19295226006843885 +xxy:JJ_NN:I -0.19295226006843885 +xy:dog:I -0.19295226006843885 +xxy:JJ_NN:O 0.38590452013687765 +xy:dog:O 0.38590452013687765 +xxy:NN_VBD:B 0.38590452013687765 +xy:bite:B 0.38590452013687765 +xxy:NN_VBD:I -0.19295226006843885 +xy:bite:I -0.19295226006843885 +xxy:NN_VBD:O -0.19295226006843885 +xy:bite:O -0.19295226006843885 +xxy:VBD_PRN:B -0.19295226006843885 +xy:I:B -0.19295226006843885 +xxy:VBD_PRN:I 0.38590452013687765 +xy:I:I 0.38590452013687765 +xxy:VBD_PRN:O -0.19295226006843885 +xy:I:O -0.19295226006843885 +xxy:PRN_.:B -0.16038191506717553 +xy:.:B -0.32076383013435106 +xxy:PRN_.:I -0.16038191506717553 +xy:.:I -0.32076383013435106 +xxy:PRN_.:O 0.32076383013435134 +xy:.:O 0.64152766026870267 +xxy:<s>_PRN:B -0.19295226006843871 +xy:i:B -0.19295226006843871 +xxy:<s>_PRN:I -0.19295226006843871 +xy:i:I -0.19295226006843871 +xxy:<s>_PRN:O 0.38590452013687804 +xy:i:O 0.38590452013687804 +xxy:PRN_VBD:B 0.38590452013687804 +xy:eat:B 0.38590452013687804 +xxy:PRN_VBD:I -0.19295226006843871 +xy:eat:I -0.19295226006843871 +xxy:PRN_VBD:O -0.19295226006843871 +xy:eat:O -0.19295226006843871 +xxy:VBD_DT:B -0.19295226006843877 +xy:some:B -0.19295226006843877 +xxy:VBD_DT:I 0.38590452013687798 +xy:some:I 0.38590452013687798 +xxy:VBD_DT:O -0.19295226006843877 +xy:some:O -0.19295226006843877 +xxy:DT_NN:B -0.19295226006843877 +xy:pie:B -0.19295226006843877 +xxy:DT_NN:I 0.38590452013687798 +xy:pie:I 0.38590452013687798 +xxy:DT_NN:O -0.19295226006843877 +xy:pie:O -0.19295226006843877 +xxy:NN_.:B -0.16038191506717553 +xxy:NN_.:I -0.16038191506717553 +xxy:NN_.:O 0.32076383013435134 diff --git a/tests/system_tests/ftrans/input.txt b/tests/system_tests/ftrans/input.txt index aa37b2e7..99624d85 100644 --- a/tests/system_tests/ftrans/input.txt +++ b/tests/system_tests/ftrans/input.txt @@ -1 +1 @@ -{"rules":[1,"[B] ||| b ||| b",2,"[C] ||| c ||| c",3,"[A] ||| [B,1] [C,2] ||| [1] [2] ||| Mono=1",4,"[A] ||| [C,1] [B,2] ||| [1] [2] ||| Inv=1",5,"[S] ||| [A,1] ||| [1]"],"features":["Mono","Inv"],"edges":[{"tail":[],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"B"},"edges":[{"tail":[],"feats":[],"rule":2}],"node":{"in_edges":[1],"cat":"C"},"edges":[{"tail":[0,1],"feats":[0,1],"rule":3},{"tail":[1,0],"feats":[1,1],"rule":4}],"node":{"in_edges":[2,3],"cat":"A"},"edges":[{"tail":[2],"feats":[],"rule":5}],"node":{"in_edges":[4],"cat":"S"}} +::forest:: input0.hg.bin.gz diff --git a/tests/system_tests/ftrans/input0.hg.bin.gz b/tests/system_tests/ftrans/input0.hg.bin.gz Binary files differnew file mode 100644 index 00000000..210f4a44 --- /dev/null +++ b/tests/system_tests/ftrans/input0.hg.bin.gz diff --git a/training/const_reorder/Makefile.am b/training/const_reorder/Makefile.am index 2e81e588..367ac904 100644 --- a/training/const_reorder/Makefile.am +++ b/training/const_reorder/Makefile.am @@ -1,8 +1,12 @@ +noinst_LIBRARIES = libtrainer.a + +libtrainer_a_SOURCES = trainer.h trainer.cc + bin_PROGRAMS = const_reorder_model_trainer argument_reorder_model_trainer AM_CPPFLAGS = -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder const_reorder_model_trainer_SOURCES = constituent_reorder_model.cc -const_reorder_model_trainer_LDADD = ../../utils/libutils.a +const_reorder_model_trainer_LDADD = ../../utils/libutils.a libtrainer.a argument_reorder_model_trainer_SOURCES = argument_reorder_model.cc -argument_reorder_model_trainer_LDADD = ../../utils/libutils.a +argument_reorder_model_trainer_LDADD = ../../utils/libutils.a libtrainer.a diff --git a/training/const_reorder/argument_reorder_model.cc b/training/const_reorder/argument_reorder_model.cc index 54402436..87f2ce2f 100644 --- a/training/const_reorder/argument_reorder_model.cc +++ b/training/const_reorder/argument_reorder_model.cc @@ -14,7 +14,7 @@ #include "utils/filelib.h" -#include "decoder/ff_const_reorder_common.h" +#include "trainer.h" using namespace std; using namespace const_reorder; @@ -93,8 +93,8 @@ struct SArgumentReorderTrainer { strcpy(pszNewInstanceFName, pszInstanceFname); } - Tsuruoka_Maxent* pMaxent = new Tsuruoka_Maxent(NULL); - pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname, 300); + Tsuruoka_Maxent_Trainer* pMaxent = new Tsuruoka_Maxent_Trainer; + pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname); delete pMaxent; if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) { diff --git a/training/const_reorder/constituent_reorder_model.cc b/training/const_reorder/constituent_reorder_model.cc index 6bec3f0b..d3ad0f2b 100644 --- a/training/const_reorder/constituent_reorder_model.cc +++ b/training/const_reorder/constituent_reorder_model.cc @@ -12,7 +12,7 @@ #include "utils/filelib.h" -#include "decoder/ff_const_reorder_common.h" +#include "trainer.h" using namespace std; using namespace const_reorder; @@ -104,8 +104,8 @@ struct SConstReorderTrainer { pZhangleMaxent->fnTrain(pszInstanceFname, "lbfgs", pszModelFname, 100, 2.0); delete pZhangleMaxent;*/ - Tsuruoka_Maxent* pMaxent = new Tsuruoka_Maxent(NULL); - pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname, 300); + Tsuruoka_Maxent_Trainer* pMaxent = new Tsuruoka_Maxent_Trainer; + pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname); delete pMaxent; if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) { diff --git a/training/const_reorder/trainer.cc b/training/const_reorder/trainer.cc new file mode 100644 index 00000000..89bd7479 --- /dev/null +++ b/training/const_reorder/trainer.cc @@ -0,0 +1,67 @@ +#include "trainer.h" + +Tsuruoka_Maxent_Trainer::Tsuruoka_Maxent_Trainer() + : const_reorder::Tsuruoka_Maxent(NULL) {} + +void Tsuruoka_Maxent_Trainer::fnTrain(const char* pszInstanceFName, + const char* pszAlgorithm, + const char* pszModelFName) { + assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 || + strcmp(pszAlgorithm, "sgd") == 0 || strcmp(pszAlgorithm, "SGD") == 0); + FILE* fpIn = fopen(pszInstanceFName, "r"); + + maxent::ME_Model* pModel = new maxent::ME_Model(); + + char* pszLine = new char[100001]; + int iNumInstances = 0; + int iLen; + while (!feof(fpIn)) { + pszLine[0] = '\0'; + fgets(pszLine, 20000, fpIn); + if (strlen(pszLine) == 0) { + continue; + } + + iLen = strlen(pszLine); + while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { + pszLine[iLen - 1] = '\0'; + iLen--; + } + + iNumInstances++; + + maxent::ME_Sample* pmes = new maxent::ME_Sample(); + + char* p = strrchr(pszLine, ' '); + assert(p != NULL); + p[0] = '\0'; + p++; + std::vector<std::string> vecContext; + SplitOnWhitespace(std::string(pszLine), &vecContext); + + pmes->label = std::string(p); + for (size_t i = 0; i < vecContext.size(); i++) + pmes->add_feature(vecContext[i]); + pModel->add_training_sample((*pmes)); + if (iNumInstances % 100000 == 0) + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + delete pmes; + } + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + fclose(fpIn); + + if (strcmp(pszAlgorithm, "l1") == 0) + pModel->use_l1_regularizer(1.0); + else if (strcmp(pszAlgorithm, "l2") == 0) + pModel->use_l2_regularizer(1.0); + else + pModel->use_SGD(); + + pModel->train(); + pModel->save_to_file(pszModelFName); + + delete pModel; + fprintf(stdout, "......Finished Training\n"); + fprintf(stdout, "......Model saved as %s\n", pszModelFName); + delete[] pszLine; +} diff --git a/training/const_reorder/trainer.h b/training/const_reorder/trainer.h new file mode 100644 index 00000000..e574a536 --- /dev/null +++ b/training/const_reorder/trainer.h @@ -0,0 +1,12 @@ +#ifndef TRAINING_CONST_REORDER_TRAINER_H_ +#define TRAINING_CONST_REORDER_TRAINER_H_ + +#include "decoder/ff_const_reorder_common.h" + +struct Tsuruoka_Maxent_Trainer : const_reorder::Tsuruoka_Maxent { + Tsuruoka_Maxent_Trainer(); + void fnTrain(const char* pszInstanceFName, const char* pszAlgorithm, + const char* pszModelFName); +}; + +#endif // TRAINING_CONST_REORDER_TRAINER_H_ diff --git a/training/dpmert/lo_test.cc b/training/dpmert/lo_test.cc index b8776169..69e5aa3f 100644 --- a/training/dpmert/lo_test.cc +++ b/training/dpmert/lo_test.cc @@ -56,10 +56,11 @@ BOOST_AUTO_TEST_CASE(TestConvexHull) { } BOOST_AUTO_TEST_CASE(TestConvexHullInside) { - const string json = "{\"rules\":[1,\"[X] ||| a ||| a\",2,\"[X] ||| A [X] ||| A [1]\",3,\"[X] ||| c ||| c\",4,\"[X] ||| C [X] ||| C [1]\",5,\"[X] ||| [X] B [X] ||| [1] B [2]\",6,\"[X] ||| [X] b [X] ||| [1] b [2]\",7,\"[X] ||| X [X] ||| X [1]\",8,\"[X] ||| Z [X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - istringstream instr(json); - HypergraphIO::ReadFromJSON(&instr, &hg); + ReadFile rf(path + "/test-ch-inside.bin.gz"); + assert(rf); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -121,13 +122,13 @@ BOOST_AUTO_TEST_CASE( TestS1) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - ReadFile rf(path + "/0.json.gz"); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + ReadFile rf(path + "/0.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(wts); Hypergraph hg2; - ReadFile rf2(path + "/1.json.gz"); - HypergraphIO::ReadFromJSON(rf2.stream(), &hg2); + ReadFile rf2(path + "/1.bin.gz"); + HypergraphIO::ReadFromBinary(rf2.stream(), &hg2); hg2.Reweight(wts); vector<vector<WordID> > refs1(4); @@ -193,10 +194,11 @@ BOOST_AUTO_TEST_CASE( TestS1) { } BOOST_AUTO_TEST_CASE(TestZeroOrigin) { - const string json = "{\"rules\":[1,\"[X7] ||| blA ||| without ||| LHSProb=3.92173 LexE2F=2.90799 LexF2E=1.85003 GenerativeProb=10.5381 RulePenalty=1 XFE=2.77259 XEF=0.441833 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=0.693147\",2,\"[X7] ||| blA ||| except ||| LHSProb=4.92173 LexE2F=3.90799 LexF2E=1.85003 GenerativeProb=11.5381 RulePenalty=1 XFE=2.77259 XEF=1.44183 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=1.69315\",3,\"[S] ||| [X7,1] ||| [1] ||| GlueTop=1\",4,\"[X28] ||| EnwAn ||| title ||| LHSProb=3.96802 LexE2F=2.22462 LexF2E=1.83258 GenerativeProb=10.0863 RulePenalty=1 XFE=0 XEF=1.20397 LabelledEF=1.20397 LabelledFE=-1.98341e-08 LogRuleCount=1.09861\",5,\"[X0] ||| EnwAn ||| funny ||| LHSProb=3.98479 LexE2F=1.79176 LexF2E=3.21888 GenerativeProb=11.1681 RulePenalty=1 XFE=0 XEF=2.30259 LabelledEF=2.30259 LabelledFE=0 LogRuleCount=0 SingletonRule=1\",6,\"[X8] ||| [X7,1] EnwAn ||| entitled [1] ||| LHSProb=3.82533 LexE2F=3.21888 LexF2E=2.52573 GenerativeProb=11.3276 RulePenalty=1 XFE=1.20397 XEF=1.20397 LabelledEF=2.30259 LabelledFE=2.30259 LogRuleCount=0 SingletonRule=1\",7,\"[S] ||| [S,1] [X28,2] ||| [1] [2] ||| Glue=1\",8,\"[S] ||| [S,1] [X0,2] ||| [1] [2] ||| Glue=1\",9,\"[S] ||| [X8,1] ||| [1] ||| GlueTop=1\",10,\"[Goal] ||| [S,1] ||| [1]\"],\"features\":[\"PassThrough\",\"Glue\",\"GlueTop\",\"LanguageModel\",\"WordPenalty\",\"LHSProb\",\"LexE2F\",\"LexF2E\",\"GenerativeProb\",\"RulePenalty\",\"XFE\",\"XEF\",\"LabelledEF\",\"LabelledFE\",\"LogRuleCount\",\"SingletonRule\"],\"edges\":[{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,3.92173,6,2.90799,7,1.85003,8,10.5381,9,1,10,2.77259,11,0.441833,12,2.63906,13,4.96981,14,0.693147],\"rule\":1},{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,4.92173,6,3.90799,7,1.85003,8,11.5381,9,1,10,2.77259,11,1.44183,12,2.63906,13,4.96981,14,1.69315],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X7\"},\"edges\":[{\"tail\":[0],\"spans\":[0,1,-1,-1],\"feats\":[2,1],\"rule\":3}],\"node\":{\"in_edges\":[2],\"cat\":\"S\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.96802,6,2.22462,7,1.83258,8,10.0863,9,1,11,1.20397,12,1.20397,13,-1.98341e-08,14,1.09861],\"rule\":4}],\"node\":{\"in_edges\":[3],\"cat\":\"X28\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.98479,6,1.79176,7,3.21888,8,11.1681,9,1,11,2.30259,12,2.30259,15,1],\"rule\":5}],\"node\":{\"in_edges\":[4],\"cat\":\"X0\"},\"edges\":[{\"tail\":[0],\"spans\":[0,2,-1,-1],\"feats\":[5,3.82533,6,3.21888,7,2.52573,8,11.3276,9,1,10,1.20397,11,1.20397,12,2.30259,13,2.30259,15,1],\"rule\":6}],\"node\":{\"in_edges\":[5],\"cat\":\"X8\"},\"edges\":[{\"tail\":[1,2],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":7},{\"tail\":[1,3],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":8},{\"tail\":[4],\"spans\":[0,2,-1,-1],\"feats\":[2,1],\"rule\":9}],\"node\":{\"in_edges\":[6,7,8],\"cat\":\"S\"},\"edges\":[{\"tail\":[5],\"spans\":[0,2,-1,-1],\"feats\":[],\"rule\":10}],\"node\":{\"in_edges\":[9],\"cat\":\"Goal\"}}"; + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + ReadFile rf(path + "/test-zero-origin.bin.gz"); + assert(rf); Hypergraph hg; - istringstream instr(json); - HypergraphIO::ReadFromJSON(&instr, &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); SparseVector<double> wts; wts.set_value(FD::Convert("PassThrough"), -0.929201533002898); hg.Reweight(wts); diff --git a/training/dpmert/mr_dpmert_generate_mapper_input.cc b/training/dpmert/mr_dpmert_generate_mapper_input.cc index 199cd23a..3fa2f476 100644 --- a/training/dpmert/mr_dpmert_generate_mapper_input.cc +++ b/training/dpmert/mr_dpmert_generate_mapper_input.cc @@ -70,7 +70,7 @@ int main(int argc, char** argv) { unsigned dev_set_size = conf["dev_set_size"].as<unsigned>(); for (unsigned i = 0; i < dev_set_size; ++i) { for (unsigned j = 0; j < directions.size(); ++j) { - cout << forest_repository << '/' << i << ".json.gz " << i << ' '; + cout << forest_repository << '/' << i << ".bin.gz " << i << ' '; print(cout, origin, "=", ";"); cout << ' '; print(cout, directions[j], "=", ";"); diff --git a/training/dpmert/mr_dpmert_map.cc b/training/dpmert/mr_dpmert_map.cc index d1efcf96..2bf3f8fc 100644 --- a/training/dpmert/mr_dpmert_map.cc +++ b/training/dpmert/mr_dpmert_map.cc @@ -83,7 +83,7 @@ int main(int argc, char** argv) { istringstream is(line); int sent_id; string file, s_origin, s_direction; - // path-to-file (JSON) sent_ed starting-point search-direction + // path-to-file sent_ed starting-point search-direction is >> file >> sent_id >> s_origin >> s_direction; SparseVector<double> origin; ReadSparseVectorString(s_origin, &origin); @@ -93,7 +93,7 @@ int main(int argc, char** argv) { if (last_file != file) { last_file = file; ReadFile rf(file); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); } const ConvexHullWeightFunction wf(origin, direction); const ConvexHull hull = Inside<ConvexHull, ConvexHullWeightFunction>(hg, NULL, wf); diff --git a/training/dpmert/test_data/0.bin.gz b/training/dpmert/test_data/0.bin.gz Binary files differnew file mode 100644 index 00000000..388298e9 --- /dev/null +++ b/training/dpmert/test_data/0.bin.gz diff --git a/training/dpmert/test_data/0.json.gz b/training/dpmert/test_data/0.json.gz Binary files differdeleted file mode 100644 index 30f8dd77..00000000 --- a/training/dpmert/test_data/0.json.gz +++ /dev/null diff --git a/training/dpmert/test_data/1.bin.gz b/training/dpmert/test_data/1.bin.gz Binary files differnew file mode 100644 index 00000000..44f9e0ff --- /dev/null +++ b/training/dpmert/test_data/1.bin.gz diff --git a/training/dpmert/test_data/1.json.gz b/training/dpmert/test_data/1.json.gz Binary files differdeleted file mode 100644 index c82cc179..00000000 --- a/training/dpmert/test_data/1.json.gz +++ /dev/null diff --git a/training/dpmert/test_data/test-ch-inside.bin.gz b/training/dpmert/test_data/test-ch-inside.bin.gz Binary files differnew file mode 100644 index 00000000..392f08c6 --- /dev/null +++ b/training/dpmert/test_data/test-ch-inside.bin.gz diff --git a/training/dpmert/test_data/test-zero-origin.bin.gz b/training/dpmert/test_data/test-zero-origin.bin.gz Binary files differnew file mode 100644 index 00000000..c641faaf --- /dev/null +++ b/training/dpmert/test_data/test-zero-origin.bin.gz diff --git a/training/minrisk/minrisk_optimize.cc b/training/minrisk/minrisk_optimize.cc index da8b5260..a2938fb0 100644 --- a/training/minrisk/minrisk_optimize.cc +++ b/training/minrisk/minrisk_optimize.cc @@ -178,7 +178,7 @@ int main(int argc, char** argv) { ReadFile rf(file); if (kis.size() % 5 == 0) { cerr << '.'; } if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(weights); curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); if (kbest_file.size()) diff --git a/training/pro/mr_pro_map.cc b/training/pro/mr_pro_map.cc index da58cd24..b142fd05 100644 --- a/training/pro/mr_pro_map.cc +++ b/training/pro/mr_pro_map.cc @@ -203,7 +203,7 @@ int main(int argc, char** argv) { const string kbest_file = os.str(); if (FileExists(kbest_file)) J_i.ReadFromFile(kbest_file); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(weights); J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]); J_i.WriteToFile(kbest_file); diff --git a/training/rampion/rampion_cccp.cc b/training/rampion/rampion_cccp.cc index 1e36dc51..1c45bac5 100644 --- a/training/rampion/rampion_cccp.cc +++ b/training/rampion/rampion_cccp.cc @@ -136,7 +136,7 @@ int main(int argc, char** argv) { ReadFile rf(file); if (kis.size() % 5 == 0) { cerr << '.'; } if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(weights); curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); if (kbest_file.size()) diff --git a/training/utils/Makefile.am b/training/utils/Makefile.am index 27c6e344..edaaf3d4 100644 --- a/training/utils/Makefile.am +++ b/training/utils/Makefile.am @@ -12,10 +12,12 @@ noinst_PROGRAMS = \ EXTRA_DIST = decode-and-evaluate.pl libcall.pl parallelize.pl sentserver_SOURCES = sentserver.cc -sentserver_LDFLAGS = -pthread +sentserver_LDFLAGS = $(PTHREAD_LIBS) +sentserver_CXXFLAGS = $(PTHREAD_CFLAGS) sentclient_SOURCES = sentclient.cc -sentclient_LDFLAGS = -pthread +sentclient_LDFLAGS = $(PTHREAD_LIBS) +sentclient_CXXFLAGS = $(PTHREAD_CFLAGS) TESTS = lbfgs_test optimize_test diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc index 5c1b4d4a..04f1eb77 100644 --- a/training/utils/grammar_convert.cc +++ b/training/utils/grammar_convert.cc @@ -43,7 +43,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::notify(*conf); if (conf->count("help") || conf->count("input") == 0) { - cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; + cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into serialized hypergraph.\n"; cerr << dcmdline_options << endl; exit(1); } @@ -254,7 +254,8 @@ void ProcessHypergraph(const vector<double>& w, const po::variables_map& conf, c if (w.size() > 0) { hg->Reweight(w); } if (conf.count("collapse_weights")) CollapseWeights(hg); if (conf["output"].as<string>() == "json") { - HypergraphIO::WriteToJSON(*hg, false, &cout); + cerr << "NOT IMPLEMENTED ... talk to cdyer if you need this functionality\n"; + // HypergraphIO::WriteToBinary(*hg, &cout); if (!ref.empty()) { cerr << "REF: " << ref << endl; } } else { vector<WordID> onebest; @@ -315,11 +316,11 @@ int main(int argc, char **argv) { line = line.substr(0, pos + 2); } istringstream is(line); - if (HypergraphIO::ReadFromJSON(&is, &hg)) { + if (HypergraphIO::ReadFromBinary(&is, &hg)) { ProcessHypergraph(w, conf, ref, &hg); hg.clear(); } else { - cerr << "Error reading grammar from JSON: line " << lc << endl; + cerr << "Error reading grammar line " << lc << endl; exit(1); } } else { diff --git a/utils/Makefile.am b/utils/Makefile.am index fabb4454..dd74ddc0 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -22,6 +22,7 @@ libutils_a_SOURCES = \ alias_sampler.h \ alignment_io.h \ array2d.h \ + b64featvector.h \ b64tools.h \ batched_append.h \ city.h \ @@ -38,11 +39,8 @@ libutils_a_SOURCES = \ have_64_bits.h \ indices_after.h \ kernel_string_subseq.h \ - lbfgs.h \ - lbfgs.cpp \ logval.h \ m.h \ - mathvec.h \ maxent.h \ maxent.cpp \ murmur_hash3.h \ @@ -50,8 +48,6 @@ libutils_a_SOURCES = \ named_enum.h \ null_deleter.h \ null_traits.h \ - owlqn.cpp \ - sgd.cpp \ perfect_hash.h \ prob.h \ sampler.h \ @@ -77,6 +73,7 @@ libutils_a_SOURCES = \ fast_lexical_cast.hpp \ intrusive_refcount.hpp \ alignment_io.cc \ + b64featvector.cc \ b64tools.cc \ corpus_tools.cc \ dict.cc \ diff --git a/utils/alias_sampler.h b/utils/alias_sampler.h index 81541f7a..0f9d3f6d 100644 --- a/utils/alias_sampler.h +++ b/utils/alias_sampler.h @@ -1,5 +1,5 @@ -#ifndef _ALIAS_SAMPLER_H_ -#define _ALIAS_SAMPLER_H_ +#ifndef ALIAS_SAMPLER_H_ +#define ALIAS_SAMPLER_H_ #include <vector> #include <limits> diff --git a/utils/alignment_io.h b/utils/alignment_io.h index 63fb916b..ec70688e 100644 --- a/utils/alignment_io.h +++ b/utils/alignment_io.h @@ -1,5 +1,5 @@ -#ifndef _ALIGNMENT_IO_H_ -#define _ALIGNMENT_IO_H_ +#ifndef ALIGNMENT_IO_H_ +#define ALIGNMENT_IO_H_ #include <string> #include <iostream> diff --git a/utils/b64featvector.cc b/utils/b64featvector.cc new file mode 100644 index 00000000..c7d08b29 --- /dev/null +++ b/utils/b64featvector.cc @@ -0,0 +1,55 @@ +#include "b64featvector.h" + +#include <sstream> +#include <boost/scoped_array.hpp> +#include "b64tools.h" +#include "fdict.h" + +using namespace std; + +static inline void EncodeFeatureWeight(const string &featname, weight_t weight, + ostream *output) { + output->write(featname.data(), featname.size() + 1); + output->write(reinterpret_cast<char *>(&weight), sizeof(weight_t)); +} + +string EncodeFeatureVector(const SparseVector<weight_t> &vec) { + string b64; + { + ostringstream base64_strm; + { + ostringstream strm; + for (SparseVector<weight_t>::const_iterator it = vec.begin(); + it != vec.end(); ++it) + if (it->second != 0) + EncodeFeatureWeight(FD::Convert(it->first), it->second, &strm); + string data(strm.str()); + B64::b64encode(data.data(), data.size(), &base64_strm); + } + b64 = base64_strm.str(); + } + return b64; +} + +void DecodeFeatureVector(const string &data, SparseVector<weight_t> *vec) { + vec->clear(); + if (data.empty()) return; + // Decode data + size_t b64_len = data.size(), len = b64_len / 4 * 3; + boost::scoped_array<char> buf(new char[len]); + bool res = + B64::b64decode(reinterpret_cast<const unsigned char *>(data.data()), + b64_len, buf.get(), len); + assert(res); + // Apply updates + size_t cur = 0; + while (cur < len) { + string feat_name(buf.get() + cur); + if (feat_name.empty()) break; // Encountered trailing \0 + int feat_id = FD::Convert(feat_name); + weight_t feat_delta = + *reinterpret_cast<weight_t *>(buf.get() + cur + feat_name.size() + 1); + (*vec)[feat_id] = feat_delta; + cur += feat_name.size() + 1 + sizeof(weight_t); + } +} diff --git a/utils/b64featvector.h b/utils/b64featvector.h new file mode 100644 index 00000000..6ac04d44 --- /dev/null +++ b/utils/b64featvector.h @@ -0,0 +1,12 @@ +#ifndef _B64FEATVECTOR_H_ +#define _B64FEATVECTOR_H_ + +#include <string> + +#include "sparse_vector.h" +#include "weights.h" + +std::string EncodeFeatureVector(const SparseVector<weight_t> &); +void DecodeFeatureVector(const std::string &, SparseVector<weight_t> *); + +#endif // _B64FEATVECTOR_H_ diff --git a/utils/b64tools.h b/utils/b64tools.h index c821fc8f..130a9102 100644 --- a/utils/b64tools.h +++ b/utils/b64tools.h @@ -1,5 +1,5 @@ -#ifndef _B64_TOOLS_H_ -#define _B64_TOOLS_H_ +#ifndef B64_TOOLS_H_ +#define B64_TOOLS_H_ namespace B64 { bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize); diff --git a/utils/corpus_tools.h b/utils/corpus_tools.h index f6699d87..3ccaf6ef 100644 --- a/utils/corpus_tools.h +++ b/utils/corpus_tools.h @@ -1,5 +1,5 @@ -#ifndef _CORPUS_TOOLS_H_ -#define _CORPUS_TOOLS_H_ +#ifndef CORPUS_TOOLS_H_ +#define CORPUS_TOOLS_H_ #include <string> #include <set> diff --git a/utils/exp_semiring.h b/utils/exp_semiring.h index 26a22071..164286e3 100644 --- a/utils/exp_semiring.h +++ b/utils/exp_semiring.h @@ -1,5 +1,5 @@ -#ifndef _EXP_SEMIRING_H_ -#define _EXP_SEMIRING_H_ +#ifndef EXP_SEMIRING_H_ +#define EXP_SEMIRING_H_ #include <iostream> #include "star.h" diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index 6e2a77cd..1e0ab428 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -1,5 +1,5 @@ -#ifndef _FAST_SPARSE_VECTOR_H_ -#define _FAST_SPARSE_VECTOR_H_ +#ifndef FAST_SPARSE_VECTOR_H_ +#define FAST_SPARSE_VECTOR_H_ // FastSparseVector<T> is a integer indexed unordered map that supports very fast // (mathematical) vector operations when the sizes are very small, and reasonably diff --git a/utils/fdict.h b/utils/fdict.h index eb853fb2..94763890 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -1,5 +1,5 @@ -#ifndef _FDICT_H_ -#define _FDICT_H_ +#ifndef FDICT_H_ +#define FDICT_H_ #ifdef HAVE_CONFIG_H #include "config.h" diff --git a/utils/feature_vector.h b/utils/feature_vector.h index a7b61a66..bf77b5ac 100644 --- a/utils/feature_vector.h +++ b/utils/feature_vector.h @@ -1,5 +1,5 @@ -#ifndef _FEATURE_VECTOR_H_ -#define _FEATURE_VECTOR_H_ +#ifndef FEATURE_VECTOR_H_ +#define FEATURE_VECTOR_H_ #include <vector> #include "sparse_vector.h" diff --git a/utils/filelib.h b/utils/filelib.h index 4fa69760..90620d05 100644 --- a/utils/filelib.h +++ b/utils/filelib.h @@ -1,5 +1,5 @@ -#ifndef _FILELIB_H_ -#define _FILELIB_H_ +#ifndef FILELIB_H_ +#define FILELIB_H_ #include <cassert> #include <string> diff --git a/utils/kernel_string_subseq.h b/utils/kernel_string_subseq.h index 516e8b89..00ee7da7 100644 --- a/utils/kernel_string_subseq.h +++ b/utils/kernel_string_subseq.h @@ -1,5 +1,5 @@ -#ifndef _KERNEL_STRING_SUBSEQ_H_ -#define _KERNEL_STRING_SUBSEQ_H_ +#ifndef KERNEL_STRING_SUBSEQ_H_ +#define KERNEL_STRING_SUBSEQ_H_ #include <vector> #include <cmath> diff --git a/utils/lbfgs.cpp b/utils/lbfgs.cpp deleted file mode 100644 index bd26f048..00000000 --- a/utils/lbfgs.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include <vector> -#include <iostream> -#include <cmath> -#include <stdio.h> -#include "mathvec.h" -#include "lbfgs.h" -#include "maxent.h" - -using namespace std; - -const static int M = LBFGS_M; -const static double LINE_SEARCH_ALPHA = 0.1; -const static double LINE_SEARCH_BETA = 0.5; - -// stopping criteria -int LBFGS_MAX_ITER = 300; -const static double MIN_GRAD_NORM = 0.0001; - -double ME_Model::backtracking_line_search(const Vec& x0, const Vec& grad0, - const double f0, const Vec& dx, - Vec& x, Vec& grad1) { - double t = 1.0 / LINE_SEARCH_BETA; - - double f; - do { - t *= LINE_SEARCH_BETA; - x = x0 + t * dx; - f = FunctionGradient(x.STLVec(), grad1.STLVec()); - // cout << "*"; - } while (f > f0 + LINE_SEARCH_ALPHA * t * dot_product(dx, grad0)); - - return f; -} - -// -// Jorge Nocedal, "Updating Quasi-Newton Matrices With Limited Storage", -// Mathematics of Computation, Vol. 35, No. 151, pp. 773-782, 1980. -// -Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], - const Vec y[], const double z[]) { - int offset, bound; - if (iter <= M) { - offset = 0; - bound = iter; - } else { - offset = iter - M; - bound = M; - } - - Vec q = grad; - double alpha[M], beta[M]; - for (int i = bound - 1; i >= 0; i--) { - const int j = (i + offset) % M; - alpha[i] = z[j] * dot_product(s[j], q); - q += -alpha[i] * y[j]; - } - if (iter > 0) { - const int j = (iter - 1) % M; - const double gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); - // static double gamma; - // if (gamma == 0) gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); - q *= gamma; - } - for (int i = 0; i <= bound - 1; i++) { - const int j = (i + offset) % M; - beta[i] = z[j] * dot_product(y[j], q); - q += s[j] * (alpha[i] - beta[i]); - } - - return q; -} - -vector<double> ME_Model::perform_LBFGS(const vector<double>& x0) { - const size_t dim = x0.size(); - Vec x = x0; - - Vec grad(dim), dx(dim); - double f = FunctionGradient(x.STLVec(), grad.STLVec()); - - Vec s[M], y[M]; - double z[M]; // rho - - for (int iter = 0; iter < LBFGS_MAX_ITER; iter++) { - - fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); - if (_nheldout > 0) { - const double heldout_logl = heldout_likelihood(); - fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, - _heldout_error); - } - fprintf(stderr, "\n"); - - if (sqrt(dot_product(grad, grad)) < MIN_GRAD_NORM) break; - - dx = -1 * approximate_Hg(iter, grad, s, y, z); - - Vec x1(dim), grad1(dim); - f = backtracking_line_search(x, grad, f, dx, x1, grad1); - - s[iter % M] = x1 - x; - y[iter % M] = grad1 - grad; - z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); - x = x1; - grad = grad1; - } - - return x.STLVec(); -} diff --git a/utils/lbfgs.h b/utils/lbfgs.h deleted file mode 100644 index 4d706f7a..00000000 --- a/utils/lbfgs.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _LBFGS_H_ -#define _LBFGS_H_ - -#include <vector> - -// template<class FuncGrad> -// std::vector<double> -// perform_LBFGS(FuncGrad func_grad, const std::vector<double> & x0); - -std::vector<double> perform_LBFGS( - double (*func_grad)(const std::vector<double> &, std::vector<double> &), - const std::vector<double> &x0); - -std::vector<double> perform_OWLQN( - double (*func_grad)(const std::vector<double> &, std::vector<double> &), - const std::vector<double> &x0, const double C); - -const int LBFGS_M = 10; - -#endif @@ -1,5 +1,5 @@ -#ifndef _M_H_ -#define _M_H_ +#ifndef M_H_HEADER_ +#define M_H_HEADER_ #include <cassert> #include <cmath> diff --git a/utils/mathvec.h b/utils/mathvec.h deleted file mode 100644 index f8c60e5d..00000000 --- a/utils/mathvec.h +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef _MATH_VECTOR_H_ -#define _MATH_VECTOR_H_ - -#include <vector> -#include <iostream> -#include <cassert> - -class Vec { - private: - std::vector<double> _v; - - public: - Vec(const size_t n = 0, const double val = 0) { _v.resize(n, val); } - Vec(const std::vector<double>& v) : _v(v) {} - const std::vector<double>& STLVec() const { return _v; } - std::vector<double>& STLVec() { return _v; } - size_t Size() const { return _v.size(); } - double& operator[](int i) { return _v[i]; } - const double& operator[](int i) const { return _v[i]; } - Vec& operator+=(const Vec& b) { - assert(b.Size() == _v.size()); - for (size_t i = 0; i < _v.size(); i++) { - _v[i] += b[i]; - } - return *this; - } - Vec& operator*=(const double c) { - for (size_t i = 0; i < _v.size(); i++) { - _v[i] *= c; - } - return *this; - } - void Project(const Vec& y) { - for (size_t i = 0; i < _v.size(); i++) { - // if (sign(_v[i]) != sign(y[i])) _v[i] = 0; - if (_v[i] * y[i] <= 0) _v[i] = 0; - } - } -}; - -inline double dot_product(const Vec& a, const Vec& b) { - double sum = 0; - for (size_t i = 0; i < a.Size(); i++) { - sum += a[i] * b[i]; - } - return sum; -} - -inline std::ostream& operator<<(std::ostream& s, const Vec& a) { - s << "("; - for (size_t i = 0; i < a.Size(); i++) { - if (i != 0) s << ", "; - s << a[i]; - } - s << ")"; - return s; -} - -inline const Vec operator+(const Vec& a, const Vec& b) { - Vec v(a.Size()); - assert(a.Size() == b.Size()); - for (size_t i = 0; i < a.Size(); i++) { - v[i] = a[i] + b[i]; - } - return v; -} - -inline const Vec operator-(const Vec& a, const Vec& b) { - Vec v(a.Size()); - assert(a.Size() == b.Size()); - for (size_t i = 0; i < a.Size(); i++) { - v[i] = a[i] - b[i]; - } - return v; -} - -inline const Vec operator*(const Vec& a, const double c) { - Vec v(a.Size()); - for (size_t i = 0; i < a.Size(); i++) { - v[i] = a[i] * c; - } - return v; -} - -inline const Vec operator*(const double c, const Vec& a) { return a * c; } - -#endif diff --git a/utils/maxent.cpp b/utils/maxent.cpp index 0f49ee9d..fd772e08 100644 --- a/utils/maxent.cpp +++ b/utils/maxent.cpp @@ -3,12 +3,15 @@ */ #include "maxent.h" + +#include <vector> +#include <iostream> #include <cmath> #include <cstdio> -#include "lbfgs.h" using namespace std; +namespace maxent { double ME_Model::FunctionGradient(const vector<double>& x, vector<double>& grad) { assert((int)_fb.Size() == x.size()); @@ -601,6 +604,428 @@ vector<double> ME_Model::classify(ME_Sample& mes) const { return vp; } +// template<class FuncGrad> +// std::vector<double> +// perform_LBFGS(FuncGrad func_grad, const std::vector<double> & x0); + +std::vector<double> perform_LBFGS( + double (*func_grad)(const std::vector<double> &, std::vector<double> &), + const std::vector<double> &x0); + +std::vector<double> perform_OWLQN( + double (*func_grad)(const std::vector<double> &, std::vector<double> &), + const std::vector<double> &x0, const double C); + +const int LBFGS_M = 10; + +const static int M = LBFGS_M; +const static double LINE_SEARCH_ALPHA = 0.1; +const static double LINE_SEARCH_BETA = 0.5; + +// stopping criteria +int LBFGS_MAX_ITER = 300; +const static double MIN_GRAD_NORM = 0.0001; + +// LBFGS + +double ME_Model::backtracking_line_search(const Vec& x0, const Vec& grad0, + const double f0, const Vec& dx, + Vec& x, Vec& grad1) { + double t = 1.0 / LINE_SEARCH_BETA; + + double f; + do { + t *= LINE_SEARCH_BETA; + x = x0 + t * dx; + f = FunctionGradient(x.STLVec(), grad1.STLVec()); + // cout << "*"; + } while (f > f0 + LINE_SEARCH_ALPHA * t * dot_product(dx, grad0)); + + return f; +} + +// +// Jorge Nocedal, "Updating Quasi-Newton Matrices With Limited Storage", +// Mathematics of Computation, Vol. 35, No. 151, pp. 773-782, 1980. +// +Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], + const Vec y[], const double z[]) { + int offset, bound; + if (iter <= M) { + offset = 0; + bound = iter; + } else { + offset = iter - M; + bound = M; + } + + Vec q = grad; + double alpha[M], beta[M]; + for (int i = bound - 1; i >= 0; i--) { + const int j = (i + offset) % M; + alpha[i] = z[j] * dot_product(s[j], q); + q += -alpha[i] * y[j]; + } + if (iter > 0) { + const int j = (iter - 1) % M; + const double gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); + // static double gamma; + // if (gamma == 0) gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); + q *= gamma; + } + for (int i = 0; i <= bound - 1; i++) { + const int j = (i + offset) % M; + beta[i] = z[j] * dot_product(y[j], q); + q += s[j] * (alpha[i] - beta[i]); + } + + return q; +} + +vector<double> ME_Model::perform_LBFGS(const vector<double>& x0) { + const size_t dim = x0.size(); + Vec x = x0; + + Vec grad(dim), dx(dim); + double f = FunctionGradient(x.STLVec(), grad.STLVec()); + + Vec s[M], y[M]; + double z[M]; // rho + + for (int iter = 0; iter < LBFGS_MAX_ITER; iter++) { + + fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); + if (_nheldout > 0) { + const double heldout_logl = heldout_likelihood(); + fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, + _heldout_error); + } + fprintf(stderr, "\n"); + + if (sqrt(dot_product(grad, grad)) < MIN_GRAD_NORM) break; + + dx = -1 * approximate_Hg(iter, grad, s, y, z); + + Vec x1(dim), grad1(dim); + f = backtracking_line_search(x, grad, f, dx, x1, grad1); + + s[iter % M] = x1 - x; + y[iter % M] = grad1 - grad; + z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); + x = x1; + grad = grad1; + } + + return x.STLVec(); +} + +// OWLQN + +// stopping criteria +int OWLQN_MAX_ITER = 300; + +Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], + const Vec y[], const double z[]); + +inline int sign(double x) { + if (x > 0) return 1; + if (x < 0) return -1; + return 0; +}; + +static Vec pseudo_gradient(const Vec& x, const Vec& grad0, const double C) { + Vec grad = grad0; + for (size_t i = 0; i < x.Size(); i++) { + if (x[i] != 0) { + grad[i] += C * sign(x[i]); + continue; + } + const double gm = grad0[i] - C; + if (gm > 0) { + grad[i] = gm; + continue; + } + const double gp = grad0[i] + C; + if (gp < 0) { + grad[i] = gp; + continue; + } + grad[i] = 0; + } + + return grad; +} + +double ME_Model::regularized_func_grad(const double C, const Vec& x, + Vec& grad) { + double f = FunctionGradient(x.STLVec(), grad.STLVec()); + for (size_t i = 0; i < x.Size(); i++) { + f += C * fabs(x[i]); + } + + return f; +} + +double ME_Model::constrained_line_search(double C, const Vec& x0, + const Vec& grad0, const double f0, + const Vec& dx, Vec& x, Vec& grad1) { + // compute the orthant to explore + Vec orthant = x0; + for (size_t i = 0; i < orthant.Size(); i++) { + if (orthant[i] == 0) orthant[i] = -grad0[i]; + } + + double t = 1.0 / LINE_SEARCH_BETA; + + double f; + do { + t *= LINE_SEARCH_BETA; + x = x0 + t * dx; + x.Project(orthant); + // for (size_t i = 0; i < x.Size(); i++) { + // if (x0[i] != 0 && sign(x[i]) != sign(x0[i])) x[i] = 0; + // } + + f = regularized_func_grad(C, x, grad1); + // cout << "*"; + } while (f > f0 + LINE_SEARCH_ALPHA * dot_product(x - x0, grad0)); + + return f; +} + +vector<double> ME_Model::perform_OWLQN(const vector<double>& x0, + const double C) { + const size_t dim = x0.size(); + Vec x = x0; + + Vec grad(dim), dx(dim); + double f = regularized_func_grad(C, x, grad); + + Vec s[M], y[M]; + double z[M]; // rho + + for (int iter = 0; iter < OWLQN_MAX_ITER; iter++) { + Vec pg = pseudo_gradient(x, grad, C); + + fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); + if (_nheldout > 0) { + const double heldout_logl = heldout_likelihood(); + fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, + _heldout_error); + } + fprintf(stderr, "\n"); + + if (sqrt(dot_product(pg, pg)) < MIN_GRAD_NORM) break; + + dx = -1 * approximate_Hg(iter, pg, s, y, z); + if (dot_product(dx, pg) >= 0) dx.Project(-1 * pg); + + Vec x1(dim), grad1(dim); + f = constrained_line_search(C, x, pg, f, dx, x1, grad1); + + s[iter % M] = x1 - x; + y[iter % M] = grad1 - grad; + z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); + + x = x1; + grad = grad1; + } + + return x.STLVec(); +} + +// SGD + +// const double SGD_ETA0 = 1; +// const double SGD_ITER = 30; +// const double SGD_ALPHA = 0.85; + +//#define FOLOS_NAIVE +//#define FOLOS_LAZY +#define SGD_CP + +inline void apply_l1_penalty(const int i, const double u, vector<double>& _vl, + vector<double>& q) { + double& w = _vl[i]; + const double z = w; + double& qi = q[i]; + if (w > 0) { + w = max(0.0, w - (u + qi)); + } else if (w < 0) { + w = min(0.0, w + (u - qi)); + } + qi += w - z; +} + +static double l1norm(const vector<double>& v) { + double sum = 0; + for (size_t i = 0; i < v.size(); i++) sum += abs(v[i]); + return sum; +} + +inline void update_folos_lazy(const int iter_sample, const int k, + vector<double>& _vl, + const vector<double>& sum_eta, + vector<int>& last_updated) { + const double penalty = sum_eta[iter_sample] - sum_eta[last_updated[k]]; + double& x = _vl[k]; + if (x > 0) + x = max(0.0, x - penalty); + else + x = min(0.0, x + penalty); + last_updated[k] = iter_sample; +} + +int ME_Model::perform_SGD() { + if (_l2reg > 0) { + cerr << "error: L2 regularization is currently not supported in SGD mode." + << endl; + exit(1); + } + + cerr << "performing SGD" << endl; + + const double l1param = _l1reg; + + const int d = _fb.Size(); + + vector<int> ri(_vs.size()); + for (size_t i = 0; i < ri.size(); i++) ri[i] = i; + + vector<double> grad(d); + int iter_sample = 0; + const double eta0 = SGD_ETA0; + + // cerr << "l1param = " << l1param << endl; + cerr << "eta0 = " << eta0 << " alpha = " << SGD_ALPHA << endl; + + double u = 0; + vector<double> q(d, 0); + vector<int> last_updated(d, 0); + vector<double> sum_eta; + sum_eta.push_back(0); + + for (int iter = 0; iter < SGD_ITER; iter++) { + + random_shuffle(ri.begin(), ri.end()); + + double logl = 0; + int ncorrect = 0, ntotal = 0; + for (size_t i = 0; i < _vs.size(); i++, ntotal++, iter_sample++) { + const Sample& s = _vs[ri[i]]; + +#ifdef FOLOS_LAZY + for (vector<int>::const_iterator j = s.positive_features.begin(); + j != s.positive_features.end(); j++) { + for (vector<int>::const_iterator k = _feature2mef[*j].begin(); + k != _feature2mef[*j].end(); k++) { + update_folos_lazy(iter_sample, *k, _vl, sum_eta, last_updated); + } + } +#endif + + vector<double> membp(_num_classes); + const int max_label = conditional_probability(s, membp); + + const double eta = + eta0 * pow(SGD_ALPHA, + (double)iter_sample / _vs.size()); // exponential decay + // const double eta = eta0 / (1.0 + (double)iter_sample / + // _vs.size()); + + // if (iter_sample % _vs.size() == 0) cerr << "eta = " << eta << + // endl; + u += eta * l1param; + + sum_eta.push_back(sum_eta.back() + eta * l1param); + + logl += log(membp[s.label]); + if (max_label == s.label) ncorrect++; + + // binary features + for (vector<int>::const_iterator j = s.positive_features.begin(); + j != s.positive_features.end(); j++) { + for (vector<int>::const_iterator k = _feature2mef[*j].begin(); + k != _feature2mef[*j].end(); k++) { + const double me = membp[_fb.Feature(*k).label()]; + const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); + const double grad = (me - ee); + _vl[*k] -= eta * grad; +#ifdef SGD_CP + apply_l1_penalty(*k, u, _vl, q); +#endif + } + } + // real-valued features + for (vector<pair<int, double> >::const_iterator j = s.rvfeatures.begin(); + j != s.rvfeatures.end(); j++) { + for (vector<int>::const_iterator k = _feature2mef[j->first].begin(); + k != _feature2mef[j->first].end(); k++) { + const double me = membp[_fb.Feature(*k).label()]; + const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); + const double grad = (me - ee) * j->second; + _vl[*k] -= eta * grad; +#ifdef SGD_CP + apply_l1_penalty(*k, u, _vl, q); +#endif + } + } + +#ifdef FOLOS_NAIVE + for (size_t j = 0; j < d; j++) { + double& x = _vl[j]; + if (x > 0) + x = max(0.0, x - eta * l1param); + else + x = min(0.0, x + eta * l1param); + } +#endif + } + logl /= _vs.size(); +// fprintf(stderr, "%4d logl = %8.3f acc = %6.4f ", iter, logl, +// (double)ncorrect / ntotal); + +#ifdef FOLOS_LAZY + if (l1param > 0) { + for (size_t j = 0; j < d; j++) + update_folos_lazy(iter_sample, j, _vl, sum_eta, last_updated); + } +#endif + + double f = logl; + if (l1param > 0) { + const double l1 = + l1norm(_vl); // this is not accurate when lazy update is used + // cerr << "f0 = " << update_model_expectation() - l1param * l1 << " + // "; + f -= l1param * l1; + int nonzero = 0; + for (int j = 0; j < d; j++) + if (_vl[j] != 0) nonzero++; + // cerr << " f = " << f << " l1 = " << l1 << " nonzero_features = " + // << nonzero << endl; + } + // fprintf(stderr, "%4d obj = %7.3f acc = %6.4f", iter+1, f, + // (double)ncorrect/ntotal); + // fprintf(stderr, "%4d obj = %f", iter+1, f); + fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, f, + 1 - (double)ncorrect / ntotal); + + if (_nheldout > 0) { + double heldout_logl = heldout_likelihood(); + // fprintf(stderr, " heldout_logl = %f acc = %6.4f\n", + // heldout_logl, 1 - _heldout_error); + fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, + _heldout_error); + } + fprintf(stderr, "\n"); + } + + return 0; +} + +} // namespace maxent + /* * $Log: maxent.cpp,v $ * Revision 1.1.1.1 2007/05/15 08:30:35 kyoshida diff --git a/utils/maxent.h b/utils/maxent.h index b1efd88e..74d13a6f 100644 --- a/utils/maxent.h +++ b/utils/maxent.h @@ -5,21 +5,95 @@ #ifndef __MAXENT_H_ #define __MAXENT_H_ -#include <string> -#include <vector> -#include <list> -#include <map> #include <algorithm> #include <iostream> +#include <list> +#include <map> #include <string> +#include <unordered_map> +#include <vector> + #include <cassert> -#include "mathvec.h" -#define USE_HASH_MAP // if you encounter errors with hash, try commenting out - // this line. (the program will be a bit slower, though) -#ifdef USE_HASH_MAP -#include <unordered_map> -#endif +namespace maxent { +class Vec { + private: + std::vector<double> _v; + + public: + Vec(const size_t n = 0, const double val = 0) { _v.resize(n, val); } + Vec(const std::vector<double>& v) : _v(v) {} + const std::vector<double>& STLVec() const { return _v; } + std::vector<double>& STLVec() { return _v; } + size_t Size() const { return _v.size(); } + double& operator[](int i) { return _v[i]; } + const double& operator[](int i) const { return _v[i]; } + Vec& operator+=(const Vec& b) { + assert(b.Size() == _v.size()); + for (size_t i = 0; i < _v.size(); i++) { + _v[i] += b[i]; + } + return *this; + } + Vec& operator*=(const double c) { + for (size_t i = 0; i < _v.size(); i++) { + _v[i] *= c; + } + return *this; + } + void Project(const Vec& y) { + for (size_t i = 0; i < _v.size(); i++) { + // if (sign(_v[i]) != sign(y[i])) _v[i] = 0; + if (_v[i] * y[i] <= 0) _v[i] = 0; + } + } +}; + +inline double dot_product(const Vec& a, const Vec& b) { + double sum = 0; + for (size_t i = 0; i < a.Size(); i++) { + sum += a[i] * b[i]; + } + return sum; +} + +inline std::ostream& operator<<(std::ostream& s, const Vec& a) { + s << "("; + for (size_t i = 0; i < a.Size(); i++) { + if (i != 0) s << ", "; + s << a[i]; + } + s << ")"; + return s; +} + +inline const Vec operator+(const Vec& a, const Vec& b) { + Vec v(a.Size()); + assert(a.Size() == b.Size()); + for (size_t i = 0; i < a.Size(); i++) { + v[i] = a[i] + b[i]; + } + return v; +} + +inline const Vec operator-(const Vec& a, const Vec& b) { + Vec v(a.Size()); + assert(a.Size() == b.Size()); + for (size_t i = 0; i < a.Size(); i++) { + v[i] = a[i] - b[i]; + } + return v; +} + +inline const Vec operator*(const Vec& a, const double c) { + Vec v(a.Size()); + for (size_t i = 0; i < a.Size(); i++) { + v[i] = a[i] * c; + } + return v; +} + +inline const Vec operator*(const double c, const Vec& a) { return a * c; } // // data format for each sample for training/testing @@ -309,6 +383,7 @@ class ME_Model { static double FunctionGradientWrapper(const std::vector<double>& x, std::vector<double>& grad); }; +} // namespace maxent #endif diff --git a/utils/murmur_hash3.h b/utils/murmur_hash3.h index a125d775..e8a8b10b 100644 --- a/utils/murmur_hash3.h +++ b/utils/murmur_hash3.h @@ -2,8 +2,8 @@ // MurmurHash3 was written by Austin Appleby, and is placed in the public // domain. The author hereby disclaims copyright to this source code. -#ifndef _MURMURHASH3_H_ -#define _MURMURHASH3_H_ +#ifndef MURMURHASH3_H_ +#define MURMURHASH3_H_ //----------------------------------------------------------------------------- // Platform-specific functions and macros diff --git a/utils/owlqn.cpp b/utils/owlqn.cpp deleted file mode 100644 index c3a0f0da..00000000 --- a/utils/owlqn.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include <vector> -#include <iostream> -#include <cmath> -#include <stdio.h> -#include "mathvec.h" -#include "lbfgs.h" -#include "maxent.h" - -using namespace std; - -const static int M = LBFGS_M; -const static double LINE_SEARCH_ALPHA = 0.1; -const static double LINE_SEARCH_BETA = 0.5; - -// stopping criteria -int OWLQN_MAX_ITER = 300; -const static double MIN_GRAD_NORM = 0.0001; - -Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], - const Vec y[], const double z[]); - -inline int sign(double x) { - if (x > 0) return 1; - if (x < 0) return -1; - return 0; -}; - -static Vec pseudo_gradient(const Vec& x, const Vec& grad0, const double C) { - Vec grad = grad0; - for (size_t i = 0; i < x.Size(); i++) { - if (x[i] != 0) { - grad[i] += C * sign(x[i]); - continue; - } - const double gm = grad0[i] - C; - if (gm > 0) { - grad[i] = gm; - continue; - } - const double gp = grad0[i] + C; - if (gp < 0) { - grad[i] = gp; - continue; - } - grad[i] = 0; - } - - return grad; -} - -double ME_Model::regularized_func_grad(const double C, const Vec& x, - Vec& grad) { - double f = FunctionGradient(x.STLVec(), grad.STLVec()); - for (size_t i = 0; i < x.Size(); i++) { - f += C * fabs(x[i]); - } - - return f; -} - -double ME_Model::constrained_line_search(double C, const Vec& x0, - const Vec& grad0, const double f0, - const Vec& dx, Vec& x, Vec& grad1) { - // compute the orthant to explore - Vec orthant = x0; - for (size_t i = 0; i < orthant.Size(); i++) { - if (orthant[i] == 0) orthant[i] = -grad0[i]; - } - - double t = 1.0 / LINE_SEARCH_BETA; - - double f; - do { - t *= LINE_SEARCH_BETA; - x = x0 + t * dx; - x.Project(orthant); - // for (size_t i = 0; i < x.Size(); i++) { - // if (x0[i] != 0 && sign(x[i]) != sign(x0[i])) x[i] = 0; - // } - - f = regularized_func_grad(C, x, grad1); - // cout << "*"; - } while (f > f0 + LINE_SEARCH_ALPHA * dot_product(x - x0, grad0)); - - return f; -} - -vector<double> ME_Model::perform_OWLQN(const vector<double>& x0, - const double C) { - const size_t dim = x0.size(); - Vec x = x0; - - Vec grad(dim), dx(dim); - double f = regularized_func_grad(C, x, grad); - - Vec s[M], y[M]; - double z[M]; // rho - - for (int iter = 0; iter < OWLQN_MAX_ITER; iter++) { - Vec pg = pseudo_gradient(x, grad, C); - - fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); - if (_nheldout > 0) { - const double heldout_logl = heldout_likelihood(); - fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, - _heldout_error); - } - fprintf(stderr, "\n"); - - if (sqrt(dot_product(pg, pg)) < MIN_GRAD_NORM) break; - - dx = -1 * approximate_Hg(iter, pg, s, y, z); - if (dot_product(dx, pg) >= 0) dx.Project(-1 * pg); - - Vec x1(dim), grad1(dim); - f = constrained_line_search(C, x, pg, f, dx, x1, grad1); - - s[iter % M] = x1 - x; - y[iter % M] = grad1 - grad; - z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); - - x = x1; - grad = grad1; - } - - return x.STLVec(); -} diff --git a/utils/perfect_hash.h b/utils/perfect_hash.h index 29ea48a9..8c12c9f0 100644 --- a/utils/perfect_hash.h +++ b/utils/perfect_hash.h @@ -1,5 +1,5 @@ -#ifndef _PERFECT_HASH_MAP_H_ -#define _PERFECT_HASH_MAP_H_ +#ifndef PERFECT_HASH_MAP_H_ +#define PERFECT_HASH_MAP_H_ #include <vector> #include <boost/utility.hpp> diff --git a/utils/prob.h b/utils/prob.h index bc297870..32ba9a86 100644 --- a/utils/prob.h +++ b/utils/prob.h @@ -1,5 +1,5 @@ -#ifndef _PROB_H_ -#define _PROB_H_ +#ifndef PROB_H_ +#define PROB_H_ #include "logval.h" diff --git a/utils/sgd.cpp b/utils/sgd.cpp deleted file mode 100644 index 8613edca..00000000 --- a/utils/sgd.cpp +++ /dev/null @@ -1,193 +0,0 @@ -#include "maxent.h" -#include <cmath> -#include <stdio.h> - -using namespace std; - -// const double SGD_ETA0 = 1; -// const double SGD_ITER = 30; -// const double SGD_ALPHA = 0.85; - -//#define FOLOS_NAIVE -//#define FOLOS_LAZY -#define SGD_CP - -inline void apply_l1_penalty(const int i, const double u, vector<double>& _vl, - vector<double>& q) { - double& w = _vl[i]; - const double z = w; - double& qi = q[i]; - if (w > 0) { - w = max(0.0, w - (u + qi)); - } else if (w < 0) { - w = min(0.0, w + (u - qi)); - } - qi += w - z; -} - -static double l1norm(const vector<double>& v) { - double sum = 0; - for (size_t i = 0; i < v.size(); i++) sum += abs(v[i]); - return sum; -} - -inline void update_folos_lazy(const int iter_sample, const int k, - vector<double>& _vl, - const vector<double>& sum_eta, - vector<int>& last_updated) { - const double penalty = sum_eta[iter_sample] - sum_eta[last_updated[k]]; - double& x = _vl[k]; - if (x > 0) - x = max(0.0, x - penalty); - else - x = min(0.0, x + penalty); - last_updated[k] = iter_sample; -} - -int ME_Model::perform_SGD() { - if (_l2reg > 0) { - cerr << "error: L2 regularization is currently not supported in SGD mode." - << endl; - exit(1); - } - - cerr << "performing SGD" << endl; - - const double l1param = _l1reg; - - const int d = _fb.Size(); - - vector<int> ri(_vs.size()); - for (size_t i = 0; i < ri.size(); i++) ri[i] = i; - - vector<double> grad(d); - int iter_sample = 0; - const double eta0 = SGD_ETA0; - - // cerr << "l1param = " << l1param << endl; - cerr << "eta0 = " << eta0 << " alpha = " << SGD_ALPHA << endl; - - double u = 0; - vector<double> q(d, 0); - vector<int> last_updated(d, 0); - vector<double> sum_eta; - sum_eta.push_back(0); - - for (int iter = 0; iter < SGD_ITER; iter++) { - - random_shuffle(ri.begin(), ri.end()); - - double logl = 0; - int ncorrect = 0, ntotal = 0; - for (size_t i = 0; i < _vs.size(); i++, ntotal++, iter_sample++) { - const Sample& s = _vs[ri[i]]; - -#ifdef FOLOS_LAZY - for (vector<int>::const_iterator j = s.positive_features.begin(); - j != s.positive_features.end(); j++) { - for (vector<int>::const_iterator k = _feature2mef[*j].begin(); - k != _feature2mef[*j].end(); k++) { - update_folos_lazy(iter_sample, *k, _vl, sum_eta, last_updated); - } - } -#endif - - vector<double> membp(_num_classes); - const int max_label = conditional_probability(s, membp); - - const double eta = - eta0 * pow(SGD_ALPHA, - (double)iter_sample / _vs.size()); // exponential decay - // const double eta = eta0 / (1.0 + (double)iter_sample / - // _vs.size()); - - // if (iter_sample % _vs.size() == 0) cerr << "eta = " << eta << - // endl; - u += eta * l1param; - - sum_eta.push_back(sum_eta.back() + eta * l1param); - - logl += log(membp[s.label]); - if (max_label == s.label) ncorrect++; - - // binary features - for (vector<int>::const_iterator j = s.positive_features.begin(); - j != s.positive_features.end(); j++) { - for (vector<int>::const_iterator k = _feature2mef[*j].begin(); - k != _feature2mef[*j].end(); k++) { - const double me = membp[_fb.Feature(*k).label()]; - const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); - const double grad = (me - ee); - _vl[*k] -= eta * grad; -#ifdef SGD_CP - apply_l1_penalty(*k, u, _vl, q); -#endif - } - } - // real-valued features - for (vector<pair<int, double> >::const_iterator j = s.rvfeatures.begin(); - j != s.rvfeatures.end(); j++) { - for (vector<int>::const_iterator k = _feature2mef[j->first].begin(); - k != _feature2mef[j->first].end(); k++) { - const double me = membp[_fb.Feature(*k).label()]; - const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); - const double grad = (me - ee) * j->second; - _vl[*k] -= eta * grad; -#ifdef SGD_CP - apply_l1_penalty(*k, u, _vl, q); -#endif - } - } - -#ifdef FOLOS_NAIVE - for (size_t j = 0; j < d; j++) { - double& x = _vl[j]; - if (x > 0) - x = max(0.0, x - eta * l1param); - else - x = min(0.0, x + eta * l1param); - } -#endif - } - logl /= _vs.size(); -// fprintf(stderr, "%4d logl = %8.3f acc = %6.4f ", iter, logl, -// (double)ncorrect / ntotal); - -#ifdef FOLOS_LAZY - if (l1param > 0) { - for (size_t j = 0; j < d; j++) - update_folos_lazy(iter_sample, j, _vl, sum_eta, last_updated); - } -#endif - - double f = logl; - if (l1param > 0) { - const double l1 = - l1norm(_vl); // this is not accurate when lazy update is used - // cerr << "f0 = " << update_model_expectation() - l1param * l1 << " - // "; - f -= l1param * l1; - int nonzero = 0; - for (int j = 0; j < d; j++) - if (_vl[j] != 0) nonzero++; - // cerr << " f = " << f << " l1 = " << l1 << " nonzero_features = " - // << nonzero << endl; - } - // fprintf(stderr, "%4d obj = %7.3f acc = %6.4f", iter+1, f, - // (double)ncorrect/ntotal); - // fprintf(stderr, "%4d obj = %f", iter+1, f); - fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, f, - 1 - (double)ncorrect / ntotal); - - if (_nheldout > 0) { - double heldout_logl = heldout_likelihood(); - // fprintf(stderr, " heldout_logl = %f acc = %6.4f\n", - // heldout_logl, 1 - _heldout_error); - fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, - _heldout_error); - } - fprintf(stderr, "\n"); - } - - return 0; -} diff --git a/utils/small_vector.h b/utils/small_vector.h index 280ab72c..f16bc898 100644 --- a/utils/small_vector.h +++ b/utils/small_vector.h @@ -1,5 +1,5 @@ -#ifndef _SMALL_VECTOR_H_ -#define _SMALL_VECTOR_H_ +#ifndef SMALL_VECTOR_H_ +#define SMALL_VECTOR_H_ /* REQUIRES that T is POD (can be memcpy). won't work (yet) due to union with SMALL_VECTOR_POD==0 - may be possible to handle movable types that have ctor/dtor, by using explicit allocation, ctor/dtor calls. but for now JUST USE THIS FOR no-meaningful ctor/dtor POD types. @@ -15,6 +15,7 @@ #include <new> #include <stdint.h> #include <boost/functional/hash.hpp> +#include <boost/serialization/map.hpp> //sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1 @@ -297,6 +298,21 @@ public: return hash_range(data_.ptr,data_.ptr+size_); } + template<class Archive> + void save(Archive & ar, const unsigned int) const { + ar & size_; + for (unsigned i = 0; i < size_; ++i) + ar & (*this)[i]; + } + template<class Archive> + void load(Archive & ar, const unsigned int) { + uint16_t s; + ar & s; + this->resize(s); + for (unsigned i = 0; i < size_; ++i) + ar & (*this)[i]; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() private: union StorageType { T vals[SV_MAX]; diff --git a/utils/small_vector_test.cc b/utils/small_vector_test.cc index a4eb89ae..9e1a148d 100644 --- a/utils/small_vector_test.cc +++ b/utils/small_vector_test.cc @@ -3,6 +3,10 @@ #define BOOST_TEST_MODULE svTest #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <string> +#include <sstream> #include <iostream> #include <vector> @@ -128,3 +132,29 @@ BOOST_AUTO_TEST_CASE(Small) { cerr << sizeof(SmallVectorInt) << endl; cerr << sizeof(vector<int>) << endl; } + +BOOST_AUTO_TEST_CASE(Serialize) { + std::string in; + { + SmallVectorInt v; + v.push_back(0); + v.push_back(1); + v.push_back(-2); + ostringstream os; + boost::archive::text_oarchive oa(os); + oa << v; + in = os.str(); + cerr << in; + } + { + istringstream is(in); + boost::archive::text_iarchive ia(is); + SmallVectorInt v; + ia >> v; + BOOST_CHECK_EQUAL(v.size(), 3); + BOOST_CHECK_EQUAL(v[0], 0); + BOOST_CHECK_EQUAL(v[1], 1); + BOOST_CHECK_EQUAL(v[2], -2); + } +} + diff --git a/utils/sparse_vector.h b/utils/sparse_vector.h index 049151f7..13601376 100644 --- a/utils/sparse_vector.h +++ b/utils/sparse_vector.h @@ -1,5 +1,5 @@ -#ifndef _SPARSE_VECTOR_H_ -#define _SPARSE_VECTOR_H_ +#ifndef SPARSE_VECTOR_H_ +#define SPARSE_VECTOR_H_ #include "fast_sparse_vector.h" #define SparseVector FastSparseVector diff --git a/utils/star.h b/utils/star.h index 21977dc9..01433d12 100644 --- a/utils/star.h +++ b/utils/star.h @@ -1,5 +1,5 @@ -#ifndef _STAR_H_ -#define _STAR_H_ +#ifndef STAR_H_ +#define STAR_H_ // star(x) computes the infinite sum x^0 + x^1 + x^2 + ... diff --git a/utils/sv_test.cc b/utils/sv_test.cc index 67df8c57..b006e66d 100644 --- a/utils/sv_test.cc +++ b/utils/sv_test.cc @@ -1,7 +1,12 @@ #define BOOST_TEST_MODULE WeightsTest #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <sstream> +#include <string> #include "sparse_vector.h" +#include "fdict.h" using namespace std; @@ -33,3 +38,29 @@ BOOST_AUTO_TEST_CASE(Division) { x /= -1; BOOST_CHECK(x == y); } + +BOOST_AUTO_TEST_CASE(Serialization) { + string arc; + FD::dict_.clear(); + { + SparseVector<double> x; + x.set_value(FD::Convert("Feature1"), 1.0); + x.set_value(FD::Convert("Pi"), 3.14); + ostringstream os; + boost::archive::text_oarchive oa(os); + oa << x; + arc = os.str(); + } + FD::dict_.clear(); + FD::Convert("SomeNewString"); + { + SparseVector<double> x; + istringstream is(arc); + boost::archive::text_iarchive ia(is); + ia >> x; + cerr << x << endl; + BOOST_CHECK_CLOSE(x.get(FD::Convert("Pi")), 3.14, 1e-9); + BOOST_CHECK_CLOSE(x.get(FD::Convert("Feature1")), 1.0, 1e-9); + } +} + diff --git a/utils/tdict.h b/utils/tdict.h index bb19ecd5..eed33c3a 100644 --- a/utils/tdict.h +++ b/utils/tdict.h @@ -1,5 +1,5 @@ -#ifndef _TDICT_H_ -#define _TDICT_H_ +#ifndef TDICT_H_ +#define TDICT_H_ #include <string> #include <vector> diff --git a/utils/timing_stats.h b/utils/timing_stats.h index 0a9f7656..69a1cf4b 100644 --- a/utils/timing_stats.h +++ b/utils/timing_stats.h @@ -1,5 +1,5 @@ -#ifndef _TIMING_STATS_H_ -#define _TIMING_STATS_H_ +#ifndef TIMING_STATS_H_ +#define TIMING_STATS_H_ #include <string> #include <map> diff --git a/utils/verbose.h b/utils/verbose.h index 73476383..e39e23cb 100644 --- a/utils/verbose.h +++ b/utils/verbose.h @@ -1,5 +1,5 @@ -#ifndef _VERBOSE_H_ -#define _VERBOSE_H_ +#ifndef VERBOSE_H_ +#define VERBOSE_H_ extern bool SILENT; diff --git a/utils/weights.h b/utils/weights.h index 920fdd75..0bd4c2d9 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -1,5 +1,5 @@ -#ifndef _WEIGHTS_H_ -#define _WEIGHTS_H_ +#ifndef WEIGHTS_H_ +#define WEIGHTS_H_ #include <string> #include <vector> diff --git a/utils/wordid.h b/utils/wordid.h index 714dcd0b..3aa6cc23 100644 --- a/utils/wordid.h +++ b/utils/wordid.h @@ -1,5 +1,5 @@ -#ifndef _WORD_ID_H_ -#define _WORD_ID_H_ +#ifndef WORD_ID_H_ +#define WORD_ID_H_ #include <limits> |