diff options
Diffstat (limited to 'decoder')
96 files changed, 1150 insertions, 1856 deletions
| diff --git a/decoder/JSON_parser.c b/decoder/JSON_parser.c deleted file mode 100644 index 5e392bc6..00000000 --- a/decoder/JSON_parser.c +++ /dev/null @@ -1,1012 +0,0 @@ -/* JSON_parser.c */ - -/* 2007-08-24 */ - -/* -Copyright (c) 2005 JSON.org - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -The Software shall be used for Good, not Evil. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -*/ - -/* -    Callbacks, comments, Unicode handling by Jean Gressmann (jean@0x42.de), 2007-2009. - -    For the added features the license above applies also. - -    Changelog: -        2009-05-17 -            Incorporated benrudiak@googlemail.com fix for UTF16 decoding. - -        2009-05-14 -            Fixed float parsing bug related to a locale being set that didn't -            use '.' as decimal point character (charles@transmissionbt.com). - -        2008-10-14 -            Renamed states.IN to states.IT to avoid name clash which IN macro -            defined in windef.h (alexey.pelykh@gmail.com) - -        2008-07-19 -            Removed some duplicate code & debugging variable (charles@transmissionbt.com) - -        2008-05-28 -            Made JSON_value structure ansi C compliant. This bug was report by -            trisk@acm.jhu.edu - -        2008-05-20 -            Fixed bug reported by charles@transmissionbt.com where the switching -            from static to dynamic parse buffer did not copy the static parse -            buffer's content. -*/ - - - -#include <assert.h> -#include <ctype.h> -#include <float.h> -#include <stddef.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <locale.h> - -#include "JSON_parser.h" - -#ifdef _MSC_VER -#   if _MSC_VER >= 1400 /* Visual Studio 2005 and up */ -#      pragma warning(disable:4996) // unsecure sscanf -#   endif -#endif - - -#define true  1 -#define false 0 -#define __   -1     /* the universal error code */ - -/* values chosen so that the object size is approx equal to one page (4K) */ -#ifndef JSON_PARSER_STACK_SIZE -#   define JSON_PARSER_STACK_SIZE 128 -#endif - -#ifndef JSON_PARSER_PARSE_BUFFER_SIZE -#   define JSON_PARSER_PARSE_BUFFER_SIZE 3500 -#endif - -typedef unsigned short UTF16; - -struct JSON_parser_struct { -    JSON_parser_callback callback; -    void* ctx; -    signed char state, before_comment_state, type, escaped, comment, allow_comments, handle_floats_manually; -    UTF16 utf16_high_surrogate; -    long depth; -    long top; -    signed char* stack; -    long stack_capacity; -    char decimal_point; -    char* parse_buffer; -    size_t parse_buffer_capacity; -    size_t parse_buffer_count; -    size_t comment_begin_offset; -    signed char static_stack[JSON_PARSER_STACK_SIZE]; -    char static_parse_buffer[JSON_PARSER_PARSE_BUFFER_SIZE]; -}; - -#define COUNTOF(x) (sizeof(x)/sizeof(x[0])) - -/* -    Characters are mapped into these character classes. This allows for -    a significant reduction in the size of the state transition table. -*/ - - - -enum classes { -    C_SPACE,  /* space */ -    C_WHITE,  /* other whitespace */ -    C_LCURB,  /* {  */ -    C_RCURB,  /* } */ -    C_LSQRB,  /* [ */ -    C_RSQRB,  /* ] */ -    C_COLON,  /* : */ -    C_COMMA,  /* , */ -    C_QUOTE,  /* " */ -    C_BACKS,  /* \ */ -    C_SLASH,  /* / */ -    C_PLUS,   /* + */ -    C_MINUS,  /* - */ -    C_POINT,  /* . */ -    C_ZERO ,  /* 0 */ -    C_DIGIT,  /* 123456789 */ -    C_LOW_A,  /* a */ -    C_LOW_B,  /* b */ -    C_LOW_C,  /* c */ -    C_LOW_D,  /* d */ -    C_LOW_E,  /* e */ -    C_LOW_F,  /* f */ -    C_LOW_L,  /* l */ -    C_LOW_N,  /* n */ -    C_LOW_R,  /* r */ -    C_LOW_S,  /* s */ -    C_LOW_T,  /* t */ -    C_LOW_U,  /* u */ -    C_ABCDF,  /* ABCDF */ -    C_E,      /* E */ -    C_ETC,    /* everything else */ -    C_STAR,   /* * */ -    NR_CLASSES -}; - -static int ascii_class[128] = { -/* -    This array maps the 128 ASCII characters into character classes. -    The remaining Unicode characters should be mapped to C_ETC. -    Non-whitespace control characters are errors. -*/ -    __,      __,      __,      __,      __,      __,      __,      __, -    __,      C_WHITE, C_WHITE, __,      __,      C_WHITE, __,      __, -    __,      __,      __,      __,      __,      __,      __,      __, -    __,      __,      __,      __,      __,      __,      __,      __, - -    C_SPACE, C_ETC,   C_QUOTE, C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC, -    C_ETC,   C_ETC,   C_STAR,   C_PLUS,  C_COMMA, C_MINUS, C_POINT, C_SLASH, -    C_ZERO,  C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, -    C_DIGIT, C_DIGIT, C_COLON, C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC, - -    C_ETC,   C_ABCDF, C_ABCDF, C_ABCDF, C_ABCDF, C_E,     C_ABCDF, C_ETC, -    C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC, -    C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_ETC, -    C_ETC,   C_ETC,   C_ETC,   C_LSQRB, C_BACKS, C_RSQRB, C_ETC,   C_ETC, - -    C_ETC,   C_LOW_A, C_LOW_B, C_LOW_C, C_LOW_D, C_LOW_E, C_LOW_F, C_ETC, -    C_ETC,   C_ETC,   C_ETC,   C_ETC,   C_LOW_L, C_ETC,   C_LOW_N, C_ETC, -    C_ETC,   C_ETC,   C_LOW_R, C_LOW_S, C_LOW_T, C_LOW_U, C_ETC,   C_ETC, -    C_ETC,   C_ETC,   C_ETC,   C_LCURB, C_ETC,   C_RCURB, C_ETC,   C_ETC -}; - - -/* -    The state codes. -*/ -enum states { -    GO,  /* start    */ -    OK,  /* ok       */ -    OB,  /* object   */ -    KE,  /* key      */ -    CO,  /* colon    */ -    VA,  /* value    */ -    AR,  /* array    */ -    ST,  /* string   */ -    ES,  /* escape   */ -    U1,  /* u1       */ -    U2,  /* u2       */ -    U3,  /* u3       */ -    U4,  /* u4       */ -    MI,  /* minus    */ -    ZE,  /* zero     */ -    IT,  /* integer  */ -    FR,  /* fraction */ -    E1,  /* e        */ -    E2,  /* ex       */ -    E3,  /* exp      */ -    T1,  /* tr       */ -    T2,  /* tru      */ -    T3,  /* true     */ -    F1,  /* fa       */ -    F2,  /* fal      */ -    F3,  /* fals     */ -    F4,  /* false    */ -    N1,  /* nu       */ -    N2,  /* nul      */ -    N3,  /* null     */ -    C1,  /* /        */ -    C2,  /* / *     */ -    C3,  /* *        */ -    FX,  /* *.* *eE* */ -    D1,  /* second UTF-16 character decoding started by \ */ -    D2,  /* second UTF-16 character proceeded by u */ -    NR_STATES -}; - -enum actions -{ -    CB = -10, /* comment begin */ -    CE = -11, /* comment end */ -    FA = -12, /* false */ -    TR = -13, /* false */ -    NU = -14, /* null */ -    DE = -15, /* double detected by exponent e E */ -    DF = -16, /* double detected by fraction . */ -    SB = -17, /* string begin */ -    MX = -18, /* integer detected by minus */ -    ZX = -19, /* integer detected by zero */ -    IX = -20, /* integer detected by 1-9 */ -    EX = -21, /* next char is escaped */ -    UC = -22  /* Unicode character read */ -}; - - -static int state_transition_table[NR_STATES][NR_CLASSES] = { -/* -    The state transition table takes the current state and the current symbol, -    and returns either a new state or an action. An action is represented as a -    negative number. A JSON text is accepted if at the end of the text the -    state is OK and if the mode is MODE_DONE. - -                 white                                      1-9                                   ABCDF  etc -             space |  {  }  [  ]  :  ,  "  \  /  +  -  .  0  |  a  b  c  d  e  f  l  n  r  s  t  u  |  E  |  * */ -/*start  GO*/ {GO,GO,-6,__,-5,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ok     OK*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*object OB*/ {OB,OB,__,-9,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*key    KE*/ {KE,KE,__,__,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*colon  CO*/ {CO,CO,__,__,__,__,-2,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*value  VA*/ {VA,VA,-6,__,-5,__,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*array  AR*/ {AR,AR,-6,__,-5,-7,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*string ST*/ {ST,__,ST,ST,ST,ST,ST,ST,-4,EX,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST}, -/*escape ES*/ {__,__,__,__,__,__,__,__,ST,ST,ST,__,__,__,__,__,__,ST,__,__,__,ST,__,ST,ST,__,ST,U1,__,__,__,__}, -/*u1     U1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U2,U2,U2,U2,U2,U2,U2,U2,__,__,__,__,__,__,U2,U2,__,__}, -/*u2     U2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U3,U3,U3,U3,U3,U3,U3,U3,__,__,__,__,__,__,U3,U3,__,__}, -/*u3     U3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U4,U4,U4,U4,U4,U4,U4,U4,__,__,__,__,__,__,U4,U4,__,__}, -/*u4     U4*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,UC,UC,UC,UC,UC,UC,UC,UC,__,__,__,__,__,__,UC,UC,__,__}, -/*minus  MI*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,ZE,IT,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*zero   ZE*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*int    IT*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,IT,IT,__,__,__,__,DE,__,__,__,__,__,__,__,__,DE,__,__}, -/*frac   FR*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*e      E1*/ {__,__,__,__,__,__,__,__,__,__,__,E2,E2,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ex     E2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*exp    E3*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*tr     T1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T2,__,__,__,__,__,__,__}, -/*tru    T2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T3,__,__,__,__}, -/*true   T3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*fa     F1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*fal    F2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F3,__,__,__,__,__,__,__,__,__}, -/*fals   F3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F4,__,__,__,__,__,__}, -/*false  F4*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*nu     N1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N2,__,__,__,__}, -/*nul    N2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N3,__,__,__,__,__,__,__,__,__}, -/*null   N3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__}, -/*/      C1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,C2}, -/*/*     C2*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/**      C3*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,CE,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/*_.     FX*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*\      D1*/ {__,__,__,__,__,__,__,__,__,D2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*\      D2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,U1,__,__,__,__}, -}; - - -/* -    These modes can be pushed on the stack. -*/ -enum modes { -    MODE_ARRAY = 1, -    MODE_DONE = 2, -    MODE_KEY = 3, -    MODE_OBJECT = 4 -}; - -static int -push(JSON_parser jc, int mode) -{ -/* -    Push a mode onto the stack. Return false if there is overflow. -*/ -    jc->top += 1; -    if (jc->depth < 0) { -        if (jc->top >= jc->stack_capacity) { -            size_t bytes_to_allocate; -            jc->stack_capacity *= 2; -            bytes_to_allocate = jc->stack_capacity * sizeof(jc->static_stack[0]); -            if (jc->stack == &jc->static_stack[0]) { -                jc->stack = (signed char*)malloc(bytes_to_allocate); -                memcpy(jc->stack, jc->static_stack, sizeof(jc->static_stack)); -            } else { -                jc->stack = (signed char*)realloc(jc->stack, bytes_to_allocate); -            } -        } -    } else { -        if (jc->top >= jc->depth) { -            return false; -        } -    } - -    jc->stack[jc->top] = mode; -    return true; -} - - -static int -pop(JSON_parser jc, int mode) -{ -/* -    Pop the stack, assuring that the current mode matches the expectation. -    Return false if there is underflow or if the modes mismatch. -*/ -    if (jc->top < 0 || jc->stack[jc->top] != mode) { -        return false; -    } -    jc->top -= 1; -    return true; -} - - -#define parse_buffer_clear(jc) \ -    do {\ -        jc->parse_buffer_count = 0;\ -        jc->parse_buffer[0] = 0;\ -    } while (0) - -#define parse_buffer_pop_back_char(jc)\ -    do {\ -        assert(jc->parse_buffer_count >= 1);\ -        --jc->parse_buffer_count;\ -        jc->parse_buffer[jc->parse_buffer_count] = 0;\ -    } while (0) - -void delete_JSON_parser(JSON_parser jc) -{ -    if (jc) { -        if (jc->stack != &jc->static_stack[0]) { -            free((void*)jc->stack); -        } -        if (jc->parse_buffer != &jc->static_parse_buffer[0]) { -            free((void*)jc->parse_buffer); -        } -        free((void*)jc); -     } -} - - -JSON_parser -new_JSON_parser(JSON_config* config) -{ -/* -    new_JSON_parser starts the checking process by constructing a JSON_parser -    object. It takes a depth parameter that restricts the level of maximum -    nesting. - -    To continue the process, call JSON_parser_char for each character in the -    JSON text, and then call JSON_parser_done to obtain the final result. -    These functions are fully reentrant. -*/ - -    int depth = 0; -    JSON_config default_config; - -    JSON_parser jc = (JSON_parser)malloc(sizeof(struct JSON_parser_struct)); - -    memset(jc, 0, sizeof(*jc)); - - -    /* initialize configuration */ -    init_JSON_config(&default_config); - -    /* set to default configuration if none was provided */ -    if (config == NULL) { -        config = &default_config; -    } - -    depth = config->depth; - -    /* We need to be able to push at least one object */ -    if (depth == 0) { -        depth = 1; -    } - -    jc->state = GO; -    jc->top = -1; - -    /* Do we want non-bound stack? */ -    if (depth > 0) { -        jc->stack_capacity = depth; -        jc->depth = depth; -        if (depth <= (int)COUNTOF(jc->static_stack)) { -            jc->stack = &jc->static_stack[0]; -        } else { -            jc->stack = (signed char*)malloc(jc->stack_capacity * sizeof(jc->static_stack[0])); -        } -    } else { -        jc->stack_capacity = COUNTOF(jc->static_stack); -        jc->depth = -1; -        jc->stack = &jc->static_stack[0]; -    } - -    /* set parser to start */ -    push(jc, MODE_DONE); - -    /* set up the parse buffer */ -    jc->parse_buffer = &jc->static_parse_buffer[0]; -    jc->parse_buffer_capacity = COUNTOF(jc->static_parse_buffer); -    parse_buffer_clear(jc); - -    /* set up callback, comment & float handling */ -    jc->callback = config->callback; -    jc->ctx = config->callback_ctx; -    jc->allow_comments = config->allow_comments != 0; -    jc->handle_floats_manually = config->handle_floats_manually != 0; - -    /* set up decimal point */ -    jc->decimal_point = *localeconv()->decimal_point; - -    return jc; -} - -static void grow_parse_buffer(JSON_parser jc) -{ -    size_t bytes_to_allocate; -    jc->parse_buffer_capacity *= 2; -    bytes_to_allocate = jc->parse_buffer_capacity * sizeof(jc->parse_buffer[0]); -    if (jc->parse_buffer == &jc->static_parse_buffer[0]) { -        jc->parse_buffer = (char*)malloc(bytes_to_allocate); -        memcpy(jc->parse_buffer, jc->static_parse_buffer, jc->parse_buffer_count); -    } else { -        jc->parse_buffer = (char*)realloc(jc->parse_buffer, bytes_to_allocate); -    } -} - -#define parse_buffer_push_back_char(jc, c)\ -    do {\ -        if (jc->parse_buffer_count + 1 >= jc->parse_buffer_capacity) grow_parse_buffer(jc);\ -        jc->parse_buffer[jc->parse_buffer_count++] = c;\ -        jc->parse_buffer[jc->parse_buffer_count]   = 0;\ -    } while (0) - -#define assert_is_non_container_type(jc) \ -    assert( \ -        jc->type == JSON_T_NULL || \ -        jc->type == JSON_T_FALSE || \ -        jc->type == JSON_T_TRUE || \ -        jc->type == JSON_T_FLOAT || \ -        jc->type == JSON_T_INTEGER || \ -        jc->type == JSON_T_STRING) - - -static int parse_parse_buffer(JSON_parser jc) -{ -    if (jc->callback) { -        JSON_value value, *arg = NULL; - -        if (jc->type != JSON_T_NONE) { -            assert_is_non_container_type(jc); - -            switch(jc->type) { -                case JSON_T_FLOAT: -                    arg = &value; -                    if (jc->handle_floats_manually) { -                        value.vu.str.value = jc->parse_buffer; -                        value.vu.str.length = jc->parse_buffer_count; -                    } else { -                        /*sscanf(jc->parse_buffer, "%Lf", &value.vu.float_value);*/ - -                        /* not checking with end pointer b/c there may be trailing ws */ -                        value.vu.float_value = strtod(jc->parse_buffer, NULL); -                    } -                    break; -                case JSON_T_INTEGER: -                    arg = &value; -                    sscanf(jc->parse_buffer, JSON_PARSER_INTEGER_SSCANF_TOKEN, &value.vu.integer_value); -                    break; -                case JSON_T_STRING: -                    arg = &value; -                    value.vu.str.value = jc->parse_buffer; -                    value.vu.str.length = jc->parse_buffer_count; -                    break; -            } - -            if (!(*jc->callback)(jc->ctx, jc->type, arg)) { -                return false; -            } -        } -    } - -    parse_buffer_clear(jc); - -    return true; -} - -#define IS_HIGH_SURROGATE(uc) (((uc) & 0xFC00) == 0xD800) -#define IS_LOW_SURROGATE(uc)  (((uc) & 0xFC00) == 0xDC00) -#define DECODE_SURROGATE_PAIR(hi,lo) ((((hi) & 0x3FF) << 10) + ((lo) & 0x3FF) + 0x10000) -static unsigned char utf8_lead_bits[4] = { 0x00, 0xC0, 0xE0, 0xF0 }; - -static int decode_unicode_char(JSON_parser jc) -{ -    int i; -    unsigned uc = 0; -    char* p; -    int trail_bytes; - -    assert(jc->parse_buffer_count >= 6); - -    p = &jc->parse_buffer[jc->parse_buffer_count - 4]; - -    for (i = 12; i >= 0; i -= 4, ++p) { -        unsigned x = *p; - -        if (x >= 'a') { -            x -= ('a' - 10); -        } else if (x >= 'A') { -            x -= ('A' - 10); -        } else { -            x &= ~0x30u; -        } - -        assert(x < 16); - -        uc |= x << i; -    } - -    /* clear UTF-16 char from buffer */ -    jc->parse_buffer_count -= 6; -    jc->parse_buffer[jc->parse_buffer_count] = 0; - -    /* attempt decoding ... */ -    if (jc->utf16_high_surrogate) { -        if (IS_LOW_SURROGATE(uc)) { -            uc = DECODE_SURROGATE_PAIR(jc->utf16_high_surrogate, uc); -            trail_bytes = 3; -            jc->utf16_high_surrogate = 0; -        } else { -            /* high surrogate without a following low surrogate */ -            return false; -        } -    } else { -        if (uc < 0x80) { -            trail_bytes = 0; -        } else if (uc < 0x800) { -            trail_bytes = 1; -        } else if (IS_HIGH_SURROGATE(uc)) { -            /* save the high surrogate and wait for the low surrogate */ -            jc->utf16_high_surrogate = uc; -            return true; -        } else if (IS_LOW_SURROGATE(uc)) { -            /* low surrogate without a preceding high surrogate */ -            return false; -        } else { -            trail_bytes = 2; -        } -    } - -    jc->parse_buffer[jc->parse_buffer_count++] = (char) ((uc >> (trail_bytes * 6)) | utf8_lead_bits[trail_bytes]); - -    for (i = trail_bytes * 6 - 6; i >= 0; i -= 6) { -        jc->parse_buffer[jc->parse_buffer_count++] = (char) (((uc >> i) & 0x3F) | 0x80); -    } - -    jc->parse_buffer[jc->parse_buffer_count] = 0; - -    return true; -} - -static int add_escaped_char_to_parse_buffer(JSON_parser jc, int next_char) -{ -    jc->escaped = 0; -    /* remove the backslash */ -    parse_buffer_pop_back_char(jc); -    switch(next_char) { -        case 'b': -            parse_buffer_push_back_char(jc, '\b'); -            break; -        case 'f': -            parse_buffer_push_back_char(jc, '\f'); -            break; -        case 'n': -            parse_buffer_push_back_char(jc, '\n'); -            break; -        case 'r': -            parse_buffer_push_back_char(jc, '\r'); -            break; -        case 't': -            parse_buffer_push_back_char(jc, '\t'); -            break; -        case '"': -            parse_buffer_push_back_char(jc, '"'); -            break; -        case '\\': -            parse_buffer_push_back_char(jc, '\\'); -            break; -        case '/': -            parse_buffer_push_back_char(jc, '/'); -            break; -        case 'u': -            parse_buffer_push_back_char(jc, '\\'); -            parse_buffer_push_back_char(jc, 'u'); -            break; -        default: -            return false; -    } - -    return true; -} - -#define add_char_to_parse_buffer(jc, next_char, next_class) \ -    do { \ -        if (jc->escaped) { \ -            if (!add_escaped_char_to_parse_buffer(jc, next_char)) \ -                return false; \ -        } else if (!jc->comment) { \ -            if ((jc->type != JSON_T_NONE) | !((next_class == C_SPACE) | (next_class == C_WHITE)) /* non-white-space */) { \ -                parse_buffer_push_back_char(jc, (char)next_char); \ -            } \ -        } \ -    } while (0) - - -#define assert_type_isnt_string_null_or_bool(jc) \ -    assert(jc->type != JSON_T_FALSE); \ -    assert(jc->type != JSON_T_TRUE); \ -    assert(jc->type != JSON_T_NULL); \ -    assert(jc->type != JSON_T_STRING) - - -int -JSON_parser_char(JSON_parser jc, int next_char) -{ -/* -    After calling new_JSON_parser, call this function for each character (or -    partial character) in your JSON text. It can accept UTF-8, UTF-16, or -    UTF-32. It returns true if things are looking ok so far. If it rejects the -    text, it returns false. -*/ -    int next_class, next_state; - -/* -    Determine the character's class. -*/ -    if (next_char < 0) { -        return false; -    } -    if (next_char >= 128) { -        next_class = C_ETC; -    } else { -        next_class = ascii_class[next_char]; -        if (next_class <= __) { -            return false; -        } -    } - -    add_char_to_parse_buffer(jc, next_char, next_class); - -/* -    Get the next state from the state transition table. -*/ -    next_state = state_transition_table[jc->state][next_class]; -    if (next_state >= 0) { -/* -    Change the state. -*/ -        jc->state = next_state; -    } else { -/* -    Or perform one of the actions. -*/ -        switch (next_state) { -/* Unicode character */ -        case UC: -            if(!decode_unicode_char(jc)) { -                return false; -            } -            /* check if we need to read a second UTF-16 char */ -            if (jc->utf16_high_surrogate) { -                jc->state = D1; -            } else { -                jc->state = ST; -            } -            break; -/* escaped char */ -        case EX: -            jc->escaped = 1; -            jc->state = ES; -            break; -/* integer detected by minus */ -        case MX: -            jc->type = JSON_T_INTEGER; -            jc->state = MI; -            break; -/* integer detected by zero */ -        case ZX: -            jc->type = JSON_T_INTEGER; -            jc->state = ZE; -            break; -/* integer detected by 1-9 */ -        case IX: -            jc->type = JSON_T_INTEGER; -            jc->state = IT; -            break; - -/* floating point number detected by exponent*/ -        case DE: -            assert_type_isnt_string_null_or_bool(jc); -            jc->type = JSON_T_FLOAT; -            jc->state = E1; -            break; - -/* floating point number detected by fraction */ -        case DF: -            assert_type_isnt_string_null_or_bool(jc); -            if (!jc->handle_floats_manually) { -/* -    Some versions of strtod (which underlies sscanf) don't support converting -    C-locale formated floating point values. -*/ -                assert(jc->parse_buffer[jc->parse_buffer_count-1] == '.'); -                jc->parse_buffer[jc->parse_buffer_count-1] = jc->decimal_point; -            } -            jc->type = JSON_T_FLOAT; -            jc->state = FX; -            break; -/* string begin " */ -        case SB: -            parse_buffer_clear(jc); -            assert(jc->type == JSON_T_NONE); -            jc->type = JSON_T_STRING; -            jc->state = ST; -            break; - -/* n */ -        case NU: -            assert(jc->type == JSON_T_NONE); -            jc->type = JSON_T_NULL; -            jc->state = N1; -            break; -/* f */ -        case FA: -            assert(jc->type == JSON_T_NONE); -            jc->type = JSON_T_FALSE; -            jc->state = F1; -            break; -/* t */ -        case TR: -            assert(jc->type == JSON_T_NONE); -            jc->type = JSON_T_TRUE; -            jc->state = T1; -            break; - -/* closing comment */ -        case CE: -            jc->comment = 0; -            assert(jc->parse_buffer_count == 0); -            assert(jc->type == JSON_T_NONE); -            jc->state = jc->before_comment_state; -            break; - -/* opening comment  */ -        case CB: -            if (!jc->allow_comments) { -                return false; -            } -            parse_buffer_pop_back_char(jc); -            if (!parse_parse_buffer(jc)) { -                return false; -            } -            assert(jc->parse_buffer_count == 0); -            assert(jc->type != JSON_T_STRING); -            switch (jc->stack[jc->top]) { -            case MODE_ARRAY: -            case MODE_OBJECT: -                switch(jc->state) { -                case VA: -                case AR: -                    jc->before_comment_state = jc->state; -                    break; -                default: -                    jc->before_comment_state = OK; -                    break; -                } -                break; -            default: -                jc->before_comment_state = jc->state; -                break; -            } -            jc->type = JSON_T_NONE; -            jc->state = C1; -            jc->comment = 1; -            break; -/* empty } */ -        case -9: -            parse_buffer_clear(jc); -            if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { -                return false; -            } -            if (!pop(jc, MODE_KEY)) { -                return false; -            } -            jc->state = OK; -            break; - -/* } */ case -8: -            parse_buffer_pop_back_char(jc); -            if (!parse_parse_buffer(jc)) { -                return false; -            } -            if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { -                return false; -            } -            if (!pop(jc, MODE_OBJECT)) { -                return false; -            } -            jc->type = JSON_T_NONE; -            jc->state = OK; -            break; - -/* ] */ case -7: -            parse_buffer_pop_back_char(jc); -            if (!parse_parse_buffer(jc)) { -                return false; -            } -            if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_END, NULL)) { -                return false; -            } -            if (!pop(jc, MODE_ARRAY)) { -                return false; -            } - -            jc->type = JSON_T_NONE; -            jc->state = OK; -            break; - -/* { */ case -6: -            parse_buffer_pop_back_char(jc); -            if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_BEGIN, NULL)) { -                return false; -            } -            if (!push(jc, MODE_KEY)) { -                return false; -            } -            assert(jc->type == JSON_T_NONE); -            jc->state = OB; -            break; - -/* [ */ case -5: -            parse_buffer_pop_back_char(jc); -            if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_BEGIN, NULL)) { -                return false; -            } -            if (!push(jc, MODE_ARRAY)) { -                return false; -            } -            assert(jc->type == JSON_T_NONE); -            jc->state = AR; -            break; - -/* string end " */ case -4: -            parse_buffer_pop_back_char(jc); -            switch (jc->stack[jc->top]) { -            case MODE_KEY: -                assert(jc->type == JSON_T_STRING); -                jc->type = JSON_T_NONE; -                jc->state = CO; - -                if (jc->callback) { -                    JSON_value value; -                    value.vu.str.value = jc->parse_buffer; -                    value.vu.str.length = jc->parse_buffer_count; -                    if (!(*jc->callback)(jc->ctx, JSON_T_KEY, &value)) { -                        return false; -                    } -                } -                parse_buffer_clear(jc); -                break; -            case MODE_ARRAY: -            case MODE_OBJECT: -                assert(jc->type == JSON_T_STRING); -                if (!parse_parse_buffer(jc)) { -                    return false; -                } -                jc->type = JSON_T_NONE; -                jc->state = OK; -                break; -            default: -                return false; -            } -            break; - -/* , */ case -3: -            parse_buffer_pop_back_char(jc); -            if (!parse_parse_buffer(jc)) { -                return false; -            } -            switch (jc->stack[jc->top]) { -            case MODE_OBJECT: -/* -    A comma causes a flip from object mode to key mode. -*/ -                if (!pop(jc, MODE_OBJECT) || !push(jc, MODE_KEY)) { -                    return false; -                } -                assert(jc->type != JSON_T_STRING); -                jc->type = JSON_T_NONE; -                jc->state = KE; -                break; -            case MODE_ARRAY: -                assert(jc->type != JSON_T_STRING); -                jc->type = JSON_T_NONE; -                jc->state = VA; -                break; -            default: -                return false; -            } -            break; - -/* : */ case -2: -/* -    A colon causes a flip from key mode to object mode. -*/ -            parse_buffer_pop_back_char(jc); -            if (!pop(jc, MODE_KEY) || !push(jc, MODE_OBJECT)) { -                return false; -            } -            assert(jc->type == JSON_T_NONE); -            jc->state = VA; -            break; -/* -    Bad action. -*/ -        default: -            return false; -        } -    } -    return true; -} - - -int -JSON_parser_done(JSON_parser jc) -{ -    const int result = jc->state == OK && pop(jc, MODE_DONE); - -    return result; -} - - -int JSON_parser_is_legal_white_space_string(const char* s) -{ -    int c, char_class; - -    if (s == NULL) { -        return false; -    } - -    for (; *s; ++s) { -        c = *s; - -        if (c < 0 || c >= 128) { -            return false; -        } - -        char_class = ascii_class[c]; - -        if (char_class != C_SPACE && char_class != C_WHITE) { -            return false; -        } -    } - -    return true; -} - - - -void init_JSON_config(JSON_config* config) -{ -    if (config) { -        memset(config, 0, sizeof(*config)); - -        config->depth = JSON_PARSER_STACK_SIZE - 1; -    } -} diff --git a/decoder/JSON_parser.h b/decoder/JSON_parser.h deleted file mode 100644 index de980072..00000000 --- a/decoder/JSON_parser.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef JSON_PARSER_H -#define JSON_PARSER_H - -/* JSON_parser.h */ - - -#include <stddef.h> - -/* Windows DLL stuff */ -#ifdef _WIN32 -#	ifdef JSON_PARSER_DLL_EXPORTS -#		define JSON_PARSER_DLL_API __declspec(dllexport) -#	else -#		define JSON_PARSER_DLL_API __declspec(dllimport) -#   endif -#else -#	define JSON_PARSER_DLL_API -#endif - -/* Determine the integer type use to parse non-floating point numbers */ -#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1 -typedef long long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld" -#else -typedef long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld" -#endif - - -#ifdef __cplusplus -extern "C" { -#endif - -typedef enum -{ -    JSON_T_NONE = 0, -    JSON_T_ARRAY_BEGIN,  // 1 -    JSON_T_ARRAY_END,    // 2 -    JSON_T_OBJECT_BEGIN, // 3 -    JSON_T_OBJECT_END,   // 4 -    JSON_T_INTEGER,      // 5 -    JSON_T_FLOAT,        // 6 -    JSON_T_NULL,         // 7 -    JSON_T_TRUE,         // 8 -    JSON_T_FALSE,        // 9 -    JSON_T_STRING,       // 10 -    JSON_T_KEY,          // 11 -    JSON_T_MAX           // 12 -} JSON_type; - -typedef struct JSON_value_struct { -    union { -        JSON_int_t integer_value; - -        double float_value; - -        struct { -            const char* value; -            size_t length; -        } str; -    } vu; -} JSON_value; - -typedef struct JSON_parser_struct* JSON_parser; - -/*! \brief JSON parser callback - -    \param ctx The pointer passed to new_JSON_parser. -    \param type An element of JSON_type but not JSON_T_NONE. -    \param value A representation of the parsed value. This parameter is NULL for -        JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END, -        JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned -        as zero-terminated C strings. - -    \return Non-zero if parsing should continue, else zero. -*/ -typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value); - - -/*! \brief The structure used to configure a JSON parser object - -    \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise -        the depth is the limit -    \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity. -    \param Callback context. This parameter may be NULL. -    \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting. -    \param allowComments. To allow C style comments in JSON, set to non-zero. -    \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero. - -    \return The parser object. -*/ -typedef struct { -    JSON_parser_callback     callback; -    void*                    callback_ctx; -    int                      depth; -    int                      allow_comments; -    int                      handle_floats_manually; -} JSON_config; - - -/*! \brief Initializes the JSON parser configuration structure to default values. - -    The default configuration is -    - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c) -    - no parsing, just checking for JSON syntax -    - no comments - -    \param config. Used to configure the parser. -*/ -JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config); - -/*! \brief Create a JSON parser object - -    \param config. Used to configure the parser. Set to NULL to use the default configuration. -        See init_JSON_config - -    \return The parser object. -*/ -JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config); - -/*! \brief Destroy a previously created JSON parser object. */ -JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc); - -/*! \brief Parse a character. - -    \return Non-zero, if all characters passed to this function are part of are valid JSON. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char); - -/*! \brief Finalize parsing. - -    Call this method once after all input characters have been consumed. - -    \return Non-zero, if all parsed characters are valid JSON, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc); - -/*! \brief Determine if a given string is valid JSON white space - -    \return Non-zero if the string is valid, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s); - - -#ifdef __cplusplus -} -#endif - - -#endif /* JSON_PARSER_H */ diff --git a/decoder/Makefile.am b/decoder/Makefile.am index f791b3cc..f9f90cfb 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -33,10 +33,10 @@ noinst_LIBRARIES = libcdec.a  EXTRA_DIST = test_data rule_lexer.ll  libcdec_a_SOURCES = \ -  JSON_parser.h \    aligner.h \    apply_models.h \    bottom_up_parser.h \ +  bottom_up_parser-rs.h \    csplit.h \    decoder.h \    earley_composer.h \ @@ -45,6 +45,7 @@ libcdec_a_SOURCES = \    ff_basic.h \    ff_bleu.h \    ff_charset.h \ +  ff_conll.h \    ff_const_reorder_common.h \    ff_const_reorder.h \    ff_context.h \ @@ -52,7 +53,7 @@ libcdec_a_SOURCES = \    ff_external.h \    ff_factory.h \    ff_klm.h \ -	ff_lexical.h \ +  ff_lexical.h \    ff_lm.h \    ff_ngrams.h \    ff_parse_match.h \ @@ -82,7 +83,6 @@ libcdec_a_SOURCES = \    hg_union.h \    incremental.h \    inside_outside.h \ -  json_parse.h \    kbest.h \    lattice.h \    lexalign.h \ @@ -102,6 +102,7 @@ libcdec_a_SOURCES = \    aligner.cc \    apply_models.cc \    bottom_up_parser.cc \ +  bottom_up_parser-rs.cc \    cdec.cc \    cdec_ff.cc \    csplit.cc \ @@ -112,6 +113,7 @@ libcdec_a_SOURCES = \    ff_basic.cc \    ff_bleu.cc \    ff_charset.cc \ +  ff_conll.cc \    ff_context.cc \    ff_const_reorder.cc \    ff_csplit.cc \ @@ -144,7 +146,6 @@ libcdec_a_SOURCES = \    hg_sampler.cc \    hg_union.cc \    incremental.cc \ -  json_parse.cc \    lattice.cc \    lexalign.cc \    lextrans.cc \ @@ -160,5 +161,4 @@ libcdec_a_SOURCES = \    tagger.cc \    translator.cc \    trule.cc \ -  viterbi.cc \ -  JSON_parser.c +  viterbi.cc diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 232e022a..fd648370 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -198,13 +198,13 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice,    }    const Hypergraph* g = &in_g;    HypergraphP new_hg; -  if (!src_lattice.IsSentence() || -      !trg_lattice.IsSentence()) { +  if (!IsSentence(src_lattice) || +      !IsSentence(trg_lattice)) {      if (map_instead_of_viterbi) {        cerr << "  Lattice alignment: using Viterbi instead of MAP alignment\n";      }      map_instead_of_viterbi = false; -    fix_up_src_spans = !src_lattice.IsSentence(); +    fix_up_src_spans = !IsSentence(src_lattice);    }    KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal> kbest(in_g, k_best); diff --git a/decoder/aligner.h b/decoder/aligner.h index a34795c9..d68ceefc 100644 --- a/decoder/aligner.h +++ b/decoder/aligner.h @@ -1,4 +1,4 @@ -#ifndef _ALIGNER_H_ +#ifndef ALIGNER_H  #include <string>  #include <iostream> diff --git a/decoder/apply_models.h b/decoder/apply_models.h index 19a4c7be..f03c973a 100644 --- a/decoder/apply_models.h +++ b/decoder/apply_models.h @@ -1,5 +1,5 @@ -#ifndef _APPLY_MODELS_H_ -#define _APPLY_MODELS_H_ +#ifndef APPLY_MODELS_H_ +#define APPLY_MODELS_H_  #include <iostream> diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc new file mode 100644 index 00000000..fbde7e24 --- /dev/null +++ b/decoder/bottom_up_parser-rs.cc @@ -0,0 +1,341 @@ +#include "bottom_up_parser-rs.h" + +#include <iostream> +#include <map> + +#include "node_state_hash.h" +#include "nt_span.h" +#include "hg.h" +#include "array2d.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +static WordID kEPS = 0; + +struct RSActiveItem; +class RSChart { + public: +  RSChart(const string& goal, +               const vector<GrammarPtr>& grammars, +               const Lattice& input, +               Hypergraph* forest); +  ~RSChart(); + +  void AddToChart(const RSActiveItem& x, int i, int j); +  void ConsumeTerminal(const RSActiveItem& x, int i, int j, int k); +  void ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k); +  bool Parse(); +  inline bool GoalFound() const { return goal_idx_ >= 0; } +  inline int GetGoalIndex() const { return goal_idx_; } + + private: +  void ApplyRules(const int i, +                  const int j, +                  const RuleBin* rules, +                  const Hypergraph::TailNodeVector& tail, +                  const float lattice_cost); + +  // returns true if a new node was added to the chart +  // false otherwise +  bool ApplyRule(const int i, +                 const int j, +                 const TRulePtr& r, +                 const Hypergraph::TailNodeVector& ant_nodes, +                 const float lattice_cost); + +  void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); +  void TopoSortUnaries(); + +  const vector<GrammarPtr>& grammars_; +  const Lattice& input_; +  Hypergraph* forest_; +  Array2D<vector<int>> chart_;   // chart_(i,j) is the list of nodes (represented +                                 // by their index in forest_->nodes_) derived spanning i,j +  typedef map<int, int> Cat2NodeMap; +  Array2D<Cat2NodeMap> nodemap_; +  const WordID goal_cat_;    // category that is being searched for at [0,n] +  TRulePtr goal_rule_; +  int goal_idx_;             // index of goal node, if found +  const int lc_fid_; +  vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars + +  static WordID kGOAL;       // [Goal] +}; + +WordID RSChart::kGOAL = 0; + +// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" +struct RSActiveItem { +  explicit RSActiveItem(const GrammarIter* g, int i) : +    gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {} +  void ExtendTerminal(int symbol, float src_cost) { +    lattice_cost += src_cost; +    if (symbol != kEPS) +      gptr_ = gptr_->Extend(symbol); +  } +  void ExtendNonTerminal(const Hypergraph* hg, int node_index) { +    gptr_ = gptr_->Extend(hg->nodes_[node_index].cat_); +    ant_nodes_.push_back(node_index); +  } +  // returns false if the extension has failed +  explicit operator bool() const { +    return gptr_; +  } +  const GrammarIter* gptr_; +  Hypergraph::TailNodeVector ant_nodes_; +  float lattice_cost;  // TODO: use SparseVector<double> to encode input features +  short i_; +}; + +// some notes on the implementation +// "X" in Rico's Algorithm 2 roughly looks like it is just a pointer into a grammar +// trie, but it is actually a full "dotted item" since it needs to contain the information +// to build the hypergraph (i.e., it must remember the antecedent nodes and where they are, +// also any information about the path costs). + +RSChart::RSChart(const string& goal, +                           const vector<GrammarPtr>& grammars, +                           const Lattice& input, +                           Hypergraph* forest) : +    grammars_(grammars), +    input_(input), +    forest_(forest), +    chart_(input.size()+1, input.size()+1), +    nodemap_(input.size()+1, input.size()+1), +    goal_cat_(TD::Convert(goal) * -1), +    goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), +    goal_idx_(-1), +    lc_fid_(FD::Convert("LatticeCost")), +    unaries_() { +  for (unsigned i = 0; i < grammars_.size(); ++i) { +    const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules(); +    for (unsigned j = 0; j < u.size(); ++j) +      unaries_.push_back(u[j]); +  } +  TopoSortUnaries(); +  if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; +  if (!SILENT) cerr << "  Goal category: [" << goal << ']' << endl; +} + +static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { +  if (mark[node] == 1) { +    cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; +    return false; // cycle detected +  } else if (mark[node] == 2) { +    return true; // already been  +  } +  mark[node] = 1; +  const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); +  if (nit != g.end()) { +    const vector<TRulePtr>& edges = nit->second; +    vector<bool> okay(edges.size(), true); +    for (unsigned i = 0; i < edges.size(); ++i) { +      okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); +      if (!okay[i]) { +        cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; +      } +   } +    for (unsigned i = 0; i < edges.size(); ++i) { +      if (okay[i]) u.push_back(edges[i]); +      //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; +    } +  } +  mark[node] = 2; +  return true; +} + +void RSChart::TopoSortUnaries() { +  vector<TRulePtr> u(unaries_.size()); u.clear(); +  map<int, vector<TRulePtr> > g; +  map<int, int> mark; +  //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; +  mark[goal_cat_] = 2; +  for (unsigned i = 0; i < unaries_.size(); ++i) { +    //cerr << "Adding: " << unaries_[i]->AsString() << endl; +    g[unaries_[i]->f()[0]].push_back(unaries_[i]); +  } +    //m[unaries_[i]->lhs_].push_back(unaries_[i]); +  for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { +    //cerr << "PROC: " << TD::Convert(-it->first) << endl; +    if (mark[it->first] > 0) { +      //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; +    } else { +      TopoSortVisit(it->first, u, g, mark); +    } +  } +  unaries_.clear(); +  for (int i = u.size() - 1; i >= 0; --i) +    unaries_.push_back(u[i]); +} + +bool RSChart::ApplyRule(const int i, +                        const int j, +                        const TRulePtr& r, +                        const Hypergraph::TailNodeVector& ant_nodes, +                        const float lattice_cost) { +  Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); +  //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; +  new_edge->prev_i_ = r->prev_i; +  new_edge->prev_j_ = r->prev_j; +  new_edge->i_ = i; +  new_edge->j_ = j; +  new_edge->feature_values_ = r->GetFeatureValues(); +  if (lattice_cost && lc_fid_) +    new_edge->feature_values_.set_value(lc_fid_, lattice_cost); +  Cat2NodeMap& c2n = nodemap_(i,j); +  const bool is_goal = (r->GetLHS() == kGOAL); +  const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); +  Hypergraph::Node* node = NULL; +  bool added_node = false; +  if (ni == c2n.end()) { +    //cerr << "(" << i << "," << j << ") => " << TD::Convert(-r->GetLHS()) << endl; +    added_node = true; +    node = forest_->AddNode(r->GetLHS()); +    c2n[r->GetLHS()] = node->id_; +    if (is_goal) { +      assert(goal_idx_ == -1); +      goal_idx_ = node->id_; +    } else { +      chart_(i,j).push_back(node->id_); +    } +  } else { +    node = &forest_->nodes_[ni->second]; +  } +  forest_->ConnectEdgeToHeadNode(new_edge, node); +  return added_node; +} + +void RSChart::ApplyRules(const int i, +                         const int j, +                         const RuleBin* rules, +                         const Hypergraph::TailNodeVector& tail, +                         const float lattice_cost) { +  const int n = rules->GetNumRules(); +  //cerr << i << " " << j << ": NUM RULES: " << n << endl; +  for (int k = 0; k < n; ++k) { +    //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; +    TRulePtr rule = rules->GetIthRule(k); +    // apply rule, and if we create a new node, apply any necessary +    // unary rules +    if (ApplyRule(i, j, rule, tail, lattice_cost)) { +      unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; +      ApplyUnaryRules(i, j, rule->lhs_, nodeidx); +    } +  } +} + +void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx) { +  for (unsigned ri = 0; ri < unaries_.size(); ++ri) { +    //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; +    if (unaries_[ri]->f()[0] == cat) { +      //cerr << "  --MATCH\n"; +      WordID new_lhs = unaries_[ri]->GetLHS(); +      const Hypergraph::TailNodeVector ant(1, nodeidx); +      if (ApplyRule(i, j, unaries_[ri], ant, 0)) { +        //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; +        unsigned nodeidx = nodemap_(i,j)[new_lhs]; +        ApplyUnaryRules(i, j, new_lhs, nodeidx); +      } +    } +  } +} + +void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { +  // deal with completed rules +  const RuleBin* rb = x.gptr_->GetRules(); +  if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost); + +  //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; +  // continue looking for extensions of the rule to the right +  for (unsigned k = j+1; k <= input_.size(); ++k) { +    ConsumeTerminal(x, i, j, k); +    ConsumeNonTerminal(x, i, j, k); +  } +} + +void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { +  //cerr << "ConsumeT(" << i << "," << j << "," << k << "):\n"; +   +  const unsigned check_edge_len = k - j; +  // long-term TODO preindex this search so i->len->words is constant time rather than fan out +  for (auto& in_edge : input_[j]) { +    if (in_edge.dist2next == check_edge_len) { +      //cerr << "  Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; +      RSActiveItem copy = x; +      copy.ExtendTerminal(in_edge.label, in_edge.cost); +      if (copy) AddToChart(copy, i, k); +    } +  } +} + +void RSChart::ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k) { +  //cerr << "ConsumeNT(" << i << "," << j << "," << k << "):\n"; +  for (auto& nodeidx : chart_(j,k)) { +    //cerr << "  Found completed NT in (" << j << "," << k << ") of type " << TD::Convert(-forest_->nodes_[nodeidx].cat_) << endl; +    RSActiveItem copy = x; +    copy.ExtendNonTerminal(forest_, nodeidx); +    if (copy) AddToChart(copy, i, k); +  } +} + +bool RSChart::Parse() { +  size_t in_size_2 = input_.size() * input_.size(); +  forest_->nodes_.reserve(in_size_2 * 2); +  size_t res = min(static_cast<size_t>(2000000), static_cast<size_t>(in_size_2 * 1000)); +  forest_->edges_.reserve(res); +  goal_idx_ = -1; +  const int N = input_.size(); +  for (int i = N - 1; i >= 0; --i) { +    for (int j = i + 1; j <= N; ++j) { +      for (unsigned gi = 0; gi < grammars_.size(); ++gi) { +        RSActiveItem item(grammars_[gi]->GetRoot(), i); +        ConsumeTerminal(item, i, i, j); +      } +      for (unsigned gi = 0; gi < grammars_.size(); ++gi) { +        RSActiveItem item(grammars_[gi]->GetRoot(), i); +        ConsumeNonTerminal(item, i, i, j); +      } +    } +  } + +  // look for goal +  const vector<int>& dh = chart_(0, input_.size()); +  for (unsigned di = 0; di < dh.size(); ++di) { +    const Hypergraph::Node& node = forest_->nodes_[dh[di]]; +    if (node.cat_ == goal_cat_) { +      Hypergraph::TailNodeVector ant(1, node.id_); +      ApplyRule(0, input_.size(), goal_rule_, ant, 0); +    } +  } +  if (!SILENT) cerr << endl; + +  if (GoalFound()) +    forest_->PruneUnreachable(forest_->nodes_.size() - 1); +  return GoalFound(); +} + +RSChart::~RSChart() {} + +RSExhaustiveBottomUpParser::RSExhaustiveBottomUpParser( +    const string& goal_sym, +    const vector<GrammarPtr>& grammars) : +  goal_sym_(goal_sym), +  grammars_(grammars) {} + +bool RSExhaustiveBottomUpParser::Parse(const Lattice& input, +                                     Hypergraph* forest) const { +  kEPS = TD::Convert("*EPS*"); +  RSChart chart(goal_sym_, grammars_, input, forest); +  const bool result = chart.Parse(); + +  if (result) { +    for (auto& node : forest->nodes_) { +      Span prev; +      const Span s = forest->NodeSpan(node.id_, &prev); +      node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); +    } +  } +  return result; +} diff --git a/decoder/bottom_up_parser-rs.h b/decoder/bottom_up_parser-rs.h new file mode 100644 index 00000000..2e271e99 --- /dev/null +++ b/decoder/bottom_up_parser-rs.h @@ -0,0 +1,29 @@ +#ifndef RSBOTTOM_UP_PARSER_H_ +#define RSBOTTOM_UP_PARSER_H_ + +#include <vector> +#include <string> + +#include "lattice.h" +#include "grammar.h" + +class Hypergraph; + +// implementation of Sennrich (2014) parser +// http://aclweb.org/anthology/W/W14/W14-4011.pdf +class RSExhaustiveBottomUpParser { + public: +  RSExhaustiveBottomUpParser(const std::string& goal_sym, +                             const std::vector<GrammarPtr>& grammars); + +  // returns true if goal reached spanning the full input +  // forest contains the full (i.e., unpruned) parse forest +  bool Parse(const Lattice& input, +             Hypergraph* forest) const; + + private: +  const std::string goal_sym_; +  const std::vector<GrammarPtr> grammars_; +}; + +#endif diff --git a/decoder/bottom_up_parser.h b/decoder/bottom_up_parser.h index 546bfb54..628bb96d 100644 --- a/decoder/bottom_up_parser.h +++ b/decoder/bottom_up_parser.h @@ -1,5 +1,5 @@ -#ifndef _BOTTOM_UP_PARSER_H_ -#define _BOTTOM_UP_PARSER_H_ +#ifndef BOTTOM_UP_PARSER_H_ +#define BOTTOM_UP_PARSER_H_  #include <vector>  #include <string> diff --git a/decoder/csplit.cc b/decoder/csplit.cc index 4a723822..7ee4092e 100644 --- a/decoder/csplit.cc +++ b/decoder/csplit.cc @@ -151,6 +151,7 @@ bool CompoundSplit::TranslateImpl(const string& input,    smeta->SetSourceLength(in.size());  // TODO do utf8 or somethign    for (int i = 0; i < in.size(); ++i)      smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1))); +  smeta->ComputeInputLatticeType();    pimpl_->BuildTrellis(in, forest);    forest->Reweight(weights);    return true; diff --git a/decoder/csplit.h b/decoder/csplit.h index 82ed23fc..83d457b8 100644 --- a/decoder/csplit.h +++ b/decoder/csplit.h @@ -1,5 +1,5 @@ -#ifndef _CSPLIT_H_ -#define _CSPLIT_H_ +#ifndef CSPLIT_H_ +#define CSPLIT_H_  #include "translator.h"  #include "lattice.h" diff --git a/decoder/decoder.cc b/decoder/decoder.cc index c384c33f..1e6c3194 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; }  #include "fdict.h"  #include "timing_stats.h"  #include "verbose.h" +#include "b64featvector.h"  #include "translator.h"  #include "phrasebased_translator.h" @@ -86,7 +87,7 @@ struct ELengthWeightFunction {    }  };  inline void ShowBanner() { -  cerr << "cdec (c) 2009--2014 by Chris Dyer\n"; +  cerr << "cdec (c) 2009--2014 by Chris Dyer" << endl;  }  inline string str(char const* name,po::variables_map const& conf) { @@ -195,7 +196,7 @@ struct DecoderImpl {        }        forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);        if (!forestname.empty()) forestname=" "+forestname; -      if (!SILENT) {  +      if (!SILENT) {          forest_stats(forest,"  Pruned "+forestname+" forest",false,false);          cerr << "  Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;        } @@ -261,7 +262,7 @@ struct DecoderImpl {        assert(ref);        LatticeTools::ConvertTextOrPLF(sref, ref);      } -  }  +  }    // used to construct the suffix string to get the name of arguments for multiple passes    // e.g., the "2" in --weights2 @@ -284,7 +285,7 @@ struct DecoderImpl {    boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;    int sample_max_trans;    bool aligner_mode; -  bool graphviz;  +  bool graphviz;    bool joshua_viz;    bool encode_b64;    bool kbest; @@ -301,6 +302,7 @@ struct DecoderImpl {    bool feature_expectations; // TODO Observer    bool output_training_vector; // TODO Observer    bool remove_intersected_rule_annotations; +  bool mr_mira_compat;  // Mr.MIRA compatibility mode.    boost::scoped_ptr<IncrementalBase> incremental; @@ -404,6 +406,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")          ("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory")          ("show_derivations", po::value<string>(), "Directory to print the derivation structures to") +        ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures")          ("graphviz","Show (constrained) translation forest in GraphViz format")          ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")          ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") @@ -414,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")          ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")          ("forest_output,O",po::value<string>(),"Directory to write forests to") -        ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); +        ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)") +        ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)");    // ob.AddOptions(&opts);    po::options_description clo("Command line options"); @@ -665,7 +669,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    unique_kbest = conf.count("unique_k_best");    get_oracle_forest = conf.count("get_oracle_forest");    oracle.show_derivation=conf.count("show_derivations"); +  oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();    remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); +  mr_mira_compat = conf.count("mr_mira_compat");    combine_size = conf["combine_size"].as<int>();    if (combine_size < 1) combine_size = 1; @@ -699,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string    static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);  } +static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) { +  SparseVector<weight_t> delta; +  DecodeFeatureVector(delta_b64, &delta); +  if (delta.empty()) return; +  // Apply updates +  for (SparseVector<weight_t>::iterator dit = delta.begin(); +       dit != delta.end(); ++dit) { +    int feat_id = dit->first; +    union { weight_t weight; unsigned long long repr; } feat_delta; +    feat_delta.weight = dit->second; +    if (!SILENT) +      cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight +           << " = " << hex << feat_delta.repr << endl; +    if (weights->size() <= feat_id) weights->resize(feat_id + 1); +    (*weights)[feat_id] += feat_delta.weight; +  } +} +  bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    string buf = input;    NgramCache::Clear();   // clear ngram cache for remote LM (if used) @@ -709,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    if (sgml.find("id") != sgml.end())      sent_id = atoi(sgml["id"].c_str()); +  // Add delta from input to weights before decoding +  if (mr_mira_compat) +    ApplyWeightDelta(sgml["delta"], init_weights.get()); +    if (!SILENT) {      cerr << "\nINPUT: ";      if (buf.size() < 100) @@ -928,14 +956,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {        Hypergraph new_hg;        {          ReadFile rf(writer.fname_); -        bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); +        bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);          if (!succeeded) abort();        }        HG::Union(forest, &new_hg); -      bool succeeded = writer.Write(new_hg, false); +      bool succeeded = writer.Write(new_hg);        if (!succeeded) abort();      } else { -      bool succeeded = writer.Write(forest, false); +      bool succeeded = writer.Write(forest);        if (!succeeded) abort();      }    } @@ -947,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {      if (kbest && !has_ref) {        //TODO: does this work properly?        const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; -      oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); +      oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);      } else if (csplit_output_plf) {        cout << HypergraphIO::AsPLF(forest, false) << endl;      } else { @@ -1021,14 +1049,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {            Hypergraph new_hg;            {              ReadFile rf(writer.fname_); -            bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); +            bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);              if (!succeeded) abort();            }            HG::Union(forest, &new_hg); -          bool succeeded = writer.Write(new_hg, false); +          bool succeeded = writer.Write(new_hg);            if (!succeeded) abort();          } else { -          bool succeeded = writer.Write(forest, false); +          bool succeeded = writer.Write(forest);            if (!succeeded) abort();          }        } @@ -1078,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {        if (conf.count("graphviz")) forest.PrintGraphviz();        if (kbest) {          const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; -        oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); +        oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);        }        if (conf.count("show_conditional_prob")) {          const prob_t ref_z = Inside<prob_t, EdgeProb>(forest); @@ -1098,4 +1126,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    o->NotifyDecodingComplete(smeta);    return true;  } - diff --git a/decoder/decoder.h b/decoder/decoder.h index 8039a42b..a545206b 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -1,5 +1,5 @@ -#ifndef _DECODER_H_ -#define _DECODER_H_ +#ifndef DECODER_H_ +#define DECODER_H_  #include <iostream>  #include <string> diff --git a/decoder/earley_composer.h b/decoder/earley_composer.h index 9f786bf6..31602f67 100644 --- a/decoder/earley_composer.h +++ b/decoder/earley_composer.h @@ -1,5 +1,5 @@ -#ifndef _EARLEY_COMPOSER_H_ -#define _EARLEY_COMPOSER_H_ +#ifndef EARLEY_COMPOSER_H_ +#define EARLEY_COMPOSER_H_  #include <iostream> diff --git a/decoder/factored_lexicon_helper.h b/decoder/factored_lexicon_helper.h index 7fedc517..460bdebb 100644 --- a/decoder/factored_lexicon_helper.h +++ b/decoder/factored_lexicon_helper.h @@ -1,5 +1,5 @@ -#ifndef _FACTORED_LEXICON_HELPER_ -#define _FACTORED_LEXICON_HELPER_ +#ifndef FACTORED_LEXICON_HELPER_ +#define FACTORED_LEXICON_HELPER_  #include <cassert>  #include <vector> diff --git a/decoder/ff.h b/decoder/ff.h index 647b4834..d6487d97 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -1,5 +1,5 @@ -#ifndef _FF_H_ -#define _FF_H_ +#ifndef FF_H_ +#define FF_H_  #include <string>  #include <vector> diff --git a/decoder/ff_basic.cc b/decoder/ff_basic.cc index f9404d24..f960418a 100644 --- a/decoder/ff_basic.cc +++ b/decoder/ff_basic.cc @@ -49,9 +49,7 @@ void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,    features->set_value(fid_, edge.rule_->FWords() * value_);  } - -ArityPenalty::ArityPenalty(const std::string& param) : -    value_(-1.0 / log(10)) { +ArityPenalty::ArityPenalty(const std::string& param) {    string fname = "Arity_";    unsigned MAX=DEFAULT_MAX_ARITY;    using namespace boost; @@ -61,7 +59,8 @@ ArityPenalty::ArityPenalty(const std::string& param) :      WordID fid=FD::Convert(fname+lexical_cast<string>(i));      fids_.push_back(fid);    } -  while (!fids_.empty() && fids_.back()==0) fids_.pop_back(); // pretty up features vector in case FD was frozen.  doesn't change anything +  // pretty up features vector in case FD was frozen.  doesn't change anything +  while (!fids_.empty() && fids_.back()==0) fids_.pop_back();  }  void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, @@ -75,6 +74,6 @@ void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,    (void) state;    (void) estimated_features;    unsigned a=edge.Arity(); -  features->set_value(a<fids_.size()?fids_[a]:0, value_); +  if (a < fids_.size()) features->set_value(fids_[a], 1.0);  } diff --git a/decoder/ff_basic.h b/decoder/ff_basic.h index 901c0110..c63daf0f 100644 --- a/decoder/ff_basic.h +++ b/decoder/ff_basic.h @@ -1,5 +1,5 @@ -#ifndef _FF_BASIC_H_ -#define _FF_BASIC_H_ +#ifndef FF_BASIC_H_ +#define FF_BASIC_H_  #include "ff.h" @@ -41,7 +41,7 @@ class SourceWordPenalty : public FeatureFunction {    const double value_;  }; -#define DEFAULT_MAX_ARITY 9 +#define DEFAULT_MAX_ARITY 50  #define DEFAULT_MAX_ARITY_STRINGIZE(x) #x  #define DEFAULT_MAX_ARITY_STRINGIZE_EVAL(x) DEFAULT_MAX_ARITY_STRINGIZE(x)  #define DEFAULT_MAX_ARITY_STR DEFAULT_MAX_ARITY_STRINGIZE_EVAL(DEFAULT_MAX_ARITY) @@ -62,7 +62,6 @@ class ArityPenalty : public FeatureFunction {                                       void* context) const;   private:    std::vector<WordID> fids_; -  const double value_;  };  #endif diff --git a/decoder/ff_bleu.h b/decoder/ff_bleu.h index 344dc788..8ca2c095 100644 --- a/decoder/ff_bleu.h +++ b/decoder/ff_bleu.h @@ -1,5 +1,5 @@ -#ifndef _BLEU_FF_H_ -#define _BLEU_FF_H_ +#ifndef BLEU_FF_H_ +#define BLEU_FF_H_  #include <vector>  #include <string> diff --git a/decoder/ff_charset.h b/decoder/ff_charset.h index 267ef65d..e22ece2b 100644 --- a/decoder/ff_charset.h +++ b/decoder/ff_charset.h @@ -1,5 +1,5 @@ -#ifndef _FFCHARSET_H_ -#define _FFCHARSET_H_ +#ifndef FFCHARSET_H_ +#define FFCHARSET_H_  #include <string>  #include <map> diff --git a/decoder/ff_conll.cc b/decoder/ff_conll.cc new file mode 100644 index 00000000..8ded44b7 --- /dev/null +++ b/decoder/ff_conll.cc @@ -0,0 +1,250 @@ +#include "ff_conll.h" + +#include <stdlib.h> +#include <sstream> +#include <cassert> +#include <cmath> +#include <boost/lexical_cast.hpp> + +#include "hg.h" +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" +#include "tdict.h" + +CoNLLFeatures::CoNLLFeatures(const string& param) { +  //  cerr << "initializing CoNLLFeatures with parameters: " << param; +  kSOS = TD::Convert("<s>"); +  kEOS = TD::Convert("</s>"); +  macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]"); +  ParseArgs(param); +} + +string CoNLLFeatures::Escape(const string& x) const { +  string y = x; +  for (int i = 0; i < y.size(); ++i) { +    if (y[i] == '=') y[i]='_'; +    if (y[i] == ';') y[i]='_'; +  } +  return y; +} + +// replace %x[relative_location] or %y[relative_location] with actual_token +// within feature_instance +void CoNLLFeatures::ReplaceMacroWithString( +  string& feature_instance, bool token_vs_label, int relative_location,  +  const string& actual_token) const { + +  stringstream macro; +  if (token_vs_label) { +    macro << "%x["; +  } else { +    macro << "%y["; +  } +  macro << relative_location << "]"; +  int macro_index = feature_instance.find(macro.str()); +  if (macro_index == string::npos) { +    cerr << "Can't find macro " << macro.str() << " in feature template "  +	 << feature_instance; +    abort(); +  } +  feature_instance.replace(macro_index, macro.str().size(), actual_token); +} + +void CoNLLFeatures::ReplaceTokenMacroWithString( +  string& feature_instance, int relative_location,  +  const string& actual_token) const { + +  ReplaceMacroWithString(feature_instance, true, relative_location,  +				actual_token); +} + +void CoNLLFeatures::ReplaceLabelMacroWithString( +  string& feature_instance, int relative_location,  +  const string& actual_token) const { + +  ReplaceMacroWithString(feature_instance, false, relative_location,  +				actual_token); +} + +void CoNLLFeatures::Error(const string& error_message) const { +  cerr << "Error: " << error_message << "\n\n" + +       << "CoNLLFeatures Usage: \n"			      +       << "  feature_function=CoNLLFeatures -t <TEMPLATE>\n\n"  + +       << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n"  + +       << "%x[k] and %y[k] are macros to be instantiated with an input\n"  +       << "token (for x) or a label (for y). k specifies the relative\n"  +       << "location of the input token or label with respect to the current\n" +       << "position. For x, k is an integer value. For y, k must be 0 (to\n"  +       << "be extended).\n\n"; + +  abort(); +} + +void CoNLLFeatures::ParseArgs(const string& in) { +  which_feat = 0; +  vector<string> const& argv = SplitOnWhitespace(in); +  for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) { +    string const& s = *i; +    if (s[0] == '-') { +      if (s.size() > 2) {  +	stringstream msg; +	msg << s << " is an invalid option for CoNLLFeatures."; +	Error(msg.str()); +      } + +      switch (s[1]) { + +      case 'w': { +        if (++i == argv.end()) { +          Error("Missing parameter to -w"); +        } +        which_feat = boost::lexical_cast<unsigned>(*i); +        break; +      } +      // feature template +      case 't': { +	if (++i == argv.end()) {  +	  Error("Can't find template.");  +	} +	feature_template = *i; +	string::const_iterator start = feature_template.begin(); +	string::const_iterator end = feature_template.end(); +	smatch macro_match; + +        // parse the template +	while (regex_search(start, end, macro_match, macro_regex)) { +	  // get the relative location +	  string relative_location_str(macro_match[2].first,  +				       macro_match[2].second); +	  int relative_location = atoi(relative_location_str.c_str()); +	  // add it to the list of relative locations for token or label +	  // (i.e. x or y) +	  bool valid_location = true; +          if (*macro_match[1].first == 'x') { +	    // add it to token locations +	    token_relative_locations.push_back(relative_location); +	  } else { +	    if (relative_location != 0) { valid_location = false; } +	    // add it to label locations +	    label_relative_locations.push_back(relative_location); +	  } +	  if (!valid_location) { +	    stringstream msg; +	    msg << "Relative location " << relative_location +		<< " in feature template " << feature_template +		<< " is invalid."; +	    Error(msg.str()); +	  } +	  start = macro_match[0].second; +	} +	break; +      } + +      // TODO: arguments to specify kSOS and kEOS + +      default: { +	stringstream msg; +	msg << "Invalid option on CoNLLFeatures: " << s; +	Error(msg.str()); +	break; +      } +      } // end of switch +    } // end of if (token starts with hyphen) +  } // end of for loop (over arguments) + +  // the -t (i.e. template) option is mandatory in this feature function +  if (label_relative_locations.size() == 0 ||  +      token_relative_locations.size() == 0) { +    stringstream msg; +    msg << "Feature template must specify at least one"  +	 << "token macro (e.g. x[-1]) and one label macro (e.g. y[0])."; +    Error(msg.str()); +  } +} + +void CoNLLFeatures::PrepareForInput(const SentenceMetadata& smeta) { +  const Lattice& sl = smeta.GetSourceLattice(); +  current_input.resize(sl.size()); +  for (unsigned i = 0; i < sl.size(); ++i) { +    if (sl[i].size() != 1) { +      stringstream msg; +      msg << "CoNLLFeatures don't support lattice inputs!\nid="  +	   << smeta.GetSentenceId() << endl; +      Error(msg.str()); +    } +    current_input[i] = sl[i][0].label; +  } +  vector<WordID> wids; +  string fn = "feat"; +  fn += boost::lexical_cast<string>(which_feat); +  string feats = smeta.GetSGMLValue(fn); +  if (feats.size() == 0) { +    Error("Can't find " + fn + " in <seg>\n"); +  } +  TD::ConvertSentence(feats, &wids); +  assert(current_input.size() == wids.size()); +  current_input = wids; +} + +void CoNLLFeatures::TraversalFeaturesImpl( +  const SentenceMetadata& smeta, const Hypergraph::Edge& edge, +  const vector<const void*>& ant_contexts, SparseVector<double>* features, +  SparseVector<double>* estimated_features, void* context) const { + +  const TRule& rule = *edge.rule_; +  // arity = 0, no nonterminals +  // size = 1, predicted label is a single token +  if (rule.Arity() != 0 || +      rule.e_.size() != 1) {  +    return;  +  } + +  // replace label macros with actual label strings +  // NOTE: currently, this feature function doesn't allow any label +  // macros except %y[0]. but you can look at as much of the source as you want +  const WordID y0 = rule.e_[0]; +  string y0_str = TD::Convert(y0); + +  // start of the span in the input being labeled +  const int from_src_index = edge.i_;    +  // end of the span in the input +  const int to_src_index = edge.j_; + +  // in the case of tagging the size of the spans being labeled will +  //  always be 1, but in other formalisms, you can have bigger spans +  if (to_src_index - from_src_index != 1) { +    cerr << "CoNLLFeatures doesn't support input spans of length != 1"; +    abort(); +  } + +  string feature_instance = feature_template; +  // replace token macros with actual token strings +  for (unsigned i = 0; i < token_relative_locations.size(); ++i) { +    int loc = token_relative_locations[i]; +    WordID x = loc < 0? kSOS: kEOS; +    if(from_src_index + loc >= 0 &&  +       from_src_index + loc < current_input.size()) { +      x = current_input[from_src_index + loc]; +    } +    string x_str = TD::Convert(x); +    ReplaceTokenMacroWithString(feature_instance, loc, x_str); +  } +     +  ReplaceLabelMacroWithString(feature_instance, 0, y0_str); + +  // pick a real value for this feature +  double fval = 1.0;  + +  // add it to the feature vector +  //   FD::Convert converts the feature string to a feature int +  //   Escape makes sure the feature string doesn't have any bad +  //     symbols that could confuse a parser somewhere +  features->add_value(FD::Convert(Escape(feature_instance)), fval); +} diff --git a/decoder/ff_conll.h b/decoder/ff_conll.h new file mode 100644 index 00000000..b37356d8 --- /dev/null +++ b/decoder/ff_conll.h @@ -0,0 +1,45 @@ +#ifndef FF_CONLL_H_ +#define FF_CONLL_H_ + +#include <vector> +#include <boost/xpressive/xpressive.hpp> +#include "ff.h" + +using namespace boost::xpressive; +using namespace std; + +class CoNLLFeatures : public FeatureFunction { + public: +  CoNLLFeatures(const string& param); + protected: +  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, +                                     const HG::Edge& edge, +                                     const vector<const void*>& ant_contexts, +                                     SparseVector<double>* features, +                                     SparseVector<double>* estimated_features, +                                     void* context) const; +  virtual void PrepareForInput(const SentenceMetadata& smeta); +  virtual void ParseArgs(const string& in); +  virtual string Escape(const string& x) const; +  virtual void ReplaceMacroWithString(string& feature_instance,  +				      bool token_vs_label,  +				      int relative_location,  +				      const string& actual_token) const; +  virtual void ReplaceTokenMacroWithString(string& feature_instance,  +					   int relative_location,  +					   const string& actual_token) const; +  virtual void ReplaceLabelMacroWithString(string& feature_instance,  +					   int relative_location,  +					   const string& actual_token) const; +  virtual void Error(const string&) const; + + private: +  vector<int> token_relative_locations, label_relative_locations; +  string feature_template; +  vector<WordID> current_input; +  WordID kSOS, kEOS; +  sregex macro_regex; +  unsigned which_feat; +}; + +#endif diff --git a/decoder/ff_context.h b/decoder/ff_context.h index 19198ec3..ed1aea2b 100644 --- a/decoder/ff_context.h +++ b/decoder/ff_context.h @@ -1,6 +1,5 @@ - -#ifndef _FF_CONTEXT_H_ -#define _FF_CONTEXT_H_ +#ifndef FF_CONTEXT_H_ +#define FF_CONTEXT_H_  #include <vector>  #include <boost/xpressive/xpressive.hpp> diff --git a/decoder/ff_csplit.h b/decoder/ff_csplit.h index 79bf2886..227f2a14 100644 --- a/decoder/ff_csplit.h +++ b/decoder/ff_csplit.h @@ -1,5 +1,5 @@ -#ifndef _FF_CSPLIT_H_ -#define _FF_CSPLIT_H_ +#ifndef FF_CSPLIT_H_ +#define FF_CSPLIT_H_  #include <boost/shared_ptr.hpp> diff --git a/decoder/ff_external.h b/decoder/ff_external.h index 3e2bee51..fd12a37c 100644 --- a/decoder/ff_external.h +++ b/decoder/ff_external.h @@ -1,5 +1,5 @@ -#ifndef _FFEXTERNAL_H_ -#define _FFEXTERNAL_H_ +#ifndef FFEXTERNAL_H_ +#define FFEXTERNAL_H_  #include "ff.h" diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h index 1aa8e55f..ba9be9ac 100644 --- a/decoder/ff_factory.h +++ b/decoder/ff_factory.h @@ -1,5 +1,5 @@ -#ifndef _FF_FACTORY_H_ -#define _FF_FACTORY_H_ +#ifndef FF_FACTORY_H_ +#define FF_FACTORY_H_  //TODO: use http://www.boost.org/doc/libs/1_43_0/libs/functional/factory/doc/html/index.html ? diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index db4032f7..c8350623 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -1,5 +1,5 @@ -#ifndef _KLM_FF_H_ -#define _KLM_FF_H_ +#ifndef KLM_FF_H_ +#define KLM_FF_H_  #include <vector>  #include <string> diff --git a/decoder/ff_lm.h b/decoder/ff_lm.h index 85e79704..83a2e186 100644 --- a/decoder/ff_lm.h +++ b/decoder/ff_lm.h @@ -1,5 +1,5 @@ -#ifndef _LM_FF_H_ -#define _LM_FF_H_ +#ifndef LM_FF_H_ +#define LM_FF_H_  #include <vector>  #include <string> diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h index 4965d235..5dea9a7d 100644 --- a/decoder/ff_ngrams.h +++ b/decoder/ff_ngrams.h @@ -1,5 +1,5 @@ -#ifndef _NGRAMS_FF_H_ -#define _NGRAMS_FF_H_ +#ifndef NGRAMS_FF_H_ +#define NGRAMS_FF_H_  #include <vector>  #include <map> diff --git a/decoder/ff_parse_match.h b/decoder/ff_parse_match.h index 7820b418..188c406a 100644 --- a/decoder/ff_parse_match.h +++ b/decoder/ff_parse_match.h @@ -1,5 +1,5 @@ -#ifndef _FF_PARSE_MATCH_H_ -#define _FF_PARSE_MATCH_H_ +#ifndef FF_PARSE_MATCH_H_ +#define FF_PARSE_MATCH_H_  #include "ff.h"  #include "hg.h" diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h index f210dc65..5c4cf45e 100644 --- a/decoder/ff_rules.h +++ b/decoder/ff_rules.h @@ -1,5 +1,5 @@ -#ifndef _FF_RULES_H_ -#define _FF_RULES_H_ +#ifndef FF_RULES_H_ +#define FF_RULES_H_  #include <vector>  #include <map> diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h index 488cfd84..66914f5d 100644 --- a/decoder/ff_ruleshape.h +++ b/decoder/ff_ruleshape.h @@ -1,5 +1,5 @@ -#ifndef _FF_RULESHAPE_H_ -#define _FF_RULESHAPE_H_ +#ifndef FF_RULESHAPE_H_ +#define FF_RULESHAPE_H_  #include <vector>  #include <map> diff --git a/decoder/ff_soft_syntax.h b/decoder/ff_soft_syntax.h index e71825d5..da51df7f 100644 --- a/decoder/ff_soft_syntax.h +++ b/decoder/ff_soft_syntax.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOFT_SYNTAX_H_ -#define _FF_SOFT_SYNTAX_H_ +#ifndef FF_SOFT_SYNTAX_H_ +#define FF_SOFT_SYNTAX_H_  #include "ff.h"  #include "hg.h" diff --git a/decoder/ff_soft_syntax_mindist.h b/decoder/ff_soft_syntax_mindist.h index bf938b38..205eff4b 100644 --- a/decoder/ff_soft_syntax_mindist.h +++ b/decoder/ff_soft_syntax_mindist.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOFT_SYNTAX_MINDIST_H_ -#define _FF_SOFT_SYNTAX_MINDIST_H_ +#ifndef FF_SOFT_SYNTAX_MINDIST_H_ +#define FF_SOFT_SYNTAX_MINDIST_H_  #include "ff.h"  #include "hg.h" diff --git a/decoder/ff_source_path.h b/decoder/ff_source_path.h index 03126412..fc309264 100644 --- a/decoder/ff_source_path.h +++ b/decoder/ff_source_path.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOURCE_PATH_H_ -#define _FF_SOURCE_PATH_H_ +#ifndef FF_SOURCE_PATH_H_ +#define FF_SOURCE_PATH_H_  #include <vector>  #include <map> diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h index bdd638c1..6316e881 100644 --- a/decoder/ff_source_syntax.h +++ b/decoder/ff_source_syntax.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOURCE_SYNTAX_H_ -#define _FF_SOURCE_SYNTAX_H_ +#ifndef FF_SOURCE_SYNTAX_H_ +#define FF_SOURCE_SYNTAX_H_  #include "ff.h"  #include "hg.h" diff --git a/decoder/ff_source_syntax2.h b/decoder/ff_source_syntax2.h index f606c2bf..bbfa9eb6 100644 --- a/decoder/ff_source_syntax2.h +++ b/decoder/ff_source_syntax2.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOURCE_SYNTAX2_H_ -#define _FF_SOURCE_SYNTAX2_H_ +#ifndef FF_SOURCE_SYNTAX2_H_ +#define FF_SOURCE_SYNTAX2_H_  #include "ff.h"  #include "hg.h" diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h index d2f5e84c..e2475491 100644 --- a/decoder/ff_spans.h +++ b/decoder/ff_spans.h @@ -1,5 +1,5 @@ -#ifndef _FF_SPANS_H_ -#define _FF_SPANS_H_ +#ifndef FF_SPANS_H_ +#define FF_SPANS_H_  #include <vector>  #include <map> diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h index 46418b0c..0cb8c648 100644 --- a/decoder/ff_tagger.h +++ b/decoder/ff_tagger.h @@ -1,5 +1,5 @@ -#ifndef _FF_TAGGER_H_ -#define _FF_TAGGER_H_ +#ifndef FF_TAGGER_H_ +#define FF_TAGGER_H_  #include <map>  #include <boost/scoped_ptr.hpp> diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index 0161f603..ec454621 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -1,5 +1,5 @@ -#ifndef _FF_WORD_ALIGN_H_ -#define _FF_WORD_ALIGN_H_ +#ifndef FF_WORD_ALIGN_H_ +#define FF_WORD_ALIGN_H_  #include "ff.h"  #include "array2d.h" diff --git a/decoder/ff_wordset.h b/decoder/ff_wordset.h index e78cd2fb..94f5ff8a 100644 --- a/decoder/ff_wordset.h +++ b/decoder/ff_wordset.h @@ -1,5 +1,5 @@ -#ifndef _FF_WORDSET_H_ -#define _FF_WORDSET_H_ +#ifndef FF_WORDSET_H_ +#define FF_WORDSET_H_  #include "ff.h"  #include "tdict.h" diff --git a/decoder/ffset.h b/decoder/ffset.h index a69a75fa..84f9fdb9 100644 --- a/decoder/ffset.h +++ b/decoder/ffset.h @@ -1,5 +1,5 @@ -#ifndef _FFSET_H_ -#define _FFSET_H_ +#ifndef FFSET_H_ +#define FFSET_H_  #include <utility>  #include <vector> diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc index 6e4cccb3..cc9094d7 100644 --- a/decoder/forest_writer.cc +++ b/decoder/forest_writer.cc @@ -11,13 +11,13 @@  using namespace std;  ForestWriter::ForestWriter(const std::string& path, int num) : -  fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {} +  fname_(path + '/' + boost::lexical_cast<string>(num) + ".bin.gz"), used_(false) {} -bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) { +bool ForestWriter::Write(const Hypergraph& forest) {    assert(!used_);    used_ = true;    cerr << "  Writing forest to " << fname_ << endl;    WriteFile wf(fname_); -  return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream()); +  return HypergraphIO::WriteToBinary(forest, wf.stream());  } diff --git a/decoder/forest_writer.h b/decoder/forest_writer.h index 819a8940..54e83470 100644 --- a/decoder/forest_writer.h +++ b/decoder/forest_writer.h @@ -1,5 +1,5 @@ -#ifndef _FOREST_WRITER_H_ -#define _FOREST_WRITER_H_ +#ifndef FOREST_WRITER_H_ +#define FOREST_WRITER_H_  #include <string> @@ -7,7 +7,7 @@ class Hypergraph;  struct ForestWriter {    ForestWriter(const std::string& path, int num); -  bool Write(const Hypergraph& forest, bool minimal_rules); +  bool Write(const Hypergraph& forest);    const std::string fname_;    bool used_; diff --git a/decoder/freqdict.h b/decoder/freqdict.h index 4e03fadd..07d797e2 100644 --- a/decoder/freqdict.h +++ b/decoder/freqdict.h @@ -1,5 +1,5 @@ -#ifndef _FREQDICT_H_ -#define _FREQDICT_H_ +#ifndef FREQDICT_H_ +#define FREQDICT_H_  #include <iostream>  #include <map> diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 4253b652..fe28f4c6 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -27,11 +27,15 @@ struct FSTTranslatorImpl {                   const vector<double>& weights,                   Hypergraph* forest) {      bool composed = false; -    if (input.find("{\"rules\"") == 0) { +    if (input.find("::forest::") == 0) {        istringstream is(input); +      string header, fname; +      is >> header >> fname; +      ReadFile rf(fname); +      if (!rf) { cerr << "Failed to open " << fname << endl; abort(); }        Hypergraph src_cfg_hg; -      if (!HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)) { -        cerr << "Failed to read HG from JSON.\n"; +      if (!HypergraphIO::ReadFromBinary(rf.stream(), &src_cfg_hg)) { +        cerr << "Failed to read HG.\n";          abort();        }        if (add_pass_through_rules) { @@ -95,6 +99,7 @@ bool FSTTranslator::TranslateImpl(const string& input,                                const vector<double>& weights,                                Hypergraph* minus_lm_forest) {    smeta->SetSourceLength(0);  // don't know how to compute this +  smeta->input_type_ = cdec::kFOREST;    return pimpl_->Translate(input, weights, minus_lm_forest);  } diff --git a/decoder/hg.h b/decoder/hg.h index 4ed27d87..c756012e 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -1,5 +1,5 @@ -#ifndef _HG_H_ -#define _HG_H_ +#ifndef HG_H_ +#define HG_H_  // define USE_INFO_EDGE 1 if you want lots of debug info shown with --show_derivations - otherwise it adds quite a bit of overhead if ffs have their logging enabled (e.g. ff_from_fsa)  #ifndef USE_INFO_EDGE @@ -18,6 +18,7 @@  #include <string>  #include <vector>  #include <boost/shared_ptr.hpp> +#include <boost/serialization/vector.hpp>  #include "feature_vector.h"  #include "small_vector.h" @@ -69,6 +70,18 @@ namespace HG {      short int j_;      short int prev_i_;      short int prev_j_; +    template<class Archive> +    void serialize(Archive & ar, const unsigned int /*version*/) { +      ar & head_node_; +      ar & tail_nodes_; +      ar & rule_; +      ar & feature_values_; +      ar & i_; +      ar & j_; +      ar & prev_i_; +      ar & prev_j_; +      ar & id_; +    }      void show(std::ostream &o,unsigned mask=SPAN|RULE) const {        o<<'{';        if (mask&CATEGORY) @@ -149,6 +162,24 @@ namespace HG {      WordID NT() const { return -cat_; }      EdgesVector in_edges_;   // an in edge is an edge with this node as its head.  (in edges come from the bottom up to us)  indices in edges_      EdgesVector out_edges_;  // an out edge is an edge with this node as its tail.  (out edges leave us up toward the top/goal). indices in edges_ +    template<class Archive> +    void save(Archive & ar, const unsigned int /*version*/) const { +      ar & node_hash; +      ar & id_; +      ar & TD::Convert(-cat_); +      ar & in_edges_; +      ar & out_edges_; +    } +    template<class Archive> +    void load(Archive & ar, const unsigned int /*version*/) { +      ar & node_hash; +      ar & id_; +      std::string cat; ar & cat; +      cat_ = -TD::Convert(cat); +      ar & in_edges_; +      ar & out_edges_; +    } +    BOOST_SERIALIZATION_SPLIT_MEMBER()      void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting        node_hash = o.node_hash;        cat_=o.cat_; @@ -492,6 +523,27 @@ public:    void set_ids(); // resync edge,node .id_    void check_ids() const; // assert that .id_ have been kept in sync +  template<class Archive> +  void save(Archive & ar, const unsigned int /*version*/) const { +    unsigned ns = nodes_.size(); ar & ns; +    unsigned es = edges_.size(); ar & es; +    for (auto& n : nodes_) ar & n; +    for (auto& e : edges_) ar & e; +    int x; +    x = edges_topo_; ar & x; +    x = is_linear_chain_; ar & x; +  } +  template<class Archive> +  void load(Archive & ar, const unsigned int /*version*/) { +    unsigned ns; ar & ns; nodes_.resize(ns); +    unsigned es; ar & es; edges_.resize(es); +    for (auto& n : nodes_) ar & n; +    for (auto& e : edges_) ar & e; +    int x; +    ar & x; edges_topo_ = x; +    ar & x; is_linear_chain_ = x; +  } +  BOOST_SERIALIZATION_SPLIT_MEMBER()  private:    Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {}  }; diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index 02f5a401..b9381d02 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -88,7 +88,7 @@ namespace HG {  bool Intersect(const Lattice& target, Hypergraph* hg) {    // there are a number of faster algorithms available for restricted    // classes of hypergraph and/or target. -  if (hg->IsLinearChain() && target.IsSentence()) +  if (hg->IsLinearChain() && IsSentence(target))      return FastLinearIntersect(target, hg);    vector<bool> rem(hg->edges_.size(), false); diff --git a/decoder/hg_intersect.h b/decoder/hg_intersect.h index 29a5ea2a..19c1c177 100644 --- a/decoder/hg_intersect.h +++ b/decoder/hg_intersect.h @@ -1,5 +1,5 @@ -#ifndef _HG_INTERSECT_H_ -#define _HG_INTERSECT_H_ +#ifndef HG_INTERSECT_H_ +#define HG_INTERSECT_H_  #include "lattice.h" diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index eb0be3d4..626b2954 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -6,362 +6,27 @@  #include <sstream>  #include <iostream> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +  #include "fast_lexical_cast.hpp"  #include "tdict.h" -#include "json_parse.h"  #include "hg.h"  using namespace std; -struct HGReader : public JSONParser { -  HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } - -  void CreateNode(const string& cat, const string& shash, const vector<int>& in_edges) { -    WordID c = TD::Convert("X") * -1; -    if (!cat.empty()) c = TD::Convert(cat) * -1; -    Hypergraph::Node* node = hg.AddNode(c); -    char* dend; -    if (shash.size()) -      node->node_hash = strtoull(shash.c_str(), &dend, 16); -    else -      node->node_hash = 0; -    for (int i = 0; i < in_edges.size(); ++i) { -      if (in_edges[i] >= hg.edges_.size()) { -        cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] -             << ", but hg only has " << hg.edges_.size() << " edges!\n"; -        abort(); -      } -      hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); -    } -  } -  void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVectorUnsigned& tail) { -    Hypergraph::Edge* edge = hg.AddEdge(rule, tail); -    feats->swap(edge->feature_values_); -    edge->i_ = spans[0]; -    edge->j_ = spans[1]; -    edge->prev_i_ = spans[2]; -    edge->prev_j_ = spans[3]; -  } - -  bool HandleJSONEvent(int type, const JSON_value* value) { -    switch(state) { -    case -1: -      assert(type == JSON_T_OBJECT_BEGIN); -      state = 0; -      break; -    case 0: -      if (type == JSON_T_OBJECT_END) { -        //cerr << "HG created\n";  // TODO, signal some kind of callback -      } else if (type == JSON_T_KEY) { -        string val = value->vu.str.value; -        if (val == "features") { assert(fdict.empty()); state = 1; } -        else if (val == "is_sorted") { state = 3; } -        else if (val == "rules") { assert(rules.empty()); state = 4; } -        else if (val == "node") { state = 8; } -        else if (val == "edges") { state = 13; } -        else { cerr << "Unexpected key: " << val << endl; return false; } -      } -      break; - -    // features -    case 1: -      if(type == JSON_T_NULL) { state = 0; break; } -      assert(type == JSON_T_ARRAY_BEGIN); -      state = 2; -      break; -    case 2: -      if(type == JSON_T_ARRAY_END) { state = 0; break; } -      assert(type == JSON_T_STRING); -      fdict.push_back(FD::Convert(value->vu.str.value)); -      assert(fdict.back() > 0); -      break; - -    // is_sorted -    case 3: -      assert(type == JSON_T_TRUE || type == JSON_T_FALSE); -      is_sorted = (type == JSON_T_TRUE); -      if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; } -      state = 0; -      break; - -    // rules -    case 4: -      if(type == JSON_T_NULL) { state = 0; break; } -      assert(type == JSON_T_ARRAY_BEGIN); -      state = 5; -      break; -    case 5: -      if(type == JSON_T_ARRAY_END) { state = 0; break; } -      assert(type == JSON_T_INTEGER); -      state = 6; -      rule_id = value->vu.integer_value; -      break; -    case 6: -      assert(type == JSON_T_STRING); -      rules[rule_id] = TRulePtr(new TRule(value->vu.str.value)); -      state = 5; -      break; - -    // Nodes -    case 8: -      assert(type == JSON_T_OBJECT_BEGIN); -      ++nodes; -      in_edges.clear(); -      cat.clear(); -      shash.clear(); -      state = 9; break; -    case 9: -      if (type == JSON_T_OBJECT_END) { -        //cerr << "Creating NODE\n"; -        CreateNode(cat, shash, in_edges); -        state = 0; break; -      } -      assert(type == JSON_T_KEY); -      cur_key = value->vu.str.value; -      if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } -      if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } -      if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; } -      cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; -      return false; -    case 10: -      assert(type == JSON_T_STRING || type == JSON_T_NULL); -      cat = value->vu.str.value; -      state = 9; break; -    case 11: -      if (type == JSON_T_NULL) { state = 9; break; } -      assert(type == JSON_T_ARRAY_BEGIN); -      state = 12; break; -    case 12: -      if (type == JSON_T_ARRAY_END) { state = 9; break; } -      assert(type == JSON_T_INTEGER); -      //cerr << "in_edges: " << value->vu.integer_value << endl; -      in_edges.push_back(value->vu.integer_value); -      break; - -    //   "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12}, -    //         { "tail": null, "feats" : [0,0.87,1,0.02], "spans":[1,2,3,4], "rule": 17}, -    //         { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}] -    case 13: -      assert(type == JSON_T_ARRAY_BEGIN); -      state = 14; -      break; -    case 14: -      if (type == JSON_T_ARRAY_END) { state = 0; break; } -      assert(type == JSON_T_OBJECT_BEGIN); -      //cerr << "New edge\n"; -      ++edges; -      cur_rule.reset(); feats.clear(); tail.clear(); -      state = 15; break; -    case 15: -      if (type == JSON_T_OBJECT_END) { -        CreateEdge(cur_rule, &feats, tail); -        state = 14; break; -      } -      assert(type == JSON_T_KEY); -      cur_key = value->vu.str.value; -      //cerr << "edge key " << cur_key << endl; -      if (cur_key == "rule") { assert(!cur_rule); state = 16; break; } -      if (cur_key == "spans") { assert(!cur_rule); state = 22; break; } -      if (cur_key == "feats") { assert(feats.empty()); state = 17; break; } -      if (cur_key == "tail") { assert(tail.empty()); state = 20; break; } -      cerr << "Unexpected key " << cur_key << " in edge specification\n"; -      return false; -    case 16:    // edge.rule -      if (type == JSON_T_INTEGER) { -        int rule_id = value->vu.integer_value; -        if (rules.find(rule_id) == rules.end()) { -          // rules list must come before the edge definitions! -          cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n"; -          return false; -        } -        cur_rule = rules[rule_id]; -      } else if (type == JSON_T_STRING) { -        cur_rule.reset(new TRule(value->vu.str.value)); -      } else { -        cerr << "Rule must be either a rule id or a rule string" << endl; -        return false; -      } -      // cerr << "Edge: rule=" << cur_rule->AsString() << endl; -      state = 15; -      break; -    case 17:      // edge.feats -      if (type == JSON_T_NULL) { state = 15; break; } -      assert(type == JSON_T_ARRAY_BEGIN); -      state = 18; break; -    case 18: -      if (type == JSON_T_ARRAY_END) { state = 15; break; } -      if (type != JSON_T_INTEGER && type != JSON_T_STRING) { -        cerr << "Unexpected feature id type\n"; return false; -      } -      if (type == JSON_T_INTEGER) { -        fid = value->vu.integer_value; -        assert(fid < fdict.size()); -        fid = fdict[fid]; -      } else if (JSON_T_STRING) { -        fid = FD::Convert(value->vu.str.value); -      } else { abort(); } -      state = 19; -      break; -    case 19: -      { -        assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT); -        double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) : -	                                       strtod(value->vu.str.value, NULL)); -        feats.set_value(fid, val); -        state = 18; -        break; -      } -    case 20:     // edge.tail -      if (type == JSON_T_NULL) { state = 15; break; } -      assert(type == JSON_T_ARRAY_BEGIN); -      state = 21; break; -    case 21: -      if (type == JSON_T_ARRAY_END) { state = 15; break; } -      assert(type == JSON_T_INTEGER); -      tail.push_back(value->vu.integer_value); -      break; -    case 22:     // edge.spans -      assert(type == JSON_T_ARRAY_BEGIN); -      state = 23; -      spans[0] = spans[1] = spans[2] = spans[3] = -1; -      spanc = 0; -      break; -    case 23: -      if (type == JSON_T_ARRAY_END) { state = 15; break; } -      assert(type == JSON_T_INTEGER); -      assert(spanc < 4); -      spans[spanc] = value->vu.integer_value; -      ++spanc; -      break; -    case 24:  // read node hash -      assert(type == JSON_T_STRING); -      shash = value->vu.str.value; -      state = 9; -      break; -    } -    return true; -  } -  string rp; -  string cat; -  SmallVectorUnsigned tail; -  vector<int> in_edges; -  string shash; -  TRulePtr cur_rule; -  map<int, TRulePtr> rules; -  vector<int> fdict; -  SparseVector<double> feats; -  int state; -  int fid; -  int nodes; -  int edges; -  int spans[4]; -  int spanc; -  string cur_key; -  Hypergraph& hg; -  int rule_id; -  bool nodes_needed; -  bool edges_needed; -  bool is_sorted; -}; - -bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { +bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) { +  boost::archive::binary_iarchive oa(*in);    hg->clear(); -  HGReader reader(hg); -  return reader.Parse(in); -} - -static void WriteRule(const TRule& r, ostream* out) { -  if (!r.lhs_) { (*out) << "[X] ||| "; } -  JSONParser::WriteEscapedString(r.AsString(), out); +  oa >> *hg; +  return true;  } -bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { -  if (hg.empty()) { *out << "{}\n"; return true; } -  map<const TRule*, int> rid; -  ostream& o = *out; -  rid[NULL] = 0; -  o << '{'; -  if (!remove_rules) { -    o << "\"rules\":["; -    for (int i = 0; i < hg.edges_.size(); ++i) { -      const TRule* r = hg.edges_[i].rule_.get(); -      int &id = rid[r]; -      if (!id) { -        id=rid.size() - 1; -        if (id > 1) o << ','; -        o << id << ','; -        WriteRule(*r, &o); -      }; -    } -    o << "],"; -  } -  const bool use_fdict = FD::NumFeats() < 1000; -  if (use_fdict) { -    o << "\"features\":["; -    for (int i = 1; i < FD::NumFeats(); ++i) { -      o << (i==1 ? "":","); -      JSONParser::WriteEscapedString(FD::Convert(i), &o); -    } -    o << "],"; -  } -  vector<int> edgemap(hg.edges_.size(), -1);  // edges may be in non-topo order -  int edge_count = 0; -  for (int i = 0; i < hg.nodes_.size(); ++i) { -    const Hypergraph::Node& node = hg.nodes_[i]; -    if (i > 0) { o << ","; } -    o << "\"edges\":["; -    for (int j = 0; j < node.in_edges_.size(); ++j) { -      const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; -      edgemap[edge.id_] = edge_count; -      ++edge_count; -      o << (j == 0 ? "" : ",") << "{"; - -      o << "\"tail\":["; -      for (int k = 0; k < edge.tail_nodes_.size(); ++k) { -        o << (k > 0 ? "," : "") << edge.tail_nodes_[k]; -      } -      o << "],"; - -      o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],"; - -      o << "\"feats\":["; -      bool first = true; -      for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) { -        if (!it->second) continue;   // don't write features that have a zero value -        if (!it->first) continue;    // if the feature set was frozen this might happen -        if (!first) o << ','; -        if (use_fdict) -          o << (it->first - 1); -        else { -	  JSONParser::WriteEscapedString(FD::Convert(it->first), &o); -        } -	o << ',' << it->second; -        first = false; -      } -      o << "]"; -      if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; } -      o << "}"; -    } -    o << "],"; - -    o << "\"node\":{\"in_edges\":["; -    for (int j = 0; j < node.in_edges_.size(); ++j) { -      int mapped_edge = edgemap[node.in_edges_[j]]; -      assert(mapped_edge >= 0); -      o << (j == 0 ? "" : ",") << mapped_edge; -    } -    o << "]"; -    if (node.cat_ < 0) { -       o << ",\"cat\":"; -       JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o); -    } -    char buf[48]; -    sprintf(buf, "%016lX", node.node_hash); -    o << ",\"node_hash\":\"" << buf << "\""; -    o << "}"; -  } -  o << "}\n"; +bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) { +  boost::archive::binary_oarchive oa(*out); +  oa << hg;    return true;  } diff --git a/decoder/hg_io.h b/decoder/hg_io.h index 58af8132..93a9e280 100644 --- a/decoder/hg_io.h +++ b/decoder/hg_io.h @@ -1,5 +1,5 @@ -#ifndef _HG_IO_H_ -#define _HG_IO_H_ +#ifndef HG_IO_H_ +#define HG_IO_H_  #include <iostream>  #include <string> @@ -9,19 +9,11 @@ class Hypergraph;  struct HypergraphIO { -  // the format is basically a list of nodes and edges in topological order -  // any edge you read, you must have already read its tail nodes -  // any node you read, you must have already read its incoming edges -  // this may make writing a bit more challenging if your forest is not -  // topologically sorted (but that probably doesn't happen very often), -  // but it makes reading much more memory efficient. -  // see test_data/small.json.gz for an email encoding -  static bool ReadFromJSON(std::istream* in, Hypergraph* out); +  static bool ReadFromBinary(std::istream* in, Hypergraph* out); +  static bool WriteToBinary(const Hypergraph& hg, std::ostream* out);    // if remove_rules is used, the hypergraph is serialized without rule information    // (so it only contains structure and feature information) -  static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out); -    static void WriteAsCFG(const Hypergraph& hg);    // Write only the target size information in bottom-up order.   diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h index 82f06039..f67fe6e2 100644 --- a/decoder/hg_remove_eps.h +++ b/decoder/hg_remove_eps.h @@ -1,5 +1,5 @@ -#ifndef _HG_REMOVE_EPS_H_ -#define _HG_REMOVE_EPS_H_ +#ifndef HG_REMOVE_EPS_H_ +#define HG_REMOVE_EPS_H_  #include "wordid.h"  class Hypergraph; diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h index 6ac39a20..4267b5ec 100644 --- a/decoder/hg_sampler.h +++ b/decoder/hg_sampler.h @@ -1,6 +1,5 @@ -#ifndef _HG_SAMPLER_H_ -#define _HG_SAMPLER_H_ - +#ifndef HG_SAMPLER_H_ +#define HG_SAMPLER_H_  #include <vector>  #include <string> diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 5cb8626a..366b269d 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -1,10 +1,14 @@  #define BOOST_TEST_MODULE hg_test  #include <boost/test/unit_test.hpp>  #include <boost/test/floating_point_comparison.hpp> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +#include <boost/serialization/vector.hpp> +#include <sstream>  #include <iostream>  #include "tdict.h" -#include "json_parse.h"  #include "hg_intersect.h"  #include "hg_union.h"  #include "viterbi.h" @@ -394,16 +398,6 @@ BOOST_AUTO_TEST_CASE(Small) {    BOOST_CHECK_CLOSE(2.1431036, log(c2), 1e-4);  } -BOOST_AUTO_TEST_CASE(JSONTest) { -  std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); -  ostringstream os; -  JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); -  BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str()); -  ostringstream os2; -  JSONParser::WriteEscapedString("yes", &os2); -  BOOST_CHECK_EQUAL("\"yes\"", os2.str()); -} -  BOOST_AUTO_TEST_CASE(TestGenericKBest) {    std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);    Hypergraph hg; @@ -427,19 +421,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) {    }  } -BOOST_AUTO_TEST_CASE(TestReadWriteHG) { +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {    std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); -  Hypergraph hg,hg2; -  CreateHG(path, &hg); -  hg.edges_.front().j_ = 23; -  hg.edges_.back().prev_i_ = 99; -  ostringstream os; -  HypergraphIO::WriteToJSON(hg, false, &os); -  istringstream is(os.str()); -  HypergraphIO::ReadFromJSON(&is, &hg2); -  BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); -  BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); -  BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); +  Hypergraph hg; +  Hypergraph hg2; +  std::string out; +  { +    CreateHG(path, &hg); +    hg.edges_.front().j_ = 23; +    hg.edges_.back().prev_i_ = 99; +    ostringstream os; +    boost::archive::text_oarchive oa(os); +    oa << hg; +    out = os.str(); +  } +  { +    cerr << out << endl; +    istringstream is(out); +    boost::archive::text_iarchive ia(is); +    ia >> hg2; +    BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); +    BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); +    BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); +  }  }  BOOST_AUTO_TEST_SUITE_END() diff --git a/decoder/hg_test.h b/decoder/hg_test.h index b7bab3c2..575b9c54 100644 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -12,12 +12,8 @@ namespace {  typedef char const* Name; -Name urdu_json="urdu.json.gz"; -Name urdu_wts="Arity_0 1.70741473606976 Arity_1 1.12426238048012 Arity_2 1.14986187839554 Glue -0.04589037041388 LanguageModel 1.09051 PassThrough -3.66226367902928 PhraseModel_0 -1.94633451863252 PhraseModel_1 -0.1475347695476 PhraseModel_2 -1.614818994946 WordPenalty -3.0 WordPenaltyFsa -0.56028442964748 ShorterThanPrev -10 LongerThanPrev -10"; -Name small_json="small.json.gz"; +Name small_json="small.bin.gz";  Name small_wts="Model_0 -2 Model_1 -.5 Model_2 -1.1 Model_3 -1 Model_4 -1 Model_5 .5 Model_6 .2 Model_7 -.3"; -Name perro_json="perro.json.gz"; -Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 Glue -1.0 EgivenF -0.5 FgivenE -0.5 LexEgivenF -0.5 LexFgivenE -0.5 LM 1";  } @@ -32,7 +28,7 @@ struct HGSetup {    static void JsonFile(Hypergraph *hg,std::string f) {      ReadFile rf(f); -    HypergraphIO::ReadFromJSON(rf.stream(), hg); +    HypergraphIO::ReadFromBinary(rf.stream(), hg);    }    static void JsonTestFile(Hypergraph *hg,std::string path,std::string n) {      JsonFile(hg,path + "/"+n); @@ -48,35 +44,35 @@ void AddNullEdge(Hypergraph* hg) {  }  void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) { -  ReadFile rf(path + "/hg_test.tiny_lattice"); -  HypergraphIO::ReadFromJSON(rf.stream(), hg); +  ReadFile rf(path + "/hg_test.tiny_lattice.bin.gz"); +  HypergraphIO::ReadFromBinary(rf.stream(), hg);    AddNullEdge(hg);  }  void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) { -  ReadFile rf(path + "/hg_test.lattice"); -  HypergraphIO::ReadFromJSON(rf.stream(), hg); +  ReadFile rf(path + "/hg_test.lattice.bin.gz"); +  HypergraphIO::ReadFromBinary(rf.stream(), hg);    AddNullEdge(hg);  }  void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) { -  ReadFile rf(path + "/hg_test.tiny"); -  HypergraphIO::ReadFromJSON(rf.stream(), hg); +  ReadFile rf(path + "/hg_test.tiny.bin.gz"); +  HypergraphIO::ReadFromBinary(rf.stream(), hg);  }  void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) { -  ReadFile rf(path + "/hg_test.hg_int"); -  HypergraphIO::ReadFromJSON(rf.stream(), hg); +  ReadFile rf(path + "/hg_test.hg_int.bin.gz"); +  HypergraphIO::ReadFromBinary(rf.stream(), hg);  }  void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) { -  ReadFile rf(path + "/hg_test.hg"); -  HypergraphIO::ReadFromJSON(rf.stream(), hg); +  ReadFile rf(path + "/hg_test.hg.bin.gz"); +  HypergraphIO::ReadFromBinary(rf.stream(), hg);  }  void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) { -  ReadFile rf(path + "/hg_test.hg_balanced"); -  HypergraphIO::ReadFromJSON(rf.stream(), hg); +  ReadFile rf(path + "/hg_test.hg_balanced.bin.gz"); +  HypergraphIO::ReadFromBinary(rf.stream(), hg);  }  #endif diff --git a/decoder/hg_union.h b/decoder/hg_union.h index 34624246..bb7e2d09 100644 --- a/decoder/hg_union.h +++ b/decoder/hg_union.h @@ -1,5 +1,5 @@ -#ifndef _HG_UNION_H_ -#define _HG_UNION_H_ +#ifndef HG_UNION_H_ +#define HG_UNION_H_  class Hypergraph;  namespace HG { diff --git a/decoder/incremental.h b/decoder/incremental.h index f791a626..46b4817b 100644 --- a/decoder/incremental.h +++ b/decoder/incremental.h @@ -1,5 +1,5 @@ -#ifndef _INCREMENTAL_H_ -#define _INCREMENTAL_H_ +#ifndef INCREMENTAL_H_ +#define INCREMENTAL_H_  #include "weights.h"  #include <vector> diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h index c0377fe8..d5bda63c 100644 --- a/decoder/inside_outside.h +++ b/decoder/inside_outside.h @@ -1,5 +1,5 @@ -#ifndef _INSIDE_OUTSIDE_H_ -#define _INSIDE_OUTSIDE_H_ +#ifndef INSIDE_OUTSIDE_H_ +#define INSIDE_OUTSIDE_H_  #include <vector>  #include <algorithm> diff --git a/decoder/json_parse.cc b/decoder/json_parse.cc deleted file mode 100644 index f6fdfea8..00000000 --- a/decoder/json_parse.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "json_parse.h" - -#include <string> -#include <iostream> - -using namespace std; - -static const char *json_hex_chars = "0123456789abcdef"; - -void JSONParser::WriteEscapedString(const string& in, ostream* out) { -  int pos = 0; -  int start_offset = 0; -  unsigned char c = 0; -  (*out) << '"'; -  while(pos < in.size()) { -    c = in[pos]; -    switch(c) { -    case '\b': -    case '\n': -    case '\r': -    case '\t': -    case '"': -    case '\\': -    case '/': -      if(pos - start_offset > 0) -	(*out) << in.substr(start_offset, pos - start_offset); -      if(c == '\b') (*out) << "\\b"; -      else if(c == '\n') (*out) << "\\n"; -      else if(c == '\r') (*out) << "\\r"; -      else if(c == '\t') (*out) << "\\t"; -      else if(c == '"') (*out) << "\\\""; -      else if(c == '\\') (*out) << "\\\\"; -      else if(c == '/') (*out) << "\\/"; -      start_offset = ++pos; -      break; -    default: -      if(c < ' ') { -        cerr << "Warning, bad character (" << static_cast<int>(c) << ") in string\n"; -	if(pos - start_offset > 0) -	  (*out) << in.substr(start_offset, pos - start_offset); -	(*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf]; -	start_offset = ++pos; -      } else pos++; -    } -  } -  if(pos - start_offset > 0) -    (*out) << in.substr(start_offset, pos - start_offset); -  (*out) << '"'; -} - diff --git a/decoder/json_parse.h b/decoder/json_parse.h deleted file mode 100644 index c3cba954..00000000 --- a/decoder/json_parse.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef _JSON_WRAPPER_H_ -#define _JSON_WRAPPER_H_ - -#include <iostream> -#include <cassert> -#include "JSON_parser.h" - -class JSONParser { - public: -  JSONParser() { -    init_JSON_config(&config); -    hack.mf = &JSONParser::Callback; -    config.depth = 10; -    config.callback_ctx = reinterpret_cast<void*>(this); -    config.callback = hack.cb; -    config.allow_comments = 1; -    config.handle_floats_manually = 1; -    jc = new_JSON_parser(&config); -  } -  virtual ~JSONParser() { -    delete_JSON_parser(jc); -  } -  bool Parse(std::istream* in) { -    int count = 0; -    int lc = 1; -    for (; in ; ++count) { -      int next_char = in->get(); -      if (!in->good()) break; -      if (lc == '\n') { ++lc; } -      if (!JSON_parser_char(jc, next_char)) { -        std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl; -        return false; -      } -    } -    if (!JSON_parser_done(jc)) { -      std::cerr << "JSON_parser_done: syntax error\n"; -      return false; -    } -    return true; -  } -  static void WriteEscapedString(const std::string& in, std::ostream* out); - protected: -  virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0; - private: -  int Callback(int type, const JSON_value* value) { -    if (HandleJSONEvent(type, value)) return 1; -    return 0; -  } -  JSON_parser_struct* jc; -  JSON_config config; -  typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value); -  union CBHack { -    JSON_parser_callback cb; -    MF mf; -  } hack; -}; - -#endif diff --git a/decoder/kbest.h b/decoder/kbest.h index c7194c7e..d6b3eb94 100644 --- a/decoder/kbest.h +++ b/decoder/kbest.h @@ -1,5 +1,5 @@ -#ifndef _HG_KBEST_H_ -#define _HG_KBEST_H_ +#ifndef HG_KBEST_H_ +#define HG_KBEST_H_  #include <vector>  #include <utility> diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 89da3cd0..1f97048d 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -50,7 +50,6 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {    l.resize(ids.size());    for (int i = 0; i < l.size(); ++i)      l[i].push_back(LatticeArc(ids[i], 0.0, 1)); -  l.is_sentence_ = true;  }  void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { diff --git a/decoder/lattice.h b/decoder/lattice.h index ad4ca50d..1258d3f5 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -1,5 +1,5 @@ -#ifndef __LATTICE_H_ -#define __LATTICE_H_ +#ifndef LATTICE_H_ +#define LATTICE_H_  #include <string>  #include <vector> @@ -25,22 +25,24 @@ class Lattice : public std::vector<std::vector<LatticeArc> > {    friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);    friend void LatticeTools::ConvertTextToLattice(const std::string& text, Lattice* pl);   public: -  Lattice() : is_sentence_(false) {} +  Lattice() {}    explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) : -   std::vector<std::vector<LatticeArc> >(t, v), -   is_sentence_(false) {} +   std::vector<std::vector<LatticeArc>>(t, v) {}    int Distance(int from, int to) const {      if (dist_.empty())        return (to - from);      return dist_(from, to);    } -  // TODO this should actually be computed based on the contents -  // of the lattice -  bool IsSentence() const { return is_sentence_; }   private:    void ComputeDistances();    Array2D<int> dist_; -  bool is_sentence_;  }; +inline bool IsSentence(const Lattice& in) { +  bool res = true; +  for (auto& alt : in) +    if (alt.size() > 1) { res = false; break; } +  return res; +} +  #endif diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc index 11f20de7..dd529311 100644 --- a/decoder/lexalign.cc +++ b/decoder/lexalign.cc @@ -114,10 +114,9 @@ bool LexicalAlign::TranslateImpl(const string& input,                        Hypergraph* forest) {    Lattice& lattice = smeta->src_lattice_;    LatticeTools::ConvertTextOrPLF(input, &lattice); -  if (!lattice.IsSentence()) { -    // lexical models make independence assumptions -    // that don't work with lattices or conf nets -    cerr << "LexicalTrans: cannot deal with lattice source input!\n"; +  smeta->ComputeInputLatticeType(); +  if (smeta->GetInputType() != cdec::kSEQUENCE) { +    cerr << "LexicalTrans: cannot deal with non-sequence input!";      abort();    }    smeta->SetSourceLength(lattice.size()); diff --git a/decoder/lexalign.h b/decoder/lexalign.h index 7ba4fe64..6415f4f9 100644 --- a/decoder/lexalign.h +++ b/decoder/lexalign.h @@ -1,5 +1,5 @@ -#ifndef _LEXALIGN_H_ -#define _LEXALIGN_H_ +#ifndef LEXALIGN_H_ +#define LEXALIGN_H_  #include "translator.h"  #include "lattice.h" diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 74a18c3f..d13a891a 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -271,10 +271,9 @@ bool LexicalTrans::TranslateImpl(const string& input,                        Hypergraph* forest) {    Lattice& lattice = smeta->src_lattice_;    LatticeTools::ConvertTextOrPLF(input, &lattice); -  if (!lattice.IsSentence()) { -    // lexical models make independence assumptions -    // that don't work with lattices or conf nets -    cerr << "LexicalTrans: cannot deal with lattice source input!\n"; +  smeta->ComputeInputLatticeType(); +  if (smeta->GetInputType() != cdec::kSEQUENCE) { +    cerr << "LexicalTrans: cannot deal with non-sequence inputs\n";      abort();    }    smeta->SetSourceLength(lattice.size()); diff --git a/decoder/lextrans.h b/decoder/lextrans.h index 2d51e7c0..a23a4e0d 100644 --- a/decoder/lextrans.h +++ b/decoder/lextrans.h @@ -1,5 +1,5 @@ -#ifndef _LEXTrans_H_ -#define _LEXTrans_H_ +#ifndef LEXTrans_H_ +#define LEXTrans_H_  #include "translator.h"  #include "lattice.h" diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h index 9fc01a09..f380fcb1 100644 --- a/decoder/node_state_hash.h +++ b/decoder/node_state_hash.h @@ -1,5 +1,5 @@ -#ifndef _NODE_STATE_HASH_ -#define _NODE_STATE_HASH_ +#ifndef NODE_STATE_HASH_ +#define NODE_STATE_HASH_  #include <cassert>  #include <cstring> diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d2c4715c..cd587833 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -21,6 +21,7 @@  #include "kbest.h"  #include "timing_stats.h"  #include "sentences.h" +#include "b64featvector.h"  //TODO: put function impls into .cc  //TODO: move Translation into its own .h and use in cdec @@ -252,19 +253,31 @@ struct OracleBleu {    }    bool show_derivation; +  int show_derivation_mask; +    template <class Filter> -  void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { +  void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat, +             int src_len, std::ostream& kbest_out = std::cout, +             std::ostream& deriv_out = std::cerr) {      using namespace std;      using namespace boost;      typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K;      K kbest(forest,k);      //add length (f side) src length of this sentence to the psuedo-doc src length count      float curr_src_length = doc_src_length + tmp_src_length; -    for (int i = 0; i < k; ++i) { +    if (mr_mira_compat) kbest_out << k << "\n"; +    int i = 0; +    for (; i < k; ++i) {        typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);        if (!d) break; -      kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " -                << d->feature_values << " ||| " << log(d->score); +      kbest_out << sent_id << " ||| "; +      if (mr_mira_compat) kbest_out << src_len << " ||| "; +      kbest_out << TD::GetString(d->yield) << " ||| "; +      if (mr_mira_compat) +        kbest_out << EncodeFeatureVector(d->feature_values); +      else +        kbest_out << d->feature_values; +      kbest_out << " ||| " << log(d->score);        if (!refs.empty()) {          ScoreP sentscore = GetScore(d->yield,sent_id);          sentscore->PlusEquals(*doc_score,float(1)); @@ -275,14 +288,21 @@ struct OracleBleu {        if (show_derivation) {          deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k          deriv_out<<log(d->score)<<"\n"; -        deriv_out<<kbest.derivation_tree(*d,true); +        deriv_out<<kbest.derivation_tree(*d,true, show_derivation_mask);          deriv_out<<"\n"<<flush;        }      } +    if (mr_mira_compat) { +      for (; i < k; ++i) kbest_out << "\n"; +      kbest_out << flush; +    }    }  // TODO decoder output should probably be moved to another file - how about oracle_bleu.h -  void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) { +  void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, +                 const bool unique, const bool mr_mira_compat, +                 const int src_len, std::string const& kbest_out_filename_, +                 std::string const& deriv_out_filename_) {      WriteFile ko(kbest_out_filename_);      std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl; @@ -295,9 +315,11 @@ struct OracleBleu {      WriteFile oderiv(sderiv.str());      if (!unique) -      kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get()); +      kbest<KBest::NoFilter<std::vector<WordID> > >( +          sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get());      else { -      kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get()); +      kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len, +                                 ko.get(), oderiv.get());      }    } @@ -305,7 +327,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo    {      std::ostringstream kbest_string_stream;      kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; -    DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-"); +    DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(), +              "-");    }  }; diff --git a/decoder/phrasebased_translator.cc b/decoder/phrasebased_translator.cc index 8048248e..8415353a 100644 --- a/decoder/phrasebased_translator.cc +++ b/decoder/phrasebased_translator.cc @@ -114,6 +114,7 @@ struct PhraseBasedTranslatorImpl {      Lattice lattice;      LatticeTools::ConvertTextOrPLF(input, &lattice);      smeta->SetSourceLength(lattice.size()); +    smeta->ComputeInputLatticeType();      size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);      minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);      if (add_pass_through_rules) { diff --git a/decoder/phrasebased_translator.h b/decoder/phrasebased_translator.h index e5e3f8a2..10790d0d 100644 --- a/decoder/phrasebased_translator.h +++ b/decoder/phrasebased_translator.h @@ -1,5 +1,5 @@ -#ifndef _PHRASEBASED_TRANSLATOR_H_ -#define _PHRASEBASED_TRANSLATOR_H_ +#ifndef PHRASEBASED_TRANSLATOR_H_ +#define PHRASEBASED_TRANSLATOR_H_  #include "translator.h" diff --git a/decoder/phrasetable_fst.h b/decoder/phrasetable_fst.h index 477de1f7..966bb14d 100644 --- a/decoder/phrasetable_fst.h +++ b/decoder/phrasetable_fst.h @@ -1,5 +1,5 @@ -#ifndef _PHRASETABLE_FST_H_ -#define _PHRASETABLE_FST_H_ +#ifndef PHRASETABLE_FST_H_ +#define PHRASETABLE_FST_H_  #include <vector>  #include <string> diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 10192f7a..2c5fa9c4 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -3,6 +3,7 @@  #include <sstream>  #include <boost/shared_ptr.hpp> +#include "filelib.h"  #include "sentence_metadata.h"  #include "hg.h"  #include "hg_io.h" @@ -20,16 +21,18 @@ struct RescoreTranslatorImpl {    bool Translate(const string& input,                   const vector<double>& weights,                   Hypergraph* forest) { -    if (input == "{}") return false; -    if (input.find("{\"rules\"") == 0) { -      istringstream is(input); -      Hypergraph src_cfg_hg; -      if (!HypergraphIO::ReadFromJSON(&is, forest)) { -        cerr << "Parse error while reading HG from JSON.\n"; -        abort(); -      } -    } else { -      cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; +    istringstream is(input); +    string header, fname; +    is >> header >> fname; +    if (header != "::forest::") { +      cerr << "RescoreTranslator: expected input lines of form ::forest:: filename.gz\n"; +      abort(); +    } +    ReadFile rf(fname); +    if (!rf) { cerr << "Can't read " << fname << endl; abort(); } +    Hypergraph src_cfg_hg; +    if (!HypergraphIO::ReadFromBinary(rf.stream(), forest)) { +      cerr << "Parse error while reading HG.\n";        abort();      }      Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); @@ -53,6 +56,7 @@ bool RescoreTranslator::TranslateImpl(const string& input,                                const vector<double>& weights,                                Hypergraph* minus_lm_forest) {    smeta->SetSourceLength(0);  // don't know how to compute this +  smeta->input_type_ = cdec::kFOREST;    return pimpl_->Translate(input, weights, minus_lm_forest);  } diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index e15c056d..5267f9ca 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -1,5 +1,5 @@ -#ifndef _RULE_LEXER_H_ -#define _RULE_LEXER_H_ +#ifndef RULE_LEXER_H_ +#define RULE_LEXER_H_  #include <iostream>  #include <string> diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index d4a8d86b..8b48ab7b 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const  void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) {    init_default_feature_names(); +  scfglex_fname = srule;    lex_mono_rules = mono;    lex_line = 1;    rule_callback_extra = extra; diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 83b65c28..538f82ec 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -195,6 +195,7 @@ struct SCFGTranslatorImpl {      Lattice& lattice = smeta->src_lattice_;      LatticeTools::ConvertTextOrPLF(input, &lattice);      smeta->SetSourceLength(lattice.size()); +    smeta->ComputeInputLatticeType();      if (add_pass_through_rules){        if (!SILENT) cerr << "Adding pass through grammar" << endl;        PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features); diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h index f2a779f4..e13c2ca5 100644 --- a/decoder/sentence_metadata.h +++ b/decoder/sentence_metadata.h @@ -1,14 +1,20 @@ -#ifndef _SENTENCE_METADATA_H_ -#define _SENTENCE_METADATA_H_ +#ifndef SENTENCE_METADATA_H_ +#define SENTENCE_METADATA_H_  #include <string>  #include <map>  #include <cassert>  #include "lattice.h" +#include "tree_fragment.h"  struct DocScorer;  // deprecated, will be removed  struct Score;     // deprecated, will be removed +namespace cdec { +enum InputType { kSEQUENCE, kTREE, kLATTICE, kFOREST, kUNKNOWN }; +class TreeFragment; +} +  class SentenceMetadata {   public:    friend class DecoderImpl; @@ -17,7 +23,17 @@ class SentenceMetadata {      src_len_(-1),      has_reference_(ref.size() > 0),      trg_len_(ref.size()), -    ref_(has_reference_ ? &ref : NULL) {} +    ref_(has_reference_ ? &ref : NULL), +    input_type_(cdec::kUNKNOWN) {} + +  // helper function for lattice inputs +  void ComputeInputLatticeType() { +    input_type_ = cdec::kSEQUENCE; +    for (auto& alt : src_lattice_) { +      if (alt.size() > 1) { input_type_ = cdec::kLATTICE; break; } +    } +  } +  cdec::InputType GetInputType() const { return input_type_; }    int GetSentenceId() const { return sent_id_; } @@ -25,6 +41,8 @@ class SentenceMetadata {    // it has parsed the source    void SetSourceLength(int sl) { src_len_ = sl; } +  const cdec::TreeFragment& GetSourceTree() const { return src_tree_; } +    // this should be called if a separate model needs to    // specify how long the target sentence should be    void SetTargetLength(int tl) { @@ -64,12 +82,15 @@ class SentenceMetadata {    const Score* app_score;   public:    Lattice src_lattice_;  // this will only be set if inputs are finite state! +  cdec::TreeFragment src_tree_; // this will be set only if inputs are trees   private:    // you need to be very careful when depending on these values    // they will only be set during training / alignment contexts    const bool has_reference_;    int trg_len_;    const Lattice* const ref_; + public: +  cdec::InputType input_type_;  };  #endif diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 30fb055f..500d2061 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -100,6 +100,8 @@ bool Tagger::TranslateImpl(const string& input,    Lattice& lattice = smeta->src_lattice_;    LatticeTools::ConvertTextToLattice(input, &lattice);    smeta->SetSourceLength(lattice.size()); +  smeta->ComputeInputLatticeType(); +  assert(smeta->GetInputType() == cdec::kSEQUENCE);    vector<WordID> sequence(lattice.size());    for (int i = 0; i < lattice.size(); ++i) {      assert(lattice[i].size() == 1); diff --git a/decoder/tagger.h b/decoder/tagger.h index 9ac820d9..51659d5b 100644 --- a/decoder/tagger.h +++ b/decoder/tagger.h @@ -1,5 +1,5 @@ -#ifndef _TAGGER_H_ -#define _TAGGER_H_ +#ifndef TAGGER_H_ +#define TAGGER_H_  #include "translator.h" diff --git a/decoder/test_data/hg_test.hg.bin.gz b/decoder/test_data/hg_test.hg.bin.gzBinary files differ new file mode 100644 index 00000000..c07fbe8c --- /dev/null +++ b/decoder/test_data/hg_test.hg.bin.gz diff --git a/decoder/test_data/hg_test.hg_balanced.bin.gz b/decoder/test_data/hg_test.hg_balanced.bin.gzBinary files differ new file mode 100644 index 00000000..896d3d60 --- /dev/null +++ b/decoder/test_data/hg_test.hg_balanced.bin.gz diff --git a/decoder/test_data/hg_test.hg_int.bin.gz b/decoder/test_data/hg_test.hg_int.bin.gzBinary files differ new file mode 100644 index 00000000..e0bd6187 --- /dev/null +++ b/decoder/test_data/hg_test.hg_int.bin.gz diff --git a/decoder/test_data/hg_test.lattice.bin.gz b/decoder/test_data/hg_test.lattice.bin.gzBinary files differ new file mode 100644 index 00000000..8a8c05f4 --- /dev/null +++ b/decoder/test_data/hg_test.lattice.bin.gz diff --git a/decoder/test_data/hg_test.tiny.bin.gz b/decoder/test_data/hg_test.tiny.bin.gzBinary files differ new file mode 100644 index 00000000..0e68eb40 --- /dev/null +++ b/decoder/test_data/hg_test.tiny.bin.gz diff --git a/decoder/test_data/hg_test.tiny_lattice.bin.gz b/decoder/test_data/hg_test.tiny_lattice.bin.gzBinary files differ new file mode 100644 index 00000000..97e8dc05 --- /dev/null +++ b/decoder/test_data/hg_test.tiny_lattice.bin.gz diff --git a/decoder/test_data/perro.json.gz b/decoder/test_data/perro.json.gzBinary files differ deleted file mode 100644 index 41de5758..00000000 --- a/decoder/test_data/perro.json.gz +++ /dev/null diff --git a/decoder/test_data/small.bin.gz b/decoder/test_data/small.bin.gzBinary files differ new file mode 100644 index 00000000..1c5a1631 --- /dev/null +++ b/decoder/test_data/small.bin.gz diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gzBinary files differ deleted file mode 100644 index f6f37293..00000000 --- a/decoder/test_data/small.json.gz +++ /dev/null diff --git a/decoder/test_data/urdu.json.gz b/decoder/test_data/urdu.json.gzBinary files differ deleted file mode 100644 index 84535402..00000000 --- a/decoder/test_data/urdu.json.gz +++ /dev/null diff --git a/decoder/translator.h b/decoder/translator.h index ba218a0b..096cf191 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -1,5 +1,5 @@ -#ifndef _TRANSLATOR_H_ -#define _TRANSLATOR_H_ +#ifndef TRANSLATOR_H_ +#define TRANSLATOR_H_  #include <string>  #include <vector> diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index bd3b01d0..cdd83ffc 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -32,12 +32,12 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo      ++lc;      if (line.size() == 0 || line[0] == '#') continue;      std::vector<StringPiece> fields = TokenizeMultisep(line, " ||| "); -    if (has_multiple_states && fields.size() != 4) { -      cerr << "Expected 4 fields in rule file but line " << lc << " is:\n" << line << endl; +    if (has_multiple_states && fields.size() < 4) { +      cerr << "Expected at least 4 fields in rule file but line " << lc << " is:\n" << line << endl;        abort();      } -    if (!has_multiple_states && fields.size() != 3) { -      cerr << "Expected 3 fields in rule file but line " << lc << " is:\n" << line << endl; +    if (!has_multiple_states && fields.size() < 3) { +      cerr << "Expected at least 3 fields in rule file but line " << lc << " is:\n" << line << endl;        abort();      } @@ -73,6 +73,7 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo        cerr << "Not implemented...\n"; abort(); // TODO read in states      } else {        os << " ||| " << fields[1] << " ||| " << fields[2]; +      if (fields.size() > 3) os << " ||| " << fields[3];        rule.reset(new TRule(os.str()));      }      cur->rules.push_back(rule); @@ -266,6 +267,7 @@ struct Tree2StringTranslatorImpl {        for (auto sym : rule_src)          cur = &cur->next[sym];        TRulePtr rule(new TRule(rhse, rhsf, lhs)); +      rule->a_.push_back(AlignmentPoint(0, 0));        rule->ComputeArity();        rule->scores_.set_value(ntfid, 1.0);        rule->scores_.set_value(kFID, 1.0); @@ -287,6 +289,8 @@ struct Tree2StringTranslatorImpl {                   const vector<double>& weights,                   Hypergraph* minus_lm_forest) {      cdec::TreeFragment input_tree(input, false); +    smeta->src_tree_ = input_tree; +    smeta->input_type_ = cdec::kTREE;      if (add_pass_through_rules) CreatePassThroughRules(input_tree);      Hypergraph hg;      hg.ReserveNodes(input_tree.nodes.size()); diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 42f7793a..5f717c5b 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -64,6 +64,13 @@ int TreeFragment::SetupSpansRec(unsigned cur, int left) {    return right;  } +vector<int> TreeFragment::Terminals() const { +  vector<int> terms; +  for (auto& x : *this) +    if (IsTerminal(x)) terms.push_back(x); +  return terms; +} +  // cp is the character index in the tree  // np keeps track of the nodes (nonterminals) that have been built  // symp keeps track of the terminal symbols that have been built diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 6b4842ee..e19b79fb 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -72,6 +72,8 @@ class TreeFragment {    BreadthFirstIterator bfs_begin(unsigned node_idx) const;    BreadthFirstIterator bfs_end() const; +  std::vector<int> Terminals() const; +   private:    // cp is the character index in the tree    // np keeps track of the nodes (nonterminals) that have been built diff --git a/decoder/trule.h b/decoder/trule.h index 243b0da9..7af46747 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -1,5 +1,5 @@ -#ifndef _RULE_H_ -#define _RULE_H_ +#ifndef TRULE_H_ +#define TRULE_H_  #include <algorithm>  #include <vector> @@ -11,6 +11,7 @@  #include "sparse_vector.h"  #include "wordid.h" +#include "tdict.h"  class TRule;  typedef boost::shared_ptr<TRule> TRulePtr; @@ -26,6 +27,7 @@ struct AlignmentPoint {    short s_;    short t_;  }; +  inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) {    return os << static_cast<int>(p.s_) << '-' << static_cast<int>(p.t_);  } @@ -163,6 +165,66 @@ class TRule {    // optional, shows internal structure of TSG rules    boost::shared_ptr<cdec::TreeFragment> tree_structure; +  friend class boost::serialization::access; +  template<class Archive> +  void save(Archive & ar, const unsigned int /*version*/) const { +    ar & TD::Convert(-lhs_); +    unsigned f_size = f_.size(); +    ar & f_size; +    assert(f_size <= (sizeof(size_t) * 8)); +    size_t f_nt_mask = 0; +    for (int i = f_.size() - 1; i >= 0; --i) { +      f_nt_mask <<= 1; +      f_nt_mask |= (f_[i] <= 0 ? 1 : 0); +    } +    ar & f_nt_mask; +    for (unsigned i = 0; i < f_.size(); ++i) +      ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]); +    unsigned e_size = e_.size(); +    ar & e_size; +    size_t e_nt_mask = 0; +    assert(e_size <= (sizeof(size_t) * 8)); +    for (int i = e_.size() - 1; i >= 0; --i) { +      e_nt_mask <<= 1; +      e_nt_mask |= (e_[i] <= 0 ? 1 : 0); +    } +    ar & e_nt_mask; +    for (unsigned i = 0; i < e_.size(); ++i) +      if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]); +    ar & arity_; +    ar & scores_; +  } +  template<class Archive> +  void load(Archive & ar, const unsigned int /*version*/) { +    std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs); +    unsigned f_size; ar & f_size; +    f_.resize(f_size); +    size_t f_nt_mask; ar & f_nt_mask; +    std::string sym; +    for (unsigned i = 0; i < f_size; ++i) { +      bool mask = (f_nt_mask & 1); +      ar & sym; +      f_[i] = TD::Convert(sym) * (mask ? -1 : 1); +      f_nt_mask >>= 1; +    } +    unsigned e_size; ar & e_size; +    e_.resize(e_size); +    size_t e_nt_mask; ar & e_nt_mask; +    for (unsigned i = 0; i < e_size; ++i) { +      bool mask = (e_nt_mask & 1); +      if (mask) { +        ar & e_[i]; +      } else { +        ar & sym; +        e_[i] = TD::Convert(sym); +      } +      e_nt_mask >>= 1; +    } +    ar & arity_; +    ar & scores_; +  } + +  BOOST_SERIALIZATION_SPLIT_MEMBER()   private:    TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}  }; diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc index 0cb7e2e8..d75c2016 100644 --- a/decoder/trule_test.cc +++ b/decoder/trule_test.cc @@ -4,6 +4,10 @@  #include <boost/test/unit_test.hpp>  #include <boost/test/floating_point_comparison.hpp>  #include <iostream> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +#include <sstream>  #include "tdict.h"  using namespace std; @@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) {    BOOST_CHECK_EQUAL(t6.e_[3], 0);  } +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { +  string str; +  string t7str; +  { +    TRule t7; +    t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121"); +    cerr << t7.AsString() << endl; +    ostringstream os; +    TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2")); +    TRulePtr tp2 = tp1; +    boost::archive::text_oarchive oa(os); +    oa << t7; +    oa << tp1; +    oa << tp2; +    str = os.str(); +    t7str = t7.AsString(); +  } +  { +    istringstream is(str); +    boost::archive::text_iarchive ia(is); +    TRule t8; +    ia >> t8; +    TRulePtr tp3, tp4; +    ia >> tp3; +    ia >> tp4; +    cerr << t8.AsString() << endl; +    BOOST_CHECK_EQUAL(t7str, t8.AsString()); +    cerr << tp3->AsString() << endl; +    cerr << tp4->AsString() << endl; +  } +} + diff --git a/decoder/viterbi.h b/decoder/viterbi.h index a8a0ea7f..20ee73cc 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -1,5 +1,5 @@ -#ifndef _VITERBI_H_ -#define _VITERBI_H_ +#ifndef VITERBI_H_ +#define VITERBI_H_  #include <vector>  #include "prob.h" | 
