summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorWu, Ke <wuke@cs.umd.edu>2014-12-17 16:15:13 -0500
committerWu, Ke <wuke@cs.umd.edu>2014-12-17 16:15:13 -0500
commit17dbb7d5ab1544899b1b9e867d2246a0a93e3aa8 (patch)
tree7fa2a51763a1b67fb325e86b0e3f764dd119cd70 /decoder
parent1983c75c35b7f5dc3f356a2f9a9345d632b87650 (diff)
parent1613f1fc44ca67820afd7e7b21eb54b316c8ce55 (diff)
Merge branch 'const_reorder_2' into softsyn_2
Diffstat (limited to 'decoder')
-rw-r--r--decoder/JSON_parser.c1012
-rw-r--r--decoder/JSON_parser.h152
-rw-r--r--decoder/Makefile.am12
-rw-r--r--decoder/aligner.cc6
-rw-r--r--decoder/aligner.h2
-rw-r--r--decoder/apply_models.h4
-rw-r--r--decoder/bottom_up_parser-rs.cc341
-rw-r--r--decoder/bottom_up_parser-rs.h29
-rw-r--r--decoder/bottom_up_parser.h4
-rw-r--r--decoder/csplit.cc1
-rw-r--r--decoder/csplit.h4
-rw-r--r--decoder/decoder.cc55
-rw-r--r--decoder/decoder.h4
-rw-r--r--decoder/earley_composer.h4
-rw-r--r--decoder/factored_lexicon_helper.h4
-rw-r--r--decoder/ff.h11
-rw-r--r--decoder/ff_basic.cc9
-rw-r--r--decoder/ff_basic.h7
-rw-r--r--decoder/ff_bleu.h4
-rw-r--r--decoder/ff_charset.h4
-rw-r--r--decoder/ff_conll.cc250
-rw-r--r--decoder/ff_conll.h45
-rw-r--r--decoder/ff_const_reorder.cc3
-rw-r--r--decoder/ff_const_reorder_common.h99
-rw-r--r--decoder/ff_context.h5
-rw-r--r--decoder/ff_csplit.h4
-rw-r--r--decoder/ff_external.h4
-rw-r--r--decoder/ff_factory.h4
-rw-r--r--decoder/ff_klm.h4
-rw-r--r--decoder/ff_lm.h4
-rw-r--r--decoder/ff_ngrams.h4
-rw-r--r--decoder/ff_parse_match.h4
-rw-r--r--decoder/ff_rules.h4
-rw-r--r--decoder/ff_ruleshape.h4
-rw-r--r--decoder/ff_soft_syntax.h4
-rw-r--r--decoder/ff_soft_syntax_mindist.h4
-rw-r--r--decoder/ff_source_path.h4
-rw-r--r--decoder/ff_source_syntax.h4
-rw-r--r--decoder/ff_source_syntax2.h4
-rw-r--r--decoder/ff_spans.h4
-rw-r--r--decoder/ff_tagger.h4
-rw-r--r--decoder/ff_wordalign.h4
-rw-r--r--decoder/ff_wordset.h4
-rw-r--r--decoder/ffset.h4
-rw-r--r--decoder/forest_writer.cc6
-rw-r--r--decoder/forest_writer.h6
-rw-r--r--decoder/freqdict.h4
-rw-r--r--decoder/fst_translator.cc11
-rw-r--r--decoder/hg.h56
-rw-r--r--decoder/hg_intersect.cc2
-rw-r--r--decoder/hg_intersect.h4
-rw-r--r--decoder/hg_io.cc357
-rw-r--r--decoder/hg_io.h16
-rw-r--r--decoder/hg_remove_eps.h4
-rw-r--r--decoder/hg_sampler.h5
-rw-r--r--decoder/hg_test.cc50
-rw-r--r--decoder/hg_test.h32
-rw-r--r--decoder/hg_union.h4
-rw-r--r--decoder/incremental.h4
-rw-r--r--decoder/inside_outside.h4
-rw-r--r--decoder/json_parse.cc50
-rw-r--r--decoder/json_parse.h58
-rw-r--r--decoder/kbest.h4
-rw-r--r--decoder/lattice.cc1
-rw-r--r--decoder/lattice.h20
-rw-r--r--decoder/lexalign.cc7
-rw-r--r--decoder/lexalign.h4
-rw-r--r--decoder/lextrans.cc7
-rw-r--r--decoder/lextrans.h4
-rw-r--r--decoder/node_state_hash.h4
-rw-r--r--decoder/oracle_bleu.h41
-rw-r--r--decoder/phrasebased_translator.cc1
-rw-r--r--decoder/phrasebased_translator.h4
-rw-r--r--decoder/phrasetable_fst.h4
-rw-r--r--decoder/rescore_translator.cc24
-rw-r--r--decoder/rule_lexer.h4
-rw-r--r--decoder/rule_lexer.ll1
-rw-r--r--decoder/scfg_translator.cc1
-rw-r--r--decoder/sentence_metadata.h27
-rw-r--r--decoder/tagger.cc2
-rw-r--r--decoder/tagger.h4
-rw-r--r--decoder/test_data/hg_test.hg.bin.gzbin0 -> 340 bytes
-rw-r--r--decoder/test_data/hg_test.hg_balanced.bin.gzbin0 -> 324 bytes
-rw-r--r--decoder/test_data/hg_test.hg_int.bin.gzbin0 -> 184 bytes
-rw-r--r--decoder/test_data/hg_test.lattice.bin.gzbin0 -> 334 bytes
-rw-r--r--decoder/test_data/hg_test.tiny.bin.gzbin0 -> 177 bytes
-rw-r--r--decoder/test_data/hg_test.tiny_lattice.bin.gzbin0 -> 203 bytes
-rw-r--r--decoder/test_data/perro.json.gzbin608 -> 0 bytes
-rw-r--r--decoder/test_data/small.bin.gzbin0 -> 2807 bytes
-rw-r--r--decoder/test_data/small.json.gzbin1733 -> 0 bytes
-rw-r--r--decoder/test_data/urdu.json.gzbin253497 -> 0 bytes
-rw-r--r--decoder/translator.h4
-rw-r--r--decoder/tree2string_translator.cc12
-rw-r--r--decoder/tree_fragment.cc7
-rw-r--r--decoder/tree_fragment.h2
-rw-r--r--decoder/trule.h66
-rw-r--r--decoder/trule_test.cc36
-rw-r--r--decoder/viterbi.h4
98 files changed, 1163 insertions, 1952 deletions
diff --git a/decoder/JSON_parser.c b/decoder/JSON_parser.c
deleted file mode 100644
index 5e392bc6..00000000
--- a/decoder/JSON_parser.c
+++ /dev/null
@@ -1,1012 +0,0 @@
-/* JSON_parser.c */
-
-/* 2007-08-24 */
-
-/*
-Copyright (c) 2005 JSON.org
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-The Software shall be used for Good, not Evil.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-*/
-
-/*
- Callbacks, comments, Unicode handling by Jean Gressmann (jean@0x42.de), 2007-2009.
-
- For the added features the license above applies also.
-
- Changelog:
- 2009-05-17
- Incorporated benrudiak@googlemail.com fix for UTF16 decoding.
-
- 2009-05-14
- Fixed float parsing bug related to a locale being set that didn't
- use '.' as decimal point character (charles@transmissionbt.com).
-
- 2008-10-14
- Renamed states.IN to states.IT to avoid name clash which IN macro
- defined in windef.h (alexey.pelykh@gmail.com)
-
- 2008-07-19
- Removed some duplicate code & debugging variable (charles@transmissionbt.com)
-
- 2008-05-28
- Made JSON_value structure ansi C compliant. This bug was report by
- trisk@acm.jhu.edu
-
- 2008-05-20
- Fixed bug reported by charles@transmissionbt.com where the switching
- from static to dynamic parse buffer did not copy the static parse
- buffer's content.
-*/
-
-
-
-#include <assert.h>
-#include <ctype.h>
-#include <float.h>
-#include <stddef.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <locale.h>
-
-#include "JSON_parser.h"
-
-#ifdef _MSC_VER
-# if _MSC_VER >= 1400 /* Visual Studio 2005 and up */
-# pragma warning(disable:4996) // unsecure sscanf
-# endif
-#endif
-
-
-#define true 1
-#define false 0
-#define __ -1 /* the universal error code */
-
-/* values chosen so that the object size is approx equal to one page (4K) */
-#ifndef JSON_PARSER_STACK_SIZE
-# define JSON_PARSER_STACK_SIZE 128
-#endif
-
-#ifndef JSON_PARSER_PARSE_BUFFER_SIZE
-# define JSON_PARSER_PARSE_BUFFER_SIZE 3500
-#endif
-
-typedef unsigned short UTF16;
-
-struct JSON_parser_struct {
- JSON_parser_callback callback;
- void* ctx;
- signed char state, before_comment_state, type, escaped, comment, allow_comments, handle_floats_manually;
- UTF16 utf16_high_surrogate;
- long depth;
- long top;
- signed char* stack;
- long stack_capacity;
- char decimal_point;
- char* parse_buffer;
- size_t parse_buffer_capacity;
- size_t parse_buffer_count;
- size_t comment_begin_offset;
- signed char static_stack[JSON_PARSER_STACK_SIZE];
- char static_parse_buffer[JSON_PARSER_PARSE_BUFFER_SIZE];
-};
-
-#define COUNTOF(x) (sizeof(x)/sizeof(x[0]))
-
-/*
- Characters are mapped into these character classes. This allows for
- a significant reduction in the size of the state transition table.
-*/
-
-
-
-enum classes {
- C_SPACE, /* space */
- C_WHITE, /* other whitespace */
- C_LCURB, /* { */
- C_RCURB, /* } */
- C_LSQRB, /* [ */
- C_RSQRB, /* ] */
- C_COLON, /* : */
- C_COMMA, /* , */
- C_QUOTE, /* " */
- C_BACKS, /* \ */
- C_SLASH, /* / */
- C_PLUS, /* + */
- C_MINUS, /* - */
- C_POINT, /* . */
- C_ZERO , /* 0 */
- C_DIGIT, /* 123456789 */
- C_LOW_A, /* a */
- C_LOW_B, /* b */
- C_LOW_C, /* c */
- C_LOW_D, /* d */
- C_LOW_E, /* e */
- C_LOW_F, /* f */
- C_LOW_L, /* l */
- C_LOW_N, /* n */
- C_LOW_R, /* r */
- C_LOW_S, /* s */
- C_LOW_T, /* t */
- C_LOW_U, /* u */
- C_ABCDF, /* ABCDF */
- C_E, /* E */
- C_ETC, /* everything else */
- C_STAR, /* * */
- NR_CLASSES
-};
-
-static int ascii_class[128] = {
-/*
- This array maps the 128 ASCII characters into character classes.
- The remaining Unicode characters should be mapped to C_ETC.
- Non-whitespace control characters are errors.
-*/
- __, __, __, __, __, __, __, __,
- __, C_WHITE, C_WHITE, __, __, C_WHITE, __, __,
- __, __, __, __, __, __, __, __,
- __, __, __, __, __, __, __, __,
-
- C_SPACE, C_ETC, C_QUOTE, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
- C_ETC, C_ETC, C_STAR, C_PLUS, C_COMMA, C_MINUS, C_POINT, C_SLASH,
- C_ZERO, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT,
- C_DIGIT, C_DIGIT, C_COLON, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
-
- C_ETC, C_ABCDF, C_ABCDF, C_ABCDF, C_ABCDF, C_E, C_ABCDF, C_ETC,
- C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
- C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
- C_ETC, C_ETC, C_ETC, C_LSQRB, C_BACKS, C_RSQRB, C_ETC, C_ETC,
-
- C_ETC, C_LOW_A, C_LOW_B, C_LOW_C, C_LOW_D, C_LOW_E, C_LOW_F, C_ETC,
- C_ETC, C_ETC, C_ETC, C_ETC, C_LOW_L, C_ETC, C_LOW_N, C_ETC,
- C_ETC, C_ETC, C_LOW_R, C_LOW_S, C_LOW_T, C_LOW_U, C_ETC, C_ETC,
- C_ETC, C_ETC, C_ETC, C_LCURB, C_ETC, C_RCURB, C_ETC, C_ETC
-};
-
-
-/*
- The state codes.
-*/
-enum states {
- GO, /* start */
- OK, /* ok */
- OB, /* object */
- KE, /* key */
- CO, /* colon */
- VA, /* value */
- AR, /* array */
- ST, /* string */
- ES, /* escape */
- U1, /* u1 */
- U2, /* u2 */
- U3, /* u3 */
- U4, /* u4 */
- MI, /* minus */
- ZE, /* zero */
- IT, /* integer */
- FR, /* fraction */
- E1, /* e */
- E2, /* ex */
- E3, /* exp */
- T1, /* tr */
- T2, /* tru */
- T3, /* true */
- F1, /* fa */
- F2, /* fal */
- F3, /* fals */
- F4, /* false */
- N1, /* nu */
- N2, /* nul */
- N3, /* null */
- C1, /* / */
- C2, /* / * */
- C3, /* * */
- FX, /* *.* *eE* */
- D1, /* second UTF-16 character decoding started by \ */
- D2, /* second UTF-16 character proceeded by u */
- NR_STATES
-};
-
-enum actions
-{
- CB = -10, /* comment begin */
- CE = -11, /* comment end */
- FA = -12, /* false */
- TR = -13, /* false */
- NU = -14, /* null */
- DE = -15, /* double detected by exponent e E */
- DF = -16, /* double detected by fraction . */
- SB = -17, /* string begin */
- MX = -18, /* integer detected by minus */
- ZX = -19, /* integer detected by zero */
- IX = -20, /* integer detected by 1-9 */
- EX = -21, /* next char is escaped */
- UC = -22 /* Unicode character read */
-};
-
-
-static int state_transition_table[NR_STATES][NR_CLASSES] = {
-/*
- The state transition table takes the current state and the current symbol,
- and returns either a new state or an action. An action is represented as a
- negative number. A JSON text is accepted if at the end of the text the
- state is OK and if the mode is MODE_DONE.
-
- white 1-9 ABCDF etc
- space | { } [ ] : , " \ / + - . 0 | a b c d e f l n r s t u | E | * */
-/*start GO*/ {GO,GO,-6,__,-5,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*ok OK*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*object OB*/ {OB,OB,__,-9,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*key KE*/ {KE,KE,__,__,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*colon CO*/ {CO,CO,__,__,__,__,-2,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*value VA*/ {VA,VA,-6,__,-5,__,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__},
-/*array AR*/ {AR,AR,-6,__,-5,-7,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__},
-/*string ST*/ {ST,__,ST,ST,ST,ST,ST,ST,-4,EX,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST},
-/*escape ES*/ {__,__,__,__,__,__,__,__,ST,ST,ST,__,__,__,__,__,__,ST,__,__,__,ST,__,ST,ST,__,ST,U1,__,__,__,__},
-/*u1 U1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U2,U2,U2,U2,U2,U2,U2,U2,__,__,__,__,__,__,U2,U2,__,__},
-/*u2 U2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U3,U3,U3,U3,U3,U3,U3,U3,__,__,__,__,__,__,U3,U3,__,__},
-/*u3 U3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U4,U4,U4,U4,U4,U4,U4,U4,__,__,__,__,__,__,U4,U4,__,__},
-/*u4 U4*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,UC,UC,UC,UC,UC,UC,UC,UC,__,__,__,__,__,__,UC,UC,__,__},
-/*minus MI*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,ZE,IT,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*zero ZE*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*int IT*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,IT,IT,__,__,__,__,DE,__,__,__,__,__,__,__,__,DE,__,__},
-/*frac FR*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__},
-/*e E1*/ {__,__,__,__,__,__,__,__,__,__,__,E2,E2,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*ex E2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*exp E3*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*tr T1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T2,__,__,__,__,__,__,__},
-/*tru T2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T3,__,__,__,__},
-/*true T3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__},
-/*fa F1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*fal F2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F3,__,__,__,__,__,__,__,__,__},
-/*fals F3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F4,__,__,__,__,__,__},
-/*false F4*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__},
-/*nu N1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N2,__,__,__,__},
-/*nul N2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N3,__,__,__,__,__,__,__,__,__},
-/*null N3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__},
-/*/ C1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,C2},
-/*/* C2*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3},
-/** C3*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,CE,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3},
-/*_. FX*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__},
-/*\ D1*/ {__,__,__,__,__,__,__,__,__,D2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*\ D2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,U1,__,__,__,__},
-};
-
-
-/*
- These modes can be pushed on the stack.
-*/
-enum modes {
- MODE_ARRAY = 1,
- MODE_DONE = 2,
- MODE_KEY = 3,
- MODE_OBJECT = 4
-};
-
-static int
-push(JSON_parser jc, int mode)
-{
-/*
- Push a mode onto the stack. Return false if there is overflow.
-*/
- jc->top += 1;
- if (jc->depth < 0) {
- if (jc->top >= jc->stack_capacity) {
- size_t bytes_to_allocate;
- jc->stack_capacity *= 2;
- bytes_to_allocate = jc->stack_capacity * sizeof(jc->static_stack[0]);
- if (jc->stack == &jc->static_stack[0]) {
- jc->stack = (signed char*)malloc(bytes_to_allocate);
- memcpy(jc->stack, jc->static_stack, sizeof(jc->static_stack));
- } else {
- jc->stack = (signed char*)realloc(jc->stack, bytes_to_allocate);
- }
- }
- } else {
- if (jc->top >= jc->depth) {
- return false;
- }
- }
-
- jc->stack[jc->top] = mode;
- return true;
-}
-
-
-static int
-pop(JSON_parser jc, int mode)
-{
-/*
- Pop the stack, assuring that the current mode matches the expectation.
- Return false if there is underflow or if the modes mismatch.
-*/
- if (jc->top < 0 || jc->stack[jc->top] != mode) {
- return false;
- }
- jc->top -= 1;
- return true;
-}
-
-
-#define parse_buffer_clear(jc) \
- do {\
- jc->parse_buffer_count = 0;\
- jc->parse_buffer[0] = 0;\
- } while (0)
-
-#define parse_buffer_pop_back_char(jc)\
- do {\
- assert(jc->parse_buffer_count >= 1);\
- --jc->parse_buffer_count;\
- jc->parse_buffer[jc->parse_buffer_count] = 0;\
- } while (0)
-
-void delete_JSON_parser(JSON_parser jc)
-{
- if (jc) {
- if (jc->stack != &jc->static_stack[0]) {
- free((void*)jc->stack);
- }
- if (jc->parse_buffer != &jc->static_parse_buffer[0]) {
- free((void*)jc->parse_buffer);
- }
- free((void*)jc);
- }
-}
-
-
-JSON_parser
-new_JSON_parser(JSON_config* config)
-{
-/*
- new_JSON_parser starts the checking process by constructing a JSON_parser
- object. It takes a depth parameter that restricts the level of maximum
- nesting.
-
- To continue the process, call JSON_parser_char for each character in the
- JSON text, and then call JSON_parser_done to obtain the final result.
- These functions are fully reentrant.
-*/
-
- int depth = 0;
- JSON_config default_config;
-
- JSON_parser jc = (JSON_parser)malloc(sizeof(struct JSON_parser_struct));
-
- memset(jc, 0, sizeof(*jc));
-
-
- /* initialize configuration */
- init_JSON_config(&default_config);
-
- /* set to default configuration if none was provided */
- if (config == NULL) {
- config = &default_config;
- }
-
- depth = config->depth;
-
- /* We need to be able to push at least one object */
- if (depth == 0) {
- depth = 1;
- }
-
- jc->state = GO;
- jc->top = -1;
-
- /* Do we want non-bound stack? */
- if (depth > 0) {
- jc->stack_capacity = depth;
- jc->depth = depth;
- if (depth <= (int)COUNTOF(jc->static_stack)) {
- jc->stack = &jc->static_stack[0];
- } else {
- jc->stack = (signed char*)malloc(jc->stack_capacity * sizeof(jc->static_stack[0]));
- }
- } else {
- jc->stack_capacity = COUNTOF(jc->static_stack);
- jc->depth = -1;
- jc->stack = &jc->static_stack[0];
- }
-
- /* set parser to start */
- push(jc, MODE_DONE);
-
- /* set up the parse buffer */
- jc->parse_buffer = &jc->static_parse_buffer[0];
- jc->parse_buffer_capacity = COUNTOF(jc->static_parse_buffer);
- parse_buffer_clear(jc);
-
- /* set up callback, comment & float handling */
- jc->callback = config->callback;
- jc->ctx = config->callback_ctx;
- jc->allow_comments = config->allow_comments != 0;
- jc->handle_floats_manually = config->handle_floats_manually != 0;
-
- /* set up decimal point */
- jc->decimal_point = *localeconv()->decimal_point;
-
- return jc;
-}
-
-static void grow_parse_buffer(JSON_parser jc)
-{
- size_t bytes_to_allocate;
- jc->parse_buffer_capacity *= 2;
- bytes_to_allocate = jc->parse_buffer_capacity * sizeof(jc->parse_buffer[0]);
- if (jc->parse_buffer == &jc->static_parse_buffer[0]) {
- jc->parse_buffer = (char*)malloc(bytes_to_allocate);
- memcpy(jc->parse_buffer, jc->static_parse_buffer, jc->parse_buffer_count);
- } else {
- jc->parse_buffer = (char*)realloc(jc->parse_buffer, bytes_to_allocate);
- }
-}
-
-#define parse_buffer_push_back_char(jc, c)\
- do {\
- if (jc->parse_buffer_count + 1 >= jc->parse_buffer_capacity) grow_parse_buffer(jc);\
- jc->parse_buffer[jc->parse_buffer_count++] = c;\
- jc->parse_buffer[jc->parse_buffer_count] = 0;\
- } while (0)
-
-#define assert_is_non_container_type(jc) \
- assert( \
- jc->type == JSON_T_NULL || \
- jc->type == JSON_T_FALSE || \
- jc->type == JSON_T_TRUE || \
- jc->type == JSON_T_FLOAT || \
- jc->type == JSON_T_INTEGER || \
- jc->type == JSON_T_STRING)
-
-
-static int parse_parse_buffer(JSON_parser jc)
-{
- if (jc->callback) {
- JSON_value value, *arg = NULL;
-
- if (jc->type != JSON_T_NONE) {
- assert_is_non_container_type(jc);
-
- switch(jc->type) {
- case JSON_T_FLOAT:
- arg = &value;
- if (jc->handle_floats_manually) {
- value.vu.str.value = jc->parse_buffer;
- value.vu.str.length = jc->parse_buffer_count;
- } else {
- /*sscanf(jc->parse_buffer, "%Lf", &value.vu.float_value);*/
-
- /* not checking with end pointer b/c there may be trailing ws */
- value.vu.float_value = strtod(jc->parse_buffer, NULL);
- }
- break;
- case JSON_T_INTEGER:
- arg = &value;
- sscanf(jc->parse_buffer, JSON_PARSER_INTEGER_SSCANF_TOKEN, &value.vu.integer_value);
- break;
- case JSON_T_STRING:
- arg = &value;
- value.vu.str.value = jc->parse_buffer;
- value.vu.str.length = jc->parse_buffer_count;
- break;
- }
-
- if (!(*jc->callback)(jc->ctx, jc->type, arg)) {
- return false;
- }
- }
- }
-
- parse_buffer_clear(jc);
-
- return true;
-}
-
-#define IS_HIGH_SURROGATE(uc) (((uc) & 0xFC00) == 0xD800)
-#define IS_LOW_SURROGATE(uc) (((uc) & 0xFC00) == 0xDC00)
-#define DECODE_SURROGATE_PAIR(hi,lo) ((((hi) & 0x3FF) << 10) + ((lo) & 0x3FF) + 0x10000)
-static unsigned char utf8_lead_bits[4] = { 0x00, 0xC0, 0xE0, 0xF0 };
-
-static int decode_unicode_char(JSON_parser jc)
-{
- int i;
- unsigned uc = 0;
- char* p;
- int trail_bytes;
-
- assert(jc->parse_buffer_count >= 6);
-
- p = &jc->parse_buffer[jc->parse_buffer_count - 4];
-
- for (i = 12; i >= 0; i -= 4, ++p) {
- unsigned x = *p;
-
- if (x >= 'a') {
- x -= ('a' - 10);
- } else if (x >= 'A') {
- x -= ('A' - 10);
- } else {
- x &= ~0x30u;
- }
-
- assert(x < 16);
-
- uc |= x << i;
- }
-
- /* clear UTF-16 char from buffer */
- jc->parse_buffer_count -= 6;
- jc->parse_buffer[jc->parse_buffer_count] = 0;
-
- /* attempt decoding ... */
- if (jc->utf16_high_surrogate) {
- if (IS_LOW_SURROGATE(uc)) {
- uc = DECODE_SURROGATE_PAIR(jc->utf16_high_surrogate, uc);
- trail_bytes = 3;
- jc->utf16_high_surrogate = 0;
- } else {
- /* high surrogate without a following low surrogate */
- return false;
- }
- } else {
- if (uc < 0x80) {
- trail_bytes = 0;
- } else if (uc < 0x800) {
- trail_bytes = 1;
- } else if (IS_HIGH_SURROGATE(uc)) {
- /* save the high surrogate and wait for the low surrogate */
- jc->utf16_high_surrogate = uc;
- return true;
- } else if (IS_LOW_SURROGATE(uc)) {
- /* low surrogate without a preceding high surrogate */
- return false;
- } else {
- trail_bytes = 2;
- }
- }
-
- jc->parse_buffer[jc->parse_buffer_count++] = (char) ((uc >> (trail_bytes * 6)) | utf8_lead_bits[trail_bytes]);
-
- for (i = trail_bytes * 6 - 6; i >= 0; i -= 6) {
- jc->parse_buffer[jc->parse_buffer_count++] = (char) (((uc >> i) & 0x3F) | 0x80);
- }
-
- jc->parse_buffer[jc->parse_buffer_count] = 0;
-
- return true;
-}
-
-static int add_escaped_char_to_parse_buffer(JSON_parser jc, int next_char)
-{
- jc->escaped = 0;
- /* remove the backslash */
- parse_buffer_pop_back_char(jc);
- switch(next_char) {
- case 'b':
- parse_buffer_push_back_char(jc, '\b');
- break;
- case 'f':
- parse_buffer_push_back_char(jc, '\f');
- break;
- case 'n':
- parse_buffer_push_back_char(jc, '\n');
- break;
- case 'r':
- parse_buffer_push_back_char(jc, '\r');
- break;
- case 't':
- parse_buffer_push_back_char(jc, '\t');
- break;
- case '"':
- parse_buffer_push_back_char(jc, '"');
- break;
- case '\\':
- parse_buffer_push_back_char(jc, '\\');
- break;
- case '/':
- parse_buffer_push_back_char(jc, '/');
- break;
- case 'u':
- parse_buffer_push_back_char(jc, '\\');
- parse_buffer_push_back_char(jc, 'u');
- break;
- default:
- return false;
- }
-
- return true;
-}
-
-#define add_char_to_parse_buffer(jc, next_char, next_class) \
- do { \
- if (jc->escaped) { \
- if (!add_escaped_char_to_parse_buffer(jc, next_char)) \
- return false; \
- } else if (!jc->comment) { \
- if ((jc->type != JSON_T_NONE) | !((next_class == C_SPACE) | (next_class == C_WHITE)) /* non-white-space */) { \
- parse_buffer_push_back_char(jc, (char)next_char); \
- } \
- } \
- } while (0)
-
-
-#define assert_type_isnt_string_null_or_bool(jc) \
- assert(jc->type != JSON_T_FALSE); \
- assert(jc->type != JSON_T_TRUE); \
- assert(jc->type != JSON_T_NULL); \
- assert(jc->type != JSON_T_STRING)
-
-
-int
-JSON_parser_char(JSON_parser jc, int next_char)
-{
-/*
- After calling new_JSON_parser, call this function for each character (or
- partial character) in your JSON text. It can accept UTF-8, UTF-16, or
- UTF-32. It returns true if things are looking ok so far. If it rejects the
- text, it returns false.
-*/
- int next_class, next_state;
-
-/*
- Determine the character's class.
-*/
- if (next_char < 0) {
- return false;
- }
- if (next_char >= 128) {
- next_class = C_ETC;
- } else {
- next_class = ascii_class[next_char];
- if (next_class <= __) {
- return false;
- }
- }
-
- add_char_to_parse_buffer(jc, next_char, next_class);
-
-/*
- Get the next state from the state transition table.
-*/
- next_state = state_transition_table[jc->state][next_class];
- if (next_state >= 0) {
-/*
- Change the state.
-*/
- jc->state = next_state;
- } else {
-/*
- Or perform one of the actions.
-*/
- switch (next_state) {
-/* Unicode character */
- case UC:
- if(!decode_unicode_char(jc)) {
- return false;
- }
- /* check if we need to read a second UTF-16 char */
- if (jc->utf16_high_surrogate) {
- jc->state = D1;
- } else {
- jc->state = ST;
- }
- break;
-/* escaped char */
- case EX:
- jc->escaped = 1;
- jc->state = ES;
- break;
-/* integer detected by minus */
- case MX:
- jc->type = JSON_T_INTEGER;
- jc->state = MI;
- break;
-/* integer detected by zero */
- case ZX:
- jc->type = JSON_T_INTEGER;
- jc->state = ZE;
- break;
-/* integer detected by 1-9 */
- case IX:
- jc->type = JSON_T_INTEGER;
- jc->state = IT;
- break;
-
-/* floating point number detected by exponent*/
- case DE:
- assert_type_isnt_string_null_or_bool(jc);
- jc->type = JSON_T_FLOAT;
- jc->state = E1;
- break;
-
-/* floating point number detected by fraction */
- case DF:
- assert_type_isnt_string_null_or_bool(jc);
- if (!jc->handle_floats_manually) {
-/*
- Some versions of strtod (which underlies sscanf) don't support converting
- C-locale formated floating point values.
-*/
- assert(jc->parse_buffer[jc->parse_buffer_count-1] == '.');
- jc->parse_buffer[jc->parse_buffer_count-1] = jc->decimal_point;
- }
- jc->type = JSON_T_FLOAT;
- jc->state = FX;
- break;
-/* string begin " */
- case SB:
- parse_buffer_clear(jc);
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_STRING;
- jc->state = ST;
- break;
-
-/* n */
- case NU:
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_NULL;
- jc->state = N1;
- break;
-/* f */
- case FA:
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_FALSE;
- jc->state = F1;
- break;
-/* t */
- case TR:
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_TRUE;
- jc->state = T1;
- break;
-
-/* closing comment */
- case CE:
- jc->comment = 0;
- assert(jc->parse_buffer_count == 0);
- assert(jc->type == JSON_T_NONE);
- jc->state = jc->before_comment_state;
- break;
-
-/* opening comment */
- case CB:
- if (!jc->allow_comments) {
- return false;
- }
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- assert(jc->parse_buffer_count == 0);
- assert(jc->type != JSON_T_STRING);
- switch (jc->stack[jc->top]) {
- case MODE_ARRAY:
- case MODE_OBJECT:
- switch(jc->state) {
- case VA:
- case AR:
- jc->before_comment_state = jc->state;
- break;
- default:
- jc->before_comment_state = OK;
- break;
- }
- break;
- default:
- jc->before_comment_state = jc->state;
- break;
- }
- jc->type = JSON_T_NONE;
- jc->state = C1;
- jc->comment = 1;
- break;
-/* empty } */
- case -9:
- parse_buffer_clear(jc);
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) {
- return false;
- }
- if (!pop(jc, MODE_KEY)) {
- return false;
- }
- jc->state = OK;
- break;
-
-/* } */ case -8:
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) {
- return false;
- }
- if (!pop(jc, MODE_OBJECT)) {
- return false;
- }
- jc->type = JSON_T_NONE;
- jc->state = OK;
- break;
-
-/* ] */ case -7:
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_END, NULL)) {
- return false;
- }
- if (!pop(jc, MODE_ARRAY)) {
- return false;
- }
-
- jc->type = JSON_T_NONE;
- jc->state = OK;
- break;
-
-/* { */ case -6:
- parse_buffer_pop_back_char(jc);
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_BEGIN, NULL)) {
- return false;
- }
- if (!push(jc, MODE_KEY)) {
- return false;
- }
- assert(jc->type == JSON_T_NONE);
- jc->state = OB;
- break;
-
-/* [ */ case -5:
- parse_buffer_pop_back_char(jc);
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_BEGIN, NULL)) {
- return false;
- }
- if (!push(jc, MODE_ARRAY)) {
- return false;
- }
- assert(jc->type == JSON_T_NONE);
- jc->state = AR;
- break;
-
-/* string end " */ case -4:
- parse_buffer_pop_back_char(jc);
- switch (jc->stack[jc->top]) {
- case MODE_KEY:
- assert(jc->type == JSON_T_STRING);
- jc->type = JSON_T_NONE;
- jc->state = CO;
-
- if (jc->callback) {
- JSON_value value;
- value.vu.str.value = jc->parse_buffer;
- value.vu.str.length = jc->parse_buffer_count;
- if (!(*jc->callback)(jc->ctx, JSON_T_KEY, &value)) {
- return false;
- }
- }
- parse_buffer_clear(jc);
- break;
- case MODE_ARRAY:
- case MODE_OBJECT:
- assert(jc->type == JSON_T_STRING);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- jc->type = JSON_T_NONE;
- jc->state = OK;
- break;
- default:
- return false;
- }
- break;
-
-/* , */ case -3:
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- switch (jc->stack[jc->top]) {
- case MODE_OBJECT:
-/*
- A comma causes a flip from object mode to key mode.
-*/
- if (!pop(jc, MODE_OBJECT) || !push(jc, MODE_KEY)) {
- return false;
- }
- assert(jc->type != JSON_T_STRING);
- jc->type = JSON_T_NONE;
- jc->state = KE;
- break;
- case MODE_ARRAY:
- assert(jc->type != JSON_T_STRING);
- jc->type = JSON_T_NONE;
- jc->state = VA;
- break;
- default:
- return false;
- }
- break;
-
-/* : */ case -2:
-/*
- A colon causes a flip from key mode to object mode.
-*/
- parse_buffer_pop_back_char(jc);
- if (!pop(jc, MODE_KEY) || !push(jc, MODE_OBJECT)) {
- return false;
- }
- assert(jc->type == JSON_T_NONE);
- jc->state = VA;
- break;
-/*
- Bad action.
-*/
- default:
- return false;
- }
- }
- return true;
-}
-
-
-int
-JSON_parser_done(JSON_parser jc)
-{
- const int result = jc->state == OK && pop(jc, MODE_DONE);
-
- return result;
-}
-
-
-int JSON_parser_is_legal_white_space_string(const char* s)
-{
- int c, char_class;
-
- if (s == NULL) {
- return false;
- }
-
- for (; *s; ++s) {
- c = *s;
-
- if (c < 0 || c >= 128) {
- return false;
- }
-
- char_class = ascii_class[c];
-
- if (char_class != C_SPACE && char_class != C_WHITE) {
- return false;
- }
- }
-
- return true;
-}
-
-
-
-void init_JSON_config(JSON_config* config)
-{
- if (config) {
- memset(config, 0, sizeof(*config));
-
- config->depth = JSON_PARSER_STACK_SIZE - 1;
- }
-}
diff --git a/decoder/JSON_parser.h b/decoder/JSON_parser.h
deleted file mode 100644
index de980072..00000000
--- a/decoder/JSON_parser.h
+++ /dev/null
@@ -1,152 +0,0 @@
-#ifndef JSON_PARSER_H
-#define JSON_PARSER_H
-
-/* JSON_parser.h */
-
-
-#include <stddef.h>
-
-/* Windows DLL stuff */
-#ifdef _WIN32
-# ifdef JSON_PARSER_DLL_EXPORTS
-# define JSON_PARSER_DLL_API __declspec(dllexport)
-# else
-# define JSON_PARSER_DLL_API __declspec(dllimport)
-# endif
-#else
-# define JSON_PARSER_DLL_API
-#endif
-
-/* Determine the integer type use to parse non-floating point numbers */
-#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1
-typedef long long JSON_int_t;
-#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld"
-#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld"
-#else
-typedef long JSON_int_t;
-#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld"
-#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld"
-#endif
-
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-typedef enum
-{
- JSON_T_NONE = 0,
- JSON_T_ARRAY_BEGIN, // 1
- JSON_T_ARRAY_END, // 2
- JSON_T_OBJECT_BEGIN, // 3
- JSON_T_OBJECT_END, // 4
- JSON_T_INTEGER, // 5
- JSON_T_FLOAT, // 6
- JSON_T_NULL, // 7
- JSON_T_TRUE, // 8
- JSON_T_FALSE, // 9
- JSON_T_STRING, // 10
- JSON_T_KEY, // 11
- JSON_T_MAX // 12
-} JSON_type;
-
-typedef struct JSON_value_struct {
- union {
- JSON_int_t integer_value;
-
- double float_value;
-
- struct {
- const char* value;
- size_t length;
- } str;
- } vu;
-} JSON_value;
-
-typedef struct JSON_parser_struct* JSON_parser;
-
-/*! \brief JSON parser callback
-
- \param ctx The pointer passed to new_JSON_parser.
- \param type An element of JSON_type but not JSON_T_NONE.
- \param value A representation of the parsed value. This parameter is NULL for
- JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END,
- JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned
- as zero-terminated C strings.
-
- \return Non-zero if parsing should continue, else zero.
-*/
-typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value);
-
-
-/*! \brief The structure used to configure a JSON parser object
-
- \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise
- the depth is the limit
- \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity.
- \param Callback context. This parameter may be NULL.
- \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting.
- \param allowComments. To allow C style comments in JSON, set to non-zero.
- \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero.
-
- \return The parser object.
-*/
-typedef struct {
- JSON_parser_callback callback;
- void* callback_ctx;
- int depth;
- int allow_comments;
- int handle_floats_manually;
-} JSON_config;
-
-
-/*! \brief Initializes the JSON parser configuration structure to default values.
-
- The default configuration is
- - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c)
- - no parsing, just checking for JSON syntax
- - no comments
-
- \param config. Used to configure the parser.
-*/
-JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config);
-
-/*! \brief Create a JSON parser object
-
- \param config. Used to configure the parser. Set to NULL to use the default configuration.
- See init_JSON_config
-
- \return The parser object.
-*/
-JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config);
-
-/*! \brief Destroy a previously created JSON parser object. */
-JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc);
-
-/*! \brief Parse a character.
-
- \return Non-zero, if all characters passed to this function are part of are valid JSON.
-*/
-JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char);
-
-/*! \brief Finalize parsing.
-
- Call this method once after all input characters have been consumed.
-
- \return Non-zero, if all parsed characters are valid JSON, zero otherwise.
-*/
-JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc);
-
-/*! \brief Determine if a given string is valid JSON white space
-
- \return Non-zero if the string is valid, zero otherwise.
-*/
-JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s);
-
-
-#ifdef __cplusplus
-}
-#endif
-
-
-#endif /* JSON_PARSER_H */
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 727e5af5..dbec532e 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -33,10 +33,10 @@ noinst_LIBRARIES = libcdec.a
EXTRA_DIST = test_data rule_lexer.ll
libcdec_a_SOURCES = \
- JSON_parser.h \
aligner.h \
apply_models.h \
bottom_up_parser.h \
+ bottom_up_parser-rs.h \
csplit.h \
decoder.h \
earley_composer.h \
@@ -45,6 +45,7 @@ libcdec_a_SOURCES = \
ff_basic.h \
ff_bleu.h \
ff_charset.h \
+ ff_conll.h \
ff_const_reorder_common.h \
ff_const_reorder.h \
ff_context.h \
@@ -52,7 +53,7 @@ libcdec_a_SOURCES = \
ff_external.h \
ff_factory.h \
ff_klm.h \
- ff_lexical.h \
+ ff_lexical.h \
ff_lm.h \
ff_ngrams.h \
ff_parse_match.h \
@@ -83,7 +84,6 @@ libcdec_a_SOURCES = \
hg_union.h \
incremental.h \
inside_outside.h \
- json_parse.h \
kbest.h \
lattice.h \
lexalign.h \
@@ -103,6 +103,7 @@ libcdec_a_SOURCES = \
aligner.cc \
apply_models.cc \
bottom_up_parser.cc \
+ bottom_up_parser-rs.cc \
cdec.cc \
cdec_ff.cc \
csplit.cc \
@@ -113,6 +114,7 @@ libcdec_a_SOURCES = \
ff_basic.cc \
ff_bleu.cc \
ff_charset.cc \
+ ff_conll.cc \
ff_context.cc \
ff_const_reorder.cc \
ff_csplit.cc \
@@ -146,7 +148,6 @@ libcdec_a_SOURCES = \
hg_sampler.cc \
hg_union.cc \
incremental.cc \
- json_parse.cc \
lattice.cc \
lexalign.cc \
lextrans.cc \
@@ -162,5 +163,4 @@ libcdec_a_SOURCES = \
tagger.cc \
translator.cc \
trule.cc \
- viterbi.cc \
- JSON_parser.c
+ viterbi.cc
diff --git a/decoder/aligner.cc b/decoder/aligner.cc
index 232e022a..fd648370 100644
--- a/decoder/aligner.cc
+++ b/decoder/aligner.cc
@@ -198,13 +198,13 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice,
}
const Hypergraph* g = &in_g;
HypergraphP new_hg;
- if (!src_lattice.IsSentence() ||
- !trg_lattice.IsSentence()) {
+ if (!IsSentence(src_lattice) ||
+ !IsSentence(trg_lattice)) {
if (map_instead_of_viterbi) {
cerr << " Lattice alignment: using Viterbi instead of MAP alignment\n";
}
map_instead_of_viterbi = false;
- fix_up_src_spans = !src_lattice.IsSentence();
+ fix_up_src_spans = !IsSentence(src_lattice);
}
KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal> kbest(in_g, k_best);
diff --git a/decoder/aligner.h b/decoder/aligner.h
index a34795c9..d68ceefc 100644
--- a/decoder/aligner.h
+++ b/decoder/aligner.h
@@ -1,4 +1,4 @@
-#ifndef _ALIGNER_H_
+#ifndef ALIGNER_H
#include <string>
#include <iostream>
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index 19a4c7be..f03c973a 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -1,5 +1,5 @@
-#ifndef _APPLY_MODELS_H_
-#define _APPLY_MODELS_H_
+#ifndef APPLY_MODELS_H_
+#define APPLY_MODELS_H_
#include <iostream>
diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc
new file mode 100644
index 00000000..fbde7e24
--- /dev/null
+++ b/decoder/bottom_up_parser-rs.cc
@@ -0,0 +1,341 @@
+#include "bottom_up_parser-rs.h"
+
+#include <iostream>
+#include <map>
+
+#include "node_state_hash.h"
+#include "nt_span.h"
+#include "hg.h"
+#include "array2d.h"
+#include "tdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+static WordID kEPS = 0;
+
+struct RSActiveItem;
+class RSChart {
+ public:
+ RSChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest);
+ ~RSChart();
+
+ void AddToChart(const RSActiveItem& x, int i, int j);
+ void ConsumeTerminal(const RSActiveItem& x, int i, int j, int k);
+ void ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k);
+ bool Parse();
+ inline bool GoalFound() const { return goal_idx_ >= 0; }
+ inline int GetGoalIndex() const { return goal_idx_; }
+
+ private:
+ void ApplyRules(const int i,
+ const int j,
+ const RuleBin* rules,
+ const Hypergraph::TailNodeVector& tail,
+ const float lattice_cost);
+
+ // returns true if a new node was added to the chart
+ // false otherwise
+ bool ApplyRule(const int i,
+ const int j,
+ const TRulePtr& r,
+ const Hypergraph::TailNodeVector& ant_nodes,
+ const float lattice_cost);
+
+ void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx);
+ void TopoSortUnaries();
+
+ const vector<GrammarPtr>& grammars_;
+ const Lattice& input_;
+ Hypergraph* forest_;
+ Array2D<vector<int>> chart_; // chart_(i,j) is the list of nodes (represented
+ // by their index in forest_->nodes_) derived spanning i,j
+ typedef map<int, int> Cat2NodeMap;
+ Array2D<Cat2NodeMap> nodemap_;
+ const WordID goal_cat_; // category that is being searched for at [0,n]
+ TRulePtr goal_rule_;
+ int goal_idx_; // index of goal node, if found
+ const int lc_fid_;
+ vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
+
+ static WordID kGOAL; // [Goal]
+};
+
+WordID RSChart::kGOAL = 0;
+
+// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span"
+struct RSActiveItem {
+ explicit RSActiveItem(const GrammarIter* g, int i) :
+ gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {}
+ void ExtendTerminal(int symbol, float src_cost) {
+ lattice_cost += src_cost;
+ if (symbol != kEPS)
+ gptr_ = gptr_->Extend(symbol);
+ }
+ void ExtendNonTerminal(const Hypergraph* hg, int node_index) {
+ gptr_ = gptr_->Extend(hg->nodes_[node_index].cat_);
+ ant_nodes_.push_back(node_index);
+ }
+ // returns false if the extension has failed
+ explicit operator bool() const {
+ return gptr_;
+ }
+ const GrammarIter* gptr_;
+ Hypergraph::TailNodeVector ant_nodes_;
+ float lattice_cost; // TODO: use SparseVector<double> to encode input features
+ short i_;
+};
+
+// some notes on the implementation
+// "X" in Rico's Algorithm 2 roughly looks like it is just a pointer into a grammar
+// trie, but it is actually a full "dotted item" since it needs to contain the information
+// to build the hypergraph (i.e., it must remember the antecedent nodes and where they are,
+// also any information about the path costs).
+
+RSChart::RSChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest) :
+ grammars_(grammars),
+ input_(input),
+ forest_(forest),
+ chart_(input.size()+1, input.size()+1),
+ nodemap_(input.size()+1, input.size()+1),
+ goal_cat_(TD::Convert(goal) * -1),
+ goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")),
+ goal_idx_(-1),
+ lc_fid_(FD::Convert("LatticeCost")),
+ unaries_() {
+ for (unsigned i = 0; i < grammars_.size(); ++i) {
+ const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules();
+ for (unsigned j = 0; j < u.size(); ++j)
+ unaries_.push_back(u[j]);
+ }
+ TopoSortUnaries();
+ if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
+ if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl;
+}
+
+static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) {
+ if (mark[node] == 1) {
+ cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n";
+ return false; // cycle detected
+ } else if (mark[node] == 2) {
+ return true; // already been
+ }
+ mark[node] = 1;
+ const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node);
+ if (nit != g.end()) {
+ const vector<TRulePtr>& edges = nit->second;
+ vector<bool> okay(edges.size(), true);
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark);
+ if (!okay[i]) {
+ cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl;
+ }
+ }
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ if (okay[i]) u.push_back(edges[i]);
+ //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl;
+ }
+ }
+ mark[node] = 2;
+ return true;
+}
+
+void RSChart::TopoSortUnaries() {
+ vector<TRulePtr> u(unaries_.size()); u.clear();
+ map<int, vector<TRulePtr> > g;
+ map<int, int> mark;
+ //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl;
+ mark[goal_cat_] = 2;
+ for (unsigned i = 0; i < unaries_.size(); ++i) {
+ //cerr << "Adding: " << unaries_[i]->AsString() << endl;
+ g[unaries_[i]->f()[0]].push_back(unaries_[i]);
+ }
+ //m[unaries_[i]->lhs_].push_back(unaries_[i]);
+ for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) {
+ //cerr << "PROC: " << TD::Convert(-it->first) << endl;
+ if (mark[it->first] > 0) {
+ //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n";
+ } else {
+ TopoSortVisit(it->first, u, g, mark);
+ }
+ }
+ unaries_.clear();
+ for (int i = u.size() - 1; i >= 0; --i)
+ unaries_.push_back(u[i]);
+}
+
+bool RSChart::ApplyRule(const int i,
+ const int j,
+ const TRulePtr& r,
+ const Hypergraph::TailNodeVector& ant_nodes,
+ const float lattice_cost) {
+ Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
+ //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
+ new_edge->prev_i_ = r->prev_i;
+ new_edge->prev_j_ = r->prev_j;
+ new_edge->i_ = i;
+ new_edge->j_ = j;
+ new_edge->feature_values_ = r->GetFeatureValues();
+ if (lattice_cost && lc_fid_)
+ new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+ Cat2NodeMap& c2n = nodemap_(i,j);
+ const bool is_goal = (r->GetLHS() == kGOAL);
+ const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
+ Hypergraph::Node* node = NULL;
+ bool added_node = false;
+ if (ni == c2n.end()) {
+ //cerr << "(" << i << "," << j << ") => " << TD::Convert(-r->GetLHS()) << endl;
+ added_node = true;
+ node = forest_->AddNode(r->GetLHS());
+ c2n[r->GetLHS()] = node->id_;
+ if (is_goal) {
+ assert(goal_idx_ == -1);
+ goal_idx_ = node->id_;
+ } else {
+ chart_(i,j).push_back(node->id_);
+ }
+ } else {
+ node = &forest_->nodes_[ni->second];
+ }
+ forest_->ConnectEdgeToHeadNode(new_edge, node);
+ return added_node;
+}
+
+void RSChart::ApplyRules(const int i,
+ const int j,
+ const RuleBin* rules,
+ const Hypergraph::TailNodeVector& tail,
+ const float lattice_cost) {
+ const int n = rules->GetNumRules();
+ //cerr << i << " " << j << ": NUM RULES: " << n << endl;
+ for (int k = 0; k < n; ++k) {
+ //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl;
+ TRulePtr rule = rules->GetIthRule(k);
+ // apply rule, and if we create a new node, apply any necessary
+ // unary rules
+ if (ApplyRule(i, j, rule, tail, lattice_cost)) {
+ unsigned nodeidx = nodemap_(i,j)[rule->lhs_];
+ ApplyUnaryRules(i, j, rule->lhs_, nodeidx);
+ }
+ }
+}
+
+void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx) {
+ for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
+ //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
+ if (unaries_[ri]->f()[0] == cat) {
+ //cerr << " --MATCH\n";
+ WordID new_lhs = unaries_[ri]->GetLHS();
+ const Hypergraph::TailNodeVector ant(1, nodeidx);
+ if (ApplyRule(i, j, unaries_[ri], ant, 0)) {
+ //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl;
+ unsigned nodeidx = nodemap_(i,j)[new_lhs];
+ ApplyUnaryRules(i, j, new_lhs, nodeidx);
+ }
+ }
+ }
+}
+
+void RSChart::AddToChart(const RSActiveItem& x, int i, int j) {
+ // deal with completed rules
+ const RuleBin* rb = x.gptr_->GetRules();
+ if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost);
+
+ //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n";
+ // continue looking for extensions of the rule to the right
+ for (unsigned k = j+1; k <= input_.size(); ++k) {
+ ConsumeTerminal(x, i, j, k);
+ ConsumeNonTerminal(x, i, j, k);
+ }
+}
+
+void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) {
+ //cerr << "ConsumeT(" << i << "," << j << "," << k << "):\n";
+
+ const unsigned check_edge_len = k - j;
+ // long-term TODO preindex this search so i->len->words is constant time rather than fan out
+ for (auto& in_edge : input_[j]) {
+ if (in_edge.dist2next == check_edge_len) {
+ //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl;
+ RSActiveItem copy = x;
+ copy.ExtendTerminal(in_edge.label, in_edge.cost);
+ if (copy) AddToChart(copy, i, k);
+ }
+ }
+}
+
+void RSChart::ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k) {
+ //cerr << "ConsumeNT(" << i << "," << j << "," << k << "):\n";
+ for (auto& nodeidx : chart_(j,k)) {
+ //cerr << " Found completed NT in (" << j << "," << k << ") of type " << TD::Convert(-forest_->nodes_[nodeidx].cat_) << endl;
+ RSActiveItem copy = x;
+ copy.ExtendNonTerminal(forest_, nodeidx);
+ if (copy) AddToChart(copy, i, k);
+ }
+}
+
+bool RSChart::Parse() {
+ size_t in_size_2 = input_.size() * input_.size();
+ forest_->nodes_.reserve(in_size_2 * 2);
+ size_t res = min(static_cast<size_t>(2000000), static_cast<size_t>(in_size_2 * 1000));
+ forest_->edges_.reserve(res);
+ goal_idx_ = -1;
+ const int N = input_.size();
+ for (int i = N - 1; i >= 0; --i) {
+ for (int j = i + 1; j <= N; ++j) {
+ for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
+ RSActiveItem item(grammars_[gi]->GetRoot(), i);
+ ConsumeTerminal(item, i, i, j);
+ }
+ for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
+ RSActiveItem item(grammars_[gi]->GetRoot(), i);
+ ConsumeNonTerminal(item, i, i, j);
+ }
+ }
+ }
+
+ // look for goal
+ const vector<int>& dh = chart_(0, input_.size());
+ for (unsigned di = 0; di < dh.size(); ++di) {
+ const Hypergraph::Node& node = forest_->nodes_[dh[di]];
+ if (node.cat_ == goal_cat_) {
+ Hypergraph::TailNodeVector ant(1, node.id_);
+ ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+ }
+ }
+ if (!SILENT) cerr << endl;
+
+ if (GoalFound())
+ forest_->PruneUnreachable(forest_->nodes_.size() - 1);
+ return GoalFound();
+}
+
+RSChart::~RSChart() {}
+
+RSExhaustiveBottomUpParser::RSExhaustiveBottomUpParser(
+ const string& goal_sym,
+ const vector<GrammarPtr>& grammars) :
+ goal_sym_(goal_sym),
+ grammars_(grammars) {}
+
+bool RSExhaustiveBottomUpParser::Parse(const Lattice& input,
+ Hypergraph* forest) const {
+ kEPS = TD::Convert("*EPS*");
+ RSChart chart(goal_sym_, grammars_, input, forest);
+ const bool result = chart.Parse();
+
+ if (result) {
+ for (auto& node : forest->nodes_) {
+ Span prev;
+ const Span s = forest->NodeSpan(node.id_, &prev);
+ node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r);
+ }
+ }
+ return result;
+}
diff --git a/decoder/bottom_up_parser-rs.h b/decoder/bottom_up_parser-rs.h
new file mode 100644
index 00000000..2e271e99
--- /dev/null
+++ b/decoder/bottom_up_parser-rs.h
@@ -0,0 +1,29 @@
+#ifndef RSBOTTOM_UP_PARSER_H_
+#define RSBOTTOM_UP_PARSER_H_
+
+#include <vector>
+#include <string>
+
+#include "lattice.h"
+#include "grammar.h"
+
+class Hypergraph;
+
+// implementation of Sennrich (2014) parser
+// http://aclweb.org/anthology/W/W14/W14-4011.pdf
+class RSExhaustiveBottomUpParser {
+ public:
+ RSExhaustiveBottomUpParser(const std::string& goal_sym,
+ const std::vector<GrammarPtr>& grammars);
+
+ // returns true if goal reached spanning the full input
+ // forest contains the full (i.e., unpruned) parse forest
+ bool Parse(const Lattice& input,
+ Hypergraph* forest) const;
+
+ private:
+ const std::string goal_sym_;
+ const std::vector<GrammarPtr> grammars_;
+};
+
+#endif
diff --git a/decoder/bottom_up_parser.h b/decoder/bottom_up_parser.h
index 546bfb54..628bb96d 100644
--- a/decoder/bottom_up_parser.h
+++ b/decoder/bottom_up_parser.h
@@ -1,5 +1,5 @@
-#ifndef _BOTTOM_UP_PARSER_H_
-#define _BOTTOM_UP_PARSER_H_
+#ifndef BOTTOM_UP_PARSER_H_
+#define BOTTOM_UP_PARSER_H_
#include <vector>
#include <string>
diff --git a/decoder/csplit.cc b/decoder/csplit.cc
index 4a723822..7ee4092e 100644
--- a/decoder/csplit.cc
+++ b/decoder/csplit.cc
@@ -151,6 +151,7 @@ bool CompoundSplit::TranslateImpl(const string& input,
smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign
for (int i = 0; i < in.size(); ++i)
smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1)));
+ smeta->ComputeInputLatticeType();
pimpl_->BuildTrellis(in, forest);
forest->Reweight(weights);
return true;
diff --git a/decoder/csplit.h b/decoder/csplit.h
index 82ed23fc..83d457b8 100644
--- a/decoder/csplit.h
+++ b/decoder/csplit.h
@@ -1,5 +1,5 @@
-#ifndef _CSPLIT_H_
-#define _CSPLIT_H_
+#ifndef CSPLIT_H_
+#define CSPLIT_H_
#include "translator.h"
#include "lattice.h"
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index c384c33f..1e6c3194 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; }
#include "fdict.h"
#include "timing_stats.h"
#include "verbose.h"
+#include "b64featvector.h"
#include "translator.h"
#include "phrasebased_translator.h"
@@ -86,7 +87,7 @@ struct ELengthWeightFunction {
}
};
inline void ShowBanner() {
- cerr << "cdec (c) 2009--2014 by Chris Dyer\n";
+ cerr << "cdec (c) 2009--2014 by Chris Dyer" << endl;
}
inline string str(char const* name,po::variables_map const& conf) {
@@ -195,7 +196,7 @@ struct DecoderImpl {
}
forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);
if (!forestname.empty()) forestname=" "+forestname;
- if (!SILENT) {
+ if (!SILENT) {
forest_stats(forest," Pruned "+forestname+" forest",false,false);
cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
}
@@ -261,7 +262,7 @@ struct DecoderImpl {
assert(ref);
LatticeTools::ConvertTextOrPLF(sref, ref);
}
- }
+ }
// used to construct the suffix string to get the name of arguments for multiple passes
// e.g., the "2" in --weights2
@@ -284,7 +285,7 @@ struct DecoderImpl {
boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
- bool graphviz;
+ bool graphviz;
bool joshua_viz;
bool encode_b64;
bool kbest;
@@ -301,6 +302,7 @@ struct DecoderImpl {
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
bool remove_intersected_rule_annotations;
+ bool mr_mira_compat; // Mr.MIRA compatibility mode.
boost::scoped_ptr<IncrementalBase> incremental;
@@ -404,6 +406,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory")
("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
+ ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -414,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to")
- ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
+ ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)")
+ ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)");
// ob.AddOptions(&opts);
po::options_description clo("Command line options");
@@ -665,7 +669,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
oracle.show_derivation=conf.count("show_derivations");
+ oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
+ mr_mira_compat = conf.count("mr_mira_compat");
combine_size = conf["combine_size"].as<int>();
if (combine_size < 1) combine_size = 1;
@@ -699,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string
static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);
}
+static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) {
+ SparseVector<weight_t> delta;
+ DecodeFeatureVector(delta_b64, &delta);
+ if (delta.empty()) return;
+ // Apply updates
+ for (SparseVector<weight_t>::iterator dit = delta.begin();
+ dit != delta.end(); ++dit) {
+ int feat_id = dit->first;
+ union { weight_t weight; unsigned long long repr; } feat_delta;
+ feat_delta.weight = dit->second;
+ if (!SILENT)
+ cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight
+ << " = " << hex << feat_delta.repr << endl;
+ if (weights->size() <= feat_id) weights->resize(feat_id + 1);
+ (*weights)[feat_id] += feat_delta.weight;
+ }
+}
+
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string buf = input;
NgramCache::Clear(); // clear ngram cache for remote LM (if used)
@@ -709,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (sgml.find("id") != sgml.end())
sent_id = atoi(sgml["id"].c_str());
+ // Add delta from input to weights before decoding
+ if (mr_mira_compat)
+ ApplyWeightDelta(sgml["delta"], init_weights.get());
+
if (!SILENT) {
cerr << "\nINPUT: ";
if (buf.size() < 100)
@@ -928,14 +956,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
- bool succeeded = writer.Write(new_hg, false);
+ bool succeeded = writer.Write(new_hg);
if (!succeeded) abort();
} else {
- bool succeeded = writer.Write(forest, false);
+ bool succeeded = writer.Write(forest);
if (!succeeded) abort();
}
}
@@ -947,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (kbest && !has_ref) {
//TODO: does this work properly?
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -1021,14 +1049,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
- bool succeeded = writer.Write(new_hg, false);
+ bool succeeded = writer.Write(new_hg);
if (!succeeded) abort();
} else {
- bool succeeded = writer.Write(forest, false);
+ bool succeeded = writer.Write(forest);
if (!succeeded) abort();
}
}
@@ -1078,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (conf.count("graphviz")) forest.PrintGraphviz();
if (kbest) {
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
}
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
@@ -1098,4 +1126,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
o->NotifyDecodingComplete(smeta);
return true;
}
-
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 8039a42b..a545206b 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -1,5 +1,5 @@
-#ifndef _DECODER_H_
-#define _DECODER_H_
+#ifndef DECODER_H_
+#define DECODER_H_
#include <iostream>
#include <string>
diff --git a/decoder/earley_composer.h b/decoder/earley_composer.h
index 9f786bf6..31602f67 100644
--- a/decoder/earley_composer.h
+++ b/decoder/earley_composer.h
@@ -1,5 +1,5 @@
-#ifndef _EARLEY_COMPOSER_H_
-#define _EARLEY_COMPOSER_H_
+#ifndef EARLEY_COMPOSER_H_
+#define EARLEY_COMPOSER_H_
#include <iostream>
diff --git a/decoder/factored_lexicon_helper.h b/decoder/factored_lexicon_helper.h
index 7fedc517..460bdebb 100644
--- a/decoder/factored_lexicon_helper.h
+++ b/decoder/factored_lexicon_helper.h
@@ -1,5 +1,5 @@
-#ifndef _FACTORED_LEXICON_HELPER_
-#define _FACTORED_LEXICON_HELPER_
+#ifndef FACTORED_LEXICON_HELPER_
+#define FACTORED_LEXICON_HELPER_
#include <cassert>
#include <vector>
diff --git a/decoder/ff.h b/decoder/ff.h
index afa3dbca..d6487d97 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -1,5 +1,5 @@
-#ifndef _FF_H_
-#define _FF_H_
+#ifndef FF_H_
+#define FF_H_
#include <string>
#include <vector>
@@ -27,6 +27,12 @@ class FeatureFunction {
// search. When non-zero, the last N bytes in the state should be ignored when
// splitting a hypernode by the state. This allows the feature function to
// store some side data and later retrieve it via the state bytes.
+ //
+ // In general, this should not be necessary and it should always be possible
+ // to replace this with a more appropriate design of state (if you find
+ // yourself having to ignore some part of the state, you are most likely
+ // storing redundant information in the state). Be sure that you
+ // understand how this affects ApplyModelSet() before using it.
int IgnoredStateSize() const { return ignored_state_size_; }
// override this. not virtual because we want to expose this to factory template for help before creating a FF
@@ -82,6 +88,7 @@ class FeatureFunction {
state_size_ = state_size;
}
+ // See document of IgnoredStateSize() above.
void SetIgnoredStateSize(size_t ignored_state_size) {
ignored_state_size_ = ignored_state_size;
}
diff --git a/decoder/ff_basic.cc b/decoder/ff_basic.cc
index f9404d24..f960418a 100644
--- a/decoder/ff_basic.cc
+++ b/decoder/ff_basic.cc
@@ -49,9 +49,7 @@ void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
features->set_value(fid_, edge.rule_->FWords() * value_);
}
-
-ArityPenalty::ArityPenalty(const std::string& param) :
- value_(-1.0 / log(10)) {
+ArityPenalty::ArityPenalty(const std::string& param) {
string fname = "Arity_";
unsigned MAX=DEFAULT_MAX_ARITY;
using namespace boost;
@@ -61,7 +59,8 @@ ArityPenalty::ArityPenalty(const std::string& param) :
WordID fid=FD::Convert(fname+lexical_cast<string>(i));
fids_.push_back(fid);
}
- while (!fids_.empty() && fids_.back()==0) fids_.pop_back(); // pretty up features vector in case FD was frozen. doesn't change anything
+ // pretty up features vector in case FD was frozen. doesn't change anything
+ while (!fids_.empty() && fids_.back()==0) fids_.pop_back();
}
void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
@@ -75,6 +74,6 @@ void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
(void) state;
(void) estimated_features;
unsigned a=edge.Arity();
- features->set_value(a<fids_.size()?fids_[a]:0, value_);
+ if (a < fids_.size()) features->set_value(fids_[a], 1.0);
}
diff --git a/decoder/ff_basic.h b/decoder/ff_basic.h
index 901c0110..c63daf0f 100644
--- a/decoder/ff_basic.h
+++ b/decoder/ff_basic.h
@@ -1,5 +1,5 @@
-#ifndef _FF_BASIC_H_
-#define _FF_BASIC_H_
+#ifndef FF_BASIC_H_
+#define FF_BASIC_H_
#include "ff.h"
@@ -41,7 +41,7 @@ class SourceWordPenalty : public FeatureFunction {
const double value_;
};
-#define DEFAULT_MAX_ARITY 9
+#define DEFAULT_MAX_ARITY 50
#define DEFAULT_MAX_ARITY_STRINGIZE(x) #x
#define DEFAULT_MAX_ARITY_STRINGIZE_EVAL(x) DEFAULT_MAX_ARITY_STRINGIZE(x)
#define DEFAULT_MAX_ARITY_STR DEFAULT_MAX_ARITY_STRINGIZE_EVAL(DEFAULT_MAX_ARITY)
@@ -62,7 +62,6 @@ class ArityPenalty : public FeatureFunction {
void* context) const;
private:
std::vector<WordID> fids_;
- const double value_;
};
#endif
diff --git a/decoder/ff_bleu.h b/decoder/ff_bleu.h
index 344dc788..8ca2c095 100644
--- a/decoder/ff_bleu.h
+++ b/decoder/ff_bleu.h
@@ -1,5 +1,5 @@
-#ifndef _BLEU_FF_H_
-#define _BLEU_FF_H_
+#ifndef BLEU_FF_H_
+#define BLEU_FF_H_
#include <vector>
#include <string>
diff --git a/decoder/ff_charset.h b/decoder/ff_charset.h
index 267ef65d..e22ece2b 100644
--- a/decoder/ff_charset.h
+++ b/decoder/ff_charset.h
@@ -1,5 +1,5 @@
-#ifndef _FFCHARSET_H_
-#define _FFCHARSET_H_
+#ifndef FFCHARSET_H_
+#define FFCHARSET_H_
#include <string>
#include <map>
diff --git a/decoder/ff_conll.cc b/decoder/ff_conll.cc
new file mode 100644
index 00000000..8ded44b7
--- /dev/null
+++ b/decoder/ff_conll.cc
@@ -0,0 +1,250 @@
+#include "ff_conll.h"
+
+#include <stdlib.h>
+#include <sstream>
+#include <cassert>
+#include <cmath>
+#include <boost/lexical_cast.hpp>
+
+#include "hg.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+#include "tdict.h"
+
+CoNLLFeatures::CoNLLFeatures(const string& param) {
+ // cerr << "initializing CoNLLFeatures with parameters: " << param;
+ kSOS = TD::Convert("<s>");
+ kEOS = TD::Convert("</s>");
+ macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]");
+ ParseArgs(param);
+}
+
+string CoNLLFeatures::Escape(const string& x) const {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+}
+
+// replace %x[relative_location] or %y[relative_location] with actual_token
+// within feature_instance
+void CoNLLFeatures::ReplaceMacroWithString(
+ string& feature_instance, bool token_vs_label, int relative_location,
+ const string& actual_token) const {
+
+ stringstream macro;
+ if (token_vs_label) {
+ macro << "%x[";
+ } else {
+ macro << "%y[";
+ }
+ macro << relative_location << "]";
+ int macro_index = feature_instance.find(macro.str());
+ if (macro_index == string::npos) {
+ cerr << "Can't find macro " << macro.str() << " in feature template "
+ << feature_instance;
+ abort();
+ }
+ feature_instance.replace(macro_index, macro.str().size(), actual_token);
+}
+
+void CoNLLFeatures::ReplaceTokenMacroWithString(
+ string& feature_instance, int relative_location,
+ const string& actual_token) const {
+
+ ReplaceMacroWithString(feature_instance, true, relative_location,
+ actual_token);
+}
+
+void CoNLLFeatures::ReplaceLabelMacroWithString(
+ string& feature_instance, int relative_location,
+ const string& actual_token) const {
+
+ ReplaceMacroWithString(feature_instance, false, relative_location,
+ actual_token);
+}
+
+void CoNLLFeatures::Error(const string& error_message) const {
+ cerr << "Error: " << error_message << "\n\n"
+
+ << "CoNLLFeatures Usage: \n"
+ << " feature_function=CoNLLFeatures -t <TEMPLATE>\n\n"
+
+ << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n"
+
+ << "%x[k] and %y[k] are macros to be instantiated with an input\n"
+ << "token (for x) or a label (for y). k specifies the relative\n"
+ << "location of the input token or label with respect to the current\n"
+ << "position. For x, k is an integer value. For y, k must be 0 (to\n"
+ << "be extended).\n\n";
+
+ abort();
+}
+
+void CoNLLFeatures::ParseArgs(const string& in) {
+ which_feat = 0;
+ vector<string> const& argv = SplitOnWhitespace(in);
+ for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) {
+ string const& s = *i;
+ if (s[0] == '-') {
+ if (s.size() > 2) {
+ stringstream msg;
+ msg << s << " is an invalid option for CoNLLFeatures.";
+ Error(msg.str());
+ }
+
+ switch (s[1]) {
+
+ case 'w': {
+ if (++i == argv.end()) {
+ Error("Missing parameter to -w");
+ }
+ which_feat = boost::lexical_cast<unsigned>(*i);
+ break;
+ }
+ // feature template
+ case 't': {
+ if (++i == argv.end()) {
+ Error("Can't find template.");
+ }
+ feature_template = *i;
+ string::const_iterator start = feature_template.begin();
+ string::const_iterator end = feature_template.end();
+ smatch macro_match;
+
+ // parse the template
+ while (regex_search(start, end, macro_match, macro_regex)) {
+ // get the relative location
+ string relative_location_str(macro_match[2].first,
+ macro_match[2].second);
+ int relative_location = atoi(relative_location_str.c_str());
+ // add it to the list of relative locations for token or label
+ // (i.e. x or y)
+ bool valid_location = true;
+ if (*macro_match[1].first == 'x') {
+ // add it to token locations
+ token_relative_locations.push_back(relative_location);
+ } else {
+ if (relative_location != 0) { valid_location = false; }
+ // add it to label locations
+ label_relative_locations.push_back(relative_location);
+ }
+ if (!valid_location) {
+ stringstream msg;
+ msg << "Relative location " << relative_location
+ << " in feature template " << feature_template
+ << " is invalid.";
+ Error(msg.str());
+ }
+ start = macro_match[0].second;
+ }
+ break;
+ }
+
+ // TODO: arguments to specify kSOS and kEOS
+
+ default: {
+ stringstream msg;
+ msg << "Invalid option on CoNLLFeatures: " << s;
+ Error(msg.str());
+ break;
+ }
+ } // end of switch
+ } // end of if (token starts with hyphen)
+ } // end of for loop (over arguments)
+
+ // the -t (i.e. template) option is mandatory in this feature function
+ if (label_relative_locations.size() == 0 ||
+ token_relative_locations.size() == 0) {
+ stringstream msg;
+ msg << "Feature template must specify at least one"
+ << "token macro (e.g. x[-1]) and one label macro (e.g. y[0]).";
+ Error(msg.str());
+ }
+}
+
+void CoNLLFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ const Lattice& sl = smeta.GetSourceLattice();
+ current_input.resize(sl.size());
+ for (unsigned i = 0; i < sl.size(); ++i) {
+ if (sl[i].size() != 1) {
+ stringstream msg;
+ msg << "CoNLLFeatures don't support lattice inputs!\nid="
+ << smeta.GetSentenceId() << endl;
+ Error(msg.str());
+ }
+ current_input[i] = sl[i][0].label;
+ }
+ vector<WordID> wids;
+ string fn = "feat";
+ fn += boost::lexical_cast<string>(which_feat);
+ string feats = smeta.GetSGMLValue(fn);
+ if (feats.size() == 0) {
+ Error("Can't find " + fn + " in <seg>\n");
+ }
+ TD::ConvertSentence(feats, &wids);
+ assert(current_input.size() == wids.size());
+ current_input = wids;
+}
+
+void CoNLLFeatures::TraversalFeaturesImpl(
+ const SentenceMetadata& smeta, const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts, SparseVector<double>* features,
+ SparseVector<double>* estimated_features, void* context) const {
+
+ const TRule& rule = *edge.rule_;
+ // arity = 0, no nonterminals
+ // size = 1, predicted label is a single token
+ if (rule.Arity() != 0 ||
+ rule.e_.size() != 1) {
+ return;
+ }
+
+ // replace label macros with actual label strings
+ // NOTE: currently, this feature function doesn't allow any label
+ // macros except %y[0]. but you can look at as much of the source as you want
+ const WordID y0 = rule.e_[0];
+ string y0_str = TD::Convert(y0);
+
+ // start of the span in the input being labeled
+ const int from_src_index = edge.i_;
+ // end of the span in the input
+ const int to_src_index = edge.j_;
+
+ // in the case of tagging the size of the spans being labeled will
+ // always be 1, but in other formalisms, you can have bigger spans
+ if (to_src_index - from_src_index != 1) {
+ cerr << "CoNLLFeatures doesn't support input spans of length != 1";
+ abort();
+ }
+
+ string feature_instance = feature_template;
+ // replace token macros with actual token strings
+ for (unsigned i = 0; i < token_relative_locations.size(); ++i) {
+ int loc = token_relative_locations[i];
+ WordID x = loc < 0? kSOS: kEOS;
+ if(from_src_index + loc >= 0 &&
+ from_src_index + loc < current_input.size()) {
+ x = current_input[from_src_index + loc];
+ }
+ string x_str = TD::Convert(x);
+ ReplaceTokenMacroWithString(feature_instance, loc, x_str);
+ }
+
+ ReplaceLabelMacroWithString(feature_instance, 0, y0_str);
+
+ // pick a real value for this feature
+ double fval = 1.0;
+
+ // add it to the feature vector
+ // FD::Convert converts the feature string to a feature int
+ // Escape makes sure the feature string doesn't have any bad
+ // symbols that could confuse a parser somewhere
+ features->add_value(FD::Convert(Escape(feature_instance)), fval);
+}
diff --git a/decoder/ff_conll.h b/decoder/ff_conll.h
new file mode 100644
index 00000000..b37356d8
--- /dev/null
+++ b/decoder/ff_conll.h
@@ -0,0 +1,45 @@
+#ifndef FF_CONLL_H_
+#define FF_CONLL_H_
+
+#include <vector>
+#include <boost/xpressive/xpressive.hpp>
+#include "ff.h"
+
+using namespace boost::xpressive;
+using namespace std;
+
+class CoNLLFeatures : public FeatureFunction {
+ public:
+ CoNLLFeatures(const string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ virtual void ParseArgs(const string& in);
+ virtual string Escape(const string& x) const;
+ virtual void ReplaceMacroWithString(string& feature_instance,
+ bool token_vs_label,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void ReplaceTokenMacroWithString(string& feature_instance,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void ReplaceLabelMacroWithString(string& feature_instance,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void Error(const string&) const;
+
+ private:
+ vector<int> token_relative_locations, label_relative_locations;
+ string feature_template;
+ vector<WordID> current_input;
+ WordID kSOS, kEOS;
+ sregex macro_regex;
+ unsigned which_feat;
+};
+
+#endif
diff --git a/decoder/ff_const_reorder.cc b/decoder/ff_const_reorder.cc
index 95546793..f1a6f7cb 100644
--- a/decoder/ff_const_reorder.cc
+++ b/decoder/ff_const_reorder.cc
@@ -1071,6 +1071,9 @@ ConstReorderFeature::~ConstReorderFeature() { // TODO
void ConstReorderFeature::PrepareForInput(const SentenceMetadata& smeta) {
string parse_file = smeta.GetSGMLValue("parse");
+ if (parse_file.empty()) {
+ parse_file = smeta.GetSGMLValue("src_tree");
+ }
string srl_file = smeta.GetSGMLValue("srl");
assert(!(parse_file == "" && srl_file == ""));
diff --git a/decoder/ff_const_reorder_common.h b/decoder/ff_const_reorder_common.h
index 7c111de3..755fd948 100644
--- a/decoder/ff_const_reorder_common.h
+++ b/decoder/ff_const_reorder_common.h
@@ -1081,7 +1081,7 @@ typedef std::unordered_map<std::string, int>::iterator Iterator;
struct Tsuruoka_Maxent {
Tsuruoka_Maxent(const char* pszModelFName) {
if (pszModelFName != NULL) {
- m_pModel = new ME_Model();
+ m_pModel = new maxent::ME_Model();
m_pModel->load_from_file(pszModelFName);
} else
m_pModel = NULL;
@@ -1091,102 +1091,9 @@ struct Tsuruoka_Maxent {
if (m_pModel != NULL) delete m_pModel;
}
- void fnTrain(const char* pszInstanceFName, const char* pszAlgorithm,
- const char* pszModelFName, int /*iNumIteration*/) {
- assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 ||
- strcmp(pszAlgorithm, "sgd") == 0 ||
- strcmp(pszAlgorithm, "SGD") == 0);
- FILE* fpIn = fopen(pszInstanceFName, "r");
-
- ME_Model* pModel = new ME_Model();
-
- char* pszLine = new char[100001];
- int iNumInstances = 0;
- int iLen;
- while (!feof(fpIn)) {
- pszLine[0] = '\0';
- fgets(pszLine, 20000, fpIn);
- if (strlen(pszLine) == 0) {
- continue;
- }
-
- iLen = strlen(pszLine);
- while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) {
- pszLine[iLen - 1] = '\0';
- iLen--;
- }
-
- iNumInstances++;
-
- ME_Sample* pmes = new ME_Sample();
-
- char* p = strrchr(pszLine, ' ');
- assert(p != NULL);
- p[0] = '\0';
- p++;
- std::vector<std::string> vecContext;
- SplitOnWhitespace(std::string(pszLine), &vecContext);
-
- pmes->label = std::string(p);
- for (size_t i = 0; i < vecContext.size(); i++)
- pmes->add_feature(vecContext[i]);
- pModel->add_training_sample((*pmes));
- if (iNumInstances % 100000 == 0)
- fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances);
- delete pmes;
- }
- fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances);
- fclose(fpIn);
-
- if (strcmp(pszAlgorithm, "l1") == 0)
- pModel->use_l1_regularizer(1.0);
- else if (strcmp(pszAlgorithm, "l2") == 0)
- pModel->use_l2_regularizer(1.0);
- else
- pModel->use_SGD();
-
- pModel->train();
- pModel->save_to_file(pszModelFName);
-
- delete pModel;
- fprintf(stdout, "......Finished Training\n");
- fprintf(stdout, "......Model saved as %s\n", pszModelFName);
- delete[] pszLine;
- }
-
- double fnEval(const char* pszContext, const char* pszOutcome) const {
- std::vector<std::string> vecContext;
- ME_Sample* pmes = new ME_Sample();
- SplitOnWhitespace(std::string(pszContext), &vecContext);
-
- for (size_t i = 0; i < vecContext.size(); i++)
- pmes->add_feature(vecContext[i]);
- std::vector<double> vecProb = m_pModel->classify(*pmes);
- delete pmes;
- int iLableID = m_pModel->get_class_id(pszOutcome);
- return vecProb[iLableID];
- }
- void fnEval(const char* pszContext,
- std::vector<std::pair<std::string, double> >& vecOutput) const {
- std::vector<std::string> vecContext;
- ME_Sample* pmes = new ME_Sample();
- SplitOnWhitespace(std::string(pszContext), &vecContext);
-
- vecOutput.clear();
-
- for (size_t i = 0; i < vecContext.size(); i++)
- pmes->add_feature(vecContext[i]);
- std::vector<double> vecProb = m_pModel->classify(*pmes);
-
- for (size_t i = 0; i < vecProb.size(); i++) {
- std::string label = m_pModel->get_class_label(i);
- vecOutput.push_back(make_pair(label, vecProb[i]));
- }
- delete pmes;
- }
void fnEval(const char* pszContext, std::vector<double>& vecOutput) const {
std::vector<std::string> vecContext;
- ME_Sample* pmes = new ME_Sample();
+ maxent::ME_Sample* pmes = new maxent::ME_Sample();
SplitOnWhitespace(std::string(pszContext), &vecContext);
vecOutput.clear();
@@ -1206,7 +1113,7 @@ struct Tsuruoka_Maxent {
}
private:
- ME_Model* m_pModel;
+ maxent::ME_Model* m_pModel;
};
// an argument item or a predicate item (the verb itself)
diff --git a/decoder/ff_context.h b/decoder/ff_context.h
index 19198ec3..ed1aea2b 100644
--- a/decoder/ff_context.h
+++ b/decoder/ff_context.h
@@ -1,6 +1,5 @@
-
-#ifndef _FF_CONTEXT_H_
-#define _FF_CONTEXT_H_
+#ifndef FF_CONTEXT_H_
+#define FF_CONTEXT_H_
#include <vector>
#include <boost/xpressive/xpressive.hpp>
diff --git a/decoder/ff_csplit.h b/decoder/ff_csplit.h
index 79bf2886..227f2a14 100644
--- a/decoder/ff_csplit.h
+++ b/decoder/ff_csplit.h
@@ -1,5 +1,5 @@
-#ifndef _FF_CSPLIT_H_
-#define _FF_CSPLIT_H_
+#ifndef FF_CSPLIT_H_
+#define FF_CSPLIT_H_
#include <boost/shared_ptr.hpp>
diff --git a/decoder/ff_external.h b/decoder/ff_external.h
index 3e2bee51..fd12a37c 100644
--- a/decoder/ff_external.h
+++ b/decoder/ff_external.h
@@ -1,5 +1,5 @@
-#ifndef _FFEXTERNAL_H_
-#define _FFEXTERNAL_H_
+#ifndef FFEXTERNAL_H_
+#define FFEXTERNAL_H_
#include "ff.h"
diff --git a/decoder/ff_factory.h b/decoder/ff_factory.h
index 1aa8e55f..ba9be9ac 100644
--- a/decoder/ff_factory.h
+++ b/decoder/ff_factory.h
@@ -1,5 +1,5 @@
-#ifndef _FF_FACTORY_H_
-#define _FF_FACTORY_H_
+#ifndef FF_FACTORY_H_
+#define FF_FACTORY_H_
//TODO: use http://www.boost.org/doc/libs/1_43_0/libs/functional/factory/doc/html/index.html ?
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
index db4032f7..c8350623 100644
--- a/decoder/ff_klm.h
+++ b/decoder/ff_klm.h
@@ -1,5 +1,5 @@
-#ifndef _KLM_FF_H_
-#define _KLM_FF_H_
+#ifndef KLM_FF_H_
+#define KLM_FF_H_
#include <vector>
#include <string>
diff --git a/decoder/ff_lm.h b/decoder/ff_lm.h
index 85e79704..83a2e186 100644
--- a/decoder/ff_lm.h
+++ b/decoder/ff_lm.h
@@ -1,5 +1,5 @@
-#ifndef _LM_FF_H_
-#define _LM_FF_H_
+#ifndef LM_FF_H_
+#define LM_FF_H_
#include <vector>
#include <string>
diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h
index 4965d235..5dea9a7d 100644
--- a/decoder/ff_ngrams.h
+++ b/decoder/ff_ngrams.h
@@ -1,5 +1,5 @@
-#ifndef _NGRAMS_FF_H_
-#define _NGRAMS_FF_H_
+#ifndef NGRAMS_FF_H_
+#define NGRAMS_FF_H_
#include <vector>
#include <map>
diff --git a/decoder/ff_parse_match.h b/decoder/ff_parse_match.h
index 7820b418..188c406a 100644
--- a/decoder/ff_parse_match.h
+++ b/decoder/ff_parse_match.h
@@ -1,5 +1,5 @@
-#ifndef _FF_PARSE_MATCH_H_
-#define _FF_PARSE_MATCH_H_
+#ifndef FF_PARSE_MATCH_H_
+#define FF_PARSE_MATCH_H_
#include "ff.h"
#include "hg.h"
diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h
index f210dc65..5c4cf45e 100644
--- a/decoder/ff_rules.h
+++ b/decoder/ff_rules.h
@@ -1,5 +1,5 @@
-#ifndef _FF_RULES_H_
-#define _FF_RULES_H_
+#ifndef FF_RULES_H_
+#define FF_RULES_H_
#include <vector>
#include <map>
diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h
index 488cfd84..66914f5d 100644
--- a/decoder/ff_ruleshape.h
+++ b/decoder/ff_ruleshape.h
@@ -1,5 +1,5 @@
-#ifndef _FF_RULESHAPE_H_
-#define _FF_RULESHAPE_H_
+#ifndef FF_RULESHAPE_H_
+#define FF_RULESHAPE_H_
#include <vector>
#include <map>
diff --git a/decoder/ff_soft_syntax.h b/decoder/ff_soft_syntax.h
index e71825d5..da51df7f 100644
--- a/decoder/ff_soft_syntax.h
+++ b/decoder/ff_soft_syntax.h
@@ -1,5 +1,5 @@
-#ifndef _FF_SOFT_SYNTAX_H_
-#define _FF_SOFT_SYNTAX_H_
+#ifndef FF_SOFT_SYNTAX_H_
+#define FF_SOFT_SYNTAX_H_
#include "ff.h"
#include "hg.h"
diff --git a/decoder/ff_soft_syntax_mindist.h b/decoder/ff_soft_syntax_mindist.h
index bf938b38..205eff4b 100644
--- a/decoder/ff_soft_syntax_mindist.h
+++ b/decoder/ff_soft_syntax_mindist.h
@@ -1,5 +1,5 @@
-#ifndef _FF_SOFT_SYNTAX_MINDIST_H_
-#define _FF_SOFT_SYNTAX_MINDIST_H_
+#ifndef FF_SOFT_SYNTAX_MINDIST_H_
+#define FF_SOFT_SYNTAX_MINDIST_H_
#include "ff.h"
#include "hg.h"
diff --git a/decoder/ff_source_path.h b/decoder/ff_source_path.h
index 03126412..fc309264 100644
--- a/decoder/ff_source_path.h
+++ b/decoder/ff_source_path.h
@@ -1,5 +1,5 @@
-#ifndef _FF_SOURCE_PATH_H_
-#define _FF_SOURCE_PATH_H_
+#ifndef FF_SOURCE_PATH_H_
+#define FF_SOURCE_PATH_H_
#include <vector>
#include <map>
diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h
index bdd638c1..6316e881 100644
--- a/decoder/ff_source_syntax.h
+++ b/decoder/ff_source_syntax.h
@@ -1,5 +1,5 @@
-#ifndef _FF_SOURCE_SYNTAX_H_
-#define _FF_SOURCE_SYNTAX_H_
+#ifndef FF_SOURCE_SYNTAX_H_
+#define FF_SOURCE_SYNTAX_H_
#include "ff.h"
#include "hg.h"
diff --git a/decoder/ff_source_syntax2.h b/decoder/ff_source_syntax2.h
index f606c2bf..bbfa9eb6 100644
--- a/decoder/ff_source_syntax2.h
+++ b/decoder/ff_source_syntax2.h
@@ -1,5 +1,5 @@
-#ifndef _FF_SOURCE_SYNTAX2_H_
-#define _FF_SOURCE_SYNTAX2_H_
+#ifndef FF_SOURCE_SYNTAX2_H_
+#define FF_SOURCE_SYNTAX2_H_
#include "ff.h"
#include "hg.h"
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
index d2f5e84c..e2475491 100644
--- a/decoder/ff_spans.h
+++ b/decoder/ff_spans.h
@@ -1,5 +1,5 @@
-#ifndef _FF_SPANS_H_
-#define _FF_SPANS_H_
+#ifndef FF_SPANS_H_
+#define FF_SPANS_H_
#include <vector>
#include <map>
diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h
index 46418b0c..0cb8c648 100644
--- a/decoder/ff_tagger.h
+++ b/decoder/ff_tagger.h
@@ -1,5 +1,5 @@
-#ifndef _FF_TAGGER_H_
-#define _FF_TAGGER_H_
+#ifndef FF_TAGGER_H_
+#define FF_TAGGER_H_
#include <map>
#include <boost/scoped_ptr.hpp>
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index 0161f603..ec454621 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -1,5 +1,5 @@
-#ifndef _FF_WORD_ALIGN_H_
-#define _FF_WORD_ALIGN_H_
+#ifndef FF_WORD_ALIGN_H_
+#define FF_WORD_ALIGN_H_
#include "ff.h"
#include "array2d.h"
diff --git a/decoder/ff_wordset.h b/decoder/ff_wordset.h
index e78cd2fb..94f5ff8a 100644
--- a/decoder/ff_wordset.h
+++ b/decoder/ff_wordset.h
@@ -1,5 +1,5 @@
-#ifndef _FF_WORDSET_H_
-#define _FF_WORDSET_H_
+#ifndef FF_WORDSET_H_
+#define FF_WORDSET_H_
#include "ff.h"
#include "tdict.h"
diff --git a/decoder/ffset.h b/decoder/ffset.h
index a69a75fa..84f9fdb9 100644
--- a/decoder/ffset.h
+++ b/decoder/ffset.h
@@ -1,5 +1,5 @@
-#ifndef _FFSET_H_
-#define _FFSET_H_
+#ifndef FFSET_H_
+#define FFSET_H_
#include <utility>
#include <vector>
diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc
index 6e4cccb3..cc9094d7 100644
--- a/decoder/forest_writer.cc
+++ b/decoder/forest_writer.cc
@@ -11,13 +11,13 @@
using namespace std;
ForestWriter::ForestWriter(const std::string& path, int num) :
- fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {}
+ fname_(path + '/' + boost::lexical_cast<string>(num) + ".bin.gz"), used_(false) {}
-bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) {
+bool ForestWriter::Write(const Hypergraph& forest) {
assert(!used_);
used_ = true;
cerr << " Writing forest to " << fname_ << endl;
WriteFile wf(fname_);
- return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream());
+ return HypergraphIO::WriteToBinary(forest, wf.stream());
}
diff --git a/decoder/forest_writer.h b/decoder/forest_writer.h
index 819a8940..54e83470 100644
--- a/decoder/forest_writer.h
+++ b/decoder/forest_writer.h
@@ -1,5 +1,5 @@
-#ifndef _FOREST_WRITER_H_
-#define _FOREST_WRITER_H_
+#ifndef FOREST_WRITER_H_
+#define FOREST_WRITER_H_
#include <string>
@@ -7,7 +7,7 @@ class Hypergraph;
struct ForestWriter {
ForestWriter(const std::string& path, int num);
- bool Write(const Hypergraph& forest, bool minimal_rules);
+ bool Write(const Hypergraph& forest);
const std::string fname_;
bool used_;
diff --git a/decoder/freqdict.h b/decoder/freqdict.h
index 4e03fadd..07d797e2 100644
--- a/decoder/freqdict.h
+++ b/decoder/freqdict.h
@@ -1,5 +1,5 @@
-#ifndef _FREQDICT_H_
-#define _FREQDICT_H_
+#ifndef FREQDICT_H_
+#define FREQDICT_H_
#include <iostream>
#include <map>
diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc
index 4253b652..fe28f4c6 100644
--- a/decoder/fst_translator.cc
+++ b/decoder/fst_translator.cc
@@ -27,11 +27,15 @@ struct FSTTranslatorImpl {
const vector<double>& weights,
Hypergraph* forest) {
bool composed = false;
- if (input.find("{\"rules\"") == 0) {
+ if (input.find("::forest::") == 0) {
istringstream is(input);
+ string header, fname;
+ is >> header >> fname;
+ ReadFile rf(fname);
+ if (!rf) { cerr << "Failed to open " << fname << endl; abort(); }
Hypergraph src_cfg_hg;
- if (!HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)) {
- cerr << "Failed to read HG from JSON.\n";
+ if (!HypergraphIO::ReadFromBinary(rf.stream(), &src_cfg_hg)) {
+ cerr << "Failed to read HG.\n";
abort();
}
if (add_pass_through_rules) {
@@ -95,6 +99,7 @@ bool FSTTranslator::TranslateImpl(const string& input,
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
smeta->SetSourceLength(0); // don't know how to compute this
+ smeta->input_type_ = cdec::kFOREST;
return pimpl_->Translate(input, weights, minus_lm_forest);
}
diff --git a/decoder/hg.h b/decoder/hg.h
index 4ed27d87..c756012e 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -1,5 +1,5 @@
-#ifndef _HG_H_
-#define _HG_H_
+#ifndef HG_H_
+#define HG_H_
// define USE_INFO_EDGE 1 if you want lots of debug info shown with --show_derivations - otherwise it adds quite a bit of overhead if ffs have their logging enabled (e.g. ff_from_fsa)
#ifndef USE_INFO_EDGE
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <boost/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
#include "feature_vector.h"
#include "small_vector.h"
@@ -69,6 +70,18 @@ namespace HG {
short int j_;
short int prev_i_;
short int prev_j_;
+ template<class Archive>
+ void serialize(Archive & ar, const unsigned int /*version*/) {
+ ar & head_node_;
+ ar & tail_nodes_;
+ ar & rule_;
+ ar & feature_values_;
+ ar & i_;
+ ar & j_;
+ ar & prev_i_;
+ ar & prev_j_;
+ ar & id_;
+ }
void show(std::ostream &o,unsigned mask=SPAN|RULE) const {
o<<'{';
if (mask&CATEGORY)
@@ -149,6 +162,24 @@ namespace HG {
WordID NT() const { return -cat_; }
EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_
EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
+ template<class Archive>
+ void save(Archive & ar, const unsigned int /*version*/) const {
+ ar & node_hash;
+ ar & id_;
+ ar & TD::Convert(-cat_);
+ ar & in_edges_;
+ ar & out_edges_;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int /*version*/) {
+ ar & node_hash;
+ ar & id_;
+ std::string cat; ar & cat;
+ cat_ = -TD::Convert(cat);
+ ar & in_edges_;
+ ar & out_edges_;
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting
node_hash = o.node_hash;
cat_=o.cat_;
@@ -492,6 +523,27 @@ public:
void set_ids(); // resync edge,node .id_
void check_ids() const; // assert that .id_ have been kept in sync
+ template<class Archive>
+ void save(Archive & ar, const unsigned int /*version*/) const {
+ unsigned ns = nodes_.size(); ar & ns;
+ unsigned es = edges_.size(); ar & es;
+ for (auto& n : nodes_) ar & n;
+ for (auto& e : edges_) ar & e;
+ int x;
+ x = edges_topo_; ar & x;
+ x = is_linear_chain_; ar & x;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int /*version*/) {
+ unsigned ns; ar & ns; nodes_.resize(ns);
+ unsigned es; ar & es; edges_.resize(es);
+ for (auto& n : nodes_) ar & n;
+ for (auto& e : edges_) ar & e;
+ int x;
+ ar & x; edges_topo_ = x;
+ ar & x; is_linear_chain_ = x;
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {}
};
diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc
index 02f5a401..b9381d02 100644
--- a/decoder/hg_intersect.cc
+++ b/decoder/hg_intersect.cc
@@ -88,7 +88,7 @@ namespace HG {
bool Intersect(const Lattice& target, Hypergraph* hg) {
// there are a number of faster algorithms available for restricted
// classes of hypergraph and/or target.
- if (hg->IsLinearChain() && target.IsSentence())
+ if (hg->IsLinearChain() && IsSentence(target))
return FastLinearIntersect(target, hg);
vector<bool> rem(hg->edges_.size(), false);
diff --git a/decoder/hg_intersect.h b/decoder/hg_intersect.h
index 29a5ea2a..19c1c177 100644
--- a/decoder/hg_intersect.h
+++ b/decoder/hg_intersect.h
@@ -1,5 +1,5 @@
-#ifndef _HG_INTERSECT_H_
-#define _HG_INTERSECT_H_
+#ifndef HG_INTERSECT_H_
+#define HG_INTERSECT_H_
#include "lattice.h"
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index eb0be3d4..626b2954 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -6,362 +6,27 @@
#include <sstream>
#include <iostream>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+
#include "fast_lexical_cast.hpp"
#include "tdict.h"
-#include "json_parse.h"
#include "hg.h"
using namespace std;
-struct HGReader : public JSONParser {
- HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; }
-
- void CreateNode(const string& cat, const string& shash, const vector<int>& in_edges) {
- WordID c = TD::Convert("X") * -1;
- if (!cat.empty()) c = TD::Convert(cat) * -1;
- Hypergraph::Node* node = hg.AddNode(c);
- char* dend;
- if (shash.size())
- node->node_hash = strtoull(shash.c_str(), &dend, 16);
- else
- node->node_hash = 0;
- for (int i = 0; i < in_edges.size(); ++i) {
- if (in_edges[i] >= hg.edges_.size()) {
- cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i]
- << ", but hg only has " << hg.edges_.size() << " edges!\n";
- abort();
- }
- hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node);
- }
- }
- void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVectorUnsigned& tail) {
- Hypergraph::Edge* edge = hg.AddEdge(rule, tail);
- feats->swap(edge->feature_values_);
- edge->i_ = spans[0];
- edge->j_ = spans[1];
- edge->prev_i_ = spans[2];
- edge->prev_j_ = spans[3];
- }
-
- bool HandleJSONEvent(int type, const JSON_value* value) {
- switch(state) {
- case -1:
- assert(type == JSON_T_OBJECT_BEGIN);
- state = 0;
- break;
- case 0:
- if (type == JSON_T_OBJECT_END) {
- //cerr << "HG created\n"; // TODO, signal some kind of callback
- } else if (type == JSON_T_KEY) {
- string val = value->vu.str.value;
- if (val == "features") { assert(fdict.empty()); state = 1; }
- else if (val == "is_sorted") { state = 3; }
- else if (val == "rules") { assert(rules.empty()); state = 4; }
- else if (val == "node") { state = 8; }
- else if (val == "edges") { state = 13; }
- else { cerr << "Unexpected key: " << val << endl; return false; }
- }
- break;
-
- // features
- case 1:
- if(type == JSON_T_NULL) { state = 0; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 2;
- break;
- case 2:
- if(type == JSON_T_ARRAY_END) { state = 0; break; }
- assert(type == JSON_T_STRING);
- fdict.push_back(FD::Convert(value->vu.str.value));
- assert(fdict.back() > 0);
- break;
-
- // is_sorted
- case 3:
- assert(type == JSON_T_TRUE || type == JSON_T_FALSE);
- is_sorted = (type == JSON_T_TRUE);
- if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; }
- state = 0;
- break;
-
- // rules
- case 4:
- if(type == JSON_T_NULL) { state = 0; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 5;
- break;
- case 5:
- if(type == JSON_T_ARRAY_END) { state = 0; break; }
- assert(type == JSON_T_INTEGER);
- state = 6;
- rule_id = value->vu.integer_value;
- break;
- case 6:
- assert(type == JSON_T_STRING);
- rules[rule_id] = TRulePtr(new TRule(value->vu.str.value));
- state = 5;
- break;
-
- // Nodes
- case 8:
- assert(type == JSON_T_OBJECT_BEGIN);
- ++nodes;
- in_edges.clear();
- cat.clear();
- shash.clear();
- state = 9; break;
- case 9:
- if (type == JSON_T_OBJECT_END) {
- //cerr << "Creating NODE\n";
- CreateNode(cat, shash, in_edges);
- state = 0; break;
- }
- assert(type == JSON_T_KEY);
- cur_key = value->vu.str.value;
- if (cur_key == "cat") { assert(cat.empty()); state = 10; break; }
- if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; }
- if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; }
- cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n";
- return false;
- case 10:
- assert(type == JSON_T_STRING || type == JSON_T_NULL);
- cat = value->vu.str.value;
- state = 9; break;
- case 11:
- if (type == JSON_T_NULL) { state = 9; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 12; break;
- case 12:
- if (type == JSON_T_ARRAY_END) { state = 9; break; }
- assert(type == JSON_T_INTEGER);
- //cerr << "in_edges: " << value->vu.integer_value << endl;
- in_edges.push_back(value->vu.integer_value);
- break;
-
- // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12},
- // { "tail": null, "feats" : [0,0.87,1,0.02], "spans":[1,2,3,4], "rule": 17},
- // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}]
- case 13:
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 14;
- break;
- case 14:
- if (type == JSON_T_ARRAY_END) { state = 0; break; }
- assert(type == JSON_T_OBJECT_BEGIN);
- //cerr << "New edge\n";
- ++edges;
- cur_rule.reset(); feats.clear(); tail.clear();
- state = 15; break;
- case 15:
- if (type == JSON_T_OBJECT_END) {
- CreateEdge(cur_rule, &feats, tail);
- state = 14; break;
- }
- assert(type == JSON_T_KEY);
- cur_key = value->vu.str.value;
- //cerr << "edge key " << cur_key << endl;
- if (cur_key == "rule") { assert(!cur_rule); state = 16; break; }
- if (cur_key == "spans") { assert(!cur_rule); state = 22; break; }
- if (cur_key == "feats") { assert(feats.empty()); state = 17; break; }
- if (cur_key == "tail") { assert(tail.empty()); state = 20; break; }
- cerr << "Unexpected key " << cur_key << " in edge specification\n";
- return false;
- case 16: // edge.rule
- if (type == JSON_T_INTEGER) {
- int rule_id = value->vu.integer_value;
- if (rules.find(rule_id) == rules.end()) {
- // rules list must come before the edge definitions!
- cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n";
- return false;
- }
- cur_rule = rules[rule_id];
- } else if (type == JSON_T_STRING) {
- cur_rule.reset(new TRule(value->vu.str.value));
- } else {
- cerr << "Rule must be either a rule id or a rule string" << endl;
- return false;
- }
- // cerr << "Edge: rule=" << cur_rule->AsString() << endl;
- state = 15;
- break;
- case 17: // edge.feats
- if (type == JSON_T_NULL) { state = 15; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 18; break;
- case 18:
- if (type == JSON_T_ARRAY_END) { state = 15; break; }
- if (type != JSON_T_INTEGER && type != JSON_T_STRING) {
- cerr << "Unexpected feature id type\n"; return false;
- }
- if (type == JSON_T_INTEGER) {
- fid = value->vu.integer_value;
- assert(fid < fdict.size());
- fid = fdict[fid];
- } else if (JSON_T_STRING) {
- fid = FD::Convert(value->vu.str.value);
- } else { abort(); }
- state = 19;
- break;
- case 19:
- {
- assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT);
- double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) :
- strtod(value->vu.str.value, NULL));
- feats.set_value(fid, val);
- state = 18;
- break;
- }
- case 20: // edge.tail
- if (type == JSON_T_NULL) { state = 15; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 21; break;
- case 21:
- if (type == JSON_T_ARRAY_END) { state = 15; break; }
- assert(type == JSON_T_INTEGER);
- tail.push_back(value->vu.integer_value);
- break;
- case 22: // edge.spans
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 23;
- spans[0] = spans[1] = spans[2] = spans[3] = -1;
- spanc = 0;
- break;
- case 23:
- if (type == JSON_T_ARRAY_END) { state = 15; break; }
- assert(type == JSON_T_INTEGER);
- assert(spanc < 4);
- spans[spanc] = value->vu.integer_value;
- ++spanc;
- break;
- case 24: // read node hash
- assert(type == JSON_T_STRING);
- shash = value->vu.str.value;
- state = 9;
- break;
- }
- return true;
- }
- string rp;
- string cat;
- SmallVectorUnsigned tail;
- vector<int> in_edges;
- string shash;
- TRulePtr cur_rule;
- map<int, TRulePtr> rules;
- vector<int> fdict;
- SparseVector<double> feats;
- int state;
- int fid;
- int nodes;
- int edges;
- int spans[4];
- int spanc;
- string cur_key;
- Hypergraph& hg;
- int rule_id;
- bool nodes_needed;
- bool edges_needed;
- bool is_sorted;
-};
-
-bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
+bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) {
+ boost::archive::binary_iarchive oa(*in);
hg->clear();
- HGReader reader(hg);
- return reader.Parse(in);
-}
-
-static void WriteRule(const TRule& r, ostream* out) {
- if (!r.lhs_) { (*out) << "[X] ||| "; }
- JSONParser::WriteEscapedString(r.AsString(), out);
+ oa >> *hg;
+ return true;
}
-bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
- if (hg.empty()) { *out << "{}\n"; return true; }
- map<const TRule*, int> rid;
- ostream& o = *out;
- rid[NULL] = 0;
- o << '{';
- if (!remove_rules) {
- o << "\"rules\":[";
- for (int i = 0; i < hg.edges_.size(); ++i) {
- const TRule* r = hg.edges_[i].rule_.get();
- int &id = rid[r];
- if (!id) {
- id=rid.size() - 1;
- if (id > 1) o << ',';
- o << id << ',';
- WriteRule(*r, &o);
- };
- }
- o << "],";
- }
- const bool use_fdict = FD::NumFeats() < 1000;
- if (use_fdict) {
- o << "\"features\":[";
- for (int i = 1; i < FD::NumFeats(); ++i) {
- o << (i==1 ? "":",");
- JSONParser::WriteEscapedString(FD::Convert(i), &o);
- }
- o << "],";
- }
- vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
- int edge_count = 0;
- for (int i = 0; i < hg.nodes_.size(); ++i) {
- const Hypergraph::Node& node = hg.nodes_[i];
- if (i > 0) { o << ","; }
- o << "\"edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
- edgemap[edge.id_] = edge_count;
- ++edge_count;
- o << (j == 0 ? "" : ",") << "{";
-
- o << "\"tail\":[";
- for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
- o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
- }
- o << "],";
-
- o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],";
-
- o << "\"feats\":[";
- bool first = true;
- for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
- if (!it->second) continue; // don't write features that have a zero value
- if (!it->first) continue; // if the feature set was frozen this might happen
- if (!first) o << ',';
- if (use_fdict)
- o << (it->first - 1);
- else {
- JSONParser::WriteEscapedString(FD::Convert(it->first), &o);
- }
- o << ',' << it->second;
- first = false;
- }
- o << "]";
- if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
- o << "}";
- }
- o << "],";
-
- o << "\"node\":{\"in_edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- int mapped_edge = edgemap[node.in_edges_[j]];
- assert(mapped_edge >= 0);
- o << (j == 0 ? "" : ",") << mapped_edge;
- }
- o << "]";
- if (node.cat_ < 0) {
- o << ",\"cat\":";
- JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o);
- }
- char buf[48];
- sprintf(buf, "%016lX", node.node_hash);
- o << ",\"node_hash\":\"" << buf << "\"";
- o << "}";
- }
- o << "}\n";
+bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) {
+ boost::archive::binary_oarchive oa(*out);
+ oa << hg;
return true;
}
diff --git a/decoder/hg_io.h b/decoder/hg_io.h
index 58af8132..93a9e280 100644
--- a/decoder/hg_io.h
+++ b/decoder/hg_io.h
@@ -1,5 +1,5 @@
-#ifndef _HG_IO_H_
-#define _HG_IO_H_
+#ifndef HG_IO_H_
+#define HG_IO_H_
#include <iostream>
#include <string>
@@ -9,19 +9,11 @@ class Hypergraph;
struct HypergraphIO {
- // the format is basically a list of nodes and edges in topological order
- // any edge you read, you must have already read its tail nodes
- // any node you read, you must have already read its incoming edges
- // this may make writing a bit more challenging if your forest is not
- // topologically sorted (but that probably doesn't happen very often),
- // but it makes reading much more memory efficient.
- // see test_data/small.json.gz for an email encoding
- static bool ReadFromJSON(std::istream* in, Hypergraph* out);
+ static bool ReadFromBinary(std::istream* in, Hypergraph* out);
+ static bool WriteToBinary(const Hypergraph& hg, std::ostream* out);
// if remove_rules is used, the hypergraph is serialized without rule information
// (so it only contains structure and feature information)
- static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out);
-
static void WriteAsCFG(const Hypergraph& hg);
// Write only the target size information in bottom-up order.
diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h
index 82f06039..f67fe6e2 100644
--- a/decoder/hg_remove_eps.h
+++ b/decoder/hg_remove_eps.h
@@ -1,5 +1,5 @@
-#ifndef _HG_REMOVE_EPS_H_
-#define _HG_REMOVE_EPS_H_
+#ifndef HG_REMOVE_EPS_H_
+#define HG_REMOVE_EPS_H_
#include "wordid.h"
class Hypergraph;
diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h
index 6ac39a20..4267b5ec 100644
--- a/decoder/hg_sampler.h
+++ b/decoder/hg_sampler.h
@@ -1,6 +1,5 @@
-#ifndef _HG_SAMPLER_H_
-#define _HG_SAMPLER_H_
-
+#ifndef HG_SAMPLER_H_
+#define HG_SAMPLER_H_
#include <vector>
#include <string>
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 5cb8626a..366b269d 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -1,10 +1,14 @@
#define BOOST_TEST_MODULE hg_test
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
+#include <sstream>
#include <iostream>
#include "tdict.h"
-#include "json_parse.h"
#include "hg_intersect.h"
#include "hg_union.h"
#include "viterbi.h"
@@ -394,16 +398,6 @@ BOOST_AUTO_TEST_CASE(Small) {
BOOST_CHECK_CLOSE(2.1431036, log(c2), 1e-4);
}
-BOOST_AUTO_TEST_CASE(JSONTest) {
- std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
- ostringstream os;
- JSONParser::WriteEscapedString("\"I don't know\", she said.", &os);
- BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str());
- ostringstream os2;
- JSONParser::WriteEscapedString("yes", &os2);
- BOOST_CHECK_EQUAL("\"yes\"", os2.str());
-}
-
BOOST_AUTO_TEST_CASE(TestGenericKBest) {
std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
Hypergraph hg;
@@ -427,19 +421,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) {
}
}
-BOOST_AUTO_TEST_CASE(TestReadWriteHG) {
+BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {
std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
- Hypergraph hg,hg2;
- CreateHG(path, &hg);
- hg.edges_.front().j_ = 23;
- hg.edges_.back().prev_i_ = 99;
- ostringstream os;
- HypergraphIO::WriteToJSON(hg, false, &os);
- istringstream is(os.str());
- HypergraphIO::ReadFromJSON(&is, &hg2);
- BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
- BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
- BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ Hypergraph hg;
+ Hypergraph hg2;
+ std::string out;
+ {
+ CreateHG(path, &hg);
+ hg.edges_.front().j_ = 23;
+ hg.edges_.back().prev_i_ = 99;
+ ostringstream os;
+ boost::archive::text_oarchive oa(os);
+ oa << hg;
+ out = os.str();
+ }
+ {
+ cerr << out << endl;
+ istringstream is(out);
+ boost::archive::text_iarchive ia(is);
+ ia >> hg2;
+ BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
+ BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
+ BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ }
}
BOOST_AUTO_TEST_SUITE_END()
diff --git a/decoder/hg_test.h b/decoder/hg_test.h
index b7bab3c2..575b9c54 100644
--- a/decoder/hg_test.h
+++ b/decoder/hg_test.h
@@ -12,12 +12,8 @@ namespace {
typedef char const* Name;
-Name urdu_json="urdu.json.gz";
-Name urdu_wts="Arity_0 1.70741473606976 Arity_1 1.12426238048012 Arity_2 1.14986187839554 Glue -0.04589037041388 LanguageModel 1.09051 PassThrough -3.66226367902928 PhraseModel_0 -1.94633451863252 PhraseModel_1 -0.1475347695476 PhraseModel_2 -1.614818994946 WordPenalty -3.0 WordPenaltyFsa -0.56028442964748 ShorterThanPrev -10 LongerThanPrev -10";
-Name small_json="small.json.gz";
+Name small_json="small.bin.gz";
Name small_wts="Model_0 -2 Model_1 -.5 Model_2 -1.1 Model_3 -1 Model_4 -1 Model_5 .5 Model_6 .2 Model_7 -.3";
-Name perro_json="perro.json.gz";
-Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 Glue -1.0 EgivenF -0.5 FgivenE -0.5 LexEgivenF -0.5 LexFgivenE -0.5 LM 1";
}
@@ -32,7 +28,7 @@ struct HGSetup {
static void JsonFile(Hypergraph *hg,std::string f) {
ReadFile rf(f);
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
static void JsonTestFile(Hypergraph *hg,std::string path,std::string n) {
JsonFile(hg,path + "/"+n);
@@ -48,35 +44,35 @@ void AddNullEdge(Hypergraph* hg) {
}
void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.tiny_lattice");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.tiny_lattice.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
AddNullEdge(hg);
}
void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.lattice");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.lattice.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
AddNullEdge(hg);
}
void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.tiny");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.tiny.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.hg_int");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.hg_int.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.hg");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.hg.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.hg_balanced");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.hg_balanced.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
#endif
diff --git a/decoder/hg_union.h b/decoder/hg_union.h
index 34624246..bb7e2d09 100644
--- a/decoder/hg_union.h
+++ b/decoder/hg_union.h
@@ -1,5 +1,5 @@
-#ifndef _HG_UNION_H_
-#define _HG_UNION_H_
+#ifndef HG_UNION_H_
+#define HG_UNION_H_
class Hypergraph;
namespace HG {
diff --git a/decoder/incremental.h b/decoder/incremental.h
index f791a626..46b4817b 100644
--- a/decoder/incremental.h
+++ b/decoder/incremental.h
@@ -1,5 +1,5 @@
-#ifndef _INCREMENTAL_H_
-#define _INCREMENTAL_H_
+#ifndef INCREMENTAL_H_
+#define INCREMENTAL_H_
#include "weights.h"
#include <vector>
diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h
index c0377fe8..d5bda63c 100644
--- a/decoder/inside_outside.h
+++ b/decoder/inside_outside.h
@@ -1,5 +1,5 @@
-#ifndef _INSIDE_OUTSIDE_H_
-#define _INSIDE_OUTSIDE_H_
+#ifndef INSIDE_OUTSIDE_H_
+#define INSIDE_OUTSIDE_H_
#include <vector>
#include <algorithm>
diff --git a/decoder/json_parse.cc b/decoder/json_parse.cc
deleted file mode 100644
index f6fdfea8..00000000
--- a/decoder/json_parse.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-#include "json_parse.h"
-
-#include <string>
-#include <iostream>
-
-using namespace std;
-
-static const char *json_hex_chars = "0123456789abcdef";
-
-void JSONParser::WriteEscapedString(const string& in, ostream* out) {
- int pos = 0;
- int start_offset = 0;
- unsigned char c = 0;
- (*out) << '"';
- while(pos < in.size()) {
- c = in[pos];
- switch(c) {
- case '\b':
- case '\n':
- case '\r':
- case '\t':
- case '"':
- case '\\':
- case '/':
- if(pos - start_offset > 0)
- (*out) << in.substr(start_offset, pos - start_offset);
- if(c == '\b') (*out) << "\\b";
- else if(c == '\n') (*out) << "\\n";
- else if(c == '\r') (*out) << "\\r";
- else if(c == '\t') (*out) << "\\t";
- else if(c == '"') (*out) << "\\\"";
- else if(c == '\\') (*out) << "\\\\";
- else if(c == '/') (*out) << "\\/";
- start_offset = ++pos;
- break;
- default:
- if(c < ' ') {
- cerr << "Warning, bad character (" << static_cast<int>(c) << ") in string\n";
- if(pos - start_offset > 0)
- (*out) << in.substr(start_offset, pos - start_offset);
- (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf];
- start_offset = ++pos;
- } else pos++;
- }
- }
- if(pos - start_offset > 0)
- (*out) << in.substr(start_offset, pos - start_offset);
- (*out) << '"';
-}
-
diff --git a/decoder/json_parse.h b/decoder/json_parse.h
deleted file mode 100644
index c3cba954..00000000
--- a/decoder/json_parse.h
+++ /dev/null
@@ -1,58 +0,0 @@
-#ifndef _JSON_WRAPPER_H_
-#define _JSON_WRAPPER_H_
-
-#include <iostream>
-#include <cassert>
-#include "JSON_parser.h"
-
-class JSONParser {
- public:
- JSONParser() {
- init_JSON_config(&config);
- hack.mf = &JSONParser::Callback;
- config.depth = 10;
- config.callback_ctx = reinterpret_cast<void*>(this);
- config.callback = hack.cb;
- config.allow_comments = 1;
- config.handle_floats_manually = 1;
- jc = new_JSON_parser(&config);
- }
- virtual ~JSONParser() {
- delete_JSON_parser(jc);
- }
- bool Parse(std::istream* in) {
- int count = 0;
- int lc = 1;
- for (; in ; ++count) {
- int next_char = in->get();
- if (!in->good()) break;
- if (lc == '\n') { ++lc; }
- if (!JSON_parser_char(jc, next_char)) {
- std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl;
- return false;
- }
- }
- if (!JSON_parser_done(jc)) {
- std::cerr << "JSON_parser_done: syntax error\n";
- return false;
- }
- return true;
- }
- static void WriteEscapedString(const std::string& in, std::ostream* out);
- protected:
- virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0;
- private:
- int Callback(int type, const JSON_value* value) {
- if (HandleJSONEvent(type, value)) return 1;
- return 0;
- }
- JSON_parser_struct* jc;
- JSON_config config;
- typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value);
- union CBHack {
- JSON_parser_callback cb;
- MF mf;
- } hack;
-};
-
-#endif
diff --git a/decoder/kbest.h b/decoder/kbest.h
index c7194c7e..d6b3eb94 100644
--- a/decoder/kbest.h
+++ b/decoder/kbest.h
@@ -1,5 +1,5 @@
-#ifndef _HG_KBEST_H_
-#define _HG_KBEST_H_
+#ifndef HG_KBEST_H_
+#define HG_KBEST_H_
#include <vector>
#include <utility>
diff --git a/decoder/lattice.cc b/decoder/lattice.cc
index 89da3cd0..1f97048d 100644
--- a/decoder/lattice.cc
+++ b/decoder/lattice.cc
@@ -50,7 +50,6 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
l.resize(ids.size());
for (int i = 0; i < l.size(); ++i)
l[i].push_back(LatticeArc(ids[i], 0.0, 1));
- l.is_sentence_ = true;
}
void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
diff --git a/decoder/lattice.h b/decoder/lattice.h
index ad4ca50d..1258d3f5 100644
--- a/decoder/lattice.h
+++ b/decoder/lattice.h
@@ -1,5 +1,5 @@
-#ifndef __LATTICE_H_
-#define __LATTICE_H_
+#ifndef LATTICE_H_
+#define LATTICE_H_
#include <string>
#include <vector>
@@ -25,22 +25,24 @@ class Lattice : public std::vector<std::vector<LatticeArc> > {
friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);
friend void LatticeTools::ConvertTextToLattice(const std::string& text, Lattice* pl);
public:
- Lattice() : is_sentence_(false) {}
+ Lattice() {}
explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) :
- std::vector<std::vector<LatticeArc> >(t, v),
- is_sentence_(false) {}
+ std::vector<std::vector<LatticeArc>>(t, v) {}
int Distance(int from, int to) const {
if (dist_.empty())
return (to - from);
return dist_(from, to);
}
- // TODO this should actually be computed based on the contents
- // of the lattice
- bool IsSentence() const { return is_sentence_; }
private:
void ComputeDistances();
Array2D<int> dist_;
- bool is_sentence_;
};
+inline bool IsSentence(const Lattice& in) {
+ bool res = true;
+ for (auto& alt : in)
+ if (alt.size() > 1) { res = false; break; }
+ return res;
+}
+
#endif
diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc
index 11f20de7..dd529311 100644
--- a/decoder/lexalign.cc
+++ b/decoder/lexalign.cc
@@ -114,10 +114,9 @@ bool LexicalAlign::TranslateImpl(const string& input,
Hypergraph* forest) {
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextOrPLF(input, &lattice);
- if (!lattice.IsSentence()) {
- // lexical models make independence assumptions
- // that don't work with lattices or conf nets
- cerr << "LexicalTrans: cannot deal with lattice source input!\n";
+ smeta->ComputeInputLatticeType();
+ if (smeta->GetInputType() != cdec::kSEQUENCE) {
+ cerr << "LexicalTrans: cannot deal with non-sequence input!";
abort();
}
smeta->SetSourceLength(lattice.size());
diff --git a/decoder/lexalign.h b/decoder/lexalign.h
index 7ba4fe64..6415f4f9 100644
--- a/decoder/lexalign.h
+++ b/decoder/lexalign.h
@@ -1,5 +1,5 @@
-#ifndef _LEXALIGN_H_
-#define _LEXALIGN_H_
+#ifndef LEXALIGN_H_
+#define LEXALIGN_H_
#include "translator.h"
#include "lattice.h"
diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc
index 74a18c3f..d13a891a 100644
--- a/decoder/lextrans.cc
+++ b/decoder/lextrans.cc
@@ -271,10 +271,9 @@ bool LexicalTrans::TranslateImpl(const string& input,
Hypergraph* forest) {
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextOrPLF(input, &lattice);
- if (!lattice.IsSentence()) {
- // lexical models make independence assumptions
- // that don't work with lattices or conf nets
- cerr << "LexicalTrans: cannot deal with lattice source input!\n";
+ smeta->ComputeInputLatticeType();
+ if (smeta->GetInputType() != cdec::kSEQUENCE) {
+ cerr << "LexicalTrans: cannot deal with non-sequence inputs\n";
abort();
}
smeta->SetSourceLength(lattice.size());
diff --git a/decoder/lextrans.h b/decoder/lextrans.h
index 2d51e7c0..a23a4e0d 100644
--- a/decoder/lextrans.h
+++ b/decoder/lextrans.h
@@ -1,5 +1,5 @@
-#ifndef _LEXTrans_H_
-#define _LEXTrans_H_
+#ifndef LEXTrans_H_
+#define LEXTrans_H_
#include "translator.h"
#include "lattice.h"
diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h
index 9fc01a09..f380fcb1 100644
--- a/decoder/node_state_hash.h
+++ b/decoder/node_state_hash.h
@@ -1,5 +1,5 @@
-#ifndef _NODE_STATE_HASH_
-#define _NODE_STATE_HASH_
+#ifndef NODE_STATE_HASH_
+#define NODE_STATE_HASH_
#include <cassert>
#include <cstring>
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index d2c4715c..cd587833 100644
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -21,6 +21,7 @@
#include "kbest.h"
#include "timing_stats.h"
#include "sentences.h"
+#include "b64featvector.h"
//TODO: put function impls into .cc
//TODO: move Translation into its own .h and use in cdec
@@ -252,19 +253,31 @@ struct OracleBleu {
}
bool show_derivation;
+ int show_derivation_mask;
+
template <class Filter>
- void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) {
+ void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat,
+ int src_len, std::ostream& kbest_out = std::cout,
+ std::ostream& deriv_out = std::cerr) {
using namespace std;
using namespace boost;
typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K;
K kbest(forest,k);
//add length (f side) src length of this sentence to the psuedo-doc src length count
float curr_src_length = doc_src_length + tmp_src_length;
- for (int i = 0; i < k; ++i) {
+ if (mr_mira_compat) kbest_out << k << "\n";
+ int i = 0;
+ for (; i < k; ++i) {
typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score);
+ kbest_out << sent_id << " ||| ";
+ if (mr_mira_compat) kbest_out << src_len << " ||| ";
+ kbest_out << TD::GetString(d->yield) << " ||| ";
+ if (mr_mira_compat)
+ kbest_out << EncodeFeatureVector(d->feature_values);
+ else
+ kbest_out << d->feature_values;
+ kbest_out << " ||| " << log(d->score);
if (!refs.empty()) {
ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
@@ -275,14 +288,21 @@ struct OracleBleu {
if (show_derivation) {
deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k
deriv_out<<log(d->score)<<"\n";
- deriv_out<<kbest.derivation_tree(*d,true);
+ deriv_out<<kbest.derivation_tree(*d,true, show_derivation_mask);
deriv_out<<"\n"<<flush;
}
}
+ if (mr_mira_compat) {
+ for (; i < k; ++i) kbest_out << "\n";
+ kbest_out << flush;
+ }
}
// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
- void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) {
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k,
+ const bool unique, const bool mr_mira_compat,
+ const int src_len, std::string const& kbest_out_filename_,
+ std::string const& deriv_out_filename_) {
WriteFile ko(kbest_out_filename_);
std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl;
@@ -295,9 +315,11 @@ struct OracleBleu {
WriteFile oderiv(sderiv.str());
if (!unique)
- kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::NoFilter<std::vector<WordID> > >(
+ sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get());
else {
- kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len,
+ ko.get(), oderiv.get());
}
}
@@ -305,7 +327,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo
{
std::ostringstream kbest_string_stream;
kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-");
+ DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(),
+ "-");
}
};
diff --git a/decoder/phrasebased_translator.cc b/decoder/phrasebased_translator.cc
index 8048248e..8415353a 100644
--- a/decoder/phrasebased_translator.cc
+++ b/decoder/phrasebased_translator.cc
@@ -114,6 +114,7 @@ struct PhraseBasedTranslatorImpl {
Lattice lattice;
LatticeTools::ConvertTextOrPLF(input, &lattice);
smeta->SetSourceLength(lattice.size());
+ smeta->ComputeInputLatticeType();
size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);
minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);
if (add_pass_through_rules) {
diff --git a/decoder/phrasebased_translator.h b/decoder/phrasebased_translator.h
index e5e3f8a2..10790d0d 100644
--- a/decoder/phrasebased_translator.h
+++ b/decoder/phrasebased_translator.h
@@ -1,5 +1,5 @@
-#ifndef _PHRASEBASED_TRANSLATOR_H_
-#define _PHRASEBASED_TRANSLATOR_H_
+#ifndef PHRASEBASED_TRANSLATOR_H_
+#define PHRASEBASED_TRANSLATOR_H_
#include "translator.h"
diff --git a/decoder/phrasetable_fst.h b/decoder/phrasetable_fst.h
index 477de1f7..966bb14d 100644
--- a/decoder/phrasetable_fst.h
+++ b/decoder/phrasetable_fst.h
@@ -1,5 +1,5 @@
-#ifndef _PHRASETABLE_FST_H_
-#define _PHRASETABLE_FST_H_
+#ifndef PHRASETABLE_FST_H_
+#define PHRASETABLE_FST_H_
#include <vector>
#include <string>
diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc
index 10192f7a..2c5fa9c4 100644
--- a/decoder/rescore_translator.cc
+++ b/decoder/rescore_translator.cc
@@ -3,6 +3,7 @@
#include <sstream>
#include <boost/shared_ptr.hpp>
+#include "filelib.h"
#include "sentence_metadata.h"
#include "hg.h"
#include "hg_io.h"
@@ -20,16 +21,18 @@ struct RescoreTranslatorImpl {
bool Translate(const string& input,
const vector<double>& weights,
Hypergraph* forest) {
- if (input == "{}") return false;
- if (input.find("{\"rules\"") == 0) {
- istringstream is(input);
- Hypergraph src_cfg_hg;
- if (!HypergraphIO::ReadFromJSON(&is, forest)) {
- cerr << "Parse error while reading HG from JSON.\n";
- abort();
- }
- } else {
- cerr << "Can only read HG input from JSON: use training/grammar_convert\n";
+ istringstream is(input);
+ string header, fname;
+ is >> header >> fname;
+ if (header != "::forest::") {
+ cerr << "RescoreTranslator: expected input lines of form ::forest:: filename.gz\n";
+ abort();
+ }
+ ReadFile rf(fname);
+ if (!rf) { cerr << "Can't read " << fname << endl; abort(); }
+ Hypergraph src_cfg_hg;
+ if (!HypergraphIO::ReadFromBinary(rf.stream(), forest)) {
+ cerr << "Parse error while reading HG.\n";
abort();
}
Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
@@ -53,6 +56,7 @@ bool RescoreTranslator::TranslateImpl(const string& input,
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
smeta->SetSourceLength(0); // don't know how to compute this
+ smeta->input_type_ = cdec::kFOREST;
return pimpl_->Translate(input, weights, minus_lm_forest);
}
diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h
index e15c056d..5267f9ca 100644
--- a/decoder/rule_lexer.h
+++ b/decoder/rule_lexer.h
@@ -1,5 +1,5 @@
-#ifndef _RULE_LEXER_H_
-#define _RULE_LEXER_H_
+#ifndef RULE_LEXER_H_
+#define RULE_LEXER_H_
#include <iostream>
#include <string>
diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll
index d4a8d86b..8b48ab7b 100644
--- a/decoder/rule_lexer.ll
+++ b/decoder/rule_lexer.ll
@@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const
void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) {
init_default_feature_names();
+ scfglex_fname = srule;
lex_mono_rules = mono;
lex_line = 1;
rule_callback_extra = extra;
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 83b65c28..538f82ec 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -195,6 +195,7 @@ struct SCFGTranslatorImpl {
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextOrPLF(input, &lattice);
smeta->SetSourceLength(lattice.size());
+ smeta->ComputeInputLatticeType();
if (add_pass_through_rules){
if (!SILENT) cerr << "Adding pass through grammar" << endl;
PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features);
diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h
index f2a779f4..e13c2ca5 100644
--- a/decoder/sentence_metadata.h
+++ b/decoder/sentence_metadata.h
@@ -1,14 +1,20 @@
-#ifndef _SENTENCE_METADATA_H_
-#define _SENTENCE_METADATA_H_
+#ifndef SENTENCE_METADATA_H_
+#define SENTENCE_METADATA_H_
#include <string>
#include <map>
#include <cassert>
#include "lattice.h"
+#include "tree_fragment.h"
struct DocScorer; // deprecated, will be removed
struct Score; // deprecated, will be removed
+namespace cdec {
+enum InputType { kSEQUENCE, kTREE, kLATTICE, kFOREST, kUNKNOWN };
+class TreeFragment;
+}
+
class SentenceMetadata {
public:
friend class DecoderImpl;
@@ -17,7 +23,17 @@ class SentenceMetadata {
src_len_(-1),
has_reference_(ref.size() > 0),
trg_len_(ref.size()),
- ref_(has_reference_ ? &ref : NULL) {}
+ ref_(has_reference_ ? &ref : NULL),
+ input_type_(cdec::kUNKNOWN) {}
+
+ // helper function for lattice inputs
+ void ComputeInputLatticeType() {
+ input_type_ = cdec::kSEQUENCE;
+ for (auto& alt : src_lattice_) {
+ if (alt.size() > 1) { input_type_ = cdec::kLATTICE; break; }
+ }
+ }
+ cdec::InputType GetInputType() const { return input_type_; }
int GetSentenceId() const { return sent_id_; }
@@ -25,6 +41,8 @@ class SentenceMetadata {
// it has parsed the source
void SetSourceLength(int sl) { src_len_ = sl; }
+ const cdec::TreeFragment& GetSourceTree() const { return src_tree_; }
+
// this should be called if a separate model needs to
// specify how long the target sentence should be
void SetTargetLength(int tl) {
@@ -64,12 +82,15 @@ class SentenceMetadata {
const Score* app_score;
public:
Lattice src_lattice_; // this will only be set if inputs are finite state!
+ cdec::TreeFragment src_tree_; // this will be set only if inputs are trees
private:
// you need to be very careful when depending on these values
// they will only be set during training / alignment contexts
const bool has_reference_;
int trg_len_;
const Lattice* const ref_;
+ public:
+ cdec::InputType input_type_;
};
#endif
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
index 30fb055f..500d2061 100644
--- a/decoder/tagger.cc
+++ b/decoder/tagger.cc
@@ -100,6 +100,8 @@ bool Tagger::TranslateImpl(const string& input,
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextToLattice(input, &lattice);
smeta->SetSourceLength(lattice.size());
+ smeta->ComputeInputLatticeType();
+ assert(smeta->GetInputType() == cdec::kSEQUENCE);
vector<WordID> sequence(lattice.size());
for (int i = 0; i < lattice.size(); ++i) {
assert(lattice[i].size() == 1);
diff --git a/decoder/tagger.h b/decoder/tagger.h
index 9ac820d9..51659d5b 100644
--- a/decoder/tagger.h
+++ b/decoder/tagger.h
@@ -1,5 +1,5 @@
-#ifndef _TAGGER_H_
-#define _TAGGER_H_
+#ifndef TAGGER_H_
+#define TAGGER_H_
#include "translator.h"
diff --git a/decoder/test_data/hg_test.hg.bin.gz b/decoder/test_data/hg_test.hg.bin.gz
new file mode 100644
index 00000000..c07fbe8c
--- /dev/null
+++ b/decoder/test_data/hg_test.hg.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.hg_balanced.bin.gz b/decoder/test_data/hg_test.hg_balanced.bin.gz
new file mode 100644
index 00000000..896d3d60
--- /dev/null
+++ b/decoder/test_data/hg_test.hg_balanced.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.hg_int.bin.gz b/decoder/test_data/hg_test.hg_int.bin.gz
new file mode 100644
index 00000000..e0bd6187
--- /dev/null
+++ b/decoder/test_data/hg_test.hg_int.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.lattice.bin.gz b/decoder/test_data/hg_test.lattice.bin.gz
new file mode 100644
index 00000000..8a8c05f4
--- /dev/null
+++ b/decoder/test_data/hg_test.lattice.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.tiny.bin.gz b/decoder/test_data/hg_test.tiny.bin.gz
new file mode 100644
index 00000000..0e68eb40
--- /dev/null
+++ b/decoder/test_data/hg_test.tiny.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.tiny_lattice.bin.gz b/decoder/test_data/hg_test.tiny_lattice.bin.gz
new file mode 100644
index 00000000..97e8dc05
--- /dev/null
+++ b/decoder/test_data/hg_test.tiny_lattice.bin.gz
Binary files differ
diff --git a/decoder/test_data/perro.json.gz b/decoder/test_data/perro.json.gz
deleted file mode 100644
index 41de5758..00000000
--- a/decoder/test_data/perro.json.gz
+++ /dev/null
Binary files differ
diff --git a/decoder/test_data/small.bin.gz b/decoder/test_data/small.bin.gz
new file mode 100644
index 00000000..1c5a1631
--- /dev/null
+++ b/decoder/test_data/small.bin.gz
Binary files differ
diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz
deleted file mode 100644
index f6f37293..00000000
--- a/decoder/test_data/small.json.gz
+++ /dev/null
Binary files differ
diff --git a/decoder/test_data/urdu.json.gz b/decoder/test_data/urdu.json.gz
deleted file mode 100644
index 84535402..00000000
--- a/decoder/test_data/urdu.json.gz
+++ /dev/null
Binary files differ
diff --git a/decoder/translator.h b/decoder/translator.h
index ba218a0b..096cf191 100644
--- a/decoder/translator.h
+++ b/decoder/translator.h
@@ -1,5 +1,5 @@
-#ifndef _TRANSLATOR_H_
-#define _TRANSLATOR_H_
+#ifndef TRANSLATOR_H_
+#define TRANSLATOR_H_
#include <string>
#include <vector>
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index bd3b01d0..cdd83ffc 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -32,12 +32,12 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo
++lc;
if (line.size() == 0 || line[0] == '#') continue;
std::vector<StringPiece> fields = TokenizeMultisep(line, " ||| ");
- if (has_multiple_states && fields.size() != 4) {
- cerr << "Expected 4 fields in rule file but line " << lc << " is:\n" << line << endl;
+ if (has_multiple_states && fields.size() < 4) {
+ cerr << "Expected at least 4 fields in rule file but line " << lc << " is:\n" << line << endl;
abort();
}
- if (!has_multiple_states && fields.size() != 3) {
- cerr << "Expected 3 fields in rule file but line " << lc << " is:\n" << line << endl;
+ if (!has_multiple_states && fields.size() < 3) {
+ cerr << "Expected at least 3 fields in rule file but line " << lc << " is:\n" << line << endl;
abort();
}
@@ -73,6 +73,7 @@ static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bo
cerr << "Not implemented...\n"; abort(); // TODO read in states
} else {
os << " ||| " << fields[1] << " ||| " << fields[2];
+ if (fields.size() > 3) os << " ||| " << fields[3];
rule.reset(new TRule(os.str()));
}
cur->rules.push_back(rule);
@@ -266,6 +267,7 @@ struct Tree2StringTranslatorImpl {
for (auto sym : rule_src)
cur = &cur->next[sym];
TRulePtr rule(new TRule(rhse, rhsf, lhs));
+ rule->a_.push_back(AlignmentPoint(0, 0));
rule->ComputeArity();
rule->scores_.set_value(ntfid, 1.0);
rule->scores_.set_value(kFID, 1.0);
@@ -287,6 +289,8 @@ struct Tree2StringTranslatorImpl {
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
cdec::TreeFragment input_tree(input, false);
+ smeta->src_tree_ = input_tree;
+ smeta->input_type_ = cdec::kTREE;
if (add_pass_through_rules) CreatePassThroughRules(input_tree);
Hypergraph hg;
hg.ReserveNodes(input_tree.nodes.size());
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index 42f7793a..5f717c5b 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -64,6 +64,13 @@ int TreeFragment::SetupSpansRec(unsigned cur, int left) {
return right;
}
+vector<int> TreeFragment::Terminals() const {
+ vector<int> terms;
+ for (auto& x : *this)
+ if (IsTerminal(x)) terms.push_back(x);
+ return terms;
+}
+
// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built
// symp keeps track of the terminal symbols that have been built
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index 6b4842ee..e19b79fb 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -72,6 +72,8 @@ class TreeFragment {
BreadthFirstIterator bfs_begin(unsigned node_idx) const;
BreadthFirstIterator bfs_end() const;
+ std::vector<int> Terminals() const;
+
private:
// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built
diff --git a/decoder/trule.h b/decoder/trule.h
index 243b0da9..7af46747 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -1,5 +1,5 @@
-#ifndef _RULE_H_
-#define _RULE_H_
+#ifndef TRULE_H_
+#define TRULE_H_
#include <algorithm>
#include <vector>
@@ -11,6 +11,7 @@
#include "sparse_vector.h"
#include "wordid.h"
+#include "tdict.h"
class TRule;
typedef boost::shared_ptr<TRule> TRulePtr;
@@ -26,6 +27,7 @@ struct AlignmentPoint {
short s_;
short t_;
};
+
inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) {
return os << static_cast<int>(p.s_) << '-' << static_cast<int>(p.t_);
}
@@ -163,6 +165,66 @@ class TRule {
// optional, shows internal structure of TSG rules
boost::shared_ptr<cdec::TreeFragment> tree_structure;
+ friend class boost::serialization::access;
+ template<class Archive>
+ void save(Archive & ar, const unsigned int /*version*/) const {
+ ar & TD::Convert(-lhs_);
+ unsigned f_size = f_.size();
+ ar & f_size;
+ assert(f_size <= (sizeof(size_t) * 8));
+ size_t f_nt_mask = 0;
+ for (int i = f_.size() - 1; i >= 0; --i) {
+ f_nt_mask <<= 1;
+ f_nt_mask |= (f_[i] <= 0 ? 1 : 0);
+ }
+ ar & f_nt_mask;
+ for (unsigned i = 0; i < f_.size(); ++i)
+ ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]);
+ unsigned e_size = e_.size();
+ ar & e_size;
+ size_t e_nt_mask = 0;
+ assert(e_size <= (sizeof(size_t) * 8));
+ for (int i = e_.size() - 1; i >= 0; --i) {
+ e_nt_mask <<= 1;
+ e_nt_mask |= (e_[i] <= 0 ? 1 : 0);
+ }
+ ar & e_nt_mask;
+ for (unsigned i = 0; i < e_.size(); ++i)
+ if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]);
+ ar & arity_;
+ ar & scores_;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int /*version*/) {
+ std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs);
+ unsigned f_size; ar & f_size;
+ f_.resize(f_size);
+ size_t f_nt_mask; ar & f_nt_mask;
+ std::string sym;
+ for (unsigned i = 0; i < f_size; ++i) {
+ bool mask = (f_nt_mask & 1);
+ ar & sym;
+ f_[i] = TD::Convert(sym) * (mask ? -1 : 1);
+ f_nt_mask >>= 1;
+ }
+ unsigned e_size; ar & e_size;
+ e_.resize(e_size);
+ size_t e_nt_mask; ar & e_nt_mask;
+ for (unsigned i = 0; i < e_size; ++i) {
+ bool mask = (e_nt_mask & 1);
+ if (mask) {
+ ar & e_[i];
+ } else {
+ ar & sym;
+ e_[i] = TD::Convert(sym);
+ }
+ e_nt_mask >>= 1;
+ }
+ ar & arity_;
+ ar & scores_;
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}
};
diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc
index 0cb7e2e8..d75c2016 100644
--- a/decoder/trule_test.cc
+++ b/decoder/trule_test.cc
@@ -4,6 +4,10 @@
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
#include <iostream>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+#include <sstream>
#include "tdict.h"
using namespace std;
@@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) {
BOOST_CHECK_EQUAL(t6.e_[3], 0);
}
+BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {
+ string str;
+ string t7str;
+ {
+ TRule t7;
+ t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121");
+ cerr << t7.AsString() << endl;
+ ostringstream os;
+ TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2"));
+ TRulePtr tp2 = tp1;
+ boost::archive::text_oarchive oa(os);
+ oa << t7;
+ oa << tp1;
+ oa << tp2;
+ str = os.str();
+ t7str = t7.AsString();
+ }
+ {
+ istringstream is(str);
+ boost::archive::text_iarchive ia(is);
+ TRule t8;
+ ia >> t8;
+ TRulePtr tp3, tp4;
+ ia >> tp3;
+ ia >> tp4;
+ cerr << t8.AsString() << endl;
+ BOOST_CHECK_EQUAL(t7str, t8.AsString());
+ cerr << tp3->AsString() << endl;
+ cerr << tp4->AsString() << endl;
+ }
+}
+
diff --git a/decoder/viterbi.h b/decoder/viterbi.h
index a8a0ea7f..20ee73cc 100644
--- a/decoder/viterbi.h
+++ b/decoder/viterbi.h
@@ -1,5 +1,5 @@
-#ifndef _VITERBI_H_
-#define _VITERBI_H_
+#ifndef VITERBI_H_
+#define VITERBI_H_
#include <vector>
#include "prob.h"