diff options
Diffstat (limited to 'decoder')
45 files changed, 3870 insertions, 1726 deletions
diff --git a/decoder/JSON_parser.c b/decoder/JSON_parser.c deleted file mode 100644 index 5e392bc6..00000000 --- a/decoder/JSON_parser.c +++ /dev/null @@ -1,1012 +0,0 @@ -/* JSON_parser.c */ - -/* 2007-08-24 */ - -/* -Copyright (c) 2005 JSON.org - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -The Software shall be used for Good, not Evil. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -*/ - -/* - Callbacks, comments, Unicode handling by Jean Gressmann (jean@0x42.de), 2007-2009. - - For the added features the license above applies also. - - Changelog: - 2009-05-17 - Incorporated benrudiak@googlemail.com fix for UTF16 decoding. - - 2009-05-14 - Fixed float parsing bug related to a locale being set that didn't - use '.' as decimal point character (charles@transmissionbt.com). - - 2008-10-14 - Renamed states.IN to states.IT to avoid name clash which IN macro - defined in windef.h (alexey.pelykh@gmail.com) - - 2008-07-19 - Removed some duplicate code & debugging variable (charles@transmissionbt.com) - - 2008-05-28 - Made JSON_value structure ansi C compliant. This bug was report by - trisk@acm.jhu.edu - - 2008-05-20 - Fixed bug reported by charles@transmissionbt.com where the switching - from static to dynamic parse buffer did not copy the static parse - buffer's content. -*/ - - - -#include <assert.h> -#include <ctype.h> -#include <float.h> -#include <stddef.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <locale.h> - -#include "JSON_parser.h" - -#ifdef _MSC_VER -# if _MSC_VER >= 1400 /* Visual Studio 2005 and up */ -# pragma warning(disable:4996) // unsecure sscanf -# endif -#endif - - -#define true 1 -#define false 0 -#define __ -1 /* the universal error code */ - -/* values chosen so that the object size is approx equal to one page (4K) */ -#ifndef JSON_PARSER_STACK_SIZE -# define JSON_PARSER_STACK_SIZE 128 -#endif - -#ifndef JSON_PARSER_PARSE_BUFFER_SIZE -# define JSON_PARSER_PARSE_BUFFER_SIZE 3500 -#endif - -typedef unsigned short UTF16; - -struct JSON_parser_struct { - JSON_parser_callback callback; - void* ctx; - signed char state, before_comment_state, type, escaped, comment, allow_comments, handle_floats_manually; - UTF16 utf16_high_surrogate; - long depth; - long top; - signed char* stack; - long stack_capacity; - char decimal_point; - char* parse_buffer; - size_t parse_buffer_capacity; - size_t parse_buffer_count; - size_t comment_begin_offset; - signed char static_stack[JSON_PARSER_STACK_SIZE]; - char static_parse_buffer[JSON_PARSER_PARSE_BUFFER_SIZE]; -}; - -#define COUNTOF(x) (sizeof(x)/sizeof(x[0])) - -/* - Characters are mapped into these character classes. This allows for - a significant reduction in the size of the state transition table. -*/ - - - -enum classes { - C_SPACE, /* space */ - C_WHITE, /* other whitespace */ - C_LCURB, /* { */ - C_RCURB, /* } */ - C_LSQRB, /* [ */ - C_RSQRB, /* ] */ - C_COLON, /* : */ - C_COMMA, /* , */ - C_QUOTE, /* " */ - C_BACKS, /* \ */ - C_SLASH, /* / */ - C_PLUS, /* + */ - C_MINUS, /* - */ - C_POINT, /* . */ - C_ZERO , /* 0 */ - C_DIGIT, /* 123456789 */ - C_LOW_A, /* a */ - C_LOW_B, /* b */ - C_LOW_C, /* c */ - C_LOW_D, /* d */ - C_LOW_E, /* e */ - C_LOW_F, /* f */ - C_LOW_L, /* l */ - C_LOW_N, /* n */ - C_LOW_R, /* r */ - C_LOW_S, /* s */ - C_LOW_T, /* t */ - C_LOW_U, /* u */ - C_ABCDF, /* ABCDF */ - C_E, /* E */ - C_ETC, /* everything else */ - C_STAR, /* * */ - NR_CLASSES -}; - -static int ascii_class[128] = { -/* - This array maps the 128 ASCII characters into character classes. - The remaining Unicode characters should be mapped to C_ETC. - Non-whitespace control characters are errors. -*/ - __, __, __, __, __, __, __, __, - __, C_WHITE, C_WHITE, __, __, C_WHITE, __, __, - __, __, __, __, __, __, __, __, - __, __, __, __, __, __, __, __, - - C_SPACE, C_ETC, C_QUOTE, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_STAR, C_PLUS, C_COMMA, C_MINUS, C_POINT, C_SLASH, - C_ZERO, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, - C_DIGIT, C_DIGIT, C_COLON, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - - C_ETC, C_ABCDF, C_ABCDF, C_ABCDF, C_ABCDF, C_E, C_ABCDF, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_LSQRB, C_BACKS, C_RSQRB, C_ETC, C_ETC, - - C_ETC, C_LOW_A, C_LOW_B, C_LOW_C, C_LOW_D, C_LOW_E, C_LOW_F, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_LOW_L, C_ETC, C_LOW_N, C_ETC, - C_ETC, C_ETC, C_LOW_R, C_LOW_S, C_LOW_T, C_LOW_U, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_LCURB, C_ETC, C_RCURB, C_ETC, C_ETC -}; - - -/* - The state codes. -*/ -enum states { - GO, /* start */ - OK, /* ok */ - OB, /* object */ - KE, /* key */ - CO, /* colon */ - VA, /* value */ - AR, /* array */ - ST, /* string */ - ES, /* escape */ - U1, /* u1 */ - U2, /* u2 */ - U3, /* u3 */ - U4, /* u4 */ - MI, /* minus */ - ZE, /* zero */ - IT, /* integer */ - FR, /* fraction */ - E1, /* e */ - E2, /* ex */ - E3, /* exp */ - T1, /* tr */ - T2, /* tru */ - T3, /* true */ - F1, /* fa */ - F2, /* fal */ - F3, /* fals */ - F4, /* false */ - N1, /* nu */ - N2, /* nul */ - N3, /* null */ - C1, /* / */ - C2, /* / * */ - C3, /* * */ - FX, /* *.* *eE* */ - D1, /* second UTF-16 character decoding started by \ */ - D2, /* second UTF-16 character proceeded by u */ - NR_STATES -}; - -enum actions -{ - CB = -10, /* comment begin */ - CE = -11, /* comment end */ - FA = -12, /* false */ - TR = -13, /* false */ - NU = -14, /* null */ - DE = -15, /* double detected by exponent e E */ - DF = -16, /* double detected by fraction . */ - SB = -17, /* string begin */ - MX = -18, /* integer detected by minus */ - ZX = -19, /* integer detected by zero */ - IX = -20, /* integer detected by 1-9 */ - EX = -21, /* next char is escaped */ - UC = -22 /* Unicode character read */ -}; - - -static int state_transition_table[NR_STATES][NR_CLASSES] = { -/* - The state transition table takes the current state and the current symbol, - and returns either a new state or an action. An action is represented as a - negative number. A JSON text is accepted if at the end of the text the - state is OK and if the mode is MODE_DONE. - - white 1-9 ABCDF etc - space | { } [ ] : , " \ / + - . 0 | a b c d e f l n r s t u | E | * */ -/*start GO*/ {GO,GO,-6,__,-5,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ok OK*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*object OB*/ {OB,OB,__,-9,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*key KE*/ {KE,KE,__,__,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*colon CO*/ {CO,CO,__,__,__,__,-2,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*value VA*/ {VA,VA,-6,__,-5,__,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*array AR*/ {AR,AR,-6,__,-5,-7,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*string ST*/ {ST,__,ST,ST,ST,ST,ST,ST,-4,EX,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST}, -/*escape ES*/ {__,__,__,__,__,__,__,__,ST,ST,ST,__,__,__,__,__,__,ST,__,__,__,ST,__,ST,ST,__,ST,U1,__,__,__,__}, -/*u1 U1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U2,U2,U2,U2,U2,U2,U2,U2,__,__,__,__,__,__,U2,U2,__,__}, -/*u2 U2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U3,U3,U3,U3,U3,U3,U3,U3,__,__,__,__,__,__,U3,U3,__,__}, -/*u3 U3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U4,U4,U4,U4,U4,U4,U4,U4,__,__,__,__,__,__,U4,U4,__,__}, -/*u4 U4*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,UC,UC,UC,UC,UC,UC,UC,UC,__,__,__,__,__,__,UC,UC,__,__}, -/*minus MI*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,ZE,IT,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*zero ZE*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*int IT*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,IT,IT,__,__,__,__,DE,__,__,__,__,__,__,__,__,DE,__,__}, -/*frac FR*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*e E1*/ {__,__,__,__,__,__,__,__,__,__,__,E2,E2,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ex E2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*exp E3*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*tr T1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T2,__,__,__,__,__,__,__}, -/*tru T2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T3,__,__,__,__}, -/*true T3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*fa F1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*fal F2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F3,__,__,__,__,__,__,__,__,__}, -/*fals F3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F4,__,__,__,__,__,__}, -/*false F4*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*nu N1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N2,__,__,__,__}, -/*nul N2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N3,__,__,__,__,__,__,__,__,__}, -/*null N3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__}, -/*/ C1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,C2}, -/*/* C2*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/** C3*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,CE,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/*_. FX*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*\ D1*/ {__,__,__,__,__,__,__,__,__,D2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*\ D2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,U1,__,__,__,__}, -}; - - -/* - These modes can be pushed on the stack. -*/ -enum modes { - MODE_ARRAY = 1, - MODE_DONE = 2, - MODE_KEY = 3, - MODE_OBJECT = 4 -}; - -static int -push(JSON_parser jc, int mode) -{ -/* - Push a mode onto the stack. Return false if there is overflow. -*/ - jc->top += 1; - if (jc->depth < 0) { - if (jc->top >= jc->stack_capacity) { - size_t bytes_to_allocate; - jc->stack_capacity *= 2; - bytes_to_allocate = jc->stack_capacity * sizeof(jc->static_stack[0]); - if (jc->stack == &jc->static_stack[0]) { - jc->stack = (signed char*)malloc(bytes_to_allocate); - memcpy(jc->stack, jc->static_stack, sizeof(jc->static_stack)); - } else { - jc->stack = (signed char*)realloc(jc->stack, bytes_to_allocate); - } - } - } else { - if (jc->top >= jc->depth) { - return false; - } - } - - jc->stack[jc->top] = mode; - return true; -} - - -static int -pop(JSON_parser jc, int mode) -{ -/* - Pop the stack, assuring that the current mode matches the expectation. - Return false if there is underflow or if the modes mismatch. -*/ - if (jc->top < 0 || jc->stack[jc->top] != mode) { - return false; - } - jc->top -= 1; - return true; -} - - -#define parse_buffer_clear(jc) \ - do {\ - jc->parse_buffer_count = 0;\ - jc->parse_buffer[0] = 0;\ - } while (0) - -#define parse_buffer_pop_back_char(jc)\ - do {\ - assert(jc->parse_buffer_count >= 1);\ - --jc->parse_buffer_count;\ - jc->parse_buffer[jc->parse_buffer_count] = 0;\ - } while (0) - -void delete_JSON_parser(JSON_parser jc) -{ - if (jc) { - if (jc->stack != &jc->static_stack[0]) { - free((void*)jc->stack); - } - if (jc->parse_buffer != &jc->static_parse_buffer[0]) { - free((void*)jc->parse_buffer); - } - free((void*)jc); - } -} - - -JSON_parser -new_JSON_parser(JSON_config* config) -{ -/* - new_JSON_parser starts the checking process by constructing a JSON_parser - object. It takes a depth parameter that restricts the level of maximum - nesting. - - To continue the process, call JSON_parser_char for each character in the - JSON text, and then call JSON_parser_done to obtain the final result. - These functions are fully reentrant. -*/ - - int depth = 0; - JSON_config default_config; - - JSON_parser jc = (JSON_parser)malloc(sizeof(struct JSON_parser_struct)); - - memset(jc, 0, sizeof(*jc)); - - - /* initialize configuration */ - init_JSON_config(&default_config); - - /* set to default configuration if none was provided */ - if (config == NULL) { - config = &default_config; - } - - depth = config->depth; - - /* We need to be able to push at least one object */ - if (depth == 0) { - depth = 1; - } - - jc->state = GO; - jc->top = -1; - - /* Do we want non-bound stack? */ - if (depth > 0) { - jc->stack_capacity = depth; - jc->depth = depth; - if (depth <= (int)COUNTOF(jc->static_stack)) { - jc->stack = &jc->static_stack[0]; - } else { - jc->stack = (signed char*)malloc(jc->stack_capacity * sizeof(jc->static_stack[0])); - } - } else { - jc->stack_capacity = COUNTOF(jc->static_stack); - jc->depth = -1; - jc->stack = &jc->static_stack[0]; - } - - /* set parser to start */ - push(jc, MODE_DONE); - - /* set up the parse buffer */ - jc->parse_buffer = &jc->static_parse_buffer[0]; - jc->parse_buffer_capacity = COUNTOF(jc->static_parse_buffer); - parse_buffer_clear(jc); - - /* set up callback, comment & float handling */ - jc->callback = config->callback; - jc->ctx = config->callback_ctx; - jc->allow_comments = config->allow_comments != 0; - jc->handle_floats_manually = config->handle_floats_manually != 0; - - /* set up decimal point */ - jc->decimal_point = *localeconv()->decimal_point; - - return jc; -} - -static void grow_parse_buffer(JSON_parser jc) -{ - size_t bytes_to_allocate; - jc->parse_buffer_capacity *= 2; - bytes_to_allocate = jc->parse_buffer_capacity * sizeof(jc->parse_buffer[0]); - if (jc->parse_buffer == &jc->static_parse_buffer[0]) { - jc->parse_buffer = (char*)malloc(bytes_to_allocate); - memcpy(jc->parse_buffer, jc->static_parse_buffer, jc->parse_buffer_count); - } else { - jc->parse_buffer = (char*)realloc(jc->parse_buffer, bytes_to_allocate); - } -} - -#define parse_buffer_push_back_char(jc, c)\ - do {\ - if (jc->parse_buffer_count + 1 >= jc->parse_buffer_capacity) grow_parse_buffer(jc);\ - jc->parse_buffer[jc->parse_buffer_count++] = c;\ - jc->parse_buffer[jc->parse_buffer_count] = 0;\ - } while (0) - -#define assert_is_non_container_type(jc) \ - assert( \ - jc->type == JSON_T_NULL || \ - jc->type == JSON_T_FALSE || \ - jc->type == JSON_T_TRUE || \ - jc->type == JSON_T_FLOAT || \ - jc->type == JSON_T_INTEGER || \ - jc->type == JSON_T_STRING) - - -static int parse_parse_buffer(JSON_parser jc) -{ - if (jc->callback) { - JSON_value value, *arg = NULL; - - if (jc->type != JSON_T_NONE) { - assert_is_non_container_type(jc); - - switch(jc->type) { - case JSON_T_FLOAT: - arg = &value; - if (jc->handle_floats_manually) { - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - } else { - /*sscanf(jc->parse_buffer, "%Lf", &value.vu.float_value);*/ - - /* not checking with end pointer b/c there may be trailing ws */ - value.vu.float_value = strtod(jc->parse_buffer, NULL); - } - break; - case JSON_T_INTEGER: - arg = &value; - sscanf(jc->parse_buffer, JSON_PARSER_INTEGER_SSCANF_TOKEN, &value.vu.integer_value); - break; - case JSON_T_STRING: - arg = &value; - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - break; - } - - if (!(*jc->callback)(jc->ctx, jc->type, arg)) { - return false; - } - } - } - - parse_buffer_clear(jc); - - return true; -} - -#define IS_HIGH_SURROGATE(uc) (((uc) & 0xFC00) == 0xD800) -#define IS_LOW_SURROGATE(uc) (((uc) & 0xFC00) == 0xDC00) -#define DECODE_SURROGATE_PAIR(hi,lo) ((((hi) & 0x3FF) << 10) + ((lo) & 0x3FF) + 0x10000) -static unsigned char utf8_lead_bits[4] = { 0x00, 0xC0, 0xE0, 0xF0 }; - -static int decode_unicode_char(JSON_parser jc) -{ - int i; - unsigned uc = 0; - char* p; - int trail_bytes; - - assert(jc->parse_buffer_count >= 6); - - p = &jc->parse_buffer[jc->parse_buffer_count - 4]; - - for (i = 12; i >= 0; i -= 4, ++p) { - unsigned x = *p; - - if (x >= 'a') { - x -= ('a' - 10); - } else if (x >= 'A') { - x -= ('A' - 10); - } else { - x &= ~0x30u; - } - - assert(x < 16); - - uc |= x << i; - } - - /* clear UTF-16 char from buffer */ - jc->parse_buffer_count -= 6; - jc->parse_buffer[jc->parse_buffer_count] = 0; - - /* attempt decoding ... */ - if (jc->utf16_high_surrogate) { - if (IS_LOW_SURROGATE(uc)) { - uc = DECODE_SURROGATE_PAIR(jc->utf16_high_surrogate, uc); - trail_bytes = 3; - jc->utf16_high_surrogate = 0; - } else { - /* high surrogate without a following low surrogate */ - return false; - } - } else { - if (uc < 0x80) { - trail_bytes = 0; - } else if (uc < 0x800) { - trail_bytes = 1; - } else if (IS_HIGH_SURROGATE(uc)) { - /* save the high surrogate and wait for the low surrogate */ - jc->utf16_high_surrogate = uc; - return true; - } else if (IS_LOW_SURROGATE(uc)) { - /* low surrogate without a preceding high surrogate */ - return false; - } else { - trail_bytes = 2; - } - } - - jc->parse_buffer[jc->parse_buffer_count++] = (char) ((uc >> (trail_bytes * 6)) | utf8_lead_bits[trail_bytes]); - - for (i = trail_bytes * 6 - 6; i >= 0; i -= 6) { - jc->parse_buffer[jc->parse_buffer_count++] = (char) (((uc >> i) & 0x3F) | 0x80); - } - - jc->parse_buffer[jc->parse_buffer_count] = 0; - - return true; -} - -static int add_escaped_char_to_parse_buffer(JSON_parser jc, int next_char) -{ - jc->escaped = 0; - /* remove the backslash */ - parse_buffer_pop_back_char(jc); - switch(next_char) { - case 'b': - parse_buffer_push_back_char(jc, '\b'); - break; - case 'f': - parse_buffer_push_back_char(jc, '\f'); - break; - case 'n': - parse_buffer_push_back_char(jc, '\n'); - break; - case 'r': - parse_buffer_push_back_char(jc, '\r'); - break; - case 't': - parse_buffer_push_back_char(jc, '\t'); - break; - case '"': - parse_buffer_push_back_char(jc, '"'); - break; - case '\\': - parse_buffer_push_back_char(jc, '\\'); - break; - case '/': - parse_buffer_push_back_char(jc, '/'); - break; - case 'u': - parse_buffer_push_back_char(jc, '\\'); - parse_buffer_push_back_char(jc, 'u'); - break; - default: - return false; - } - - return true; -} - -#define add_char_to_parse_buffer(jc, next_char, next_class) \ - do { \ - if (jc->escaped) { \ - if (!add_escaped_char_to_parse_buffer(jc, next_char)) \ - return false; \ - } else if (!jc->comment) { \ - if ((jc->type != JSON_T_NONE) | !((next_class == C_SPACE) | (next_class == C_WHITE)) /* non-white-space */) { \ - parse_buffer_push_back_char(jc, (char)next_char); \ - } \ - } \ - } while (0) - - -#define assert_type_isnt_string_null_or_bool(jc) \ - assert(jc->type != JSON_T_FALSE); \ - assert(jc->type != JSON_T_TRUE); \ - assert(jc->type != JSON_T_NULL); \ - assert(jc->type != JSON_T_STRING) - - -int -JSON_parser_char(JSON_parser jc, int next_char) -{ -/* - After calling new_JSON_parser, call this function for each character (or - partial character) in your JSON text. It can accept UTF-8, UTF-16, or - UTF-32. It returns true if things are looking ok so far. If it rejects the - text, it returns false. -*/ - int next_class, next_state; - -/* - Determine the character's class. -*/ - if (next_char < 0) { - return false; - } - if (next_char >= 128) { - next_class = C_ETC; - } else { - next_class = ascii_class[next_char]; - if (next_class <= __) { - return false; - } - } - - add_char_to_parse_buffer(jc, next_char, next_class); - -/* - Get the next state from the state transition table. -*/ - next_state = state_transition_table[jc->state][next_class]; - if (next_state >= 0) { -/* - Change the state. -*/ - jc->state = next_state; - } else { -/* - Or perform one of the actions. -*/ - switch (next_state) { -/* Unicode character */ - case UC: - if(!decode_unicode_char(jc)) { - return false; - } - /* check if we need to read a second UTF-16 char */ - if (jc->utf16_high_surrogate) { - jc->state = D1; - } else { - jc->state = ST; - } - break; -/* escaped char */ - case EX: - jc->escaped = 1; - jc->state = ES; - break; -/* integer detected by minus */ - case MX: - jc->type = JSON_T_INTEGER; - jc->state = MI; - break; -/* integer detected by zero */ - case ZX: - jc->type = JSON_T_INTEGER; - jc->state = ZE; - break; -/* integer detected by 1-9 */ - case IX: - jc->type = JSON_T_INTEGER; - jc->state = IT; - break; - -/* floating point number detected by exponent*/ - case DE: - assert_type_isnt_string_null_or_bool(jc); - jc->type = JSON_T_FLOAT; - jc->state = E1; - break; - -/* floating point number detected by fraction */ - case DF: - assert_type_isnt_string_null_or_bool(jc); - if (!jc->handle_floats_manually) { -/* - Some versions of strtod (which underlies sscanf) don't support converting - C-locale formated floating point values. -*/ - assert(jc->parse_buffer[jc->parse_buffer_count-1] == '.'); - jc->parse_buffer[jc->parse_buffer_count-1] = jc->decimal_point; - } - jc->type = JSON_T_FLOAT; - jc->state = FX; - break; -/* string begin " */ - case SB: - parse_buffer_clear(jc); - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_STRING; - jc->state = ST; - break; - -/* n */ - case NU: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_NULL; - jc->state = N1; - break; -/* f */ - case FA: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_FALSE; - jc->state = F1; - break; -/* t */ - case TR: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_TRUE; - jc->state = T1; - break; - -/* closing comment */ - case CE: - jc->comment = 0; - assert(jc->parse_buffer_count == 0); - assert(jc->type == JSON_T_NONE); - jc->state = jc->before_comment_state; - break; - -/* opening comment */ - case CB: - if (!jc->allow_comments) { - return false; - } - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - assert(jc->parse_buffer_count == 0); - assert(jc->type != JSON_T_STRING); - switch (jc->stack[jc->top]) { - case MODE_ARRAY: - case MODE_OBJECT: - switch(jc->state) { - case VA: - case AR: - jc->before_comment_state = jc->state; - break; - default: - jc->before_comment_state = OK; - break; - } - break; - default: - jc->before_comment_state = jc->state; - break; - } - jc->type = JSON_T_NONE; - jc->state = C1; - jc->comment = 1; - break; -/* empty } */ - case -9: - parse_buffer_clear(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { - return false; - } - if (!pop(jc, MODE_KEY)) { - return false; - } - jc->state = OK; - break; - -/* } */ case -8: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { - return false; - } - if (!pop(jc, MODE_OBJECT)) { - return false; - } - jc->type = JSON_T_NONE; - jc->state = OK; - break; - -/* ] */ case -7: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_END, NULL)) { - return false; - } - if (!pop(jc, MODE_ARRAY)) { - return false; - } - - jc->type = JSON_T_NONE; - jc->state = OK; - break; - -/* { */ case -6: - parse_buffer_pop_back_char(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_BEGIN, NULL)) { - return false; - } - if (!push(jc, MODE_KEY)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = OB; - break; - -/* [ */ case -5: - parse_buffer_pop_back_char(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_BEGIN, NULL)) { - return false; - } - if (!push(jc, MODE_ARRAY)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = AR; - break; - -/* string end " */ case -4: - parse_buffer_pop_back_char(jc); - switch (jc->stack[jc->top]) { - case MODE_KEY: - assert(jc->type == JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = CO; - - if (jc->callback) { - JSON_value value; - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - if (!(*jc->callback)(jc->ctx, JSON_T_KEY, &value)) { - return false; - } - } - parse_buffer_clear(jc); - break; - case MODE_ARRAY: - case MODE_OBJECT: - assert(jc->type == JSON_T_STRING); - if (!parse_parse_buffer(jc)) { - return false; - } - jc->type = JSON_T_NONE; - jc->state = OK; - break; - default: - return false; - } - break; - -/* , */ case -3: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - switch (jc->stack[jc->top]) { - case MODE_OBJECT: -/* - A comma causes a flip from object mode to key mode. -*/ - if (!pop(jc, MODE_OBJECT) || !push(jc, MODE_KEY)) { - return false; - } - assert(jc->type != JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = KE; - break; - case MODE_ARRAY: - assert(jc->type != JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = VA; - break; - default: - return false; - } - break; - -/* : */ case -2: -/* - A colon causes a flip from key mode to object mode. -*/ - parse_buffer_pop_back_char(jc); - if (!pop(jc, MODE_KEY) || !push(jc, MODE_OBJECT)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = VA; - break; -/* - Bad action. -*/ - default: - return false; - } - } - return true; -} - - -int -JSON_parser_done(JSON_parser jc) -{ - const int result = jc->state == OK && pop(jc, MODE_DONE); - - return result; -} - - -int JSON_parser_is_legal_white_space_string(const char* s) -{ - int c, char_class; - - if (s == NULL) { - return false; - } - - for (; *s; ++s) { - c = *s; - - if (c < 0 || c >= 128) { - return false; - } - - char_class = ascii_class[c]; - - if (char_class != C_SPACE && char_class != C_WHITE) { - return false; - } - } - - return true; -} - - - -void init_JSON_config(JSON_config* config) -{ - if (config) { - memset(config, 0, sizeof(*config)); - - config->depth = JSON_PARSER_STACK_SIZE - 1; - } -} diff --git a/decoder/JSON_parser.h b/decoder/JSON_parser.h deleted file mode 100644 index de980072..00000000 --- a/decoder/JSON_parser.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef JSON_PARSER_H -#define JSON_PARSER_H - -/* JSON_parser.h */ - - -#include <stddef.h> - -/* Windows DLL stuff */ -#ifdef _WIN32 -# ifdef JSON_PARSER_DLL_EXPORTS -# define JSON_PARSER_DLL_API __declspec(dllexport) -# else -# define JSON_PARSER_DLL_API __declspec(dllimport) -# endif -#else -# define JSON_PARSER_DLL_API -#endif - -/* Determine the integer type use to parse non-floating point numbers */ -#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1 -typedef long long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld" -#else -typedef long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld" -#endif - - -#ifdef __cplusplus -extern "C" { -#endif - -typedef enum -{ - JSON_T_NONE = 0, - JSON_T_ARRAY_BEGIN, // 1 - JSON_T_ARRAY_END, // 2 - JSON_T_OBJECT_BEGIN, // 3 - JSON_T_OBJECT_END, // 4 - JSON_T_INTEGER, // 5 - JSON_T_FLOAT, // 6 - JSON_T_NULL, // 7 - JSON_T_TRUE, // 8 - JSON_T_FALSE, // 9 - JSON_T_STRING, // 10 - JSON_T_KEY, // 11 - JSON_T_MAX // 12 -} JSON_type; - -typedef struct JSON_value_struct { - union { - JSON_int_t integer_value; - - double float_value; - - struct { - const char* value; - size_t length; - } str; - } vu; -} JSON_value; - -typedef struct JSON_parser_struct* JSON_parser; - -/*! \brief JSON parser callback - - \param ctx The pointer passed to new_JSON_parser. - \param type An element of JSON_type but not JSON_T_NONE. - \param value A representation of the parsed value. This parameter is NULL for - JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END, - JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned - as zero-terminated C strings. - - \return Non-zero if parsing should continue, else zero. -*/ -typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value); - - -/*! \brief The structure used to configure a JSON parser object - - \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise - the depth is the limit - \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity. - \param Callback context. This parameter may be NULL. - \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting. - \param allowComments. To allow C style comments in JSON, set to non-zero. - \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero. - - \return The parser object. -*/ -typedef struct { - JSON_parser_callback callback; - void* callback_ctx; - int depth; - int allow_comments; - int handle_floats_manually; -} JSON_config; - - -/*! \brief Initializes the JSON parser configuration structure to default values. - - The default configuration is - - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c) - - no parsing, just checking for JSON syntax - - no comments - - \param config. Used to configure the parser. -*/ -JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config); - -/*! \brief Create a JSON parser object - - \param config. Used to configure the parser. Set to NULL to use the default configuration. - See init_JSON_config - - \return The parser object. -*/ -JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config); - -/*! \brief Destroy a previously created JSON parser object. */ -JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc); - -/*! \brief Parse a character. - - \return Non-zero, if all characters passed to this function are part of are valid JSON. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char); - -/*! \brief Finalize parsing. - - Call this method once after all input characters have been consumed. - - \return Non-zero, if all parsed characters are valid JSON, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc); - -/*! \brief Determine if a given string is valid JSON white space - - \return Non-zero if the string is valid, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s); - - -#ifdef __cplusplus -} -#endif - - -#endif /* JSON_PARSER_H */ diff --git a/decoder/Makefile.am b/decoder/Makefile.am index b23bbad4..e313f1f9 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -36,10 +36,10 @@ noinst_LIBRARIES = libcdec.a EXTRA_DIST = test_data rule_lexer.ll libcdec_a_SOURCES = \ - JSON_parser.h \ aligner.h \ apply_models.h \ bottom_up_parser.h \ + bottom_up_parser-rs.h \ csplit.h \ decoder.h \ earley_composer.h \ @@ -48,12 +48,15 @@ libcdec_a_SOURCES = \ ff_basic.h \ ff_bleu.h \ ff_charset.h \ + ff_conll.h \ + ff_const_reorder_common.h \ + ff_const_reorder.h \ ff_context.h \ ff_csplit.h \ ff_external.h \ ff_factory.h \ ff_klm.h \ - ff_lexical.h \ + ff_lexical.h \ ff_lm.h \ ff_ngrams.h \ ff_parse_match.h \ @@ -61,6 +64,7 @@ libcdec_a_SOURCES = \ ff_rules.h \ ff_ruleshape.h \ ff_sample_fsa.h \ + ff_soft_syn.h \ ff_soft_syntax.h \ ff_soft_syntax_mindist.h \ ff_source_path.h \ @@ -83,7 +87,6 @@ libcdec_a_SOURCES = \ hg_union.h \ incremental.h \ inside_outside.h \ - json_parse.h \ kbest.h \ lattice.h \ lexalign.h \ @@ -103,6 +106,7 @@ libcdec_a_SOURCES = \ aligner.cc \ apply_models.cc \ bottom_up_parser.cc \ + bottom_up_parser-rs.cc \ cdec.cc \ cdec_ff.cc \ csplit.cc \ @@ -113,7 +117,9 @@ libcdec_a_SOURCES = \ ff_basic.cc \ ff_bleu.cc \ ff_charset.cc \ + ff_conll.cc \ ff_context.cc \ + ff_const_reorder.cc \ ff_csplit.cc \ ff_external.cc \ ff_factory.cc \ @@ -123,6 +129,7 @@ libcdec_a_SOURCES = \ ff_parse_match.cc \ ff_rules.cc \ ff_ruleshape.cc \ + ff_soft_syn.cc \ ff_soft_syntax.cc \ ff_soft_syntax_mindist.cc \ ff_source_path.cc \ @@ -144,7 +151,6 @@ libcdec_a_SOURCES = \ hg_sampler.cc \ hg_union.cc \ incremental.cc \ - json_parse.cc \ lattice.cc \ lexalign.cc \ lextrans.cc \ @@ -160,5 +166,4 @@ libcdec_a_SOURCES = \ tagger.cc \ translator.cc \ trule.cc \ - viterbi.cc \ - JSON_parser.c + viterbi.cc diff --git a/decoder/aligner.h b/decoder/aligner.h index a34795c9..d68ceefc 100644 --- a/decoder/aligner.h +++ b/decoder/aligner.h @@ -1,4 +1,4 @@ -#ifndef _ALIGNER_H_ +#ifndef ALIGNER_H #include <string> #include <iostream> diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9f8bbead..18c83fd4 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -233,7 +233,20 @@ public: void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); new_edge->edge_prob_ = item->out_edge_.edge_prob_; - Candidate*& o_item = (*s2n)[item->state_]; + + Candidate** o_item_ptr = nullptr; + if (item->state_.size() && models.NeedsStateErasure()) { + // When erasure of certain state bytes is needed, we must make a copy of + // the state instead of doing the erasure in-place because future + // candidates may require the information in the bytes to be erased. + FFState state(item->state_); + models.EraseIgnoredBytes(&state); + o_item_ptr = &(*s2n)[state]; + } else { + o_item_ptr = &(*s2n)[item->state_]; + } + Candidate*& o_item = *o_item_ptr; + if (!o_item) o_item = item; int& node_id = o_item->node_index_; @@ -254,7 +267,18 @@ public: // score is the same for all items with a common residual DP // state if (item->vit_prob_ > o_item->vit_prob_) { - assert(o_item->state_ == item->state_); // sanity check! + if (item->state_.size() && models.NeedsStateErasure()) { + // node_states_ should still point to the unerased state. + node_states_[o_item->node_index_] = item->state_; + // sanity check! + FFState item_state(item->state_), o_item_state(o_item->state_); + models.EraseIgnoredBytes(&item_state); + models.EraseIgnoredBytes(&o_item_state); + assert(item_state == o_item_state); + } else { + assert(o_item->state_ == item->state_); // sanity check! + } + o_item->est_prob_ = item->est_prob_; o_item->vit_prob_ = item->vit_prob_; } @@ -599,9 +623,10 @@ void ApplyModelSet(const Hypergraph& in, if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state ma.Apply(); - } else if (config.algorithm == IntersectionConfiguration::CUBE - || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING - || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { + } else if (config.algorithm == IntersectionConfiguration::CUBE || + config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING || + config.algorithm == + IntersectionConfiguration::FAST_CUBE_PRUNING_2) { int pl = config.pop_limit; const int max_pl_for_large=50; if (pl > max_pl_for_large && in.nodes_.size() > 80000) { @@ -628,4 +653,3 @@ void ApplyModelSet(const Hypergraph& in, out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed // automatically } - diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc new file mode 100644 index 00000000..fbde7e24 --- /dev/null +++ b/decoder/bottom_up_parser-rs.cc @@ -0,0 +1,341 @@ +#include "bottom_up_parser-rs.h" + +#include <iostream> +#include <map> + +#include "node_state_hash.h" +#include "nt_span.h" +#include "hg.h" +#include "array2d.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +static WordID kEPS = 0; + +struct RSActiveItem; +class RSChart { + public: + RSChart(const string& goal, + const vector<GrammarPtr>& grammars, + const Lattice& input, + Hypergraph* forest); + ~RSChart(); + + void AddToChart(const RSActiveItem& x, int i, int j); + void ConsumeTerminal(const RSActiveItem& x, int i, int j, int k); + void ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k); + bool Parse(); + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() const { return goal_idx_; } + + private: + void ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost); + + // returns true if a new node was added to the chart + // false otherwise + bool ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost); + + void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); + void TopoSortUnaries(); + + const vector<GrammarPtr>& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D<vector<int>> chart_; // chart_(i,j) is the list of nodes (represented + // by their index in forest_->nodes_) derived spanning i,j + typedef map<int, int> Cat2NodeMap; + Array2D<Cat2NodeMap> nodemap_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + const int lc_fid_; + vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars + + static WordID kGOAL; // [Goal] +}; + +WordID RSChart::kGOAL = 0; + +// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" +struct RSActiveItem { + explicit RSActiveItem(const GrammarIter* g, int i) : + gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {} + void ExtendTerminal(int symbol, float src_cost) { + lattice_cost += src_cost; + if (symbol != kEPS) + gptr_ = gptr_->Extend(symbol); + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index) { + gptr_ = gptr_->Extend(hg->nodes_[node_index].cat_); + ant_nodes_.push_back(node_index); + } + // returns false if the extension has failed + explicit operator bool() const { + return gptr_; + } + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + float lattice_cost; // TODO: use SparseVector<double> to encode input features + short i_; +}; + +// some notes on the implementation +// "X" in Rico's Algorithm 2 roughly looks like it is just a pointer into a grammar +// trie, but it is actually a full "dotted item" since it needs to contain the information +// to build the hypergraph (i.e., it must remember the antecedent nodes and where they are, +// also any information about the path costs). + +RSChart::RSChart(const string& goal, + const vector<GrammarPtr>& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), + goal_idx_(-1), + lc_fid_(FD::Convert("LatticeCost")), + unaries_() { + for (unsigned i = 0; i < grammars_.size(); ++i) { + const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; +} + +static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector<TRulePtr>& edges = nit->second; + vector<bool> okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void RSChart::TopoSortUnaries() { + vector<TRulePtr> u(unaries_.size()); u.clear(); + map<int, vector<TRulePtr> > g; + map<int, int> mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + +bool RSChart::ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const float lattice_cost) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + if (lattice_cost && lc_fid_) + new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + bool added_node = false; + if (ni == c2n.end()) { + //cerr << "(" << i << "," << j << ") => " << TD::Convert(-r->GetLHS()) << endl; + added_node = true; + node = forest_->AddNode(r->GetLHS()); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); + return added_node; +} + +void RSChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const float lattice_cost) { + const int n = rules->GetNumRules(); + //cerr << i << " " << j << ": NUM RULES: " << n << endl; + for (int k = 0; k < n; ++k) { + //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; + TRulePtr rule = rules->GetIthRule(k); + // apply rule, and if we create a new node, apply any necessary + // unary rules + if (ApplyRule(i, j, rule, tail, lattice_cost)) { + unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; + ApplyUnaryRules(i, j, rule->lhs_, nodeidx); + } + } +} + +void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx) { + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; + WordID new_lhs = unaries_[ri]->GetLHS(); + const Hypergraph::TailNodeVector ant(1, nodeidx); + if (ApplyRule(i, j, unaries_[ri], ant, 0)) { + //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; + unsigned nodeidx = nodemap_(i,j)[new_lhs]; + ApplyUnaryRules(i, j, new_lhs, nodeidx); + } + } + } +} + +void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { + // deal with completed rules + const RuleBin* rb = x.gptr_->GetRules(); + if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost); + + //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; + // continue looking for extensions of the rule to the right + for (unsigned k = j+1; k <= input_.size(); ++k) { + ConsumeTerminal(x, i, j, k); + ConsumeNonTerminal(x, i, j, k); + } +} + +void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeT(" << i << "," << j << "," << k << "):\n"; + + const unsigned check_edge_len = k - j; + // long-term TODO preindex this search so i->len->words is constant time rather than fan out + for (auto& in_edge : input_[j]) { + if (in_edge.dist2next == check_edge_len) { + //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; + RSActiveItem copy = x; + copy.ExtendTerminal(in_edge.label, in_edge.cost); + if (copy) AddToChart(copy, i, k); + } + } +} + +void RSChart::ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeNT(" << i << "," << j << "," << k << "):\n"; + for (auto& nodeidx : chart_(j,k)) { + //cerr << " Found completed NT in (" << j << "," << k << ") of type " << TD::Convert(-forest_->nodes_[nodeidx].cat_) << endl; + RSActiveItem copy = x; + copy.ExtendNonTerminal(forest_, nodeidx); + if (copy) AddToChart(copy, i, k); + } +} + +bool RSChart::Parse() { + size_t in_size_2 = input_.size() * input_.size(); + forest_->nodes_.reserve(in_size_2 * 2); + size_t res = min(static_cast<size_t>(2000000), static_cast<size_t>(in_size_2 * 1000)); + forest_->edges_.reserve(res); + goal_idx_ = -1; + const int N = input_.size(); + for (int i = N - 1; i >= 0; --i) { + for (int j = i + 1; j <= N; ++j) { + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeTerminal(item, i, i, j); + } + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeNonTerminal(item, i, i, j); + } + } + } + + // look for goal + const vector<int>& dh = chart_(0, input_.size()); + for (unsigned di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant, 0); + } + } + if (!SILENT) cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +RSChart::~RSChart() {} + +RSExhaustiveBottomUpParser::RSExhaustiveBottomUpParser( + const string& goal_sym, + const vector<GrammarPtr>& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool RSExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + kEPS = TD::Convert("*EPS*"); + RSChart chart(goal_sym_, grammars_, input, forest); + const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } + return result; +} diff --git a/decoder/bottom_up_parser-rs.h b/decoder/bottom_up_parser-rs.h new file mode 100644 index 00000000..2e271e99 --- /dev/null +++ b/decoder/bottom_up_parser-rs.h @@ -0,0 +1,29 @@ +#ifndef RSBOTTOM_UP_PARSER_H_ +#define RSBOTTOM_UP_PARSER_H_ + +#include <vector> +#include <string> + +#include "lattice.h" +#include "grammar.h" + +class Hypergraph; + +// implementation of Sennrich (2014) parser +// http://aclweb.org/anthology/W/W14/W14-4011.pdf +class RSExhaustiveBottomUpParser { + public: + RSExhaustiveBottomUpParser(const std::string& goal_sym, + const std::vector<GrammarPtr>& grammars); + + // returns true if goal reached spanning the full input + // forest contains the full (i.e., unpruned) parse forest + bool Parse(const Lattice& input, + Hypergraph* forest) const; + + private: + const std::string goal_sym_; + const std::vector<GrammarPtr> grammars_; +}; + +#endif diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 7f7e075b..973a643a 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -3,6 +3,7 @@ #include "ff.h" #include "ff_basic.h" #include "ff_context.h" +#include "ff_const_reorder.h" #include "ff_spans.h" #include "ff_lm.h" #include "ff_klm.h" @@ -14,6 +15,7 @@ #include "ff_rules.h" #include "ff_ruleshape.h" #include "ff_bleu.h" +#include "ff_soft_syn.h" #include "ff_soft_syntax.h" #include "ff_soft_syntax_mindist.h" #include "ff_source_path.h" @@ -77,6 +79,7 @@ void register_feature_functions() { ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>); ff_registry.Register("SourcePathFeatures", new FFFactory<SourcePathFeatures>); ff_registry.Register("WordSet", new FFFactory<WordSet>); + ff_registry.Register("ConstReorderFeature", new FFFactory<ConstReorderFeature>); ff_registry.Register("External", new FFFactory<ExternalFeature>); + ff_registry.Register("SoftSynFeature", new SoftSynFeatureFactory()); } - diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 9e8d692a..1e6c3194 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; } #include "fdict.h" #include "timing_stats.h" #include "verbose.h" +#include "b64featvector.h" #include "translator.h" #include "phrasebased_translator.h" @@ -195,7 +196,7 @@ struct DecoderImpl { } forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1); if (!forestname.empty()) forestname=" "+forestname; - if (!SILENT) { + if (!SILENT) { forest_stats(forest," Pruned "+forestname+" forest",false,false); cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl; } @@ -261,7 +262,7 @@ struct DecoderImpl { assert(ref); LatticeTools::ConvertTextOrPLF(sref, ref); } - } + } // used to construct the suffix string to get the name of arguments for multiple passes // e.g., the "2" in --weights2 @@ -284,7 +285,7 @@ struct DecoderImpl { boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng; int sample_max_trans; bool aligner_mode; - bool graphviz; + bool graphviz; bool joshua_viz; bool encode_b64; bool kbest; @@ -301,6 +302,7 @@ struct DecoderImpl { bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer bool remove_intersected_rule_annotations; + bool mr_mira_compat; // Mr.MIRA compatibility mode. boost::scoped_ptr<IncrementalBase> incremental; @@ -404,6 +406,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory") ("show_derivations", po::value<string>(), "Directory to print the derivation structures to") + ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart") ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") @@ -414,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value<string>(),"Directory to write forests to") - ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); + ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)") + ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)"); // ob.AddOptions(&opts); po::options_description clo("Command line options"); @@ -665,7 +669,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); oracle.show_derivation=conf.count("show_derivations"); + oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>(); remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); + mr_mira_compat = conf.count("mr_mira_compat"); combine_size = conf["combine_size"].as<int>(); if (combine_size < 1) combine_size = 1; @@ -699,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string); } +static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) { + SparseVector<weight_t> delta; + DecodeFeatureVector(delta_b64, &delta); + if (delta.empty()) return; + // Apply updates + for (SparseVector<weight_t>::iterator dit = delta.begin(); + dit != delta.end(); ++dit) { + int feat_id = dit->first; + union { weight_t weight; unsigned long long repr; } feat_delta; + feat_delta.weight = dit->second; + if (!SILENT) + cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight + << " = " << hex << feat_delta.repr << endl; + if (weights->size() <= feat_id) weights->resize(feat_id + 1); + (*weights)[feat_id] += feat_delta.weight; + } +} + bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { string buf = input; NgramCache::Clear(); // clear ngram cache for remote LM (if used) @@ -709,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (sgml.find("id") != sgml.end()) sent_id = atoi(sgml["id"].c_str()); + // Add delta from input to weights before decoding + if (mr_mira_compat) + ApplyWeightDelta(sgml["delta"], init_weights.get()); + if (!SILENT) { cerr << "\nINPUT: "; if (buf.size() < 100) @@ -928,14 +956,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Hypergraph new_hg; { ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); if (!succeeded) abort(); } HG::Union(forest, &new_hg); - bool succeeded = writer.Write(new_hg, false); + bool succeeded = writer.Write(new_hg); if (!succeeded) abort(); } else { - bool succeeded = writer.Write(forest, false); + bool succeeded = writer.Write(forest); if (!succeeded) abort(); } } @@ -947,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (kbest && !has_ref) { //TODO: does this work properly? const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { @@ -1021,14 +1049,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Hypergraph new_hg; { ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); if (!succeeded) abort(); } HG::Union(forest, &new_hg); - bool succeeded = writer.Write(new_hg, false); + bool succeeded = writer.Write(new_hg); if (!succeeded) abort(); } else { - bool succeeded = writer.Write(forest, false); + bool succeeded = writer.Write(forest); if (!succeeded) abort(); } } @@ -1078,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("graphviz")) forest.PrintGraphviz(); if (kbest) { const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); } if (conf.count("show_conditional_prob")) { const prob_t ref_z = Inside<prob_t, EdgeProb>(forest); @@ -1098,4 +1126,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { o->NotifyDecodingComplete(smeta); return true; } - diff --git a/decoder/ff.h b/decoder/ff.h index eed1e3fb..d6487d97 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -17,11 +17,23 @@ class FeatureFunction { friend class ExternalFeature; public: std::string name_; // set by FF factory using usage() - FeatureFunction() : state_size_() {} - explicit FeatureFunction(int state_size) : state_size_(state_size) {} + FeatureFunction() : state_size_(), ignored_state_size_() {} + explicit FeatureFunction(int state_size, int ignored_state_size = 0) + : state_size_(state_size), ignored_state_size_(ignored_state_size) {} virtual ~FeatureFunction(); bool IsStateful() const { return state_size_ > 0; } int StateSize() const { return state_size_; } + // Returns the number of bytes in the state that should be ignored during + // search. When non-zero, the last N bytes in the state should be ignored when + // splitting a hypernode by the state. This allows the feature function to + // store some side data and later retrieve it via the state bytes. + // + // In general, this should not be necessary and it should always be possible + // to replace this with a more appropriate design of state (if you find + // yourself having to ignore some part of the state, you are most likely + // storing redundant information in the state). Be sure that you + // understand how this affects ApplyModelSet() before using it. + int IgnoredStateSize() const { return ignored_state_size_; } // override this. not virtual because we want to expose this to factory template for help before creating a FF static std::string usage(bool show_params,bool show_details) { @@ -71,12 +83,18 @@ class FeatureFunction { SparseVector<double>* estimated_features, void* context) const; - // !!! ONLY call this from subclass *CONSTRUCTORS* !!! + // !!! ONLY call these from subclass *CONSTRUCTORS* !!! void SetStateSize(size_t state_size) { state_size_ = state_size; } + + // See document of IgnoredStateSize() above. + void SetIgnoredStateSize(size_t ignored_state_size) { + ignored_state_size_ = ignored_state_size; + } + private: - int state_size_; + int state_size_, ignored_state_size_; }; #endif diff --git a/decoder/ff_conll.cc b/decoder/ff_conll.cc new file mode 100644 index 00000000..8ded44b7 --- /dev/null +++ b/decoder/ff_conll.cc @@ -0,0 +1,250 @@ +#include "ff_conll.h" + +#include <stdlib.h> +#include <sstream> +#include <cassert> +#include <cmath> +#include <boost/lexical_cast.hpp> + +#include "hg.h" +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" +#include "tdict.h" + +CoNLLFeatures::CoNLLFeatures(const string& param) { + // cerr << "initializing CoNLLFeatures with parameters: " << param; + kSOS = TD::Convert("<s>"); + kEOS = TD::Convert("</s>"); + macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]"); + ParseArgs(param); +} + +string CoNLLFeatures::Escape(const string& x) const { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; +} + +// replace %x[relative_location] or %y[relative_location] with actual_token +// within feature_instance +void CoNLLFeatures::ReplaceMacroWithString( + string& feature_instance, bool token_vs_label, int relative_location, + const string& actual_token) const { + + stringstream macro; + if (token_vs_label) { + macro << "%x["; + } else { + macro << "%y["; + } + macro << relative_location << "]"; + int macro_index = feature_instance.find(macro.str()); + if (macro_index == string::npos) { + cerr << "Can't find macro " << macro.str() << " in feature template " + << feature_instance; + abort(); + } + feature_instance.replace(macro_index, macro.str().size(), actual_token); +} + +void CoNLLFeatures::ReplaceTokenMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, true, relative_location, + actual_token); +} + +void CoNLLFeatures::ReplaceLabelMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, false, relative_location, + actual_token); +} + +void CoNLLFeatures::Error(const string& error_message) const { + cerr << "Error: " << error_message << "\n\n" + + << "CoNLLFeatures Usage: \n" + << " feature_function=CoNLLFeatures -t <TEMPLATE>\n\n" + + << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n" + + << "%x[k] and %y[k] are macros to be instantiated with an input\n" + << "token (for x) or a label (for y). k specifies the relative\n" + << "location of the input token or label with respect to the current\n" + << "position. For x, k is an integer value. For y, k must be 0 (to\n" + << "be extended).\n\n"; + + abort(); +} + +void CoNLLFeatures::ParseArgs(const string& in) { + which_feat = 0; + vector<string> const& argv = SplitOnWhitespace(in); + for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) { + string const& s = *i; + if (s[0] == '-') { + if (s.size() > 2) { + stringstream msg; + msg << s << " is an invalid option for CoNLLFeatures."; + Error(msg.str()); + } + + switch (s[1]) { + + case 'w': { + if (++i == argv.end()) { + Error("Missing parameter to -w"); + } + which_feat = boost::lexical_cast<unsigned>(*i); + break; + } + // feature template + case 't': { + if (++i == argv.end()) { + Error("Can't find template."); + } + feature_template = *i; + string::const_iterator start = feature_template.begin(); + string::const_iterator end = feature_template.end(); + smatch macro_match; + + // parse the template + while (regex_search(start, end, macro_match, macro_regex)) { + // get the relative location + string relative_location_str(macro_match[2].first, + macro_match[2].second); + int relative_location = atoi(relative_location_str.c_str()); + // add it to the list of relative locations for token or label + // (i.e. x or y) + bool valid_location = true; + if (*macro_match[1].first == 'x') { + // add it to token locations + token_relative_locations.push_back(relative_location); + } else { + if (relative_location != 0) { valid_location = false; } + // add it to label locations + label_relative_locations.push_back(relative_location); + } + if (!valid_location) { + stringstream msg; + msg << "Relative location " << relative_location + << " in feature template " << feature_template + << " is invalid."; + Error(msg.str()); + } + start = macro_match[0].second; + } + break; + } + + // TODO: arguments to specify kSOS and kEOS + + default: { + stringstream msg; + msg << "Invalid option on CoNLLFeatures: " << s; + Error(msg.str()); + break; + } + } // end of switch + } // end of if (token starts with hyphen) + } // end of for loop (over arguments) + + // the -t (i.e. template) option is mandatory in this feature function + if (label_relative_locations.size() == 0 || + token_relative_locations.size() == 0) { + stringstream msg; + msg << "Feature template must specify at least one" + << "token macro (e.g. x[-1]) and one label macro (e.g. y[0])."; + Error(msg.str()); + } +} + +void CoNLLFeatures::PrepareForInput(const SentenceMetadata& smeta) { + const Lattice& sl = smeta.GetSourceLattice(); + current_input.resize(sl.size()); + for (unsigned i = 0; i < sl.size(); ++i) { + if (sl[i].size() != 1) { + stringstream msg; + msg << "CoNLLFeatures don't support lattice inputs!\nid=" + << smeta.GetSentenceId() << endl; + Error(msg.str()); + } + current_input[i] = sl[i][0].label; + } + vector<WordID> wids; + string fn = "feat"; + fn += boost::lexical_cast<string>(which_feat); + string feats = smeta.GetSGMLValue(fn); + if (feats.size() == 0) { + Error("Can't find " + fn + " in <seg>\n"); + } + TD::ConvertSentence(feats, &wids); + assert(current_input.size() == wids.size()); + current_input = wids; +} + +void CoNLLFeatures::TraversalFeaturesImpl( + const SentenceMetadata& smeta, const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, SparseVector<double>* features, + SparseVector<double>* estimated_features, void* context) const { + + const TRule& rule = *edge.rule_; + // arity = 0, no nonterminals + // size = 1, predicted label is a single token + if (rule.Arity() != 0 || + rule.e_.size() != 1) { + return; + } + + // replace label macros with actual label strings + // NOTE: currently, this feature function doesn't allow any label + // macros except %y[0]. but you can look at as much of the source as you want + const WordID y0 = rule.e_[0]; + string y0_str = TD::Convert(y0); + + // start of the span in the input being labeled + const int from_src_index = edge.i_; + // end of the span in the input + const int to_src_index = edge.j_; + + // in the case of tagging the size of the spans being labeled will + // always be 1, but in other formalisms, you can have bigger spans + if (to_src_index - from_src_index != 1) { + cerr << "CoNLLFeatures doesn't support input spans of length != 1"; + abort(); + } + + string feature_instance = feature_template; + // replace token macros with actual token strings + for (unsigned i = 0; i < token_relative_locations.size(); ++i) { + int loc = token_relative_locations[i]; + WordID x = loc < 0? kSOS: kEOS; + if(from_src_index + loc >= 0 && + from_src_index + loc < current_input.size()) { + x = current_input[from_src_index + loc]; + } + string x_str = TD::Convert(x); + ReplaceTokenMacroWithString(feature_instance, loc, x_str); + } + + ReplaceLabelMacroWithString(feature_instance, 0, y0_str); + + // pick a real value for this feature + double fval = 1.0; + + // add it to the feature vector + // FD::Convert converts the feature string to a feature int + // Escape makes sure the feature string doesn't have any bad + // symbols that could confuse a parser somewhere + features->add_value(FD::Convert(Escape(feature_instance)), fval); +} diff --git a/decoder/ff_conll.h b/decoder/ff_conll.h new file mode 100644 index 00000000..b37356d8 --- /dev/null +++ b/decoder/ff_conll.h @@ -0,0 +1,45 @@ +#ifndef FF_CONLL_H_ +#define FF_CONLL_H_ + +#include <vector> +#include <boost/xpressive/xpressive.hpp> +#include "ff.h" + +using namespace boost::xpressive; +using namespace std; + +class CoNLLFeatures : public FeatureFunction { + public: + CoNLLFeatures(const string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + virtual void ParseArgs(const string& in); + virtual string Escape(const string& x) const; + virtual void ReplaceMacroWithString(string& feature_instance, + bool token_vs_label, + int relative_location, + const string& actual_token) const; + virtual void ReplaceTokenMacroWithString(string& feature_instance, + int relative_location, + const string& actual_token) const; + virtual void ReplaceLabelMacroWithString(string& feature_instance, + int relative_location, + const string& actual_token) const; + virtual void Error(const string&) const; + + private: + vector<int> token_relative_locations, label_relative_locations; + string feature_template; + vector<WordID> current_input; + WordID kSOS, kEOS; + sregex macro_regex; + unsigned which_feat; +}; + +#endif diff --git a/decoder/ff_const_reorder.cc b/decoder/ff_const_reorder.cc new file mode 100644 index 00000000..f1a6f7cb --- /dev/null +++ b/decoder/ff_const_reorder.cc @@ -0,0 +1,1118 @@ +#include "ff_const_reorder.h" + +#include "filelib.h" +#include "stringlib.h" +#include "hg.h" +#include "sentence_metadata.h" +#include "hash.h" +#include "ff_const_reorder_common.h" + +#include <sstream> +#include <string> +#include <vector> +#include <stdio.h> + +using namespace std; +using namespace const_reorder; + +typedef HASH_MAP<std::string, vector<double> > MapClassifier; + +inline bool is_inside(int i, int left, int right) { + if (i < left || i > right) return false; + return true; +} + +/* + * assume i <= j + * [i, j] is inside [left, right] or [i, j] equates to [left, right] + */ +inline bool is_inside(int i, int j, int left, int right) { + if (i >= left && j <= right) return true; + return false; +} + +/* + * assume i <= j + * [i, j] is inside [left, right], but [i, j] not equal to [left, right] + */ +inline bool is_proper_inside(int i, int j, int left, int right) { + if (i >= left && j <= right && right - left > j - i) return true; + return false; +} + +/* + * assume i <= j + * [i, j] is proper proper inside [left, right] + */ +inline bool is_proper_proper_inside(int i, int j, int left, int right) { + if (i > left && j < right) return true; + return false; +} + +inline bool is_overlap(int left1, int right1, int left2, int right2) { + if (is_inside(left1, left2, right2) || is_inside(left2, left1, right1)) + return true; + + return false; +} + +inline void NewAndCopyCharArray(char** p, const char* q) { + if (q != NULL) { + (*p) = new char[strlen(q) + 1]; + strcpy((*p), q); + } else + (*p) = NULL; +} + +// TODO:to make the alignment more efficient +struct TargetTranslation { + TargetTranslation(int begin_pos, int end_pos,int e_num_word) + : begin_pos_(begin_pos), + end_pos_(end_pos), + e_num_words_(e_num_word), + vec_left_most_(end_pos - begin_pos + 1, e_num_word), + vec_right_most_(end_pos - begin_pos + 1, -1), + vec_f_align_bit_array_(end_pos - begin_pos + 1), + vec_e_align_bit_array_(e_num_word) { + int len = end_pos - begin_pos + 1; + align_.reserve(1.5 * len); + } + + void InsertAlignmentPoint(int s, int t) { + int i = s - begin_pos_; + + vector<bool>& b = vec_f_align_bit_array_[i]; + if (b.empty()) b.resize(e_num_words_); + b[t] = 1; + + vector<bool>& a = vec_e_align_bit_array_[t]; + if (a.empty()) a.resize(end_pos_ - begin_pos_ + 1); + a[i] = 1; + + align_.push_back({s, t}); + + if (t > vec_right_most_[i]) vec_right_most_[i] = t; + if (t < vec_left_most_[i]) vec_left_most_[i] = t; + } + + /* + * given a source span [begin, end], whether its target side is continuous, + * return "0": the source span is translated silently + * return "1": there is at least on word inside its target span, this word + * doesn't align to any word inside [begin, end], but outside [begin, end] + * return "2": otherwise + */ + string IsTargetConstinousSpan(int begin, int end) const { + int target_begin, target_end; + FindLeftRightMostTargetSpan(begin, end, target_begin, target_end); + if (target_begin == -1) return "0"; + + for (int i = target_begin; i <= target_end; i++) { + if (vec_e_align_bit_array_[i].empty()) continue; + int j = begin; + for (; j <= end; j++) { + if (vec_e_align_bit_array_[i][j - begin_pos_]) break; + } + if (j == end + 1) // e[i] is aligned, but e[i] doesn't align to any + // source word in [begin_pos, end_pos] + return "1"; + } + return "2"; + } + + string IsTargetConstinousSpan2(int begin, int end) const { + int target_begin, target_end; + FindLeftRightMostTargetSpan(begin, end, target_begin, target_end); + if (target_begin == -1) return "Unaligned"; + + for (int i = target_begin; i <= target_end; i++) { + if (vec_e_align_bit_array_[i].empty()) continue; + int j = begin; + for (; j <= end; j++) { + if (vec_e_align_bit_array_[i][j - begin_pos_]) break; + } + if (j == end + 1) // e[i] is aligned, but e[i] doesn't align to any + // source word in [begin_pos, end_pos] + return "Discon't"; + } + return "Con't"; + } + + void FindLeftRightMostTargetSpan(int begin, int end, int& target_begin, + int& target_end) const { + int b = begin - begin_pos_; + int e = end - begin_pos_ + 1; + + target_begin = vec_left_most_[b]; + target_end = vec_right_most_[b]; + for (int i = b + 1; i < e; i++) { + if (target_begin > vec_left_most_[i]) target_begin = vec_left_most_[i]; + if (target_end < vec_right_most_[i]) target_end = vec_right_most_[i]; + } + if (target_end == -1) target_begin = -1; + return; + + target_begin = e_num_words_; + target_end = -1; + + for (int i = begin - begin_pos_; i < end - begin_pos_ + 1; i++) { + if (vec_f_align_bit_array_[i].empty()) continue; + for (int j = 0; j < target_begin; j++) + if (vec_f_align_bit_array_[i][j]) { + target_begin = j; + break; + } + } + for (int i = end - begin_pos_; i > begin - begin_pos_ - 1; i--) { + if (vec_f_align_bit_array_[i].empty()) continue; + for (int j = e_num_words_ - 1; j > target_end; j--) + if (vec_f_align_bit_array_[i][j]) { + target_end = j; + break; + } + } + + if (target_end == -1) target_begin = -1; + } + + const uint16_t begin_pos_, end_pos_; // the position in input + const uint16_t e_num_words_; + vector<AlignmentPoint> align_; + + private: + vector<short> vec_left_most_; + vector<short> vec_right_most_; + vector<vector<bool> > vec_f_align_bit_array_; + vector<vector<bool> > vec_e_align_bit_array_; +}; + +struct FocusedConstituent { + FocusedConstituent(const SParsedTree* pTree) { + if (pTree == NULL) return; + for (size_t i = 0; i < pTree->m_vecTerminals.size(); i++) { + STreeItem* pParent = pTree->m_vecTerminals[i]->m_ptParent; + + while (pParent != NULL) { + // if (pParent->m_vecChildren.size() > 1 && pParent->m_iEnd - + // pParent->m_iBegin > 5) { + // if (pParent->m_vecChildren.size() > 1) { + if (true) { + + // do constituent reordering for all children of pParent + if (strcmp(pParent->m_pszTerm, "ROOT")) + focus_parents_.push_back(pParent); + } + if (pParent->m_iBrotherIndex != 0) break; + pParent = pParent->m_ptParent; + } + } + } + + ~FocusedConstituent() { // TODO + focus_parents_.clear(); + } + + vector<STreeItem*> focus_parents_; +}; + +typedef SPredicateItem FocusedPredicate; + +struct FocusedSRL { + FocusedSRL(const SSrlSentence* srl) { + if (srl == NULL) return; + for (size_t i = 0; i < srl->m_vecPred.size(); i++) { + if (strcmp(srl->m_pTree->m_vecTerminals[srl->m_vecPred[i]->m_iPosition] + ->m_ptParent->m_pszTerm, + "VA") == 0) + continue; + focus_predicates_.push_back( + new FocusedPredicate(srl->m_pTree, srl->m_vecPred[i])); + } + } + + ~FocusedSRL() { focus_predicates_.clear(); } + + vector<const FocusedPredicate*> focus_predicates_; +}; + +struct ConstReorderFeatureImpl { + ConstReorderFeatureImpl(const std::string& param) { + + b_block_feature_ = false; + b_order_feature_ = false; + b_srl_block_feature_ = false; + b_srl_order_feature_ = false; + + vector<string> terms; + SplitOnWhitespace(param, &terms); + if (terms.size() == 1) { + b_block_feature_ = true; + b_order_feature_ = true; + } else if (terms.size() >= 3) { + if (terms[1].compare("1") == 0) b_block_feature_ = true; + if (terms[2].compare("1") == 0) b_order_feature_ = true; + if (terms.size() == 6) { + if (terms[4].compare("1") == 0) b_srl_block_feature_ = true; + if (terms[5].compare("1") == 0) b_srl_order_feature_ = true; + + assert(b_srl_block_feature_ || b_srl_order_feature_); + } + + } else { + assert("ERROR"); + } + + const_reorder_classifier_left_ = NULL; + const_reorder_classifier_right_ = NULL; + + srl_reorder_classifier_left_ = NULL; + srl_reorder_classifier_right_ = NULL; + + if (b_order_feature_) { + InitializeClassifier((terms[0] + string(".left")).c_str(), + &const_reorder_classifier_left_); + InitializeClassifier((terms[0] + string(".right")).c_str(), + &const_reorder_classifier_right_); + } + + if (b_srl_order_feature_) { + InitializeClassifier((terms[3] + string(".left")).c_str(), + &srl_reorder_classifier_left_); + InitializeClassifier((terms[3] + string(".right")).c_str(), + &srl_reorder_classifier_right_); + } + + parsed_tree_ = NULL; + focused_consts_ = NULL; + + srl_sentence_ = NULL; + focused_srl_ = NULL; + + map_left_ = NULL; + map_right_ = NULL; + + map_srl_left_ = NULL; + map_srl_right_ = NULL; + + dict_block_status_ = new Dict(); + dict_block_status_->Convert("Unaligned", false); + dict_block_status_->Convert("Discon't", false); + dict_block_status_->Convert("Con't", false); + } + + ~ConstReorderFeatureImpl() { + if (const_reorder_classifier_left_) delete const_reorder_classifier_left_; + if (const_reorder_classifier_right_) delete const_reorder_classifier_right_; + if (srl_reorder_classifier_left_) delete srl_reorder_classifier_left_; + if (srl_reorder_classifier_right_) delete srl_reorder_classifier_right_; + FreeSentenceVariables(); + + delete dict_block_status_; + } + + static int ReserveStateSize() { return 1 * sizeof(TargetTranslation*); } + + void InitializeInputSentence(const std::string& parse_file, + const std::string& srl_file) { + FreeSentenceVariables(); + if (b_srl_block_feature_ || b_srl_order_feature_) { + assert(srl_file != ""); + srl_sentence_ = ReadSRLSentence(srl_file); + parsed_tree_ = srl_sentence_->m_pTree; + } else { + assert(parse_file != ""); + srl_sentence_ = NULL; + parsed_tree_ = ReadParseTree(parse_file); + } + + if (b_block_feature_ || b_order_feature_) { + focused_consts_ = new FocusedConstituent(parsed_tree_); + + if (b_order_feature_) { + // we can do the classifier "off-line" + map_left_ = new MapClassifier(); + map_right_ = new MapClassifier(); + InitializeConstReorderClassifierOutput(); + } + } + + if (b_srl_block_feature_ || b_srl_order_feature_) { + focused_srl_ = new FocusedSRL(srl_sentence_); + + if (b_srl_order_feature_) { + map_srl_left_ = new MapClassifier(); + map_srl_right_ = new MapClassifier(); + InitializeSRLReorderClassifierOutput(); + } + } + + if (parsed_tree_ != NULL) { + size_t i = parsed_tree_->m_vecTerminals.size(); + vec_target_tran_.reserve(20 * i * i * i); + } else + vec_target_tran_.reserve(1000000); + } + + void SetConstReorderFeature(const Hypergraph::Edge& edge, + SparseVector<double>* features, + const vector<const void*>& ant_states, + void* state) { + if (parsed_tree_ == NULL) return; + + short int begin = edge.i_, end = edge.j_ - 1; + + typedef TargetTranslation* PtrTargetTranslation; + PtrTargetTranslation* remnant = + reinterpret_cast<PtrTargetTranslation*>(state); + + vector<const TargetTranslation*> vec_node; + vec_node.reserve(edge.tail_nodes_.size()); + for (size_t i = 0; i < edge.tail_nodes_.size(); i++) { + const PtrTargetTranslation* astate = + reinterpret_cast<const PtrTargetTranslation*>(ant_states[i]); + vec_node.push_back(astate[0]); + } + + int e_num_word = edge.rule_->e_.size(); + for (size_t i = 0; i < vec_node.size(); i++) { + e_num_word += vec_node[i]->e_num_words_; + e_num_word--; + } + + remnant[0] = new TargetTranslation(begin, end, e_num_word); + vec_target_tran_.push_back(remnant[0]); + + // reset the alignment + // for the source side, we know its position in source sentence + // for the target side, we always assume its starting position is 0 + unsigned vc = 0; + const TRulePtr rule = edge.rule_; + std::vector<int> f_index(rule->f_.size()); + int index = edge.i_; + for (unsigned i = 0; i < rule->f_.size(); i++) { + f_index[i] = index; + const WordID& c = rule->f_[i]; + if (c < 1) + index = vec_node[vc++]->end_pos_ + 1; + else + index++; + } + assert(vc == vec_node.size()); + assert(index == edge.j_); + + std::vector<int> e_index(rule->e_.size()); + index = 0; + vc = 0; + for (unsigned i = 0; i < rule->e_.size(); i++) { + e_index[i] = index; + const WordID& c = rule->e_[i]; + if (c < 1) { + index += vec_node[-c]->e_num_words_; + vc++; + } else + index++; + } + assert(vc == vec_node.size()); + + size_t nt_pos = 0; + for (size_t i = 0; i < edge.rule_->f_.size(); i++) { + if (edge.rule_->f_[i] > 0) continue; + + // it's an NT + size_t j; + for (j = 0; j < edge.rule_->e_.size(); j++) + if (edge.rule_->e_[j] * -1 == nt_pos) break; + assert(j != edge.rule_->e_.size()); + nt_pos++; + + // i aligns j + int eindex = e_index[j]; + const vector<AlignmentPoint>& align = + vec_node[-1 * edge.rule_->e_[j]]->align_; + for (size_t k = 0; k < align.size(); k++) { + remnant[0]->InsertAlignmentPoint(align[k].s_, eindex + align[k].t_); + } + } + for (size_t i = 0; i < edge.rule_->a_.size(); i++) { + int findex = f_index[edge.rule_->a_[i].s_]; + int eindex = e_index[edge.rule_->a_[i].t_]; + remnant[0]->InsertAlignmentPoint(findex, eindex); + } + + // till now, we finished setting state values + // next, use the state values to calculate constituent reorder feature + SetConstReorderFeature(begin, end, features, remnant[0], + vec_node, f_index); + } + + void SetConstReorderFeature(short int begin, short int end, + SparseVector<double>* features, + const TargetTranslation* target_translation, + const vector<const TargetTranslation*>& vec_node, + std::vector<int>& /*findex*/) { + if (b_srl_block_feature_ || b_srl_order_feature_) { + double logprob_srl_reorder_left = 0.0, logprob_srl_reorder_right = 0.0; + for (size_t i = 0; i < focused_srl_->focus_predicates_.size(); i++) { + const FocusedPredicate* pred = focused_srl_->focus_predicates_[i]; + if (!is_overlap(begin, end, pred->begin_, pred->end_)) + continue; // have no overlap between this predicate (with its + // argument) and the current edge + + size_t j; + for (j = 0; j < vec_node.size(); j++) { + if (is_inside(pred->begin_, pred->end_, vec_node[j]->begin_pos_, + vec_node[j]->end_pos_)) + break; + } + if (j < vec_node.size()) continue; + + vector<int> vecBlockStatus; + vecBlockStatus.reserve(pred->vec_items_.size()); + for (j = 0; j < pred->vec_items_.size(); j++) { + const STreeItem* con1 = pred->vec_items_[j]->tree_item_; + if (con1->m_iBegin < begin || con1->m_iEnd > end) { + vecBlockStatus.push_back(0); + continue; + } // the node is partially outside the current edge + + string type = target_translation->IsTargetConstinousSpan2( + con1->m_iBegin, con1->m_iEnd); + vecBlockStatus.push_back(dict_block_status_->Convert(type, false)); + + if (!b_srl_block_feature_) continue; + // see if the node is covered by an NT + size_t k; + for (k = 0; k < vec_node.size(); k++) { + if (is_inside(con1->m_iBegin, con1->m_iEnd, vec_node[k]->begin_pos_, + vec_node[k]->end_pos_)) + break; + } + if (k < vec_node.size()) continue; + int f_id = FD::Convert(string(pred->vec_items_[j]->role_) + type); + if (f_id) features->add_value(f_id, 1); + } + + if (!b_srl_order_feature_) continue; + + vector<int> vecPosition, vecRelativePosition; + vector<int> vecRightPosition, vecRelativeRightPosition; + vecPosition.reserve(pred->vec_items_.size()); + vecRelativePosition.reserve(pred->vec_items_.size()); + vecRightPosition.reserve(pred->vec_items_.size()); + vecRelativeRightPosition.reserve(pred->vec_items_.size()); + for (j = 0; j < pred->vec_items_.size(); j++) { + const STreeItem* con1 = pred->vec_items_[j]->tree_item_; + if (con1->m_iBegin < begin || con1->m_iEnd > end) { + vecPosition.push_back(-1); + vecRightPosition.push_back(-1); + continue; + } // the node is partially outside the current edge + int left1 = -1, right1 = -1; + target_translation->FindLeftRightMostTargetSpan( + con1->m_iBegin, con1->m_iEnd, left1, right1); + vecPosition.push_back(left1); + vecRightPosition.push_back(right1); + } + fnGetRelativePosition(vecPosition, vecRelativePosition); + fnGetRelativePosition(vecRightPosition, vecRelativeRightPosition); + + for (j = 1; j < pred->vec_items_.size(); j++) { + const STreeItem* con1 = pred->vec_items_[j - 1]->tree_item_; + const STreeItem* con2 = pred->vec_items_[j]->tree_item_; + + if (con1->m_iBegin < begin || con2->m_iEnd > end) + continue; // one of the two nodes is partially outside the current + // edge + + // both con1 and con2 are covered, need to check if they are covered + // by the same NT + size_t k; + for (k = 0; k < vec_node.size(); k++) { + if (is_inside(con1->m_iBegin, con2->m_iEnd, vec_node[k]->begin_pos_, + vec_node[k]->end_pos_)) + break; + } + if (k < vec_node.size()) continue; + + // they are not covered bye the same NT + string outcome; + string key; + GenerateKey(pred->vec_items_[j - 1]->tree_item_, + pred->vec_items_[j]->tree_item_, vecBlockStatus[j - 1], + vecBlockStatus[j], key); + + fnGetOutcome(vecRelativePosition[j - 1], vecRelativePosition[j], + outcome); + double prob = CalculateConstReorderProb(srl_reorder_classifier_left_, + map_srl_left_, key, outcome); + // printf("%s %s %f\n", ostr.str().c_str(), outcome.c_str(), prob); + logprob_srl_reorder_left += log10(prob); + + fnGetOutcome(vecRelativeRightPosition[j - 1], + vecRelativeRightPosition[j], outcome); + prob = CalculateConstReorderProb(srl_reorder_classifier_right_, + map_srl_right_, key, outcome); + logprob_srl_reorder_right += log10(prob); + } + } + + if (b_srl_order_feature_) { + int f_id = FD::Convert("SRLReorderFeatureLeft"); + if (f_id && logprob_srl_reorder_left != 0.0) + features->set_value(f_id, logprob_srl_reorder_left); + f_id = FD::Convert("SRLReorderFeatureRight"); + if (f_id && logprob_srl_reorder_right != 0.0) + features->set_value(f_id, logprob_srl_reorder_right); + } + } + + if (b_block_feature_ || b_order_feature_) { + double logprob_const_reorder_left = 0.0, + logprob_const_reorder_right = 0.0; + + for (size_t i = 0; i < focused_consts_->focus_parents_.size(); i++) { + STreeItem* parent = focused_consts_->focus_parents_[i]; + if (!is_overlap(begin, end, parent->m_iBegin, + parent->m_iEnd)) + continue; // have no overlap between this parent node and the current + // edge + + size_t j; + for (j = 0; j < vec_node.size(); j++) { + if (is_inside(parent->m_iBegin, parent->m_iEnd, + vec_node[j]->begin_pos_, vec_node[j]->end_pos_)) + break; + } + if (j < vec_node.size()) continue; + + if (b_block_feature_) { + if (parent->m_iBegin >= begin && + parent->m_iEnd <= end) { + string type = target_translation->IsTargetConstinousSpan2( + parent->m_iBegin, parent->m_iEnd); + int f_id = FD::Convert(string(parent->m_pszTerm) + type); + if (f_id) features->add_value(f_id, 1); + } + } + + if (parent->m_vecChildren.size() == 1 || !b_order_feature_) continue; + + vector<int> vecChunkBlock; + vecChunkBlock.reserve(parent->m_vecChildren.size()); + + for (j = 0; j < parent->m_vecChildren.size(); j++) { + STreeItem* con1 = parent->m_vecChildren[j]; + if (con1->m_iBegin < begin || con1->m_iEnd > end) { + vecChunkBlock.push_back(0); + continue; + } // the node is partially outside the current edge + + string type = target_translation->IsTargetConstinousSpan2( + con1->m_iBegin, con1->m_iEnd); + vecChunkBlock.push_back(dict_block_status_->Convert(type, false)); + + /*if (!b_block_feature_) continue; + //see if the node is covered by an NT + size_t k; + for (k = 0; k < vec_node.size(); k++) { + if (is_inside(con1->m_iBegin, con1->m_iEnd, + vec_node[k]->begin_pos_, vec_node[k]->end_pos_)) + break; + } + if (k < vec_node.size()) continue; + int f_id = FD::Convert(string(con1->m_pszTerm) + type); + if (f_id) + features->add_value(f_id, 1);*/ + } + + if (!b_order_feature_) continue; + + vector<int> vecPosition, vecRelativePosition; + vector<int> vecRightPosition, vecRelativeRightPosition; + vecPosition.reserve(parent->m_vecChildren.size()); + vecRelativePosition.reserve(parent->m_vecChildren.size()); + vecRightPosition.reserve(parent->m_vecChildren.size()); + vecRelativeRightPosition.reserve(parent->m_vecChildren.size()); + for (j = 0; j < parent->m_vecChildren.size(); j++) { + STreeItem* con1 = parent->m_vecChildren[j]; + if (con1->m_iBegin < begin || con1->m_iEnd > end) { + vecPosition.push_back(-1); + vecRightPosition.push_back(-1); + continue; + } // the node is partially outside the current edge + int left1 = -1, right1 = -1; + target_translation->FindLeftRightMostTargetSpan( + con1->m_iBegin, con1->m_iEnd, left1, right1); + vecPosition.push_back(left1); + vecRightPosition.push_back(right1); + } + fnGetRelativePosition(vecPosition, vecRelativePosition); + fnGetRelativePosition(vecRightPosition, vecRelativeRightPosition); + + for (j = 1; j < parent->m_vecChildren.size(); j++) { + STreeItem* con1 = parent->m_vecChildren[j - 1]; + STreeItem* con2 = parent->m_vecChildren[j]; + + if (con1->m_iBegin < begin || con2->m_iEnd > end) + continue; // one of the two nodes is partially outside the current + // edge + + // both con1 and con2 are covered, need to check if they are covered + // by the same NT + size_t k; + for (k = 0; k < vec_node.size(); k++) { + if (is_inside(con1->m_iBegin, con2->m_iEnd, vec_node[k]->begin_pos_, + vec_node[k]->end_pos_)) + break; + } + if (k < vec_node.size()) continue; + + // they are not covered bye the same NT + string outcome; + string key; + GenerateKey(parent->m_vecChildren[j - 1], parent->m_vecChildren[j], + vecChunkBlock[j - 1], vecChunkBlock[j], key); + + fnGetOutcome(vecRelativePosition[j - 1], vecRelativePosition[j], + outcome); + double prob = CalculateConstReorderProb( + const_reorder_classifier_left_, map_left_, key, outcome); + // printf("%s %s %f\n", ostr.str().c_str(), outcome.c_str(), prob); + logprob_const_reorder_left += log10(prob); + + fnGetOutcome(vecRelativeRightPosition[j - 1], + vecRelativeRightPosition[j], outcome); + prob = CalculateConstReorderProb(const_reorder_classifier_right_, + map_right_, key, outcome); + logprob_const_reorder_right += log10(prob); + } + } + + if (b_order_feature_) { + int f_id = FD::Convert("ConstReorderFeatureLeft"); + if (f_id && logprob_const_reorder_left != 0.0) + features->set_value(f_id, logprob_const_reorder_left); + f_id = FD::Convert("ConstReorderFeatureRight"); + if (f_id && logprob_const_reorder_right != 0.0) + features->set_value(f_id, logprob_const_reorder_right); + } + } + } + + private: + void Byte_to_Char(unsigned char* str, int n) { + str[0] = (n & 255); + str[1] = n / 256; + } + void GenerateKey(const STreeItem* pCon1, const STreeItem* pCon2, + int iBlockStatus1, int iBlockStatus2, string& key) { + assert(iBlockStatus1 != 0); + assert(iBlockStatus2 != 0); + unsigned char szTerm[1001]; + Byte_to_Char(szTerm, pCon1->m_iBegin); + Byte_to_Char(szTerm + 2, pCon2->m_iEnd); + szTerm[4] = (char)iBlockStatus1; + szTerm[5] = (char)iBlockStatus2; + szTerm[6] = '\0'; + // sprintf(szTerm, "%d|%d|%d|%d|%s|%s", pCon1->m_iBegin, pCon1->m_iEnd, + // pCon2->m_iBegin, pCon2->m_iEnd, strBlockStatus1.c_str(), + // strBlockStatus2.c_str()); + key = string(szTerm, szTerm + 6); + } + void InitializeConstReorderClassifierOutput() { + if (!b_order_feature_) return; + int size_block_status = dict_block_status_->max(); + + for (size_t i = 0; i < focused_consts_->focus_parents_.size(); i++) { + STreeItem* parent = focused_consts_->focus_parents_[i]; + + for (size_t j = 1; j < parent->m_vecChildren.size(); j++) { + for (size_t k = 1; k <= size_block_status; k++) { + for (size_t l = 1; l <= size_block_status; l++) { + ostringstream ostr; + GenerateFeature(parsed_tree_, parent, j, + dict_block_status_->Convert(k), + dict_block_status_->Convert(l), ostr); + + string strKey; + GenerateKey(parent->m_vecChildren[j - 1], parent->m_vecChildren[j], + k, l, strKey); + + vector<double> vecOutput; + const_reorder_classifier_left_->fnEval(ostr.str().c_str(), + vecOutput); + (*map_left_)[strKey] = vecOutput; + + const_reorder_classifier_right_->fnEval(ostr.str().c_str(), + vecOutput); + (*map_right_)[strKey] = vecOutput; + } + } + } + } + } + + void InitializeSRLReorderClassifierOutput() { + if (!b_srl_order_feature_) return; + int size_block_status = dict_block_status_->max(); + + for (size_t i = 0; i < focused_srl_->focus_predicates_.size(); i++) { + const FocusedPredicate* pred = focused_srl_->focus_predicates_[i]; + + for (size_t j = 1; j < pred->vec_items_.size(); j++) { + for (size_t k = 1; k <= size_block_status; k++) { + for (size_t l = 1; l <= size_block_status; l++) { + ostringstream ostr; + + SArgumentReorderModel::fnGenerateFeature( + parsed_tree_, pred->pred_, pred, j, + dict_block_status_->Convert(k), dict_block_status_->Convert(l), + ostr); + + string strKey; + GenerateKey(pred->vec_items_[j - 1]->tree_item_, + pred->vec_items_[j]->tree_item_, k, l, strKey); + + vector<double> vecOutput; + srl_reorder_classifier_left_->fnEval(ostr.str().c_str(), vecOutput); + (*map_srl_left_)[strKey] = vecOutput; + + srl_reorder_classifier_right_->fnEval(ostr.str().c_str(), + vecOutput); + (*map_srl_right_)[strKey] = vecOutput; + } + } + } + } + } + + double CalculateConstReorderProb( + const Tsuruoka_Maxent* const_reorder_classifier, const MapClassifier* map, + const string& key, const string& outcome) { + MapClassifier::const_iterator iter = (*map).find(key); + assert(iter != map->end()); + int id = const_reorder_classifier->fnGetClassId(outcome); + return iter->second[id]; + } + + void FreeSentenceVariables() { + if (srl_sentence_ != NULL) { + delete srl_sentence_; + srl_sentence_ = NULL; + } else { + if (parsed_tree_ != NULL) delete parsed_tree_; + parsed_tree_ = NULL; + } + + if (focused_consts_ != NULL) delete focused_consts_; + focused_consts_ = NULL; + + for (size_t i = 0; i < vec_target_tran_.size(); i++) + delete vec_target_tran_[i]; + vec_target_tran_.clear(); + + if (map_left_ != NULL) delete map_left_; + map_left_ = NULL; + if (map_right_ != NULL) delete map_right_; + map_right_ = NULL; + + if (map_srl_left_ != NULL) delete map_srl_left_; + map_srl_left_ = NULL; + if (map_srl_right_ != NULL) delete map_srl_right_; + map_srl_right_ = NULL; + } + + void InitializeClassifier(const char* pszFname, + Tsuruoka_Maxent** ppClassifier) { + (*ppClassifier) = new Tsuruoka_Maxent(pszFname); + } + + void GenerateOutcome(const vector<int>& vecPos, vector<string>& vecOutcome) { + vecOutcome.clear(); + + for (size_t i = 1; i < vecPos.size(); i++) { + if (vecPos[i] == -1 || vecPos[i] == vecPos[i - 1]) { + vecOutcome.push_back("M"); // monotone + continue; + } + + if (vecPos[i - 1] == -1) { + // vecPos[i] is not -1 + size_t j = i - 2; + while (j > -1 && vecPos[j] == -1) j--; + + size_t k; + for (k = 0; k < j; k++) { + if (vecPos[k] > vecPos[j] || vecPos[k] <= vecPos[i]) break; + } + if (k < j) { + vecOutcome.push_back("DM"); + continue; + } + + for (k = i + 1; k < vecPos.size(); k++) + if (vecPos[k] < vecPos[i] && (j == -1 && vecPos[k] >= vecPos[j])) + break; + if (k < vecPos.size()) { + vecOutcome.push_back("DM"); + continue; + } + vecOutcome.push_back("M"); + } else { + // neither of vecPos[i-1] and vecPos[i] is -1 + if (vecPos[i - 1] < vecPos[i]) { + // monotone or discon't monotone + size_t j; + for (j = 0; j < i - 1; j++) + if (vecPos[j] > vecPos[i - 1] && vecPos[j] <= vecPos[i]) break; + if (j < i - 1) { + vecOutcome.push_back("DM"); + continue; + } + for (j = i + 1; j < vecPos.size(); j++) + if (vecPos[j] >= vecPos[i - 1] && vecPos[j] < vecPos[i]) break; + if (j < vecPos.size()) { + vecOutcome.push_back("DM"); + continue; + } + vecOutcome.push_back("M"); + } else { + // swap or discon't swap + size_t j; + for (j = 0; j < i - 1; j++) + if (vecPos[j] > vecPos[i] && vecPos[j] <= vecPos[i - 1]) break; + if (j < i - 1) { + vecOutcome.push_back("DS"); + continue; + } + for (j = i + 1; j < vecPos.size(); j++) + if (vecPos[j] >= vecPos[i] && vecPos[j] < vecPos[i - 1]) break; + if (j < vecPos.size()) { + vecOutcome.push_back("DS"); + continue; + } + vecOutcome.push_back("S"); + } + } + } + + assert(vecOutcome.size() == vecPos.size() - 1); + } + + void fnGetRelativePosition(const vector<int>& vecLeft, + vector<int>& vecPosition) { + vecPosition.clear(); + + vector<float> vec; + vec.reserve(vecLeft.size()); + for (size_t i = 0; i < vecLeft.size(); i++) { + if (vecLeft[i] == -1) { + if (i == 0) + vec.push_back(-1); + else + vec.push_back(vecLeft[i - 1] + 0.1); + } else + vec.push_back(vecLeft[i]); + } + + for (size_t i = 0; i < vecLeft.size(); i++) { + int count = 0; + + for (size_t j = 0; j < vecLeft.size(); j++) { + if (j == i) continue; + if (vec[j] < vec[i]) { + count++; + } else if (vec[j] == vec[i] && j < i) { + count++; + } + } + vecPosition.push_back(count); + } + + for (size_t i = 1; i < vecPosition.size(); i++) { + if (vecPosition[i - 1] == vecPosition[i]) { + for (size_t j = 0; j < vecLeft.size(); j++) cout << vecLeft[j] << " "; + cout << "\n"; + assert(false); + } + } + } + + inline void fnGetOutcome(int i1, int i2, string& strOutcome) { + assert(i1 != i2); + if (i1 < i2) { + if (i2 > i1 + 1) + strOutcome = string("DM"); + else + strOutcome = string("M"); + } else { + if (i1 > i2 + 1) + strOutcome = string("DS"); + else + strOutcome = string("S"); + } + } + + // features in constituent_reorder_model.cc + void GenerateFeature(const SParsedTree* pTree, const STreeItem* pParent, + int iPos, const string& strBlockStatus1, + const string& strBlockStatus2, ostringstream& ostr) { + STreeItem* pCon1, *pCon2; + pCon1 = pParent->m_vecChildren[iPos - 1]; + pCon2 = pParent->m_vecChildren[iPos]; + + string left_label = string(pCon1->m_pszTerm); + string right_label = string(pCon2->m_pszTerm); + string parent_label = string(pParent->m_pszTerm); + + vector<string> vec_other_right_sibling; + for (int i = iPos + 1; i < pParent->m_vecChildren.size(); i++) + vec_other_right_sibling.push_back( + string(pParent->m_vecChildren[i]->m_pszTerm)); + if (vec_other_right_sibling.size() == 0) + vec_other_right_sibling.push_back(string("NULL")); + vector<string> vec_other_left_sibling; + for (int i = 0; i < iPos - 1; i++) + vec_other_left_sibling.push_back( + string(pParent->m_vecChildren[i]->m_pszTerm)); + if (vec_other_left_sibling.size() == 0) + vec_other_left_sibling.push_back(string("NULL")); + + // generate features + // f1 + ostr << "f1=" << left_label << "_" << right_label << "_" << parent_label; + // f2 + for (int i = 0; i < vec_other_right_sibling.size(); i++) + ostr << " f2=" << left_label << "_" << right_label << "_" << parent_label + << "_" << vec_other_right_sibling[i]; + // f3 + for (int i = 0; i < vec_other_left_sibling.size(); i++) + ostr << " f3=" << left_label << "_" << right_label << "_" << parent_label + << "_" << vec_other_left_sibling[i]; + // f4 + ostr << " f4=" << left_label << "_" << right_label << "_" + << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_ptParent->m_pszTerm; + // f5 + ostr << " f5=" << left_label << "_" << right_label << "_" + << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_pszTerm; + // f6 + ostr << " f6=" << left_label << "_" << right_label << "_" + << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_ptParent->m_pszTerm; + // f7 + ostr << " f7=" << left_label << "_" << right_label << "_" + << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_pszTerm; + // f8 + ostr << " f8=" << left_label << "_" << right_label << "_" + << strBlockStatus1; + // f9 + ostr << " f9=" << left_label << "_" << right_label << "_" + << strBlockStatus2; + + // f10 + ostr << " f10=" << left_label << "_" << parent_label; + // f11 + ostr << " f11=" << right_label << "_" << parent_label; + } + + SParsedTree* ReadParseTree(const std::string& parse_file) { + SParseReader* reader = new SParseReader(parse_file.c_str(), false); + SParsedTree* tree = reader->fnReadNextParseTree(); + // assert(tree != NULL); + delete reader; + return tree; + } + + SSrlSentence* ReadSRLSentence(const std::string& srl_file) { + SSrlSentenceReader* reader = new SSrlSentenceReader(srl_file.c_str()); + SSrlSentence* srl = reader->fnReadNextSrlSentence(); + // assert(srl != NULL); + delete reader; + return srl; + } + + private: + Tsuruoka_Maxent* const_reorder_classifier_left_; + Tsuruoka_Maxent* const_reorder_classifier_right_; + + Tsuruoka_Maxent* srl_reorder_classifier_left_; + Tsuruoka_Maxent* srl_reorder_classifier_right_; + + MapClassifier* map_left_; + MapClassifier* map_right_; + + MapClassifier* map_srl_left_; + MapClassifier* map_srl_right_; + + SParsedTree* parsed_tree_; + FocusedConstituent* focused_consts_; + vector<TargetTranslation*> vec_target_tran_; + + bool b_order_feature_; + bool b_block_feature_; + + bool b_srl_block_feature_; + bool b_srl_order_feature_; + SSrlSentence* srl_sentence_; + FocusedSRL* focused_srl_; + + Dict* dict_block_status_; +}; + +ConstReorderFeature::ConstReorderFeature(const std::string& param) { + pimpl_ = new ConstReorderFeatureImpl(param); + SetStateSize(ConstReorderFeatureImpl::ReserveStateSize()); + SetIgnoredStateSize(ConstReorderFeatureImpl::ReserveStateSize()); + name_ = "ConstReorderFeature"; +} + +ConstReorderFeature::~ConstReorderFeature() { // TODO + delete pimpl_; +} + +void ConstReorderFeature::PrepareForInput(const SentenceMetadata& smeta) { + string parse_file = smeta.GetSGMLValue("parse"); + if (parse_file.empty()) { + parse_file = smeta.GetSGMLValue("src_tree"); + } + string srl_file = smeta.GetSGMLValue("srl"); + assert(!(parse_file == "" && srl_file == "")); + + pimpl_->InitializeInputSentence(parse_file, srl_file); +} + +void ConstReorderFeature::TraversalFeaturesImpl( + const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge, + const vector<const void*>& ant_states, SparseVector<double>* features, + SparseVector<double>* /*estimated_features*/, void* state) const { + pimpl_->SetConstReorderFeature(edge, features, ant_states, state); +} + +string ConstReorderFeature::usage(bool show_params, bool show_details) { + ostringstream out; + out << "ConstReorderFeature"; + if (show_params) { + out << " model_file_prefix [const_block=1 const_order=1] [srl_block=0 " + "srl_order=0]" + << "\nParameters:\n" + << " const_{block,order}: enable/disable constituency constraints.\n" + << " src_{block,order}: enable/disable semantic role labeling " + "constraints.\n"; + } + if (show_details) { + out << "\n" + << "Soft reordering constraint features from " + "http://www.aclweb.org/anthology/P14-1106. To train the classifers, " + "use utils/const_reorder_model_trainer for constituency reordering " + "constraints and utils/argument_reorder_model_trainer for semantic " + "role labeling reordering constraints.\n" + << "Input segments should provide path to parse tree (resp. SRL parse) " + "as \"parse\" (resp. \"srl\") properties.\n"; + } + return out.str(); +} + +boost::shared_ptr<FeatureFunction> CreateConstReorderModel( + const std::string& param) { + ConstReorderFeature* ret = new ConstReorderFeature(param); + return boost::shared_ptr<FeatureFunction>(ret); +} diff --git a/decoder/ff_const_reorder.h b/decoder/ff_const_reorder.h new file mode 100644 index 00000000..a5be02d0 --- /dev/null +++ b/decoder/ff_const_reorder.h @@ -0,0 +1,43 @@ +/* + * ff_const_reorder.h + * + * Created on: Jul 11, 2013 + * Author: junhuili + */ + +#ifndef FF_CONST_REORDER_H_ +#define FF_CONST_REORDER_H_ + +#include "ff_factory.h" +#include "ff.h" + +struct ConstReorderFeatureImpl; + +// Soft reordering constraint features from +// http://www.aclweb.org/anthology/P14-1106. To train the classifers, +// use utils/const_reorder_model_trainer for constituency reordering +// constraints and utils/argument_reorder_model_trainer for SRL +// reordering constraints. +// +// Input segments should provide path to parse tree (resp. SRL parse) +// as "parse" (resp. "srl") properties. +class ConstReorderFeature : public FeatureFunction { + public: + ConstReorderFeature(const std::string& param); + ~ConstReorderFeature(); + static std::string usage(bool param, bool verbose); + + protected: + virtual void PrepareForInput(const SentenceMetadata& smeta); + + virtual void TraversalFeaturesImpl( + const SentenceMetadata& smeta, const HG::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, SparseVector<double>* estimated_features, + void* out_context) const; + + private: + ConstReorderFeatureImpl* pimpl_; +}; + +#endif /* FF_CONST_REORDER_H_ */ diff --git a/decoder/ff_const_reorder_common.h b/decoder/ff_const_reorder_common.h new file mode 100644 index 00000000..755fd948 --- /dev/null +++ b/decoder/ff_const_reorder_common.h @@ -0,0 +1,1348 @@ +#ifndef _FF_CONST_REORDER_COMMON_H +#define _FF_CONST_REORDER_COMMON_H + +#include <string> +#include <assert.h> +#include <stdio.h> +#include <string.h> +#include <string> +#include <sstream> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "maxent.h" +#include "stringlib.h" + +namespace const_reorder { + +struct STreeItem { + STreeItem(const char *pszTerm) { + m_pszTerm = new char[strlen(pszTerm) + 1]; + strcpy(m_pszTerm, pszTerm); + + m_ptParent = NULL; + m_iBegin = -1; + m_iEnd = -1; + m_iHeadChild = -1; + m_iHeadWord = -1; + m_iBrotherIndex = -1; + } + ~STreeItem() { + delete[] m_pszTerm; + for (size_t i = 0; i < m_vecChildren.size(); i++) delete m_vecChildren[i]; + } + int fnAppend(STreeItem *ptChild) { + m_vecChildren.push_back(ptChild); + ptChild->m_iBrotherIndex = m_vecChildren.size() - 1; + ptChild->m_ptParent = this; + return m_vecChildren.size() - 1; + } + int fnGetChildrenNum() { return m_vecChildren.size(); } + + bool fnIsPreTerminal(void) { + int I; + if (this == NULL || m_vecChildren.size() == 0) return false; + + for (I = 0; I < m_vecChildren.size(); I++) + if (m_vecChildren[I]->m_vecChildren.size() > 0) return false; + + return true; + } + + public: + char *m_pszTerm; + + std::vector<STreeItem *> m_vecChildren; // children items + STreeItem *m_ptParent; // the parent item + + int m_iBegin; + int m_iEnd; // the node span words[m_iBegin, m_iEnd] + int m_iHeadChild; // the index of its head child + int m_iHeadWord; // the index of its head word + int m_iBrotherIndex; // the index in his brothers +}; + +struct SGetHeadWord { + typedef std::vector<std::string> CVectorStr; + SGetHeadWord() {} + ~SGetHeadWord() {} + int fnGetHeadWord(char *pszCFGLeft, CVectorStr vectRight) { + // 0 indicating from right to left while 1 indicating from left to right + char szaHeadLists[201] = "0"; + + /* //head rules for Egnlish + if( strcmp( pszCFGLeft, "ADJP" ) == 0 ) + strcpy( szaHeadLists, "0NNS 0QP 0NN 0$ 0ADVP 0JJ 0VBN 0VBG 0ADJP + 0JJR 0NP 0JJS 0DT 0FW 0RBR 0RBS 0SBAR 0RB 0" ); + else if( strcmp( pszCFGLeft, "ADVP" ) == 0 ) + strcpy( szaHeadLists, "1RB 1RBR 1RBS 1FW 1ADVP 1TO 1CD 1JJR 1JJ 1IN + 1NP 1JJS 1NN 1" ); + else if( strcmp( pszCFGLeft, "CONJP" ) == 0 ) + strcpy( szaHeadLists, "1CC 1RB 1IN 1" ); + else if( strcmp( pszCFGLeft, "FRAG" ) == 0 ) + strcpy( szaHeadLists, "1" ); + else if( strcmp( pszCFGLeft, "INTJ" ) == 0 ) + strcpy( szaHeadLists, "0" ); + else if( strcmp( pszCFGLeft, "LST" ) == 0 ) + strcpy( szaHeadLists, "1LS 1: 1CLN 1" ); + else if( strcmp( pszCFGLeft, "NAC" ) == 0 ) + strcpy( szaHeadLists, "0NN 0NNS 0NNP 0NNPS 0NP 0NAC 0EX 0$ 0CD 0QP + 0PRP 0VBG 0JJ 0JJS 0JJR 0ADJP 0FW 0" ); + else if( strcmp( pszCFGLeft, "PP" ) == 0 ) + strcpy( szaHeadLists, "1IN 1TO 1VBG 1VBN 1RP 1FW 1" ); + else if( strcmp( pszCFGLeft, "PRN" ) == 0 ) + strcpy( szaHeadLists, "1" ); + else if( strcmp( pszCFGLeft, "PRT" ) == 0 ) + strcpy( szaHeadLists, "1RP 1" ); + else if( strcmp( pszCFGLeft, "QP" ) == 0 ) + strcpy( szaHeadLists, "0$ 0IN 0NNS 0NN 0JJ 0RB 0DT 0CD 0NCD 0QP 0JJR + 0JJS 0" ); + else if( strcmp( pszCFGLeft, "RRC" ) == 0 ) + strcpy( szaHeadLists, "1VP 1NP 1ADVP 1ADJP 1PP 1" ); + else if( strcmp( pszCFGLeft, "S" ) == 0 ) + strcpy( szaHeadLists, "0TO 0IN 0VP 0S 0SBAR 0ADJP 0UCP 0NP 0" ); + else if( strcmp( pszCFGLeft, "SBAR" ) == 0 ) + strcpy( szaHeadLists, "0WHNP 0WHPP 0WHADVP 0WHADJP 0IN 0DT 0S 0SQ + 0SINV 0SBAR 0FRAG 0" ); + else if( strcmp( pszCFGLeft, "SBARQ" ) == 0 ) + strcpy( szaHeadLists, "0SQ 0S 0SINV 0SBARQ 0FRAG 0" ); + else if( strcmp( pszCFGLeft, "SINV" ) == 0 ) + strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0S 0SINV 0ADJP 0NP + 0" ); + else if( strcmp( pszCFGLeft, "SQ" ) == 0 ) + strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0SQ 0" ); + else if( strcmp( pszCFGLeft, "UCP" ) == 0 ) + strcpy( szaHeadLists, "1" ); + else if( strcmp( pszCFGLeft, "VP" ) == 0 ) + strcpy( szaHeadLists, "0TO 0VBD 0VBN 0MD 0VBZ 0VB 0VBG 0VBP 0VP + 0ADJP 0NN 0NNS 0NP 0" ); + else if( strcmp( pszCFGLeft, "WHADJP" ) == 0 ) + strcpy( szaHeadLists, "0CC 0WRB 0JJ 0ADJP 0" ); + else if( strcmp( pszCFGLeft, "WHADVP" ) == 0 ) + strcpy( szaHeadLists, "1CC 1WRB 1" ); + else if( strcmp( pszCFGLeft, "WHNP" ) == 0 ) + strcpy( szaHeadLists, "0WDT 0WP 0WP$ 0WHADJP 0WHPP 0WHNP 0" ); + else if( strcmp( pszCFGLeft, "WHPP" ) == 0 ) + strcpy( szaHeadLists, "1IN 1TO FW 1" ); + else if( strcmp( pszCFGLeft, "NP" ) == 0 ) + strcpy( szaHeadLists, "0NN NNP NNS NNPS NX POS JJR 0NP 0$ ADJP PRN + 0CD 0JJ JJS RB QP 0" ); + */ + + if (strcmp(pszCFGLeft, "ADJP") == 0) + strcpy(szaHeadLists, "0ADJP JJ 0AD NN CS 0"); + else if (strcmp(pszCFGLeft, "ADVP") == 0) + strcpy(szaHeadLists, "0ADVP AD 0"); + else if (strcmp(pszCFGLeft, "CLP") == 0) + strcpy(szaHeadLists, "0CLP M 0"); + else if (strcmp(pszCFGLeft, "CP") == 0) + strcpy(szaHeadLists, "0DEC SP 1ADVP CS 0CP IP 0"); + else if (strcmp(pszCFGLeft, "DNP") == 0) + strcpy(szaHeadLists, "0DNP DEG 0DEC 0"); + else if (strcmp(pszCFGLeft, "DVP") == 0) + strcpy(szaHeadLists, "0DVP DEV 0"); + else if (strcmp(pszCFGLeft, "DP") == 0) + strcpy(szaHeadLists, "1DP DT 1"); + else if (strcmp(pszCFGLeft, "FRAG") == 0) + strcpy(szaHeadLists, "0VV NR NN 0"); + else if (strcmp(pszCFGLeft, "INTJ") == 0) + strcpy(szaHeadLists, "0INTJ IJ 0"); + else if (strcmp(pszCFGLeft, "LST") == 0) + strcpy(szaHeadLists, "1LST CD OD 1"); + else if (strcmp(pszCFGLeft, "IP") == 0) + strcpy(szaHeadLists, "0IP VP 0VV 0"); + // strcpy( szaHeadLists, "0VP 0VV 1IP 0" ); + else if (strcmp(pszCFGLeft, "LCP") == 0) + strcpy(szaHeadLists, "0LCP LC 0"); + else if (strcmp(pszCFGLeft, "NP") == 0) + strcpy(szaHeadLists, "0NP NN NT NR QP 0"); + else if (strcmp(pszCFGLeft, "PP") == 0) + strcpy(szaHeadLists, "1PP P 1"); + else if (strcmp(pszCFGLeft, "PRN") == 0) + strcpy(szaHeadLists, "0 NP IP VP NT NR NN 0"); + else if (strcmp(pszCFGLeft, "QP") == 0) + strcpy(szaHeadLists, "0QP CLP CD OD 0"); + else if (strcmp(pszCFGLeft, "VP") == 0) + strcpy(szaHeadLists, "1VP VA VC VE VV BA LB VCD VSB VRD VNV VCP 1"); + else if (strcmp(pszCFGLeft, "VCD") == 0) + strcpy(szaHeadLists, "0VCD VV VA VC VE 0"); + if (strcmp(pszCFGLeft, "VRD") == 0) + strcpy(szaHeadLists, "0VRD VV VA VC VE 0"); + else if (strcmp(pszCFGLeft, "VSB") == 0) + strcpy(szaHeadLists, "0VSB VV VA VC VE 0"); + else if (strcmp(pszCFGLeft, "VCP") == 0) + strcpy(szaHeadLists, "0VCP VV VA VC VE 0"); + else if (strcmp(pszCFGLeft, "VNV") == 0) + strcpy(szaHeadLists, "0VNV VV VA VC VE 0"); + else if (strcmp(pszCFGLeft, "VPT") == 0) + strcpy(szaHeadLists, "0VNV VV VA VC VE 0"); + else if (strcmp(pszCFGLeft, "UCP") == 0) + strcpy(szaHeadLists, "0"); + else if (strcmp(pszCFGLeft, "WHNP") == 0) + strcpy(szaHeadLists, "0WHNP NP NN NT NR QP 0"); + else if (strcmp(pszCFGLeft, "WHPP") == 0) + strcpy(szaHeadLists, "1WHPP PP P 1"); + + /* //head rules for GENIA corpus + if( strcmp( pszCFGLeft, "ADJP" ) == 0 ) + strcpy( szaHeadLists, "0NNS 0QP 0NN 0$ 0ADVP 0JJ 0VBN 0VBG 0ADJP + 0JJR 0NP 0JJS 0DT 0FW 0RBR 0RBS 0SBAR 0RB 0" ); + else if( strcmp( pszCFGLeft, "ADVP" ) == 0 ) + strcpy( szaHeadLists, "1RB 1RBR 1RBS 1FW 1ADVP 1TO 1CD 1JJR 1JJ 1IN + 1NP 1JJS 1NN 1" ); + else if( strcmp( pszCFGLeft, "CONJP" ) == 0 ) + strcpy( szaHeadLists, "1CC 1RB 1IN 1" ); + else if( strcmp( pszCFGLeft, "FRAG" ) == 0 ) + strcpy( szaHeadLists, "1" ); + else if( strcmp( pszCFGLeft, "INTJ" ) == 0 ) + strcpy( szaHeadLists, "0" ); + else if( strcmp( pszCFGLeft, "LST" ) == 0 ) + strcpy( szaHeadLists, "1LS 1: 1CLN 1" ); + else if( strcmp( pszCFGLeft, "NAC" ) == 0 ) + strcpy( szaHeadLists, "0NN 0NNS 0NNP 0NNPS 0NP 0NAC 0EX 0$ 0CD 0QP + 0PRP 0VBG 0JJ 0JJS 0JJR 0ADJP 0FW 0" ); + else if( strcmp( pszCFGLeft, "PP" ) == 0 ) + strcpy( szaHeadLists, "1IN 1TO 1VBG 1VBN 1RP 1FW 1" ); + else if( strcmp( pszCFGLeft, "PRN" ) == 0 ) + strcpy( szaHeadLists, "1" ); + else if( strcmp( pszCFGLeft, "PRT" ) == 0 ) + strcpy( szaHeadLists, "1RP 1" ); + else if( strcmp( pszCFGLeft, "QP" ) == 0 ) + strcpy( szaHeadLists, "0$ 0IN 0NNS 0NN 0JJ 0RB 0DT 0CD 0NCD 0QP 0JJR + 0JJS 0" ); + else if( strcmp( pszCFGLeft, "RRC" ) == 0 ) + strcpy( szaHeadLists, "1VP 1NP 1ADVP 1ADJP 1PP 1" ); + else if( strcmp( pszCFGLeft, "S" ) == 0 ) + strcpy( szaHeadLists, "0TO 0IN 0VP 0S 0SBAR 0ADJP 0UCP 0NP 0" ); + else if( strcmp( pszCFGLeft, "SBAR" ) == 0 ) + strcpy( szaHeadLists, "0WHNP 0WHPP 0WHADVP 0WHADJP 0IN 0DT 0S 0SQ + 0SINV 0SBAR 0FRAG 0" ); + else if( strcmp( pszCFGLeft, "SBARQ" ) == 0 ) + strcpy( szaHeadLists, "0SQ 0S 0SINV 0SBARQ 0FRAG 0" ); + else if( strcmp( pszCFGLeft, "SINV" ) == 0 ) + strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0S 0SINV 0ADJP 0NP + 0" ); + else if( strcmp( pszCFGLeft, "SQ" ) == 0 ) + strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0SQ 0" ); + else if( strcmp( pszCFGLeft, "UCP" ) == 0 ) + strcpy( szaHeadLists, "1" ); + else if( strcmp( pszCFGLeft, "VP" ) == 0 ) + strcpy( szaHeadLists, "0TO 0VBD 0VBN 0MD 0VBZ 0VB 0VBG 0VBP 0VP + 0ADJP 0NN 0NNS 0NP 0" ); + else if( strcmp( pszCFGLeft, "WHADJP" ) == 0 ) + strcpy( szaHeadLists, "0CC 0WRB 0JJ 0ADJP 0" ); + else if( strcmp( pszCFGLeft, "WHADVP" ) == 0 ) + strcpy( szaHeadLists, "1CC 1WRB 1" ); + else if( strcmp( pszCFGLeft, "WHNP" ) == 0 ) + strcpy( szaHeadLists, "0WDT 0WP 0WP$ 0WHADJP 0WHPP 0WHNP 0" ); + else if( strcmp( pszCFGLeft, "WHPP" ) == 0 ) + strcpy( szaHeadLists, "1IN 1TO FW 1" ); + else if( strcmp( pszCFGLeft, "NP" ) == 0 ) + strcpy( szaHeadLists, "0NN NNP NNS NNPS NX POS JJR 0NP 0$ ADJP PRN + 0CD 0JJ JJS RB QP 0" ); + */ + + return fnMyOwnHeadWordRule(szaHeadLists, vectRight); + } + + private: + int fnMyOwnHeadWordRule(char *pszaHeadLists, CVectorStr vectRight) { + char szHeadList[201], *p; + char szTerm[101]; + int J; + + p = pszaHeadLists; + + int iCountRight; + + iCountRight = vectRight.size(); + + szHeadList[0] = '\0'; + while (1) { + szTerm[0] = '\0'; + sscanf(p, "%s", szTerm); + if (strlen(szHeadList) == 0) { + if (strcmp(szTerm, "0") == 0) { + return iCountRight - 1; + } + if (strcmp(szTerm, "1") == 0) { + return 0; + } + + sprintf(szHeadList, "%c %s ", szTerm[0], szTerm + 1); + p = strstr(p, szTerm); + p += strlen(szTerm); + } else { + if ((szTerm[0] == '0') || (szTerm[0] == '1')) { + if (szHeadList[0] == '0') { + for (J = iCountRight - 1; J >= 0; J--) { + sprintf(szTerm, " %s ", vectRight.at(J).c_str()); + if (strstr(szHeadList, szTerm) != NULL) return J; + } + } else { + for (J = 0; J < iCountRight; J++) { + sprintf(szTerm, " %s ", vectRight.at(J).c_str()); + if (strstr(szHeadList, szTerm) != NULL) return J; + } + } + + szHeadList[0] = '\0'; + } else { + strcat(szHeadList, szTerm); + strcat(szHeadList, " "); + + p = strstr(p, szTerm); + p += strlen(szTerm); + } + } + } + + return 0; + } +}; + +struct SParsedTree { + SParsedTree() { m_ptRoot = NULL; } + ~SParsedTree() { + if (m_ptRoot != NULL) delete m_ptRoot; + } + static SParsedTree *fnConvertFromString(const char *pszStr) { + if (strcmp(pszStr, "(())") == 0) return NULL; + SParsedTree *pTree = new SParsedTree(); + + std::vector<std::string> vecSyn; + fnReadSyntactic(pszStr, vecSyn); + + int iLeft = 1, iRight = 1; //# left/right parenthesis + + STreeItem *pcurrent; + + pTree->m_ptRoot = new STreeItem(vecSyn[1].c_str()); + + pcurrent = pTree->m_ptRoot; + + for (size_t i = 2; i < vecSyn.size() - 1; i++) { + if (strcmp(vecSyn[i].c_str(), "(") == 0) + iLeft++; + else if (strcmp(vecSyn[i].c_str(), ")") == 0) { + iRight++; + if (pcurrent == NULL) { + // error + fprintf(stderr, "ERROR in ConvertFromString\n"); + fprintf(stderr, "%s\n", pszStr); + return NULL; + } + pcurrent = pcurrent->m_ptParent; + } else { + STreeItem *ptNewItem = new STreeItem(vecSyn[i].c_str()); + pcurrent->fnAppend(ptNewItem); + pcurrent = ptNewItem; + + if (strcmp(vecSyn[i - 1].c_str(), "(") != 0 && + strcmp(vecSyn[i - 1].c_str(), ")") != 0) { + pTree->m_vecTerminals.push_back(ptNewItem); + pcurrent = pcurrent->m_ptParent; + } + } + } + + if (iLeft != iRight) { + // error + fprintf(stderr, "the left and right parentheses are not matched!"); + fprintf(stderr, "ERROR in ConvertFromString\n"); + fprintf(stderr, "%s\n", pszStr); + return NULL; + } + + return pTree; + } + + int fnGetNumWord() { return m_vecTerminals.size(); } + + void fnSetSpanInfo() { + int iNextNum = 0; + fnSuffixTraverseSetSpanInfo(m_ptRoot, iNextNum); + } + + void fnSetHeadWord() { + for (size_t i = 0; i < m_vecTerminals.size(); i++) + m_vecTerminals[i]->m_iHeadWord = i; + SGetHeadWord *pGetHeadWord = new SGetHeadWord(); + fnSuffixTraverseSetHeadWord(m_ptRoot, pGetHeadWord); + delete pGetHeadWord; + } + + STreeItem *fnFindNodeForSpan(int iLeft, int iRight, bool bLowest) { + STreeItem *pTreeItem = m_vecTerminals[iLeft]; + + while (pTreeItem->m_iEnd < iRight) { + pTreeItem = pTreeItem->m_ptParent; + if (pTreeItem == NULL) break; + } + if (pTreeItem == NULL) return NULL; + if (pTreeItem->m_iEnd > iRight) return NULL; + + assert(pTreeItem->m_iEnd == iRight); + if (bLowest) return pTreeItem; + + while (pTreeItem->m_ptParent != NULL && + pTreeItem->m_ptParent->fnGetChildrenNum() == 1) + pTreeItem = pTreeItem->m_ptParent; + + return pTreeItem; + } + + private: + void fnSuffixTraverseSetSpanInfo(STreeItem *ptItem, int &iNextNum) { + int I; + int iNumChildren = ptItem->fnGetChildrenNum(); + for (I = 0; I < iNumChildren; I++) + fnSuffixTraverseSetSpanInfo(ptItem->m_vecChildren[I], iNextNum); + + if (I == 0) { + ptItem->m_iBegin = iNextNum; + ptItem->m_iEnd = iNextNum++; + } else { + ptItem->m_iBegin = ptItem->m_vecChildren[0]->m_iBegin; + ptItem->m_iEnd = ptItem->m_vecChildren[I - 1]->m_iEnd; + } + } + + void fnSuffixTraverseSetHeadWord(STreeItem *ptItem, + SGetHeadWord *pGetHeadWord) { + int I, iHeadchild; + + if (ptItem->m_vecChildren.size() == 0) return; + + for (I = 0; I < ptItem->m_vecChildren.size(); I++) + fnSuffixTraverseSetHeadWord(ptItem->m_vecChildren[I], pGetHeadWord); + + std::vector<std::string> vecRight; + + if (ptItem->m_vecChildren.size() == 1) + iHeadchild = 0; + else { + for (I = 0; I < ptItem->m_vecChildren.size(); I++) + vecRight.push_back(std::string(ptItem->m_vecChildren[I]->m_pszTerm)); + + iHeadchild = pGetHeadWord->fnGetHeadWord(ptItem->m_pszTerm, vecRight); + } + + ptItem->m_iHeadChild = iHeadchild; + ptItem->m_iHeadWord = ptItem->m_vecChildren[iHeadchild]->m_iHeadWord; + } + + static void fnReadSyntactic(const char *pszSyn, + std::vector<std::string> &vec) { + char *p; + int I; + + int iLeftNum, iRightNum; + char *pszTmp, *pszTerm; + pszTmp = new char[strlen(pszSyn)]; + pszTerm = new char[strlen(pszSyn)]; + pszTmp[0] = pszTerm[0] = '\0'; + + vec.clear(); + + char *pszLine; + pszLine = new char[strlen(pszSyn) + 1]; + strcpy(pszLine, pszSyn); + + char *pszLine2; + + while (1) { + while ((strlen(pszLine) > 0) && (pszLine[strlen(pszLine) - 1] > 0) && + (pszLine[strlen(pszLine) - 1] <= ' ')) + pszLine[strlen(pszLine) - 1] = '\0'; + + if (strlen(pszLine) == 0) break; + + // printf( "%s\n", pszLine ); + pszLine2 = pszLine; + while (pszLine2[0] <= ' ') pszLine2++; + if (pszLine2[0] == '<') continue; + + sscanf(pszLine2 + 1, "%s", pszTmp); + + if (pszLine2[0] == '(') { + iLeftNum = 0; + iRightNum = 0; + } + + p = pszLine2; + while (1) { + pszTerm[0] = '\0'; + sscanf(p, "%s", pszTerm); + + if (strlen(pszTerm) == 0) break; + p = strstr(p, pszTerm); + p += strlen(pszTerm); + + if ((pszTerm[0] == '(') || (pszTerm[strlen(pszTerm) - 1] == ')')) { + if (pszTerm[0] == '(') { + vec.push_back(std::string("(")); + iLeftNum++; + + I = 1; + while (pszTerm[I] == '(' && pszTerm[I] != '\0') { + vec.push_back(std::string("(")); + iLeftNum++; + + I++; + } + + if (strlen(pszTerm) > 1) vec.push_back(std::string(pszTerm + I)); + } else { + char *pTmp; + pTmp = pszTerm + strlen(pszTerm) - 1; + while ((pTmp[0] == ')') && (pTmp >= pszTerm)) pTmp--; + pTmp[1] = '\0'; + + if (strlen(pszTerm) > 0) vec.push_back(std::string(pszTerm)); + pTmp += 2; + + for (I = 0; I <= (int)strlen(pTmp); I++) { + vec.push_back(std::string(")")); + iRightNum++; + } + } + } else { + char *q; + q = strchr(pszTerm, ')'); + if (q != NULL) { + q[0] = '\0'; + if (pszTerm[0] != '\0') vec.push_back(std::string(pszTerm)); + vec.push_back(std::string(")")); + iRightNum++; + + q++; + while (q[0] == ')') { + vec.push_back(std::string(")")); + q++; + iRightNum++; + } + + while (q[0] == '(') { + vec.push_back(std::string("(")); + q++; + iLeftNum++; + } + + if (q[0] != '\0') vec.push_back(std::string(q)); + } else + vec.push_back(std::string(pszTerm)); + } + } + + if (iLeftNum != iRightNum) { + fprintf(stderr, "%s\n", pszSyn); + assert(iLeftNum == iRightNum); + } + /*if ( iLeftNum != iRightNum ) { + printf( "ERROR: left( and right ) is not matched, %d ( and %d + )\n", iLeftNum, iRightNum ); + return; + }*/ + + if (vec.size() >= 2 && strcmp(vec[1].c_str(), "(") == 0) { + //( (IP..) ) + std::vector<std::string>::iterator it; + it = vec.begin(); + it++; + vec.insert(it, std::string("ROOT")); + } + + break; + } + + delete[] pszLine; + delete[] pszTmp; + delete[] pszTerm; + } + + public: + STreeItem *m_ptRoot; + std::vector<STreeItem *> m_vecTerminals; // the leaf nodes +}; + +struct SParseReader { + SParseReader(const char *pszParse_Fname, bool bFlattened = false) + : m_bFlattened(bFlattened) { + m_fpIn = fopen(pszParse_Fname, "r"); + assert(m_fpIn != NULL); + } + ~SParseReader() { + if (m_fpIn != NULL) fclose(m_fpIn); + } + + SParsedTree *fnReadNextParseTree() { + SParsedTree *pTree = NULL; + char *pszLine = new char[100001]; + int iLen; + + while (fnReadNextSentence(pszLine, &iLen) == true) { + if (iLen == 0) continue; + + pTree = SParsedTree::fnConvertFromString(pszLine); + if (pTree == NULL) break; + if (m_bFlattened) + fnPostProcessingFlattenedParse(pTree); + else { + pTree->fnSetSpanInfo(); + pTree->fnSetHeadWord(); + } + break; + } + + delete[] pszLine; + return pTree; + } + + SParsedTree *fnReadNextParseTreeWithProb(double *pProb) { + SParsedTree *pTree = NULL; + char *pszLine = new char[100001]; + int iLen; + + while (fnReadNextSentence(pszLine, &iLen) == true) { + if (iLen == 0) continue; + + char *p = strchr(pszLine, ' '); + assert(p != NULL); + p[0] = '\0'; + p++; + if (pProb) (*pProb) = atof(pszLine); + + pTree = SParsedTree::fnConvertFromString(p); + if (m_bFlattened) + fnPostProcessingFlattenedParse(pTree); + else { + pTree->fnSetSpanInfo(); + pTree->fnSetHeadWord(); + } + break; + } + + delete[] pszLine; + return pTree; + } + + private: + /* + * since to the parse tree is a flattened tree, use the head mark to identify + * head info. + * the head node will be marked as "*XP*" + */ + void fnSetParseTreeHeadInfo(SParsedTree *pTree) { + for (size_t i = 0; i < pTree->m_vecTerminals.size(); i++) + pTree->m_vecTerminals[i]->m_iHeadWord = i; + fnSuffixTraverseSetHeadWord(pTree->m_ptRoot); + } + + void fnSuffixTraverseSetHeadWord(STreeItem *pTreeItem) { + if (pTreeItem->m_vecChildren.size() == 0) return; + + for (size_t i = 0; i < pTreeItem->m_vecChildren.size(); i++) + fnSuffixTraverseSetHeadWord(pTreeItem->m_vecChildren[i]); + + std::vector<std::string> vecRight; + + int iHeadchild; + + if (pTreeItem->fnIsPreTerminal()) { + iHeadchild = 0; + } else { + size_t i; + for (i = 0; i < pTreeItem->m_vecChildren.size(); i++) { + char *p = pTreeItem->m_vecChildren[i]->m_pszTerm; + if (p[0] == '*' && p[strlen(p) - 1] == '*') { + iHeadchild = i; + p[strlen(p) - 1] = '\0'; + std::string str = p + 1; + strcpy(p, str.c_str()); // erase the "*..*" + break; + } + } + assert(i < pTreeItem->m_vecChildren.size()); + } + + pTreeItem->m_iHeadChild = iHeadchild; + pTreeItem->m_iHeadWord = pTreeItem->m_vecChildren[iHeadchild]->m_iHeadWord; + } + void fnPostProcessingFlattenedParse(SParsedTree *pTree) { + pTree->fnSetSpanInfo(); + fnSetParseTreeHeadInfo(pTree); + } + bool fnReadNextSentence(char *pszLine, int *piLength) { + if (feof(m_fpIn) == true) return false; + + int iLen; + + pszLine[0] = '\0'; + + fgets(pszLine, 10001, m_fpIn); + iLen = strlen(pszLine); + while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { + pszLine[iLen - 1] = '\0'; + iLen--; + } + + if (piLength != NULL) (*piLength) = iLen; + + return true; + } + + private: + FILE *m_fpIn; + const bool m_bFlattened; +}; + +/* + * Note: + * m_vec_s_align.size() may not be equal to the length of source side + *sentence + * due to the last words may not be aligned + * + */ +struct SAlignment { + typedef std::vector<int> SingleAlign; + SAlignment(const char* pszAlign) { fnInitializeAlignment(pszAlign); } + ~SAlignment() {} + + bool fnIsAligned(int i, bool s) const { + const std::vector<SingleAlign>* palign; + if (s == true) + palign = &m_vec_s_align; + else + palign = &m_vec_t_align; + if ((*palign)[i].size() == 0) return false; + return true; + } + + /* + * return true if [b, e] is aligned phrases on source side (if s==true) or on + * the target side (if s==false); + * return false, otherwise. + */ + bool fnIsAlignedPhrase(int b, int e, bool s, int* pob, int* poe) const { + int ob, oe; //[b, e] on the other side + if (s == true) + fnGetLeftRightMost(b, e, m_vec_s_align, ob, oe); + else + fnGetLeftRightMost(b, e, m_vec_t_align, ob, oe); + + if (ob == -1) { + if (pob != NULL) (*pob) = -1; + if (poe != NULL) (*poe) = -1; + return false; // no aligned word among [b, e] + } + if (pob != NULL) (*pob) = ob; + if (poe != NULL) (*poe) = oe; + + int bb, be; //[b, e] back given [ob, oe] on the other side + if (s == true) + fnGetLeftRightMost(ob, oe, m_vec_t_align, bb, be); + else + fnGetLeftRightMost(ob, oe, m_vec_s_align, bb, be); + + if (bb < b || be > e) return false; + return true; + } + + bool fnIsAlignedTightPhrase(int b, int e, bool s, int* pob, int* poe) const { + const std::vector<SingleAlign>* palign; + if (s == true) + palign = &m_vec_s_align; + else + palign = &m_vec_t_align; + + if ((*palign).size() <= e || (*palign)[b].size() == 0 || + (*palign)[e].size() == 0) + return false; + + return fnIsAlignedPhrase(b, e, s, pob, poe); + } + + void fnGetLeftRightMost(int b, int e, bool s, int& ob, int& oe) const { + if (s == true) + fnGetLeftRightMost(b, e, m_vec_s_align, ob, oe); + else + fnGetLeftRightMost(b, e, m_vec_t_align, ob, oe); + } + + /* + * look the translation of source[b, e] is continuous or not + * 1) return "Unaligned": if the source[b, e] is translated silently; + * 2) return "Con't": if none of target words in target[.., ..] is exclusively + * aligned to any word outside source[b, e] + * 3) return "Discon't": otherwise; + */ + std::string fnIsContinuous(int b, int e) const { + int ob, oe; + fnGetLeftRightMost(b, e, true, ob, oe); + if (ob == -1) return "Unaligned"; + + for (int i = ob; i <= oe; i++) { + if (!fnIsAligned(i, false)) continue; + const SingleAlign& a = m_vec_t_align[i]; + int j; + for (j = 0; j < a.size(); j++) + if (a[j] >= b && a[j] <= e) break; + if (j == a.size()) return "Discon't"; + } + return "Con't"; + } + + const SingleAlign* fnGetSingleWordAlign(int i, bool s) const { + if (s == true) { + if (i >= m_vec_s_align.size()) return NULL; + return &(m_vec_s_align[i]); + } else { + if (i >= m_vec_t_align.size()) return NULL; + return &(m_vec_t_align[i]); + } + } + + private: + void fnGetLeftRightMost(int b, int e, const std::vector<SingleAlign>& align, + int& ob, int& oe) const { + ob = oe = -1; + for (int i = b; i <= e && i < align.size(); i++) { + if (align[i].size() > 0) { + if (align[i][0] < ob || ob == -1) ob = align[i][0]; + if (oe < align[i][align[i].size() - 1]) + oe = align[i][align[i].size() - 1]; + } + } + } + void fnInitializeAlignment(const char* pszAlign) { + m_vec_s_align.clear(); + m_vec_t_align.clear(); + + std::vector<std::string> terms = SplitOnWhitespace(std::string(pszAlign)); + int si, ti; + for (size_t i = 0; i < terms.size(); i++) { + sscanf(terms[i].c_str(), "%d-%d", &si, &ti); + + while (m_vec_s_align.size() <= si) { + SingleAlign sa; + m_vec_s_align.push_back(sa); + } + while (m_vec_t_align.size() <= ti) { + SingleAlign sa; + m_vec_t_align.push_back(sa); + } + + m_vec_s_align[si].push_back(ti); + m_vec_t_align[ti].push_back(si); + } + + // sort + for (size_t i = 0; i < m_vec_s_align.size(); i++) { + std::sort(m_vec_s_align[i].begin(), m_vec_s_align[i].end()); + } + for (size_t i = 0; i < m_vec_t_align.size(); i++) { + std::sort(m_vec_t_align[i].begin(), m_vec_t_align[i].end()); + } + } + + private: + std::vector<SingleAlign> m_vec_s_align; // source side words' alignment + std::vector<SingleAlign> m_vec_t_align; // target side words' alignment +}; + +struct SAlignmentReader { + SAlignmentReader(const char* pszFname) { + m_fpIn = fopen(pszFname, "r"); + assert(m_fpIn != NULL); + } + ~SAlignmentReader() { + if (m_fpIn != NULL) fclose(m_fpIn); + } + SAlignment* fnReadNextAlignment() { + if (feof(m_fpIn) == true) return NULL; + char* pszLine = new char[100001]; + pszLine[0] = '\0'; + fgets(pszLine, 10001, m_fpIn); + int iLen = strlen(pszLine); + if (iLen == 0) return NULL; + while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { + pszLine[iLen - 1] = '\0'; + iLen--; + } + SAlignment* pAlign = new SAlignment(pszLine); + delete[] pszLine; + return pAlign; + } + + private: + FILE* m_fpIn; +}; + +struct SArgument { + SArgument(const char* pszRole, int iBegin, int iEnd, float fProb) { + m_pszRole = new char[strlen(pszRole) + 1]; + strcpy(m_pszRole, pszRole); + m_iBegin = iBegin; + m_iEnd = iEnd; + m_fProb = fProb; + m_pTreeItem = NULL; + } + ~SArgument() { delete[] m_pszRole; } + + void fnSetTreeItem(STreeItem* pTreeItem) { + m_pTreeItem = pTreeItem; + if (m_pTreeItem != NULL && m_pTreeItem->m_iBegin != -1) { + assert(m_pTreeItem->m_iBegin == m_iBegin); + assert(m_pTreeItem->m_iEnd == m_iEnd); + } + } + + char* m_pszRole; // argument rule, e.g., ARG0, ARGM-TMP + int m_iBegin; + int m_iEnd; // the span of the argument, [m_iBegin, m_iEnd] + float m_fProb; // the probability of this role, + STreeItem* m_pTreeItem; +}; + +struct SPredicate { + SPredicate(const char* pszLemma, int iPosition) { + if (pszLemma != NULL) { + m_pszLemma = new char[strlen(pszLemma) + 1]; + strcpy(m_pszLemma, pszLemma); + } else + m_pszLemma = NULL; + m_iPosition = iPosition; + } + ~SPredicate() { + if (m_pszLemma != NULL) delete[] m_pszLemma; + for (size_t i = 0; i < m_vecArgt.size(); i++) delete m_vecArgt[i]; + } + int fnAppend(const char* pszRole, int iBegin, int iEnd) { + SArgument* pArgt = new SArgument(pszRole, iBegin, iEnd, 1.0); + return fnAppend(pArgt); + } + int fnAppend(SArgument* pArgt) { + m_vecArgt.push_back(pArgt); + int iPosition = m_vecArgt.size() - 1; + return iPosition; + } + + char* m_pszLemma; // lemma of the predicate, for Chinese, it's always as same + // as the predicate itself + int m_iPosition; // the position in sentence + std::vector<SArgument*> m_vecArgt; // arguments associated to the predicate +}; + +struct SSrlSentence { + SSrlSentence() { m_pTree = NULL; } + ~SSrlSentence() { + if (m_pTree != NULL) delete m_pTree; + + for (size_t i = 0; i < m_vecPred.size(); i++) delete m_vecPred[i]; + } + int fnAppend(const char* pszLemma, int iPosition) { + SPredicate* pPred = new SPredicate(pszLemma, iPosition); + return fnAppend(pPred); + } + int fnAppend(SPredicate* pPred) { + m_vecPred.push_back(pPred); + int iPosition = m_vecPred.size() - 1; + return iPosition; + } + int GetPredicateNum() { return m_vecPred.size(); } + + SParsedTree* m_pTree; + std::vector<SPredicate*> m_vecPred; +}; + +struct SSrlSentenceReader { + SSrlSentenceReader(const char* pszSrlFname) { + m_fpIn = fopen(pszSrlFname, "r"); + assert(m_fpIn != NULL); + } + ~SSrlSentenceReader() { + if (m_fpIn != NULL) fclose(m_fpIn); + } + + inline void fnReplaceAll(std::string& str, const std::string& from, + const std::string& to) { + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); // In case 'to' contains 'from', like replacing + // 'x' with 'yx' + } + } + + // TODO: here only considers flat predicate-argument structure + // i.e., no overlap among them + SSrlSentence* fnReadNextSrlSentence() { + std::vector<std::vector<std::string> > vecContent; + if (fnReadNextContent(vecContent) == false) return NULL; + + SSrlSentence* pSrlSentence = new SSrlSentence(); + int iSize = vecContent.size(); + // put together syntactic text + std::ostringstream ostr; + for (int i = 0; i < iSize; i++) { + std::string strSynSeg = + vecContent[i][5]; // the 5th column is the syntactic segment + size_t iPosition = strSynSeg.find_first_of('*'); + assert(iPosition != std::string::npos); + std::ostringstream ostrTmp; + ostrTmp << "(" << vecContent[i][2] << " " << vecContent[i][0] + << ")"; // the 2th column is POS-tag, and the 0th column is word + strSynSeg.replace(iPosition, 1, ostrTmp.str()); + fnReplaceAll(strSynSeg, "(", " ("); + ostr << strSynSeg; + } + std::string strSyn = ostr.str(); + pSrlSentence->m_pTree = SParsedTree::fnConvertFromString(strSyn.c_str()); + pSrlSentence->m_pTree->fnSetHeadWord(); + pSrlSentence->m_pTree->fnSetSpanInfo(); + + // read predicate-argument structure + int iNumPred = vecContent[0].size() - 8; + for (int i = 0; i < iNumPred; i++) { + std::vector<std::string> vecRole; + std::vector<int> vecBegin; + std::vector<int> vecEnd; + int iPred = -1; + for (int j = 0; j < iSize; j++) { + const char* p = vecContent[j][i + 8].c_str(); + const char* q; + if (p[0] == '(') { + // starting position of an argument(or predicate) + vecBegin.push_back(j); + q = strchr(p, '*'); + assert(q != NULL); + vecRole.push_back(vecContent[j][i + 8].substr(1, q - p - 1)); + if (vecRole.back().compare("V") == 0) { + assert(iPred == -1); + iPred = vecRole.size() - 1; + } + } + if (p[strlen(p) - 1] == ')') { + // end position of an argument(or predicate) + vecEnd.push_back(j); + assert(vecBegin.size() == vecEnd.size()); + } + } + assert(iPred != -1); + SPredicate* pPred = new SPredicate( + pSrlSentence->m_pTree->m_vecTerminals[vecBegin[iPred]]->m_pszTerm, + vecBegin[iPred]); + pSrlSentence->fnAppend(pPred); + for (size_t j = 0; j < vecBegin.size(); j++) { + if (j == iPred) continue; + pPred->fnAppend(vecRole[j].c_str(), vecBegin[j], vecEnd[j]); + pPred->m_vecArgt.back()->fnSetTreeItem( + pSrlSentence->m_pTree->fnFindNodeForSpan(vecBegin[j], vecEnd[j], + false)); + } + } + return pSrlSentence; + } + + private: + bool fnReadNextContent(std::vector<std::vector<std::string> >& vecContent) { + vecContent.clear(); + if (feof(m_fpIn) == true) return false; + char* pszLine; + pszLine = new char[100001]; + pszLine[0] = '\0'; + int iLen; + while (!feof(m_fpIn)) { + fgets(pszLine, 10001, m_fpIn); + iLen = strlen(pszLine); + while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { + pszLine[iLen - 1] = '\0'; + iLen--; + } + if (iLen == 0) break; // end of this sentence + + std::vector<std::string> terms = SplitOnWhitespace(std::string(pszLine)); + assert(terms.size() > 7); + vecContent.push_back(terms); + } + delete[] pszLine; + return true; + } + + private: + FILE* m_fpIn; +}; + +typedef std::unordered_map<std::string, int> Map; +typedef std::unordered_map<std::string, int>::iterator Iterator; + +struct Tsuruoka_Maxent { + Tsuruoka_Maxent(const char* pszModelFName) { + if (pszModelFName != NULL) { + m_pModel = new maxent::ME_Model(); + m_pModel->load_from_file(pszModelFName); + } else + m_pModel = NULL; + } + + ~Tsuruoka_Maxent() { + if (m_pModel != NULL) delete m_pModel; + } + + void fnEval(const char* pszContext, std::vector<double>& vecOutput) const { + std::vector<std::string> vecContext; + maxent::ME_Sample* pmes = new maxent::ME_Sample(); + SplitOnWhitespace(std::string(pszContext), &vecContext); + + vecOutput.clear(); + + for (size_t i = 0; i < vecContext.size(); i++) + pmes->add_feature(vecContext[i]); + std::vector<double> vecProb = m_pModel->classify(*pmes); + + for (size_t i = 0; i < vecProb.size(); i++) { + std::string label = m_pModel->get_class_label(i); + vecOutput.push_back(vecProb[i]); + } + delete pmes; + } + int fnGetClassId(const std::string& strLabel) const { + return m_pModel->get_class_id(strLabel); + } + + private: + maxent::ME_Model* m_pModel; +}; + +// an argument item or a predicate item (the verb itself) +struct SSRLItem { + SSRLItem(const STreeItem *tree_item, std::string role) + : tree_item_(tree_item), role_(role) {} + ~SSRLItem() {} + const STreeItem *tree_item_; + const std::string role_; +}; + +struct SPredicateItem { + SPredicateItem(const SParsedTree *tree, const SPredicate *pred) + : pred_(pred) { + vec_items_.reserve(pred->m_vecArgt.size() + 1); + for (int i = 0; i < pred->m_vecArgt.size(); i++) { + vec_items_.push_back( + new SSRLItem(pred->m_vecArgt[i]->m_pTreeItem, + std::string(pred->m_vecArgt[i]->m_pszRole))); + } + vec_items_.push_back( + new SSRLItem(tree->m_vecTerminals[pred->m_iPosition]->m_ptParent, + std::string("Pred"))); + sort(vec_items_.begin(), vec_items_.end(), SortFunction); + + begin_ = vec_items_[0]->tree_item_->m_iBegin; + end_ = vec_items_[vec_items_.size() - 1]->tree_item_->m_iEnd; + } + + ~SPredicateItem() { vec_items_.clear(); } + + static bool SortFunction(SSRLItem *i, SSRLItem *j) { + return (i->tree_item_->m_iBegin < j->tree_item_->m_iBegin); + } + + std::vector<SSRLItem *> vec_items_; + int begin_; + int end_; + const SPredicate *pred_; +}; + +struct SArgumentReorderModel { + public: + static std::string fnGetBlockOutcome(int iBegin, int iEnd, + SAlignment *pAlign) { + return pAlign->fnIsContinuous(iBegin, iEnd); + } + static void fnGetReorderType(SPredicateItem *pPredItem, SAlignment *pAlign, + std::vector<std::string> &vecStrLeftReorder, + std::vector<std::string> &vecStrRightReorder) { + std::vector<int> vecLeft, vecRight; + for (int i = 0; i < pPredItem->vec_items_.size(); i++) { + const STreeItem *pCon1 = pPredItem->vec_items_[i]->tree_item_; + int iLeft1, iRight1; + pAlign->fnGetLeftRightMost(pCon1->m_iBegin, pCon1->m_iEnd, true, iLeft1, + iRight1); + vecLeft.push_back(iLeft1); + vecRight.push_back(iRight1); + } + std::vector<int> vecLeftPosition; + fnGetRelativePosition(vecLeft, vecLeftPosition); + std::vector<int> vecRightPosition; + fnGetRelativePosition(vecRight, vecRightPosition); + + vecStrLeftReorder.clear(); + vecStrRightReorder.clear(); + for (int i = 1; i < vecLeftPosition.size(); i++) { + std::string strOutcome; + fnGetOutcome(vecLeftPosition[i - 1], vecLeftPosition[i], strOutcome); + vecStrLeftReorder.push_back(strOutcome); + fnGetOutcome(vecRightPosition[i - 1], vecRightPosition[i], strOutcome); + vecStrRightReorder.push_back(strOutcome); + } + } + + /* + * features: + * f1: (left_label, right_label, parent_label) + * f2: (left_label, right_label, parent_label, other_right_sibling_label) + * f3: (left_label, right_label, parent_label, other_left_sibling_label) + * f4: (left_label, right_label, left_head_pos) + * f5: (left_label, right_label, left_head_word) + * f6: (left_label, right_label, right_head_pos) + * f7: (left_label, right_label, right_head_word) + * f8: (left_label, right_label, left_chunk_status) + * f9: (left_label, right_label, right_chunk_status) + * f10: (left_label, parent_label) + * f11: (right_label, parent_label) + * + * f1: (left_role, right_role, predicate_term) + * f2: (left_role, right_role, predicate_term, other_right_role) + * f3: (left_role, right_role, predicate_term, other_left_role) + * f4: (left_role, right_role, left_head_pos) + * f5: (left_role, right_role, left_head_word) + * f6: (left_role, right_role, left_syntactic_label) + * f7: (left_role, right_role, right_head_pos) + * f8: (left_role, right_role, right_head_word) + * f8: (left_role, right_role, right_syntactic_label) + * f8: (left_role, right_role, left_chunk_status) + * f9: (left_role, right_role, right_chunk_status) + * f10: (left_role, right_role, left_chunk_status) + * f11: (left_role, right_role, right_chunk_status) + * f12: (left_label, parent_label) + * f13: (right_label, parent_label) + */ + static void fnGenerateFeature(const SParsedTree *pTree, + const SPredicate *pPred, + const SPredicateItem *pPredItem, int iPos, + const std::string &strBlock1, + const std::string &strBlock2, + std::ostringstream &ostr) { + SSRLItem *pSRLItem1 = pPredItem->vec_items_[iPos - 1]; + SSRLItem *pSRLItem2 = pPredItem->vec_items_[iPos]; + const STreeItem *pCon1 = pSRLItem1->tree_item_; + const STreeItem *pCon2 = pSRLItem2->tree_item_; + + std::string left_role = pSRLItem1->role_; + std::string right_role = pSRLItem2->role_; + + std::string predicate_term = + pTree->m_vecTerminals[pPred->m_iPosition]->m_pszTerm; + + std::vector<std::string> vec_other_right_sibling; + for (int i = iPos + 1; i < pPredItem->vec_items_.size(); i++) + vec_other_right_sibling.push_back( + std::string(pPredItem->vec_items_[i]->role_)); + if (vec_other_right_sibling.size() == 0) + vec_other_right_sibling.push_back(std::string("NULL")); + + std::vector<std::string> vec_other_left_sibling; + for (int i = 0; i < iPos - 1; i++) + vec_other_right_sibling.push_back( + std::string(pPredItem->vec_items_[i]->role_)); + if (vec_other_left_sibling.size() == 0) + vec_other_left_sibling.push_back(std::string("NULL")); + + // generate features + // f1 + ostr << "f1=" << left_role << "_" << right_role << "_" << predicate_term; + ostr << "f1=" << left_role << "_" << right_role; + + // f2 + for (int i = 0; i < vec_other_right_sibling.size(); i++) { + ostr << " f2=" << left_role << "_" << right_role << "_" << predicate_term + << "_" << vec_other_right_sibling[i]; + ostr << " f2=" << left_role << "_" << right_role << "_" + << vec_other_right_sibling[i]; + } + // f3 + for (int i = 0; i < vec_other_left_sibling.size(); i++) { + ostr << " f3=" << left_role << "_" << right_role << "_" << predicate_term + << "_" << vec_other_left_sibling[i]; + ostr << " f3=" << left_role << "_" << right_role << "_" + << vec_other_left_sibling[i]; + } + // f4 + ostr << " f4=" << left_role << "_" << right_role << "_" + << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_ptParent->m_pszTerm; + // f5 + ostr << " f5=" << left_role << "_" << right_role << "_" + << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_pszTerm; + // f6 + ostr << " f6=" << left_role << "_" << right_role << "_" << pCon2->m_pszTerm; + // f7 + ostr << " f7=" << left_role << "_" << right_role << "_" + << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_ptParent->m_pszTerm; + // f8 + ostr << " f8=" << left_role << "_" << right_role << "_" + << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_pszTerm; + // f9 + ostr << " f9=" << left_role << "_" << right_role << "_" << pCon2->m_pszTerm; + // f10 + ostr << " f10=" << left_role << "_" << right_role << "_" << strBlock1; + // f11 + ostr << " f11=" << left_role << "_" << right_role << "_" << strBlock2; + // f12 + ostr << " f12=" << left_role << "_" << predicate_term; + ostr << " f12=" << left_role; + // f13 + ostr << " f13=" << right_role << "_" << predicate_term; + ostr << " f13=" << right_role; + } + + private: + static void fnGetOutcome(int i1, int i2, std::string &strOutcome) { + assert(i1 != i2); + if (i1 < i2) { + if (i2 > i1 + 1) + strOutcome = std::string("DM"); + else + strOutcome = std::string("M"); + } else { + if (i1 > i2 + 1) + strOutcome = std::string("DS"); + else + strOutcome = std::string("S"); + } + } + + static void fnGetRelativePosition(const std::vector<int> &vecLeft, + std::vector<int> &vecPosition) { + vecPosition.clear(); + + std::vector<float> vec; + for (int i = 0; i < vecLeft.size(); i++) { + if (vecLeft[i] == -1) { + if (i == 0) + vec.push_back(-1); + else + vec.push_back(vecLeft[i - 1] + 0.1); + } else + vec.push_back(vecLeft[i]); + } + + for (int i = 0; i < vecLeft.size(); i++) { + int count = 0; + + for (int j = 0; j < vecLeft.size(); j++) { + if (j == i) continue; + if (vec[j] < vec[i]) { + count++; + } else if (vec[j] == vec[i] && j < i) { + count++; + } + } + vecPosition.push_back(count); + } + } +}; +} // namespace const_reorder + +#endif // _FF_CONST_REORDER_COMMON_H diff --git a/decoder/ff_soft_syn.cc b/decoder/ff_soft_syn.cc new file mode 100644 index 00000000..2d4369d1 --- /dev/null +++ b/decoder/ff_soft_syn.cc @@ -0,0 +1,265 @@ +/* + * ff_soft_syn.cc + * + */ +#include "ff_soft_syn.h" + +#include "filelib.h" +#include "stringlib.h" +#include "hg.h" +#include "sentence_metadata.h" +#include "ff_const_reorder_common.h" + +#include <string> +#include <vector> +#include <stdio.h> + +using namespace std; +using namespace const_reorder; + +typedef HASH_MAP<std::string, vector<string> > MapFeatures; + +struct SoftSynFeatureImpl { + SoftSynFeatureImpl(const string& /*params*/) { + parsed_tree_ = NULL; + map_features_ = NULL; + } + + ~SoftSynFeatureImpl() { FreeSentenceVariables(); } + + void InitializeInputSentence(const std::string& parse_file) { + FreeSentenceVariables(); + parsed_tree_ = ReadParseTree(parse_file); + + // we can do the features "off-line" + map_features_ = new MapFeatures(); + InitializeFeatures(map_features_); + } + + /* + * ff_const_reorder.cc::ConstReorderFeatureImpl also defines this function + */ + void FindConsts(const SParsedTree* tree, int begin, int end, + vector<STreeItem*>& consts) { + STreeItem* item; + item = tree->m_vecTerminals[begin]->m_ptParent; + while (true) { + while (item->m_ptParent != NULL && + item->m_ptParent->m_iBegin == item->m_iBegin && + item->m_ptParent->m_iEnd <= end) + item = item->m_ptParent; + + if (item->m_ptParent == NULL && item->m_vecChildren.size() == 1 && + strcmp(item->m_pszTerm, "ROOT") == 0) + item = item->m_vecChildren[0]; // we automatically add a "ROOT" node at + // the top, skip it if necessary. + + consts.push_back(item); + if (item->m_iEnd < end) + item = tree->m_vecTerminals[item->m_iEnd + 1]->m_ptParent; + else + break; + } + } + + /* + * according to Marton & Resnik (2008) + * a span cann't have both X+ style and X= style features + * a constituent XP is crossed only if the span not only covers parts of XP's + *content, but also covers one or more words outside XP + * a span may have X+, Y+ + * + * (note, we refer X* features to X= features in Marton & Resnik (2008)) + */ + void GenerateSoftFeature(int begin, int end, const SParsedTree* tree, + vector<string>& vecFeature) { + vector<STreeItem*> vecNode; + FindConsts(tree, begin, end, vecNode); + + if (vecNode.size() == 1) { + // match to one constituent + string feature_name = string(vecNode[0]->m_pszTerm) + string("*"); + vecFeature.push_back(feature_name); + } else { + // match to multiple constituents, find the lowest common parent (lcp) + STreeItem* lcp = vecNode[0]; + while (lcp->m_iEnd < end) lcp = lcp->m_ptParent; + + for (size_t i = 0; i < vecNode.size(); i++) { + STreeItem* item = vecNode[i]; + + while (item != lcp) { + if (item->m_iBegin < begin || item->m_iEnd > end) { + // item is crossed + string feature_name = string(item->m_pszTerm) + string("+"); + vecFeature.push_back(feature_name); + } + if (item->m_iBrotherIndex > 0 && + item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1] + ->m_iBegin >= begin && + item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1] + ->m_iEnd <= end) + break; // we don't want to collect crossed constituents twice + item = item->m_ptParent; + } + } + } + } + + void GenerateSoftFeatureFromFlattenedTree(int begin, int end, + const SParsedTree* tree, + vector<string>& vecFeature) { + vector<STreeItem*> vecNode; + FindConsts(tree, begin, end, vecNode); + + if (vecNode.size() == 1) { + // match to one constituent + string feature_name = string(vecNode[0]->m_pszTerm) + string("*"); + vecFeature.push_back(feature_name); + } else { + // match to multiple constituents, see if they have a common parent + size_t i = 0; + for (i = 1; i < vecNode.size(); i++) { + if (vecNode[i]->m_ptParent != vecNode[0]->m_ptParent) break; + } + if (i == vecNode.size()) { + // they share a common parent + string feature_name = + string(vecNode[0]->m_ptParent->m_pszTerm) + string("&"); + vecFeature.push_back(feature_name); + } else { + // they don't share a common parent, find the lowest common parent (lcp) + STreeItem* lcp = vecNode[0]; + while (lcp->m_iEnd < end) lcp = lcp->m_ptParent; + + for (size_t i = 0; i < vecNode.size(); i++) { + STreeItem* item = vecNode[i]; + + while (item != lcp) { + if (item->m_iBegin < begin || item->m_iEnd > end) { + // item is crossed + string feature_name = string(item->m_pszTerm) + string("+"); + vecFeature.push_back(feature_name); + } + if (item->m_iBrotherIndex > 0 && + item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1] + ->m_iBegin >= begin && + item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1] + ->m_iEnd <= end) + break; // we don't want to collect crossed constituents twice + item = item->m_ptParent; + } + } + } + } + } + + void SetSoftSynFeature(const Hypergraph::Edge& edge, + SparseVector<double>* features) { + if (parsed_tree_ == NULL) return; + + // soft feature for the whole span + const vector<string> vecFeature = + GenerateSoftFeature(edge.i_, edge.j_ - 1, map_features_); + for (size_t i = 0; i < vecFeature.size(); i++) { + int f_id = FD::Convert(vecFeature[i]); + if (f_id) features->set_value(f_id, 1); + } + } + + private: + const vector<string>& GenerateSoftFeature(int begin, int end, + MapFeatures* map_features) { + string key; + GenerateKey(begin, end, key); + MapFeatures::const_iterator iter = (*map_features).find(key); + assert(iter != map_features->end()); + return iter->second; + } + + void Byte_to_Char(unsigned char* str, int n) { + str[0] = (n & 255); + str[1] = n / 256; + } + + void GenerateKey(int begin, int end, string& key) { + unsigned char szTerm[1001]; + Byte_to_Char(szTerm, begin); + Byte_to_Char(szTerm + 2, end); + szTerm[4] = '\0'; + key = string(szTerm, szTerm + 4); + } + + void InitializeFeatures(MapFeatures* map_features) { + if (parsed_tree_ == NULL) return; + + for (size_t i = 0; i < parsed_tree_->m_vecTerminals.size(); i++) + for (size_t j = i; j < parsed_tree_->m_vecTerminals.size(); j++) { + vector<string> vecFeature; + GenerateSoftFeature(i, j, parsed_tree_, vecFeature); + string key; + GenerateKey(i, j, key); + (*map_features)[key] = vecFeature; + } + } + + void FreeSentenceVariables() { + if (parsed_tree_ != NULL) delete parsed_tree_; + if (map_features_ != NULL) delete map_features_; + } + + SParsedTree* ReadParseTree(const std::string& parse_file) { + SParseReader* reader = new SParseReader(parse_file.c_str(), false); + SParsedTree* tree = reader->fnReadNextParseTree(); + // assert(tree != NULL); + delete reader; + return tree; + } + + private: + SParsedTree* parsed_tree_; + + MapFeatures* map_features_; +}; + +SoftSynFeature::SoftSynFeature(std::string param) { + pimpl_ = new SoftSynFeatureImpl(param); + name_ = "SoftSynFeature"; +} + +SoftSynFeature::~SoftSynFeature() { delete pimpl_; } + +void SoftSynFeature::PrepareForInput(const SentenceMetadata& smeta) { + string parse_file = smeta.GetSGMLValue("parse"); + if (parse_file.empty()) { + parse_file = smeta.GetSGMLValue("src_tree"); + } + assert(parse_file.size()); + pimpl_->InitializeInputSentence(parse_file); +} + +void SoftSynFeature::TraversalFeaturesImpl( + const SentenceMetadata& /*smeta*/, const Hypergraph::Edge& edge, + const vector<const void*>& /*ant_states*/, SparseVector<double>* features, + SparseVector<double>* /*estimated_features*/, void* /*state*/) const { + pimpl_->SetSoftSynFeature(edge, features); +} + +string SoftSynFeature::usage(bool /*param*/, bool /*verbose*/) { + return "SoftSynFeature"; +} + +boost::shared_ptr<FeatureFunction> CreateSoftSynFeatureModel( + std::string param) { + SoftSynFeature* ret = new SoftSynFeature(param); + return boost::shared_ptr<FeatureFunction>(ret); +} + +boost::shared_ptr<FeatureFunction> SoftSynFeatureFactory::Create( + std::string param) const { + return CreateSoftSynFeatureModel(param); +} + +std::string SoftSynFeatureFactory::usage(bool params, bool verbose) const { + return SoftSynFeature::usage(params, verbose); +} diff --git a/decoder/ff_soft_syn.h b/decoder/ff_soft_syn.h new file mode 100644 index 00000000..df9a6cc8 --- /dev/null +++ b/decoder/ff_soft_syn.h @@ -0,0 +1,38 @@ +/* + * ff_soft_syn.h + * + */ + +#ifndef FF_SOFT_SYN_H_ +#define FF_SOFT_SYN_H_ + +#include "ff_factory.h" +#include "ff.h" + +struct SoftSynFeatureImpl; + +class SoftSynFeature : public FeatureFunction { + public: + SoftSynFeature(std::string param); + ~SoftSynFeature(); + static std::string usage(bool param, bool verbose); + + protected: + virtual void PrepareForInput(const SentenceMetadata& smeta); + + virtual void TraversalFeaturesImpl( + const SentenceMetadata& smeta, const HG::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, SparseVector<double>* estimated_features, + void* out_context) const; + + private: + SoftSynFeatureImpl* pimpl_; +}; + +struct SoftSynFeatureFactory : public FactoryBase<FeatureFunction> { + FP Create(std::string param) const; + std::string usage(bool params, bool verbose) const; +}; + +#endif /* FF_SOFT_SYN_H_ */ diff --git a/decoder/ffset.cc b/decoder/ffset.cc index 5820f421..8ba70389 100644 --- a/decoder/ffset.cc +++ b/decoder/ffset.cc @@ -14,6 +14,11 @@ ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*> for (int i = 0; i < models_.size(); ++i) { model_state_pos_[i] = state_size_; state_size_ += models_[i]->StateSize(); + int num_ignored_bytes = models_[i]->IgnoredStateSize(); + if (num_ignored_bytes > 0) { + ranges_to_erase_.push_back( + {state_size_ - num_ignored_bytes, state_size_}); + } } } @@ -70,3 +75,13 @@ void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMet edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); } +bool ModelSet::NeedsStateErasure() const { return !ranges_to_erase_.empty(); } + +void ModelSet::EraseIgnoredBytes(FFState* state) const { + // TODO: can we memset? + for (const auto& range : ranges_to_erase_) { + for (int i = range.first; i < range.second; ++i) { + (*state)[i] = 0; + } + } +} diff --git a/decoder/ffset.h b/decoder/ffset.h index b7322ee2..84f9fdb9 100644 --- a/decoder/ffset.h +++ b/decoder/ffset.h @@ -1,6 +1,7 @@ #ifndef FFSET_H_ #define FFSET_H_ +#include <utility> #include <vector> #include "value_array.h" #include "prob.h" @@ -47,11 +48,18 @@ class ModelSet { bool stateless() const { return !state_size_; } + // Part of a feature state may be used for storing some side data for + // calculating feature values but not necessary for splitting hypernodes. Such + // bytes needs to be erased for hypernode splitting. + bool NeedsStateErasure() const; + void EraseIgnoredBytes(FFState* state) const; + private: std::vector<const FeatureFunction*> models_; const std::vector<double>& weights_; int state_size_; std::vector<int> model_state_pos_; + std::vector<std::pair<int, int> > ranges_to_erase_; }; #endif diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc index 6e4cccb3..cc9094d7 100644 --- a/decoder/forest_writer.cc +++ b/decoder/forest_writer.cc @@ -11,13 +11,13 @@ using namespace std; ForestWriter::ForestWriter(const std::string& path, int num) : - fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {} + fname_(path + '/' + boost::lexical_cast<string>(num) + ".bin.gz"), used_(false) {} -bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) { +bool ForestWriter::Write(const Hypergraph& forest) { assert(!used_); used_ = true; cerr << " Writing forest to " << fname_ << endl; WriteFile wf(fname_); - return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream()); + return HypergraphIO::WriteToBinary(forest, wf.stream()); } diff --git a/decoder/forest_writer.h b/decoder/forest_writer.h index 4d28de77..54e83470 100644 --- a/decoder/forest_writer.h +++ b/decoder/forest_writer.h @@ -7,7 +7,7 @@ class Hypergraph; struct ForestWriter { ForestWriter(const std::string& path, int num); - bool Write(const Hypergraph& forest, bool minimal_rules); + bool Write(const Hypergraph& forest); const std::string fname_; bool used_; diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 50e6adcc..fe28f4c6 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -27,11 +27,15 @@ struct FSTTranslatorImpl { const vector<double>& weights, Hypergraph* forest) { bool composed = false; - if (input.find("{\"rules\"") == 0) { + if (input.find("::forest::") == 0) { istringstream is(input); + string header, fname; + is >> header >> fname; + ReadFile rf(fname); + if (!rf) { cerr << "Failed to open " << fname << endl; abort(); } Hypergraph src_cfg_hg; - if (!HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)) { - cerr << "Failed to read HG from JSON.\n"; + if (!HypergraphIO::ReadFromBinary(rf.stream(), &src_cfg_hg)) { + cerr << "Failed to read HG.\n"; abort(); } if (add_pass_through_rules) { diff --git a/decoder/hg.h b/decoder/hg.h index 256f650f..c756012e 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -18,6 +18,7 @@ #include <string> #include <vector> #include <boost/shared_ptr.hpp> +#include <boost/serialization/vector.hpp> #include "feature_vector.h" #include "small_vector.h" @@ -69,6 +70,18 @@ namespace HG { short int j_; short int prev_i_; short int prev_j_; + template<class Archive> + void serialize(Archive & ar, const unsigned int /*version*/) { + ar & head_node_; + ar & tail_nodes_; + ar & rule_; + ar & feature_values_; + ar & i_; + ar & j_; + ar & prev_i_; + ar & prev_j_; + ar & id_; + } void show(std::ostream &o,unsigned mask=SPAN|RULE) const { o<<'{'; if (mask&CATEGORY) @@ -149,6 +162,24 @@ namespace HG { WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ + template<class Archive> + void save(Archive & ar, const unsigned int /*version*/) const { + ar & node_hash; + ar & id_; + ar & TD::Convert(-cat_); + ar & in_edges_; + ar & out_edges_; + } + template<class Archive> + void load(Archive & ar, const unsigned int /*version*/) { + ar & node_hash; + ar & id_; + std::string cat; ar & cat; + cat_ = -TD::Convert(cat); + ar & in_edges_; + ar & out_edges_; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting node_hash = o.node_hash; cat_=o.cat_; @@ -492,6 +523,27 @@ public: void set_ids(); // resync edge,node .id_ void check_ids() const; // assert that .id_ have been kept in sync + template<class Archive> + void save(Archive & ar, const unsigned int /*version*/) const { + unsigned ns = nodes_.size(); ar & ns; + unsigned es = edges_.size(); ar & es; + for (auto& n : nodes_) ar & n; + for (auto& e : edges_) ar & e; + int x; + x = edges_topo_; ar & x; + x = is_linear_chain_; ar & x; + } + template<class Archive> + void load(Archive & ar, const unsigned int /*version*/) { + unsigned ns; ar & ns; nodes_.resize(ns); + unsigned es; ar & es; edges_.resize(es); + for (auto& n : nodes_) ar & n; + for (auto& e : edges_) ar & e; + int x; + ar & x; edges_topo_ = x; + ar & x; is_linear_chain_ = x; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() private: Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} }; diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index eb0be3d4..626b2954 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -6,362 +6,27 @@ #include <sstream> #include <iostream> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> + #include "fast_lexical_cast.hpp" #include "tdict.h" -#include "json_parse.h" #include "hg.h" using namespace std; -struct HGReader : public JSONParser { - HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } - - void CreateNode(const string& cat, const string& shash, const vector<int>& in_edges) { - WordID c = TD::Convert("X") * -1; - if (!cat.empty()) c = TD::Convert(cat) * -1; - Hypergraph::Node* node = hg.AddNode(c); - char* dend; - if (shash.size()) - node->node_hash = strtoull(shash.c_str(), &dend, 16); - else - node->node_hash = 0; - for (int i = 0; i < in_edges.size(); ++i) { - if (in_edges[i] >= hg.edges_.size()) { - cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] - << ", but hg only has " << hg.edges_.size() << " edges!\n"; - abort(); - } - hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); - } - } - void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVectorUnsigned& tail) { - Hypergraph::Edge* edge = hg.AddEdge(rule, tail); - feats->swap(edge->feature_values_); - edge->i_ = spans[0]; - edge->j_ = spans[1]; - edge->prev_i_ = spans[2]; - edge->prev_j_ = spans[3]; - } - - bool HandleJSONEvent(int type, const JSON_value* value) { - switch(state) { - case -1: - assert(type == JSON_T_OBJECT_BEGIN); - state = 0; - break; - case 0: - if (type == JSON_T_OBJECT_END) { - //cerr << "HG created\n"; // TODO, signal some kind of callback - } else if (type == JSON_T_KEY) { - string val = value->vu.str.value; - if (val == "features") { assert(fdict.empty()); state = 1; } - else if (val == "is_sorted") { state = 3; } - else if (val == "rules") { assert(rules.empty()); state = 4; } - else if (val == "node") { state = 8; } - else if (val == "edges") { state = 13; } - else { cerr << "Unexpected key: " << val << endl; return false; } - } - break; - - // features - case 1: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 2; - break; - case 2: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_STRING); - fdict.push_back(FD::Convert(value->vu.str.value)); - assert(fdict.back() > 0); - break; - - // is_sorted - case 3: - assert(type == JSON_T_TRUE || type == JSON_T_FALSE); - is_sorted = (type == JSON_T_TRUE); - if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; } - state = 0; - break; - - // rules - case 4: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 5; - break; - case 5: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_INTEGER); - state = 6; - rule_id = value->vu.integer_value; - break; - case 6: - assert(type == JSON_T_STRING); - rules[rule_id] = TRulePtr(new TRule(value->vu.str.value)); - state = 5; - break; - - // Nodes - case 8: - assert(type == JSON_T_OBJECT_BEGIN); - ++nodes; - in_edges.clear(); - cat.clear(); - shash.clear(); - state = 9; break; - case 9: - if (type == JSON_T_OBJECT_END) { - //cerr << "Creating NODE\n"; - CreateNode(cat, shash, in_edges); - state = 0; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } - if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } - if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; } - cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; - return false; - case 10: - assert(type == JSON_T_STRING || type == JSON_T_NULL); - cat = value->vu.str.value; - state = 9; break; - case 11: - if (type == JSON_T_NULL) { state = 9; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 12; break; - case 12: - if (type == JSON_T_ARRAY_END) { state = 9; break; } - assert(type == JSON_T_INTEGER); - //cerr << "in_edges: " << value->vu.integer_value << endl; - in_edges.push_back(value->vu.integer_value); - break; - - // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12}, - // { "tail": null, "feats" : [0,0.87,1,0.02], "spans":[1,2,3,4], "rule": 17}, - // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}] - case 13: - assert(type == JSON_T_ARRAY_BEGIN); - state = 14; - break; - case 14: - if (type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_OBJECT_BEGIN); - //cerr << "New edge\n"; - ++edges; - cur_rule.reset(); feats.clear(); tail.clear(); - state = 15; break; - case 15: - if (type == JSON_T_OBJECT_END) { - CreateEdge(cur_rule, &feats, tail); - state = 14; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - //cerr << "edge key " << cur_key << endl; - if (cur_key == "rule") { assert(!cur_rule); state = 16; break; } - if (cur_key == "spans") { assert(!cur_rule); state = 22; break; } - if (cur_key == "feats") { assert(feats.empty()); state = 17; break; } - if (cur_key == "tail") { assert(tail.empty()); state = 20; break; } - cerr << "Unexpected key " << cur_key << " in edge specification\n"; - return false; - case 16: // edge.rule - if (type == JSON_T_INTEGER) { - int rule_id = value->vu.integer_value; - if (rules.find(rule_id) == rules.end()) { - // rules list must come before the edge definitions! - cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n"; - return false; - } - cur_rule = rules[rule_id]; - } else if (type == JSON_T_STRING) { - cur_rule.reset(new TRule(value->vu.str.value)); - } else { - cerr << "Rule must be either a rule id or a rule string" << endl; - return false; - } - // cerr << "Edge: rule=" << cur_rule->AsString() << endl; - state = 15; - break; - case 17: // edge.feats - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 18; break; - case 18: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - if (type != JSON_T_INTEGER && type != JSON_T_STRING) { - cerr << "Unexpected feature id type\n"; return false; - } - if (type == JSON_T_INTEGER) { - fid = value->vu.integer_value; - assert(fid < fdict.size()); - fid = fdict[fid]; - } else if (JSON_T_STRING) { - fid = FD::Convert(value->vu.str.value); - } else { abort(); } - state = 19; - break; - case 19: - { - assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT); - double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) : - strtod(value->vu.str.value, NULL)); - feats.set_value(fid, val); - state = 18; - break; - } - case 20: // edge.tail - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 21; break; - case 21: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - assert(type == JSON_T_INTEGER); - tail.push_back(value->vu.integer_value); - break; - case 22: // edge.spans - assert(type == JSON_T_ARRAY_BEGIN); - state = 23; - spans[0] = spans[1] = spans[2] = spans[3] = -1; - spanc = 0; - break; - case 23: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - assert(type == JSON_T_INTEGER); - assert(spanc < 4); - spans[spanc] = value->vu.integer_value; - ++spanc; - break; - case 24: // read node hash - assert(type == JSON_T_STRING); - shash = value->vu.str.value; - state = 9; - break; - } - return true; - } - string rp; - string cat; - SmallVectorUnsigned tail; - vector<int> in_edges; - string shash; - TRulePtr cur_rule; - map<int, TRulePtr> rules; - vector<int> fdict; - SparseVector<double> feats; - int state; - int fid; - int nodes; - int edges; - int spans[4]; - int spanc; - string cur_key; - Hypergraph& hg; - int rule_id; - bool nodes_needed; - bool edges_needed; - bool is_sorted; -}; - -bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { +bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) { + boost::archive::binary_iarchive oa(*in); hg->clear(); - HGReader reader(hg); - return reader.Parse(in); -} - -static void WriteRule(const TRule& r, ostream* out) { - if (!r.lhs_) { (*out) << "[X] ||| "; } - JSONParser::WriteEscapedString(r.AsString(), out); + oa >> *hg; + return true; } -bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { - if (hg.empty()) { *out << "{}\n"; return true; } - map<const TRule*, int> rid; - ostream& o = *out; - rid[NULL] = 0; - o << '{'; - if (!remove_rules) { - o << "\"rules\":["; - for (int i = 0; i < hg.edges_.size(); ++i) { - const TRule* r = hg.edges_[i].rule_.get(); - int &id = rid[r]; - if (!id) { - id=rid.size() - 1; - if (id > 1) o << ','; - o << id << ','; - WriteRule(*r, &o); - }; - } - o << "],"; - } - const bool use_fdict = FD::NumFeats() < 1000; - if (use_fdict) { - o << "\"features\":["; - for (int i = 1; i < FD::NumFeats(); ++i) { - o << (i==1 ? "":","); - JSONParser::WriteEscapedString(FD::Convert(i), &o); - } - o << "],"; - } - vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order - int edge_count = 0; - for (int i = 0; i < hg.nodes_.size(); ++i) { - const Hypergraph::Node& node = hg.nodes_[i]; - if (i > 0) { o << ","; } - o << "\"edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; - edgemap[edge.id_] = edge_count; - ++edge_count; - o << (j == 0 ? "" : ",") << "{"; - - o << "\"tail\":["; - for (int k = 0; k < edge.tail_nodes_.size(); ++k) { - o << (k > 0 ? "," : "") << edge.tail_nodes_[k]; - } - o << "],"; - - o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],"; - - o << "\"feats\":["; - bool first = true; - for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) { - if (!it->second) continue; // don't write features that have a zero value - if (!it->first) continue; // if the feature set was frozen this might happen - if (!first) o << ','; - if (use_fdict) - o << (it->first - 1); - else { - JSONParser::WriteEscapedString(FD::Convert(it->first), &o); - } - o << ',' << it->second; - first = false; - } - o << "]"; - if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; } - o << "}"; - } - o << "],"; - - o << "\"node\":{\"in_edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - int mapped_edge = edgemap[node.in_edges_[j]]; - assert(mapped_edge >= 0); - o << (j == 0 ? "" : ",") << mapped_edge; - } - o << "]"; - if (node.cat_ < 0) { - o << ",\"cat\":"; - JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o); - } - char buf[48]; - sprintf(buf, "%016lX", node.node_hash); - o << ",\"node_hash\":\"" << buf << "\""; - o << "}"; - } - o << "}\n"; +bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) { + boost::archive::binary_oarchive oa(*out); + oa << hg; return true; } diff --git a/decoder/hg_io.h b/decoder/hg_io.h index 5a2bd808..93a9e280 100644 --- a/decoder/hg_io.h +++ b/decoder/hg_io.h @@ -9,19 +9,11 @@ class Hypergraph; struct HypergraphIO { - // the format is basically a list of nodes and edges in topological order - // any edge you read, you must have already read its tail nodes - // any node you read, you must have already read its incoming edges - // this may make writing a bit more challenging if your forest is not - // topologically sorted (but that probably doesn't happen very often), - // but it makes reading much more memory efficient. - // see test_data/small.json.gz for an email encoding - static bool ReadFromJSON(std::istream* in, Hypergraph* out); + static bool ReadFromBinary(std::istream* in, Hypergraph* out); + static bool WriteToBinary(const Hypergraph& hg, std::ostream* out); // if remove_rules is used, the hypergraph is serialized without rule information // (so it only contains structure and feature information) - static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out); - static void WriteAsCFG(const Hypergraph& hg); // Write only the target size information in bottom-up order. diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 5cb8626a..366b269d 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -1,10 +1,14 @@ #define BOOST_TEST_MODULE hg_test #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +#include <boost/serialization/vector.hpp> +#include <sstream> #include <iostream> #include "tdict.h" -#include "json_parse.h" #include "hg_intersect.h" #include "hg_union.h" #include "viterbi.h" @@ -394,16 +398,6 @@ BOOST_AUTO_TEST_CASE(Small) { BOOST_CHECK_CLOSE(2.1431036, log(c2), 1e-4); } -BOOST_AUTO_TEST_CASE(JSONTest) { - std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); - ostringstream os; - JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); - BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str()); - ostringstream os2; - JSONParser::WriteEscapedString("yes", &os2); - BOOST_CHECK_EQUAL("\"yes\"", os2.str()); -} - BOOST_AUTO_TEST_CASE(TestGenericKBest) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; @@ -427,19 +421,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) { } } -BOOST_AUTO_TEST_CASE(TestReadWriteHG) { +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); - Hypergraph hg,hg2; - CreateHG(path, &hg); - hg.edges_.front().j_ = 23; - hg.edges_.back().prev_i_ = 99; - ostringstream os; - HypergraphIO::WriteToJSON(hg, false, &os); - istringstream is(os.str()); - HypergraphIO::ReadFromJSON(&is, &hg2); - BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); - BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); - BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); + Hypergraph hg; + Hypergraph hg2; + std::string out; + { + CreateHG(path, &hg); + hg.edges_.front().j_ = 23; + hg.edges_.back().prev_i_ = 99; + ostringstream os; + boost::archive::text_oarchive oa(os); + oa << hg; + out = os.str(); + } + { + cerr << out << endl; + istringstream is(out); + boost::archive::text_iarchive ia(is); + ia >> hg2; + BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); + BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); + BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); + } } BOOST_AUTO_TEST_SUITE_END() diff --git a/decoder/hg_test.h b/decoder/hg_test.h index b7bab3c2..575b9c54 100644 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -12,12 +12,8 @@ namespace { typedef char const* Name; -Name urdu_json="urdu.json.gz"; -Name urdu_wts="Arity_0 1.70741473606976 Arity_1 1.12426238048012 Arity_2 1.14986187839554 Glue -0.04589037041388 LanguageModel 1.09051 PassThrough -3.66226367902928 PhraseModel_0 -1.94633451863252 PhraseModel_1 -0.1475347695476 PhraseModel_2 -1.614818994946 WordPenalty -3.0 WordPenaltyFsa -0.56028442964748 ShorterThanPrev -10 LongerThanPrev -10"; -Name small_json="small.json.gz"; +Name small_json="small.bin.gz"; Name small_wts="Model_0 -2 Model_1 -.5 Model_2 -1.1 Model_3 -1 Model_4 -1 Model_5 .5 Model_6 .2 Model_7 -.3"; -Name perro_json="perro.json.gz"; -Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 Glue -1.0 EgivenF -0.5 FgivenE -0.5 LexEgivenF -0.5 LexFgivenE -0.5 LM 1"; } @@ -32,7 +28,7 @@ struct HGSetup { static void JsonFile(Hypergraph *hg,std::string f) { ReadFile rf(f); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } static void JsonTestFile(Hypergraph *hg,std::string path,std::string n) { JsonFile(hg,path + "/"+n); @@ -48,35 +44,35 @@ void AddNullEdge(Hypergraph* hg) { } void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.tiny_lattice"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.tiny_lattice.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); AddNullEdge(hg); } void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.lattice"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.lattice.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); AddNullEdge(hg); } void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) { - ReadFile rf(path + "/hg_test.tiny"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.tiny.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg_int"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg_int.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg_balanced"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg_balanced.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } #endif diff --git a/decoder/json_parse.cc b/decoder/json_parse.cc deleted file mode 100644 index f6fdfea8..00000000 --- a/decoder/json_parse.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "json_parse.h" - -#include <string> -#include <iostream> - -using namespace std; - -static const char *json_hex_chars = "0123456789abcdef"; - -void JSONParser::WriteEscapedString(const string& in, ostream* out) { - int pos = 0; - int start_offset = 0; - unsigned char c = 0; - (*out) << '"'; - while(pos < in.size()) { - c = in[pos]; - switch(c) { - case '\b': - case '\n': - case '\r': - case '\t': - case '"': - case '\\': - case '/': - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - if(c == '\b') (*out) << "\\b"; - else if(c == '\n') (*out) << "\\n"; - else if(c == '\r') (*out) << "\\r"; - else if(c == '\t') (*out) << "\\t"; - else if(c == '"') (*out) << "\\\""; - else if(c == '\\') (*out) << "\\\\"; - else if(c == '/') (*out) << "\\/"; - start_offset = ++pos; - break; - default: - if(c < ' ') { - cerr << "Warning, bad character (" << static_cast<int>(c) << ") in string\n"; - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf]; - start_offset = ++pos; - } else pos++; - } - } - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - (*out) << '"'; -} - diff --git a/decoder/json_parse.h b/decoder/json_parse.h deleted file mode 100644 index 85e2eff1..00000000 --- a/decoder/json_parse.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef JSON_WRAPPER_H_ -#define JSON_WRAPPER_H_ - -#include <iostream> -#include <cassert> -#include "JSON_parser.h" - -class JSONParser { - public: - JSONParser() { - init_JSON_config(&config); - hack.mf = &JSONParser::Callback; - config.depth = 10; - config.callback_ctx = reinterpret_cast<void*>(this); - config.callback = hack.cb; - config.allow_comments = 1; - config.handle_floats_manually = 1; - jc = new_JSON_parser(&config); - } - virtual ~JSONParser() { - delete_JSON_parser(jc); - } - bool Parse(std::istream* in) { - int count = 0; - int lc = 1; - for (; in ; ++count) { - int next_char = in->get(); - if (!in->good()) break; - if (lc == '\n') { ++lc; } - if (!JSON_parser_char(jc, next_char)) { - std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl; - return false; - } - } - if (!JSON_parser_done(jc)) { - std::cerr << "JSON_parser_done: syntax error\n"; - return false; - } - return true; - } - static void WriteEscapedString(const std::string& in, std::ostream* out); - protected: - virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0; - private: - int Callback(int type, const JSON_value* value) { - if (HandleJSONEvent(type, value)) return 1; - return 0; - } - JSON_parser_struct* jc; - JSON_config config; - typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value); - union CBHack { - JSON_parser_callback cb; - MF mf; - } hack; -}; - -#endif diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d2c4715c..cd587833 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -21,6 +21,7 @@ #include "kbest.h" #include "timing_stats.h" #include "sentences.h" +#include "b64featvector.h" //TODO: put function impls into .cc //TODO: move Translation into its own .h and use in cdec @@ -252,19 +253,31 @@ struct OracleBleu { } bool show_derivation; + int show_derivation_mask; + template <class Filter> - void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { + void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat, + int src_len, std::ostream& kbest_out = std::cout, + std::ostream& deriv_out = std::cerr) { using namespace std; using namespace boost; typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K; K kbest(forest,k); //add length (f side) src length of this sentence to the psuedo-doc src length count float curr_src_length = doc_src_length + tmp_src_length; - for (int i = 0; i < k; ++i) { + if (mr_mira_compat) kbest_out << k << "\n"; + int i = 0; + for (; i < k; ++i) { typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; - kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " - << d->feature_values << " ||| " << log(d->score); + kbest_out << sent_id << " ||| "; + if (mr_mira_compat) kbest_out << src_len << " ||| "; + kbest_out << TD::GetString(d->yield) << " ||| "; + if (mr_mira_compat) + kbest_out << EncodeFeatureVector(d->feature_values); + else + kbest_out << d->feature_values; + kbest_out << " ||| " << log(d->score); if (!refs.empty()) { ScoreP sentscore = GetScore(d->yield,sent_id); sentscore->PlusEquals(*doc_score,float(1)); @@ -275,14 +288,21 @@ struct OracleBleu { if (show_derivation) { deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k deriv_out<<log(d->score)<<"\n"; - deriv_out<<kbest.derivation_tree(*d,true); + deriv_out<<kbest.derivation_tree(*d,true, show_derivation_mask); deriv_out<<"\n"<<flush; } } + if (mr_mira_compat) { + for (; i < k; ++i) kbest_out << "\n"; + kbest_out << flush; + } } // TODO decoder output should probably be moved to another file - how about oracle_bleu.h - void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) { + void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, + const bool unique, const bool mr_mira_compat, + const int src_len, std::string const& kbest_out_filename_, + std::string const& deriv_out_filename_) { WriteFile ko(kbest_out_filename_); std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl; @@ -295,9 +315,11 @@ struct OracleBleu { WriteFile oderiv(sderiv.str()); if (!unique) - kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get()); + kbest<KBest::NoFilter<std::vector<WordID> > >( + sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get()); else { - kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get()); + kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len, + ko.get(), oderiv.get()); } } @@ -305,7 +327,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; - DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-"); + DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(), + "-"); } }; diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 18c83c56..2c5fa9c4 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -3,6 +3,7 @@ #include <sstream> #include <boost/shared_ptr.hpp> +#include "filelib.h" #include "sentence_metadata.h" #include "hg.h" #include "hg_io.h" @@ -20,16 +21,18 @@ struct RescoreTranslatorImpl { bool Translate(const string& input, const vector<double>& weights, Hypergraph* forest) { - if (input == "{}") return false; - if (input.find("{\"rules\"") == 0) { - istringstream is(input); - Hypergraph src_cfg_hg; - if (!HypergraphIO::ReadFromJSON(&is, forest)) { - cerr << "Parse error while reading HG from JSON.\n"; - abort(); - } - } else { - cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; + istringstream is(input); + string header, fname; + is >> header >> fname; + if (header != "::forest::") { + cerr << "RescoreTranslator: expected input lines of form ::forest:: filename.gz\n"; + abort(); + } + ReadFile rf(fname); + if (!rf) { cerr << "Can't read " << fname << endl; abort(); } + Hypergraph src_cfg_hg; + if (!HypergraphIO::ReadFromBinary(rf.stream(), forest)) { + cerr << "Parse error while reading HG.\n"; abort(); } Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index d4a8d86b..8b48ab7b 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) { init_default_feature_names(); + scfglex_fname = srule; lex_mono_rules = mono; lex_line = 1; rule_callback_extra = extra; diff --git a/decoder/test_data/hg_test.hg.bin.gz b/decoder/test_data/hg_test.hg.bin.gz Binary files differnew file mode 100644 index 00000000..c07fbe8c --- /dev/null +++ b/decoder/test_data/hg_test.hg.bin.gz diff --git a/decoder/test_data/hg_test.hg_balanced.bin.gz b/decoder/test_data/hg_test.hg_balanced.bin.gz Binary files differnew file mode 100644 index 00000000..896d3d60 --- /dev/null +++ b/decoder/test_data/hg_test.hg_balanced.bin.gz diff --git a/decoder/test_data/hg_test.hg_int.bin.gz b/decoder/test_data/hg_test.hg_int.bin.gz Binary files differnew file mode 100644 index 00000000..e0bd6187 --- /dev/null +++ b/decoder/test_data/hg_test.hg_int.bin.gz diff --git a/decoder/test_data/hg_test.lattice.bin.gz b/decoder/test_data/hg_test.lattice.bin.gz Binary files differnew file mode 100644 index 00000000..8a8c05f4 --- /dev/null +++ b/decoder/test_data/hg_test.lattice.bin.gz diff --git a/decoder/test_data/hg_test.tiny.bin.gz b/decoder/test_data/hg_test.tiny.bin.gz Binary files differnew file mode 100644 index 00000000..0e68eb40 --- /dev/null +++ b/decoder/test_data/hg_test.tiny.bin.gz diff --git a/decoder/test_data/hg_test.tiny_lattice.bin.gz b/decoder/test_data/hg_test.tiny_lattice.bin.gz Binary files differnew file mode 100644 index 00000000..97e8dc05 --- /dev/null +++ b/decoder/test_data/hg_test.tiny_lattice.bin.gz diff --git a/decoder/test_data/perro.json.gz b/decoder/test_data/perro.json.gz Binary files differdeleted file mode 100644 index 41de5758..00000000 --- a/decoder/test_data/perro.json.gz +++ /dev/null diff --git a/decoder/test_data/small.bin.gz b/decoder/test_data/small.bin.gz Binary files differnew file mode 100644 index 00000000..1c5a1631 --- /dev/null +++ b/decoder/test_data/small.bin.gz diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz Binary files differdeleted file mode 100644 index f6f37293..00000000 --- a/decoder/test_data/small.json.gz +++ /dev/null diff --git a/decoder/test_data/urdu.json.gz b/decoder/test_data/urdu.json.gz Binary files differdeleted file mode 100644 index 84535402..00000000 --- a/decoder/test_data/urdu.json.gz +++ /dev/null diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 08dae64c..cdd83ffc 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -267,6 +267,7 @@ struct Tree2StringTranslatorImpl { for (auto sym : rule_src) cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); + rule->a_.push_back(AlignmentPoint(0, 0)); rule->ComputeArity(); rule->scores_.set_value(ntfid, 1.0); rule->scores_.set_value(kFID, 1.0); diff --git a/decoder/trule.h b/decoder/trule.h index adef7cc7..7af46747 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -11,6 +11,7 @@ #include "sparse_vector.h" #include "wordid.h" +#include "tdict.h" class TRule; typedef boost::shared_ptr<TRule> TRulePtr; @@ -26,6 +27,7 @@ struct AlignmentPoint { short s_; short t_; }; + inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) { return os << static_cast<int>(p.s_) << '-' << static_cast<int>(p.t_); } @@ -163,6 +165,66 @@ class TRule { // optional, shows internal structure of TSG rules boost::shared_ptr<cdec::TreeFragment> tree_structure; + friend class boost::serialization::access; + template<class Archive> + void save(Archive & ar, const unsigned int /*version*/) const { + ar & TD::Convert(-lhs_); + unsigned f_size = f_.size(); + ar & f_size; + assert(f_size <= (sizeof(size_t) * 8)); + size_t f_nt_mask = 0; + for (int i = f_.size() - 1; i >= 0; --i) { + f_nt_mask <<= 1; + f_nt_mask |= (f_[i] <= 0 ? 1 : 0); + } + ar & f_nt_mask; + for (unsigned i = 0; i < f_.size(); ++i) + ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]); + unsigned e_size = e_.size(); + ar & e_size; + size_t e_nt_mask = 0; + assert(e_size <= (sizeof(size_t) * 8)); + for (int i = e_.size() - 1; i >= 0; --i) { + e_nt_mask <<= 1; + e_nt_mask |= (e_[i] <= 0 ? 1 : 0); + } + ar & e_nt_mask; + for (unsigned i = 0; i < e_.size(); ++i) + if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]); + ar & arity_; + ar & scores_; + } + template<class Archive> + void load(Archive & ar, const unsigned int /*version*/) { + std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs); + unsigned f_size; ar & f_size; + f_.resize(f_size); + size_t f_nt_mask; ar & f_nt_mask; + std::string sym; + for (unsigned i = 0; i < f_size; ++i) { + bool mask = (f_nt_mask & 1); + ar & sym; + f_[i] = TD::Convert(sym) * (mask ? -1 : 1); + f_nt_mask >>= 1; + } + unsigned e_size; ar & e_size; + e_.resize(e_size); + size_t e_nt_mask; ar & e_nt_mask; + for (unsigned i = 0; i < e_size; ++i) { + bool mask = (e_nt_mask & 1); + if (mask) { + ar & e_[i]; + } else { + ar & sym; + e_[i] = TD::Convert(sym); + } + e_nt_mask >>= 1; + } + ar & arity_; + ar & scores_; + } + + BOOST_SERIALIZATION_SPLIT_MEMBER() private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} }; diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc index 0cb7e2e8..d75c2016 100644 --- a/decoder/trule_test.cc +++ b/decoder/trule_test.cc @@ -4,6 +4,10 @@ #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> #include <iostream> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +#include <sstream> #include "tdict.h" using namespace std; @@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) { BOOST_CHECK_EQUAL(t6.e_[3], 0); } +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { + string str; + string t7str; + { + TRule t7; + t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121"); + cerr << t7.AsString() << endl; + ostringstream os; + TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2")); + TRulePtr tp2 = tp1; + boost::archive::text_oarchive oa(os); + oa << t7; + oa << tp1; + oa << tp2; + str = os.str(); + t7str = t7.AsString(); + } + { + istringstream is(str); + boost::archive::text_iarchive ia(is); + TRule t8; + ia >> t8; + TRulePtr tp3, tp4; + ia >> tp3; + ia >> tp4; + cerr << t8.AsString() << endl; + BOOST_CHECK_EQUAL(t7str, t8.AsString()); + cerr << tp3->AsString() << endl; + cerr << tp4->AsString() << endl; + } +} + |