summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--Makefile.am1
-rw-r--r--configure.ac8
-rwxr-xr-xcorpus/conll2cdec.pl42
-rwxr-xr-xcorpus/sample-dev-sets.py74
-rw-r--r--corpus/support/token_list49
-rw-r--r--corpus/support/token_patterns2
-rwxr-xr-xcorpus/support/tokenizer.pl8
-rwxr-xr-xcorpus/utf8-normalize.sh2
-rw-r--r--decoder/JSON_parser.c1012
-rw-r--r--decoder/JSON_parser.h152
-rw-r--r--decoder/Makefile.am17
-rw-r--r--decoder/aligner.h2
-rw-r--r--decoder/apply_models.cc36
-rw-r--r--decoder/bottom_up_parser-rs.cc341
-rw-r--r--decoder/bottom_up_parser-rs.h29
-rw-r--r--decoder/cdec_ff.cc5
-rw-r--r--decoder/decoder.cc53
-rw-r--r--decoder/ff.h26
-rw-r--r--decoder/ff_conll.cc250
-rw-r--r--decoder/ff_conll.h45
-rw-r--r--decoder/ff_const_reorder.cc1118
-rw-r--r--decoder/ff_const_reorder.h43
-rw-r--r--decoder/ff_const_reorder_common.h1348
-rw-r--r--decoder/ff_soft_syn.cc265
-rw-r--r--decoder/ff_soft_syn.h38
-rw-r--r--decoder/ffset.cc15
-rw-r--r--decoder/ffset.h8
-rw-r--r--decoder/forest_writer.cc6
-rw-r--r--decoder/forest_writer.h2
-rw-r--r--decoder/fst_translator.cc10
-rw-r--r--decoder/hg.h52
-rw-r--r--decoder/hg_io.cc357
-rw-r--r--decoder/hg_io.h12
-rw-r--r--decoder/hg_test.cc50
-rw-r--r--decoder/hg_test.h32
-rw-r--r--decoder/json_parse.cc50
-rw-r--r--decoder/json_parse.h58
-rw-r--r--decoder/oracle_bleu.h41
-rw-r--r--decoder/rescore_translator.cc23
-rw-r--r--decoder/rule_lexer.ll1
-rw-r--r--decoder/test_data/hg_test.hg.bin.gzbin0 -> 340 bytes
-rw-r--r--decoder/test_data/hg_test.hg_balanced.bin.gzbin0 -> 324 bytes
-rw-r--r--decoder/test_data/hg_test.hg_int.bin.gzbin0 -> 184 bytes
-rw-r--r--decoder/test_data/hg_test.lattice.bin.gzbin0 -> 334 bytes
-rw-r--r--decoder/test_data/hg_test.tiny.bin.gzbin0 -> 177 bytes
-rw-r--r--decoder/test_data/hg_test.tiny_lattice.bin.gzbin0 -> 203 bytes
-rw-r--r--decoder/test_data/perro.json.gzbin608 -> 0 bytes
-rw-r--r--decoder/test_data/small.bin.gzbin0 -> 2807 bytes
-rw-r--r--decoder/test_data/small.json.gzbin1733 -> 0 bytes
-rw-r--r--decoder/test_data/urdu.json.gzbin253497 -> 0 bytes
-rw-r--r--decoder/tree2string_translator.cc1
-rw-r--r--decoder/trule.h62
-rw-r--r--decoder/trule_test.cc36
-rw-r--r--python/cdec/hypergraph.pxd3
-rw-r--r--python/cdec/sa/compile.py2
-rw-r--r--tests/system_tests/cfg_rescore/input.txt2
-rw-r--r--tests/system_tests/cfg_rescore/input0.hg.bin.gzbin0 -> 403 bytes
-rw-r--r--tests/system_tests/conll/README8
-rw-r--r--tests/system_tests/conll/cdec.ini13
-rw-r--r--tests/system_tests/conll/gold.statistics20
-rw-r--r--tests/system_tests/conll/gold.stdout0
-rw-r--r--tests/system_tests/conll/input.conll13
-rw-r--r--tests/system_tests/conll/input.txt2
-rw-r--r--tests/system_tests/conll/tagset.txt1
-rw-r--r--tests/system_tests/conll/weights64
-rw-r--r--tests/system_tests/ftrans/input.txt2
-rw-r--r--tests/system_tests/ftrans/input0.hg.bin.gzbin0 -> 225 bytes
-rw-r--r--training/Makefile.am4
-rw-r--r--training/const_reorder/Makefile.am8
-rw-r--r--training/const_reorder/argument_reorder_model.cc307
-rw-r--r--training/const_reorder/constituent_reorder_model.cc636
-rw-r--r--training/const_reorder/trainer.cc69
-rw-r--r--training/const_reorder/trainer.h12
-rw-r--r--training/dpmert/lo_test.cc22
-rw-r--r--training/dpmert/mr_dpmert_generate_mapper_input.cc2
-rw-r--r--training/dpmert/mr_dpmert_map.cc4
-rw-r--r--training/dpmert/test_data/0.bin.gzbin0 -> 24904 bytes
-rw-r--r--training/dpmert/test_data/0.json.gzbin13709 -> 0 bytes
-rw-r--r--training/dpmert/test_data/1.bin.gzbin0 -> 339220 bytes
-rw-r--r--training/dpmert/test_data/1.json.gzbin204803 -> 0 bytes
-rw-r--r--training/dpmert/test_data/test-ch-inside.bin.gzbin0 -> 340 bytes
-rw-r--r--training/dpmert/test_data/test-zero-origin.bin.gzbin0 -> 923 bytes
-rw-r--r--training/minrisk/minrisk_optimize.cc2
-rw-r--r--training/pro/mr_pro_map.cc2
-rw-r--r--training/rampion/rampion_cccp.cc2
-rw-r--r--training/utils/grammar_convert.cc9
-rw-r--r--utils/Makefile.am5
-rw-r--r--utils/b64featvector.cc55
-rw-r--r--utils/b64featvector.h12
-rw-r--r--utils/maxent.cpp1127
-rw-r--r--utils/maxent.h477
-rw-r--r--utils/small_vector.h16
-rw-r--r--utils/small_vector_test.cc30
-rw-r--r--utils/sv_test.cc31
95 files changed, 6983 insertions, 1763 deletions
diff --git a/.gitignore b/.gitignore
index 6400b1fc..b8e0da4e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -184,6 +184,8 @@ training/mr_reduce_to_weights
training/optimize_test
training/plftools
training/test_ngram
+training/const_reorder/argument_reorder_model_trainer
+training/const_reorder/const_reorder_model_trainer
utils/atools
utils/bin/
utils/crp_test
diff --git a/Makefile.am b/Makefile.am
index 88327477..a2d2f332 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -21,4 +21,3 @@ EXTRA_DIST = corpus tests python/cdec python/tests python/examples compound-spli
AUTOMAKE_OPTIONS = foreign
ACLOCAL_AMFLAGS = -I m4
AM_CPPFLAGS = -D_GLIBCXX_PARALLEL -march=native -mtune=native -O2 -pipe -fomit-frame-pointer -Wall
-
diff --git a/configure.ac b/configure.ac
index b8e9ef20..36cee5af 100644
--- a/configure.ac
+++ b/configure.ac
@@ -5,9 +5,9 @@ AM_INIT_AUTOMAKE
AC_CONFIG_HEADERS(config.h)
AC_PROG_LIBTOOL
AC_PROG_LEX
-case $LEX in
-:) AC_MSG_ERROR([No lex (Flex, lex, etc.) program found]);;
-esac
+case $LEX in
+:) AC_MSG_ERROR([No lex (Flex, lex, etc.) program found]);;
+esac
OLD_CXXFLAGS=$CXXFLAGS
AC_PROG_CC
AC_PROG_CXX
@@ -236,9 +236,9 @@ AC_CONFIG_FILES([training/minrisk/Makefile])
AC_CONFIG_FILES([training/mira/Makefile])
AC_CONFIG_FILES([training/latent_svm/Makefile])
AC_CONFIG_FILES([training/dtrain/Makefile])
+AC_CONFIG_FILES([training/const_reorder/Makefile])
# external feature function example code
AC_CONFIG_FILES([example_extff/Makefile])
AC_OUTPUT
-
diff --git a/corpus/conll2cdec.pl b/corpus/conll2cdec.pl
new file mode 100755
index 00000000..ee4e07db
--- /dev/null
+++ b/corpus/conll2cdec.pl
@@ -0,0 +1,42 @@
+#!/usr/bin/perl -w
+use strict;
+
+die "Usage: $0 file.conll\n\n Converts a CoNLL formatted labeled sequence into cdec's format.\n\n" unless scalar @ARGV == 1;
+open F, "<$ARGV[0]" or die "Can't read $ARGV[0]: $!\n";
+
+my @xx;
+my @yy;
+my @os;
+my $sec = undef;
+my $i = 0;
+while(<F>) {
+ chomp;
+ if (/^\s*$/) {
+ print "<seg id=\"$i\"";
+ $i++;
+ for (my $j = 0; $j < $sec; $j++) {
+ my @oo = ();
+ for (my $k = 0; $k < scalar @xx; $k++) {
+ my $sym = $os[$k]->[$j];
+ $sym =~ s/"/'/g;
+ push @oo, $sym;
+ }
+ my $zz = $j + 1;
+ print " feat$zz=\"@oo\"";
+ }
+
+ print "> @xx ||| @yy </seg>\n";
+ @xx = ();
+ @yy = ();
+ @os = ();
+ } else {
+ my ($x, @fs) = split /\s+/;
+ my $y = pop @fs;
+ if (!defined $sec) { $sec = scalar @fs; }
+ die unless $sec == scalar @fs;
+ push @xx, $x;
+ push @yy, $y;
+ push @os, \@fs;
+ }
+}
+
diff --git a/corpus/sample-dev-sets.py b/corpus/sample-dev-sets.py
new file mode 100755
index 00000000..3c969bbe
--- /dev/null
+++ b/corpus/sample-dev-sets.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python
+
+import gzip
+import os
+import sys
+
+HELP = '''Process an input corpus by dividing it into pseudo-documents and uniformly
+sampling train and dev sets (simulate uniform sampling at the document level
+when document boundaries are unknown)
+
+usage: {} in_file out_prefix doc_size docs_per_dev_set dev_sets [-lc]
+recommended: doc_size=20, docs_per_dev_set=100, dev_sets=2 (dev and test)
+'''
+
+def gzopen(f):
+ return gzip.open(f, 'rb') if f.endswith('.gz') else open(f, 'r')
+
+def wc(f):
+ return sum(1 for _ in gzopen(f))
+
+def main(argv):
+
+ if len(argv[1:]) < 5:
+ sys.stderr.write(HELP.format(os.path.basename(argv[0])))
+ sys.exit(2)
+
+ # Args
+ in_file = os.path.abspath(argv[1])
+ out_prefix = os.path.abspath(argv[2])
+ doc_size = int(argv[3])
+ docs_per_dev_set = int(argv[4])
+ dev_sets = int(argv[5])
+ lc = (len(argv[1:]) == 6 and argv[6] == '-lc')
+
+ # Compute sizes
+ corpus_size = wc(in_file)
+ total_docs = corpus_size / doc_size
+ leftover = corpus_size % doc_size
+ train_docs = total_docs - (dev_sets * docs_per_dev_set)
+ train_batch_size = (train_docs / docs_per_dev_set)
+
+ # Report
+ sys.stderr.write('Splitting {} lines ({} documents)\n'.format(corpus_size, total_docs + (1 if leftover else 0)))
+ sys.stderr.write('Train: {} ({})\n'.format((train_docs * doc_size) + leftover, train_docs + (1 if leftover else 0)))
+ sys.stderr.write('Dev: {} x {} ({})\n'.format(dev_sets, docs_per_dev_set * doc_size, docs_per_dev_set))
+
+ inp = gzopen(in_file)
+ train_out = open('{}.train'.format(out_prefix), 'w')
+ dev_out = [open('{}.dev.{}'.format(out_prefix, i + 1), 'w') for i in range(dev_sets)]
+ i = 0
+
+ # For each set of documents
+ for _ in range(docs_per_dev_set):
+ # Write several documents to train
+ for _ in range(train_batch_size):
+ for _ in range(doc_size):
+ i += 1
+ train_out.write('{} ||| {}'.format(i, inp.readline()) if lc else inp.readline())
+ # Write a document to each dev
+ for out in dev_out:
+ for _ in range(doc_size):
+ i += 1
+ out.write('{} ||| {}'.format(i, inp.readline()) if lc else inp.readline())
+ # Write leftover lines to train
+ for line in inp:
+ i += 1
+ train_out.write('{} ||| {}'.format(i, line) if lc else line)
+
+ train_out.close()
+ for out in dev_out:
+ out.close()
+
+if __name__ == '__main__':
+ main(sys.argv)
diff --git a/corpus/support/token_list b/corpus/support/token_list
index d38638cf..00daa82b 100644
--- a/corpus/support/token_list
+++ b/corpus/support/token_list
@@ -1,6 +1,55 @@
##################### hyphenated words added by Fei since 3/7/05
##X-ray
+# Finnish
+eaa.
+ap.
+arv.
+ay.
+eKr.
+em.
+engl.
+esim.
+fil.
+lis.
+fil.
+maist.
+fil.toht.
+harv.
+ilt.
+jatk.
+jKr.
+jms.
+jne.
+joht.
+klo
+ko.
+ks.
+leht.
+lv.
+lyh.
+mm.
+mon.
+nim.
+nro.
+ns.
+nti.
+os.
+oy.
+pj.
+pnä.
+puh.
+pvm.
+rva.
+tms.
+ts.
+vars.
+vrt.
+ym.
+yms.
+yo.
+>>>>>>> 8646b68e5b124f612fd65b51ea40624f65a2f3d6
+
# hindi abbreviation patterns
जन.
फर.
diff --git a/corpus/support/token_patterns b/corpus/support/token_patterns
index de64fb2a..12558cdd 100644
--- a/corpus/support/token_patterns
+++ b/corpus/support/token_patterns
@@ -1,4 +1,6 @@
/^(al|el|ul|e)\-[a-z]+$/
+/\.(fi|fr|es|co\.uk|de)$/
+/:[a-zä]+$/
/^((а|А)(ль|ш)|уль)-\p{Cyrillic}+$/
/^\p{Cyrillic}\.\p{Cyrillic}\.$/
/^(\d|\d\d|\d\d\d)\.$/
diff --git a/corpus/support/tokenizer.pl b/corpus/support/tokenizer.pl
index aa285be4..718d78cc 100755
--- a/corpus/support/tokenizer.pl
+++ b/corpus/support/tokenizer.pl
@@ -415,7 +415,7 @@ sub deep_proc_token {
}
## remove the ending periods that follow number etc.
- if($line =~ /^(.*(\d|\~|\^|\&|\:|\,|\#|\*|\%|\-|\_|\/|\\|\$|\'))(\.+)$/){
+ if($line =~ /^(.*(\d|\~|\^|\&|\:|\,|\#|\*|\%|€|\-|\_|\/|\\|\$|\'))(\.+)$/){
## 12~13. => 12~13 .
my $t1 = $1;
my $t3 = $3;
@@ -600,12 +600,12 @@ sub deep_proc_token {
## deal with "%"
- if(($line =~ /\%/) && ($Split_On_PercentSign > 0)){
+ if(($line =~ /\%|€/) && ($Split_On_PercentSign > 0)){
my $suc = 0;
if($Split_On_PercentSign >= 2){
- $suc += ($line =~ s/(\D)(\%+)/$1 $2/g);
+ $suc += ($line =~ s/(\D)(\%+|€+)/$1 $2/g);
}else{
- $suc += ($line =~ s/(\%+)/ $1 /g);
+ $suc += ($line =~ s/(\%+|€+)/ $1 /g);
}
if($suc){
diff --git a/corpus/utf8-normalize.sh b/corpus/utf8-normalize.sh
index dcf8bc59..7c0db611 100755
--- a/corpus/utf8-normalize.sh
+++ b/corpus/utf8-normalize.sh
@@ -7,7 +7,7 @@
if which uconv > /dev/null
then
- CMD="uconv -f utf8 -t utf8 -x Any-NFKC --callback skip"
+ CMD="uconv -f utf8 -t utf8 -x Any-NFKC --callback skip --remove-signature"
else
echo "Cannot find ICU uconv (http://site.icu-project.org/) ... falling back to iconv. Normalization NOT taking place." 1>&2
CMD="iconv -f utf8 -t utf8 -c"
diff --git a/decoder/JSON_parser.c b/decoder/JSON_parser.c
deleted file mode 100644
index 5e392bc6..00000000
--- a/decoder/JSON_parser.c
+++ /dev/null
@@ -1,1012 +0,0 @@
-/* JSON_parser.c */
-
-/* 2007-08-24 */
-
-/*
-Copyright (c) 2005 JSON.org
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-The Software shall be used for Good, not Evil.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-*/
-
-/*
- Callbacks, comments, Unicode handling by Jean Gressmann (jean@0x42.de), 2007-2009.
-
- For the added features the license above applies also.
-
- Changelog:
- 2009-05-17
- Incorporated benrudiak@googlemail.com fix for UTF16 decoding.
-
- 2009-05-14
- Fixed float parsing bug related to a locale being set that didn't
- use '.' as decimal point character (charles@transmissionbt.com).
-
- 2008-10-14
- Renamed states.IN to states.IT to avoid name clash which IN macro
- defined in windef.h (alexey.pelykh@gmail.com)
-
- 2008-07-19
- Removed some duplicate code & debugging variable (charles@transmissionbt.com)
-
- 2008-05-28
- Made JSON_value structure ansi C compliant. This bug was report by
- trisk@acm.jhu.edu
-
- 2008-05-20
- Fixed bug reported by charles@transmissionbt.com where the switching
- from static to dynamic parse buffer did not copy the static parse
- buffer's content.
-*/
-
-
-
-#include <assert.h>
-#include <ctype.h>
-#include <float.h>
-#include <stddef.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <locale.h>
-
-#include "JSON_parser.h"
-
-#ifdef _MSC_VER
-# if _MSC_VER >= 1400 /* Visual Studio 2005 and up */
-# pragma warning(disable:4996) // unsecure sscanf
-# endif
-#endif
-
-
-#define true 1
-#define false 0
-#define __ -1 /* the universal error code */
-
-/* values chosen so that the object size is approx equal to one page (4K) */
-#ifndef JSON_PARSER_STACK_SIZE
-# define JSON_PARSER_STACK_SIZE 128
-#endif
-
-#ifndef JSON_PARSER_PARSE_BUFFER_SIZE
-# define JSON_PARSER_PARSE_BUFFER_SIZE 3500
-#endif
-
-typedef unsigned short UTF16;
-
-struct JSON_parser_struct {
- JSON_parser_callback callback;
- void* ctx;
- signed char state, before_comment_state, type, escaped, comment, allow_comments, handle_floats_manually;
- UTF16 utf16_high_surrogate;
- long depth;
- long top;
- signed char* stack;
- long stack_capacity;
- char decimal_point;
- char* parse_buffer;
- size_t parse_buffer_capacity;
- size_t parse_buffer_count;
- size_t comment_begin_offset;
- signed char static_stack[JSON_PARSER_STACK_SIZE];
- char static_parse_buffer[JSON_PARSER_PARSE_BUFFER_SIZE];
-};
-
-#define COUNTOF(x) (sizeof(x)/sizeof(x[0]))
-
-/*
- Characters are mapped into these character classes. This allows for
- a significant reduction in the size of the state transition table.
-*/
-
-
-
-enum classes {
- C_SPACE, /* space */
- C_WHITE, /* other whitespace */
- C_LCURB, /* { */
- C_RCURB, /* } */
- C_LSQRB, /* [ */
- C_RSQRB, /* ] */
- C_COLON, /* : */
- C_COMMA, /* , */
- C_QUOTE, /* " */
- C_BACKS, /* \ */
- C_SLASH, /* / */
- C_PLUS, /* + */
- C_MINUS, /* - */
- C_POINT, /* . */
- C_ZERO , /* 0 */
- C_DIGIT, /* 123456789 */
- C_LOW_A, /* a */
- C_LOW_B, /* b */
- C_LOW_C, /* c */
- C_LOW_D, /* d */
- C_LOW_E, /* e */
- C_LOW_F, /* f */
- C_LOW_L, /* l */
- C_LOW_N, /* n */
- C_LOW_R, /* r */
- C_LOW_S, /* s */
- C_LOW_T, /* t */
- C_LOW_U, /* u */
- C_ABCDF, /* ABCDF */
- C_E, /* E */
- C_ETC, /* everything else */
- C_STAR, /* * */
- NR_CLASSES
-};
-
-static int ascii_class[128] = {
-/*
- This array maps the 128 ASCII characters into character classes.
- The remaining Unicode characters should be mapped to C_ETC.
- Non-whitespace control characters are errors.
-*/
- __, __, __, __, __, __, __, __,
- __, C_WHITE, C_WHITE, __, __, C_WHITE, __, __,
- __, __, __, __, __, __, __, __,
- __, __, __, __, __, __, __, __,
-
- C_SPACE, C_ETC, C_QUOTE, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
- C_ETC, C_ETC, C_STAR, C_PLUS, C_COMMA, C_MINUS, C_POINT, C_SLASH,
- C_ZERO, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT,
- C_DIGIT, C_DIGIT, C_COLON, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
-
- C_ETC, C_ABCDF, C_ABCDF, C_ABCDF, C_ABCDF, C_E, C_ABCDF, C_ETC,
- C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
- C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC,
- C_ETC, C_ETC, C_ETC, C_LSQRB, C_BACKS, C_RSQRB, C_ETC, C_ETC,
-
- C_ETC, C_LOW_A, C_LOW_B, C_LOW_C, C_LOW_D, C_LOW_E, C_LOW_F, C_ETC,
- C_ETC, C_ETC, C_ETC, C_ETC, C_LOW_L, C_ETC, C_LOW_N, C_ETC,
- C_ETC, C_ETC, C_LOW_R, C_LOW_S, C_LOW_T, C_LOW_U, C_ETC, C_ETC,
- C_ETC, C_ETC, C_ETC, C_LCURB, C_ETC, C_RCURB, C_ETC, C_ETC
-};
-
-
-/*
- The state codes.
-*/
-enum states {
- GO, /* start */
- OK, /* ok */
- OB, /* object */
- KE, /* key */
- CO, /* colon */
- VA, /* value */
- AR, /* array */
- ST, /* string */
- ES, /* escape */
- U1, /* u1 */
- U2, /* u2 */
- U3, /* u3 */
- U4, /* u4 */
- MI, /* minus */
- ZE, /* zero */
- IT, /* integer */
- FR, /* fraction */
- E1, /* e */
- E2, /* ex */
- E3, /* exp */
- T1, /* tr */
- T2, /* tru */
- T3, /* true */
- F1, /* fa */
- F2, /* fal */
- F3, /* fals */
- F4, /* false */
- N1, /* nu */
- N2, /* nul */
- N3, /* null */
- C1, /* / */
- C2, /* / * */
- C3, /* * */
- FX, /* *.* *eE* */
- D1, /* second UTF-16 character decoding started by \ */
- D2, /* second UTF-16 character proceeded by u */
- NR_STATES
-};
-
-enum actions
-{
- CB = -10, /* comment begin */
- CE = -11, /* comment end */
- FA = -12, /* false */
- TR = -13, /* false */
- NU = -14, /* null */
- DE = -15, /* double detected by exponent e E */
- DF = -16, /* double detected by fraction . */
- SB = -17, /* string begin */
- MX = -18, /* integer detected by minus */
- ZX = -19, /* integer detected by zero */
- IX = -20, /* integer detected by 1-9 */
- EX = -21, /* next char is escaped */
- UC = -22 /* Unicode character read */
-};
-
-
-static int state_transition_table[NR_STATES][NR_CLASSES] = {
-/*
- The state transition table takes the current state and the current symbol,
- and returns either a new state or an action. An action is represented as a
- negative number. A JSON text is accepted if at the end of the text the
- state is OK and if the mode is MODE_DONE.
-
- white 1-9 ABCDF etc
- space | { } [ ] : , " \ / + - . 0 | a b c d e f l n r s t u | E | * */
-/*start GO*/ {GO,GO,-6,__,-5,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*ok OK*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*object OB*/ {OB,OB,__,-9,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*key KE*/ {KE,KE,__,__,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*colon CO*/ {CO,CO,__,__,__,__,-2,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*value VA*/ {VA,VA,-6,__,-5,__,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__},
-/*array AR*/ {AR,AR,-6,__,-5,-7,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__},
-/*string ST*/ {ST,__,ST,ST,ST,ST,ST,ST,-4,EX,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST},
-/*escape ES*/ {__,__,__,__,__,__,__,__,ST,ST,ST,__,__,__,__,__,__,ST,__,__,__,ST,__,ST,ST,__,ST,U1,__,__,__,__},
-/*u1 U1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U2,U2,U2,U2,U2,U2,U2,U2,__,__,__,__,__,__,U2,U2,__,__},
-/*u2 U2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U3,U3,U3,U3,U3,U3,U3,U3,__,__,__,__,__,__,U3,U3,__,__},
-/*u3 U3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U4,U4,U4,U4,U4,U4,U4,U4,__,__,__,__,__,__,U4,U4,__,__},
-/*u4 U4*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,UC,UC,UC,UC,UC,UC,UC,UC,__,__,__,__,__,__,UC,UC,__,__},
-/*minus MI*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,ZE,IT,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*zero ZE*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*int IT*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,IT,IT,__,__,__,__,DE,__,__,__,__,__,__,__,__,DE,__,__},
-/*frac FR*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__},
-/*e E1*/ {__,__,__,__,__,__,__,__,__,__,__,E2,E2,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*ex E2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*exp E3*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*tr T1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T2,__,__,__,__,__,__,__},
-/*tru T2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T3,__,__,__,__},
-/*true T3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__},
-/*fa F1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*fal F2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F3,__,__,__,__,__,__,__,__,__},
-/*fals F3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F4,__,__,__,__,__,__},
-/*false F4*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__},
-/*nu N1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N2,__,__,__,__},
-/*nul N2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N3,__,__,__,__,__,__,__,__,__},
-/*null N3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__},
-/*/ C1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,C2},
-/*/* C2*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3},
-/** C3*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,CE,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3},
-/*_. FX*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__},
-/*\ D1*/ {__,__,__,__,__,__,__,__,__,D2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__},
-/*\ D2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,U1,__,__,__,__},
-};
-
-
-/*
- These modes can be pushed on the stack.
-*/
-enum modes {
- MODE_ARRAY = 1,
- MODE_DONE = 2,
- MODE_KEY = 3,
- MODE_OBJECT = 4
-};
-
-static int
-push(JSON_parser jc, int mode)
-{
-/*
- Push a mode onto the stack. Return false if there is overflow.
-*/
- jc->top += 1;
- if (jc->depth < 0) {
- if (jc->top >= jc->stack_capacity) {
- size_t bytes_to_allocate;
- jc->stack_capacity *= 2;
- bytes_to_allocate = jc->stack_capacity * sizeof(jc->static_stack[0]);
- if (jc->stack == &jc->static_stack[0]) {
- jc->stack = (signed char*)malloc(bytes_to_allocate);
- memcpy(jc->stack, jc->static_stack, sizeof(jc->static_stack));
- } else {
- jc->stack = (signed char*)realloc(jc->stack, bytes_to_allocate);
- }
- }
- } else {
- if (jc->top >= jc->depth) {
- return false;
- }
- }
-
- jc->stack[jc->top] = mode;
- return true;
-}
-
-
-static int
-pop(JSON_parser jc, int mode)
-{
-/*
- Pop the stack, assuring that the current mode matches the expectation.
- Return false if there is underflow or if the modes mismatch.
-*/
- if (jc->top < 0 || jc->stack[jc->top] != mode) {
- return false;
- }
- jc->top -= 1;
- return true;
-}
-
-
-#define parse_buffer_clear(jc) \
- do {\
- jc->parse_buffer_count = 0;\
- jc->parse_buffer[0] = 0;\
- } while (0)
-
-#define parse_buffer_pop_back_char(jc)\
- do {\
- assert(jc->parse_buffer_count >= 1);\
- --jc->parse_buffer_count;\
- jc->parse_buffer[jc->parse_buffer_count] = 0;\
- } while (0)
-
-void delete_JSON_parser(JSON_parser jc)
-{
- if (jc) {
- if (jc->stack != &jc->static_stack[0]) {
- free((void*)jc->stack);
- }
- if (jc->parse_buffer != &jc->static_parse_buffer[0]) {
- free((void*)jc->parse_buffer);
- }
- free((void*)jc);
- }
-}
-
-
-JSON_parser
-new_JSON_parser(JSON_config* config)
-{
-/*
- new_JSON_parser starts the checking process by constructing a JSON_parser
- object. It takes a depth parameter that restricts the level of maximum
- nesting.
-
- To continue the process, call JSON_parser_char for each character in the
- JSON text, and then call JSON_parser_done to obtain the final result.
- These functions are fully reentrant.
-*/
-
- int depth = 0;
- JSON_config default_config;
-
- JSON_parser jc = (JSON_parser)malloc(sizeof(struct JSON_parser_struct));
-
- memset(jc, 0, sizeof(*jc));
-
-
- /* initialize configuration */
- init_JSON_config(&default_config);
-
- /* set to default configuration if none was provided */
- if (config == NULL) {
- config = &default_config;
- }
-
- depth = config->depth;
-
- /* We need to be able to push at least one object */
- if (depth == 0) {
- depth = 1;
- }
-
- jc->state = GO;
- jc->top = -1;
-
- /* Do we want non-bound stack? */
- if (depth > 0) {
- jc->stack_capacity = depth;
- jc->depth = depth;
- if (depth <= (int)COUNTOF(jc->static_stack)) {
- jc->stack = &jc->static_stack[0];
- } else {
- jc->stack = (signed char*)malloc(jc->stack_capacity * sizeof(jc->static_stack[0]));
- }
- } else {
- jc->stack_capacity = COUNTOF(jc->static_stack);
- jc->depth = -1;
- jc->stack = &jc->static_stack[0];
- }
-
- /* set parser to start */
- push(jc, MODE_DONE);
-
- /* set up the parse buffer */
- jc->parse_buffer = &jc->static_parse_buffer[0];
- jc->parse_buffer_capacity = COUNTOF(jc->static_parse_buffer);
- parse_buffer_clear(jc);
-
- /* set up callback, comment & float handling */
- jc->callback = config->callback;
- jc->ctx = config->callback_ctx;
- jc->allow_comments = config->allow_comments != 0;
- jc->handle_floats_manually = config->handle_floats_manually != 0;
-
- /* set up decimal point */
- jc->decimal_point = *localeconv()->decimal_point;
-
- return jc;
-}
-
-static void grow_parse_buffer(JSON_parser jc)
-{
- size_t bytes_to_allocate;
- jc->parse_buffer_capacity *= 2;
- bytes_to_allocate = jc->parse_buffer_capacity * sizeof(jc->parse_buffer[0]);
- if (jc->parse_buffer == &jc->static_parse_buffer[0]) {
- jc->parse_buffer = (char*)malloc(bytes_to_allocate);
- memcpy(jc->parse_buffer, jc->static_parse_buffer, jc->parse_buffer_count);
- } else {
- jc->parse_buffer = (char*)realloc(jc->parse_buffer, bytes_to_allocate);
- }
-}
-
-#define parse_buffer_push_back_char(jc, c)\
- do {\
- if (jc->parse_buffer_count + 1 >= jc->parse_buffer_capacity) grow_parse_buffer(jc);\
- jc->parse_buffer[jc->parse_buffer_count++] = c;\
- jc->parse_buffer[jc->parse_buffer_count] = 0;\
- } while (0)
-
-#define assert_is_non_container_type(jc) \
- assert( \
- jc->type == JSON_T_NULL || \
- jc->type == JSON_T_FALSE || \
- jc->type == JSON_T_TRUE || \
- jc->type == JSON_T_FLOAT || \
- jc->type == JSON_T_INTEGER || \
- jc->type == JSON_T_STRING)
-
-
-static int parse_parse_buffer(JSON_parser jc)
-{
- if (jc->callback) {
- JSON_value value, *arg = NULL;
-
- if (jc->type != JSON_T_NONE) {
- assert_is_non_container_type(jc);
-
- switch(jc->type) {
- case JSON_T_FLOAT:
- arg = &value;
- if (jc->handle_floats_manually) {
- value.vu.str.value = jc->parse_buffer;
- value.vu.str.length = jc->parse_buffer_count;
- } else {
- /*sscanf(jc->parse_buffer, "%Lf", &value.vu.float_value);*/
-
- /* not checking with end pointer b/c there may be trailing ws */
- value.vu.float_value = strtod(jc->parse_buffer, NULL);
- }
- break;
- case JSON_T_INTEGER:
- arg = &value;
- sscanf(jc->parse_buffer, JSON_PARSER_INTEGER_SSCANF_TOKEN, &value.vu.integer_value);
- break;
- case JSON_T_STRING:
- arg = &value;
- value.vu.str.value = jc->parse_buffer;
- value.vu.str.length = jc->parse_buffer_count;
- break;
- }
-
- if (!(*jc->callback)(jc->ctx, jc->type, arg)) {
- return false;
- }
- }
- }
-
- parse_buffer_clear(jc);
-
- return true;
-}
-
-#define IS_HIGH_SURROGATE(uc) (((uc) & 0xFC00) == 0xD800)
-#define IS_LOW_SURROGATE(uc) (((uc) & 0xFC00) == 0xDC00)
-#define DECODE_SURROGATE_PAIR(hi,lo) ((((hi) & 0x3FF) << 10) + ((lo) & 0x3FF) + 0x10000)
-static unsigned char utf8_lead_bits[4] = { 0x00, 0xC0, 0xE0, 0xF0 };
-
-static int decode_unicode_char(JSON_parser jc)
-{
- int i;
- unsigned uc = 0;
- char* p;
- int trail_bytes;
-
- assert(jc->parse_buffer_count >= 6);
-
- p = &jc->parse_buffer[jc->parse_buffer_count - 4];
-
- for (i = 12; i >= 0; i -= 4, ++p) {
- unsigned x = *p;
-
- if (x >= 'a') {
- x -= ('a' - 10);
- } else if (x >= 'A') {
- x -= ('A' - 10);
- } else {
- x &= ~0x30u;
- }
-
- assert(x < 16);
-
- uc |= x << i;
- }
-
- /* clear UTF-16 char from buffer */
- jc->parse_buffer_count -= 6;
- jc->parse_buffer[jc->parse_buffer_count] = 0;
-
- /* attempt decoding ... */
- if (jc->utf16_high_surrogate) {
- if (IS_LOW_SURROGATE(uc)) {
- uc = DECODE_SURROGATE_PAIR(jc->utf16_high_surrogate, uc);
- trail_bytes = 3;
- jc->utf16_high_surrogate = 0;
- } else {
- /* high surrogate without a following low surrogate */
- return false;
- }
- } else {
- if (uc < 0x80) {
- trail_bytes = 0;
- } else if (uc < 0x800) {
- trail_bytes = 1;
- } else if (IS_HIGH_SURROGATE(uc)) {
- /* save the high surrogate and wait for the low surrogate */
- jc->utf16_high_surrogate = uc;
- return true;
- } else if (IS_LOW_SURROGATE(uc)) {
- /* low surrogate without a preceding high surrogate */
- return false;
- } else {
- trail_bytes = 2;
- }
- }
-
- jc->parse_buffer[jc->parse_buffer_count++] = (char) ((uc >> (trail_bytes * 6)) | utf8_lead_bits[trail_bytes]);
-
- for (i = trail_bytes * 6 - 6; i >= 0; i -= 6) {
- jc->parse_buffer[jc->parse_buffer_count++] = (char) (((uc >> i) & 0x3F) | 0x80);
- }
-
- jc->parse_buffer[jc->parse_buffer_count] = 0;
-
- return true;
-}
-
-static int add_escaped_char_to_parse_buffer(JSON_parser jc, int next_char)
-{
- jc->escaped = 0;
- /* remove the backslash */
- parse_buffer_pop_back_char(jc);
- switch(next_char) {
- case 'b':
- parse_buffer_push_back_char(jc, '\b');
- break;
- case 'f':
- parse_buffer_push_back_char(jc, '\f');
- break;
- case 'n':
- parse_buffer_push_back_char(jc, '\n');
- break;
- case 'r':
- parse_buffer_push_back_char(jc, '\r');
- break;
- case 't':
- parse_buffer_push_back_char(jc, '\t');
- break;
- case '"':
- parse_buffer_push_back_char(jc, '"');
- break;
- case '\\':
- parse_buffer_push_back_char(jc, '\\');
- break;
- case '/':
- parse_buffer_push_back_char(jc, '/');
- break;
- case 'u':
- parse_buffer_push_back_char(jc, '\\');
- parse_buffer_push_back_char(jc, 'u');
- break;
- default:
- return false;
- }
-
- return true;
-}
-
-#define add_char_to_parse_buffer(jc, next_char, next_class) \
- do { \
- if (jc->escaped) { \
- if (!add_escaped_char_to_parse_buffer(jc, next_char)) \
- return false; \
- } else if (!jc->comment) { \
- if ((jc->type != JSON_T_NONE) | !((next_class == C_SPACE) | (next_class == C_WHITE)) /* non-white-space */) { \
- parse_buffer_push_back_char(jc, (char)next_char); \
- } \
- } \
- } while (0)
-
-
-#define assert_type_isnt_string_null_or_bool(jc) \
- assert(jc->type != JSON_T_FALSE); \
- assert(jc->type != JSON_T_TRUE); \
- assert(jc->type != JSON_T_NULL); \
- assert(jc->type != JSON_T_STRING)
-
-
-int
-JSON_parser_char(JSON_parser jc, int next_char)
-{
-/*
- After calling new_JSON_parser, call this function for each character (or
- partial character) in your JSON text. It can accept UTF-8, UTF-16, or
- UTF-32. It returns true if things are looking ok so far. If it rejects the
- text, it returns false.
-*/
- int next_class, next_state;
-
-/*
- Determine the character's class.
-*/
- if (next_char < 0) {
- return false;
- }
- if (next_char >= 128) {
- next_class = C_ETC;
- } else {
- next_class = ascii_class[next_char];
- if (next_class <= __) {
- return false;
- }
- }
-
- add_char_to_parse_buffer(jc, next_char, next_class);
-
-/*
- Get the next state from the state transition table.
-*/
- next_state = state_transition_table[jc->state][next_class];
- if (next_state >= 0) {
-/*
- Change the state.
-*/
- jc->state = next_state;
- } else {
-/*
- Or perform one of the actions.
-*/
- switch (next_state) {
-/* Unicode character */
- case UC:
- if(!decode_unicode_char(jc)) {
- return false;
- }
- /* check if we need to read a second UTF-16 char */
- if (jc->utf16_high_surrogate) {
- jc->state = D1;
- } else {
- jc->state = ST;
- }
- break;
-/* escaped char */
- case EX:
- jc->escaped = 1;
- jc->state = ES;
- break;
-/* integer detected by minus */
- case MX:
- jc->type = JSON_T_INTEGER;
- jc->state = MI;
- break;
-/* integer detected by zero */
- case ZX:
- jc->type = JSON_T_INTEGER;
- jc->state = ZE;
- break;
-/* integer detected by 1-9 */
- case IX:
- jc->type = JSON_T_INTEGER;
- jc->state = IT;
- break;
-
-/* floating point number detected by exponent*/
- case DE:
- assert_type_isnt_string_null_or_bool(jc);
- jc->type = JSON_T_FLOAT;
- jc->state = E1;
- break;
-
-/* floating point number detected by fraction */
- case DF:
- assert_type_isnt_string_null_or_bool(jc);
- if (!jc->handle_floats_manually) {
-/*
- Some versions of strtod (which underlies sscanf) don't support converting
- C-locale formated floating point values.
-*/
- assert(jc->parse_buffer[jc->parse_buffer_count-1] == '.');
- jc->parse_buffer[jc->parse_buffer_count-1] = jc->decimal_point;
- }
- jc->type = JSON_T_FLOAT;
- jc->state = FX;
- break;
-/* string begin " */
- case SB:
- parse_buffer_clear(jc);
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_STRING;
- jc->state = ST;
- break;
-
-/* n */
- case NU:
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_NULL;
- jc->state = N1;
- break;
-/* f */
- case FA:
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_FALSE;
- jc->state = F1;
- break;
-/* t */
- case TR:
- assert(jc->type == JSON_T_NONE);
- jc->type = JSON_T_TRUE;
- jc->state = T1;
- break;
-
-/* closing comment */
- case CE:
- jc->comment = 0;
- assert(jc->parse_buffer_count == 0);
- assert(jc->type == JSON_T_NONE);
- jc->state = jc->before_comment_state;
- break;
-
-/* opening comment */
- case CB:
- if (!jc->allow_comments) {
- return false;
- }
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- assert(jc->parse_buffer_count == 0);
- assert(jc->type != JSON_T_STRING);
- switch (jc->stack[jc->top]) {
- case MODE_ARRAY:
- case MODE_OBJECT:
- switch(jc->state) {
- case VA:
- case AR:
- jc->before_comment_state = jc->state;
- break;
- default:
- jc->before_comment_state = OK;
- break;
- }
- break;
- default:
- jc->before_comment_state = jc->state;
- break;
- }
- jc->type = JSON_T_NONE;
- jc->state = C1;
- jc->comment = 1;
- break;
-/* empty } */
- case -9:
- parse_buffer_clear(jc);
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) {
- return false;
- }
- if (!pop(jc, MODE_KEY)) {
- return false;
- }
- jc->state = OK;
- break;
-
-/* } */ case -8:
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) {
- return false;
- }
- if (!pop(jc, MODE_OBJECT)) {
- return false;
- }
- jc->type = JSON_T_NONE;
- jc->state = OK;
- break;
-
-/* ] */ case -7:
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_END, NULL)) {
- return false;
- }
- if (!pop(jc, MODE_ARRAY)) {
- return false;
- }
-
- jc->type = JSON_T_NONE;
- jc->state = OK;
- break;
-
-/* { */ case -6:
- parse_buffer_pop_back_char(jc);
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_BEGIN, NULL)) {
- return false;
- }
- if (!push(jc, MODE_KEY)) {
- return false;
- }
- assert(jc->type == JSON_T_NONE);
- jc->state = OB;
- break;
-
-/* [ */ case -5:
- parse_buffer_pop_back_char(jc);
- if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_BEGIN, NULL)) {
- return false;
- }
- if (!push(jc, MODE_ARRAY)) {
- return false;
- }
- assert(jc->type == JSON_T_NONE);
- jc->state = AR;
- break;
-
-/* string end " */ case -4:
- parse_buffer_pop_back_char(jc);
- switch (jc->stack[jc->top]) {
- case MODE_KEY:
- assert(jc->type == JSON_T_STRING);
- jc->type = JSON_T_NONE;
- jc->state = CO;
-
- if (jc->callback) {
- JSON_value value;
- value.vu.str.value = jc->parse_buffer;
- value.vu.str.length = jc->parse_buffer_count;
- if (!(*jc->callback)(jc->ctx, JSON_T_KEY, &value)) {
- return false;
- }
- }
- parse_buffer_clear(jc);
- break;
- case MODE_ARRAY:
- case MODE_OBJECT:
- assert(jc->type == JSON_T_STRING);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- jc->type = JSON_T_NONE;
- jc->state = OK;
- break;
- default:
- return false;
- }
- break;
-
-/* , */ case -3:
- parse_buffer_pop_back_char(jc);
- if (!parse_parse_buffer(jc)) {
- return false;
- }
- switch (jc->stack[jc->top]) {
- case MODE_OBJECT:
-/*
- A comma causes a flip from object mode to key mode.
-*/
- if (!pop(jc, MODE_OBJECT) || !push(jc, MODE_KEY)) {
- return false;
- }
- assert(jc->type != JSON_T_STRING);
- jc->type = JSON_T_NONE;
- jc->state = KE;
- break;
- case MODE_ARRAY:
- assert(jc->type != JSON_T_STRING);
- jc->type = JSON_T_NONE;
- jc->state = VA;
- break;
- default:
- return false;
- }
- break;
-
-/* : */ case -2:
-/*
- A colon causes a flip from key mode to object mode.
-*/
- parse_buffer_pop_back_char(jc);
- if (!pop(jc, MODE_KEY) || !push(jc, MODE_OBJECT)) {
- return false;
- }
- assert(jc->type == JSON_T_NONE);
- jc->state = VA;
- break;
-/*
- Bad action.
-*/
- default:
- return false;
- }
- }
- return true;
-}
-
-
-int
-JSON_parser_done(JSON_parser jc)
-{
- const int result = jc->state == OK && pop(jc, MODE_DONE);
-
- return result;
-}
-
-
-int JSON_parser_is_legal_white_space_string(const char* s)
-{
- int c, char_class;
-
- if (s == NULL) {
- return false;
- }
-
- for (; *s; ++s) {
- c = *s;
-
- if (c < 0 || c >= 128) {
- return false;
- }
-
- char_class = ascii_class[c];
-
- if (char_class != C_SPACE && char_class != C_WHITE) {
- return false;
- }
- }
-
- return true;
-}
-
-
-
-void init_JSON_config(JSON_config* config)
-{
- if (config) {
- memset(config, 0, sizeof(*config));
-
- config->depth = JSON_PARSER_STACK_SIZE - 1;
- }
-}
diff --git a/decoder/JSON_parser.h b/decoder/JSON_parser.h
deleted file mode 100644
index de980072..00000000
--- a/decoder/JSON_parser.h
+++ /dev/null
@@ -1,152 +0,0 @@
-#ifndef JSON_PARSER_H
-#define JSON_PARSER_H
-
-/* JSON_parser.h */
-
-
-#include <stddef.h>
-
-/* Windows DLL stuff */
-#ifdef _WIN32
-# ifdef JSON_PARSER_DLL_EXPORTS
-# define JSON_PARSER_DLL_API __declspec(dllexport)
-# else
-# define JSON_PARSER_DLL_API __declspec(dllimport)
-# endif
-#else
-# define JSON_PARSER_DLL_API
-#endif
-
-/* Determine the integer type use to parse non-floating point numbers */
-#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1
-typedef long long JSON_int_t;
-#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld"
-#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld"
-#else
-typedef long JSON_int_t;
-#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld"
-#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld"
-#endif
-
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-typedef enum
-{
- JSON_T_NONE = 0,
- JSON_T_ARRAY_BEGIN, // 1
- JSON_T_ARRAY_END, // 2
- JSON_T_OBJECT_BEGIN, // 3
- JSON_T_OBJECT_END, // 4
- JSON_T_INTEGER, // 5
- JSON_T_FLOAT, // 6
- JSON_T_NULL, // 7
- JSON_T_TRUE, // 8
- JSON_T_FALSE, // 9
- JSON_T_STRING, // 10
- JSON_T_KEY, // 11
- JSON_T_MAX // 12
-} JSON_type;
-
-typedef struct JSON_value_struct {
- union {
- JSON_int_t integer_value;
-
- double float_value;
-
- struct {
- const char* value;
- size_t length;
- } str;
- } vu;
-} JSON_value;
-
-typedef struct JSON_parser_struct* JSON_parser;
-
-/*! \brief JSON parser callback
-
- \param ctx The pointer passed to new_JSON_parser.
- \param type An element of JSON_type but not JSON_T_NONE.
- \param value A representation of the parsed value. This parameter is NULL for
- JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END,
- JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned
- as zero-terminated C strings.
-
- \return Non-zero if parsing should continue, else zero.
-*/
-typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value);
-
-
-/*! \brief The structure used to configure a JSON parser object
-
- \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise
- the depth is the limit
- \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity.
- \param Callback context. This parameter may be NULL.
- \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting.
- \param allowComments. To allow C style comments in JSON, set to non-zero.
- \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero.
-
- \return The parser object.
-*/
-typedef struct {
- JSON_parser_callback callback;
- void* callback_ctx;
- int depth;
- int allow_comments;
- int handle_floats_manually;
-} JSON_config;
-
-
-/*! \brief Initializes the JSON parser configuration structure to default values.
-
- The default configuration is
- - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c)
- - no parsing, just checking for JSON syntax
- - no comments
-
- \param config. Used to configure the parser.
-*/
-JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config);
-
-/*! \brief Create a JSON parser object
-
- \param config. Used to configure the parser. Set to NULL to use the default configuration.
- See init_JSON_config
-
- \return The parser object.
-*/
-JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config);
-
-/*! \brief Destroy a previously created JSON parser object. */
-JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc);
-
-/*! \brief Parse a character.
-
- \return Non-zero, if all characters passed to this function are part of are valid JSON.
-*/
-JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char);
-
-/*! \brief Finalize parsing.
-
- Call this method once after all input characters have been consumed.
-
- \return Non-zero, if all parsed characters are valid JSON, zero otherwise.
-*/
-JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc);
-
-/*! \brief Determine if a given string is valid JSON white space
-
- \return Non-zero if the string is valid, zero otherwise.
-*/
-JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s);
-
-
-#ifdef __cplusplus
-}
-#endif
-
-
-#endif /* JSON_PARSER_H */
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index b23bbad4..e313f1f9 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -36,10 +36,10 @@ noinst_LIBRARIES = libcdec.a
EXTRA_DIST = test_data rule_lexer.ll
libcdec_a_SOURCES = \
- JSON_parser.h \
aligner.h \
apply_models.h \
bottom_up_parser.h \
+ bottom_up_parser-rs.h \
csplit.h \
decoder.h \
earley_composer.h \
@@ -48,12 +48,15 @@ libcdec_a_SOURCES = \
ff_basic.h \
ff_bleu.h \
ff_charset.h \
+ ff_conll.h \
+ ff_const_reorder_common.h \
+ ff_const_reorder.h \
ff_context.h \
ff_csplit.h \
ff_external.h \
ff_factory.h \
ff_klm.h \
- ff_lexical.h \
+ ff_lexical.h \
ff_lm.h \
ff_ngrams.h \
ff_parse_match.h \
@@ -61,6 +64,7 @@ libcdec_a_SOURCES = \
ff_rules.h \
ff_ruleshape.h \
ff_sample_fsa.h \
+ ff_soft_syn.h \
ff_soft_syntax.h \
ff_soft_syntax_mindist.h \
ff_source_path.h \
@@ -83,7 +87,6 @@ libcdec_a_SOURCES = \
hg_union.h \
incremental.h \
inside_outside.h \
- json_parse.h \
kbest.h \
lattice.h \
lexalign.h \
@@ -103,6 +106,7 @@ libcdec_a_SOURCES = \
aligner.cc \
apply_models.cc \
bottom_up_parser.cc \
+ bottom_up_parser-rs.cc \
cdec.cc \
cdec_ff.cc \
csplit.cc \
@@ -113,7 +117,9 @@ libcdec_a_SOURCES = \
ff_basic.cc \
ff_bleu.cc \
ff_charset.cc \
+ ff_conll.cc \
ff_context.cc \
+ ff_const_reorder.cc \
ff_csplit.cc \
ff_external.cc \
ff_factory.cc \
@@ -123,6 +129,7 @@ libcdec_a_SOURCES = \
ff_parse_match.cc \
ff_rules.cc \
ff_ruleshape.cc \
+ ff_soft_syn.cc \
ff_soft_syntax.cc \
ff_soft_syntax_mindist.cc \
ff_source_path.cc \
@@ -144,7 +151,6 @@ libcdec_a_SOURCES = \
hg_sampler.cc \
hg_union.cc \
incremental.cc \
- json_parse.cc \
lattice.cc \
lexalign.cc \
lextrans.cc \
@@ -160,5 +166,4 @@ libcdec_a_SOURCES = \
tagger.cc \
translator.cc \
trule.cc \
- viterbi.cc \
- JSON_parser.c
+ viterbi.cc
diff --git a/decoder/aligner.h b/decoder/aligner.h
index a34795c9..d68ceefc 100644
--- a/decoder/aligner.h
+++ b/decoder/aligner.h
@@ -1,4 +1,4 @@
-#ifndef _ALIGNER_H_
+#ifndef ALIGNER_H
#include <string>
#include <iostream>
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9f8bbead..18c83fd4 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -233,7 +233,20 @@ public:
void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) {
Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_);
new_edge->edge_prob_ = item->out_edge_.edge_prob_;
- Candidate*& o_item = (*s2n)[item->state_];
+
+ Candidate** o_item_ptr = nullptr;
+ if (item->state_.size() && models.NeedsStateErasure()) {
+ // When erasure of certain state bytes is needed, we must make a copy of
+ // the state instead of doing the erasure in-place because future
+ // candidates may require the information in the bytes to be erased.
+ FFState state(item->state_);
+ models.EraseIgnoredBytes(&state);
+ o_item_ptr = &(*s2n)[state];
+ } else {
+ o_item_ptr = &(*s2n)[item->state_];
+ }
+ Candidate*& o_item = *o_item_ptr;
+
if (!o_item) o_item = item;
int& node_id = o_item->node_index_;
@@ -254,7 +267,18 @@ public:
// score is the same for all items with a common residual DP
// state
if (item->vit_prob_ > o_item->vit_prob_) {
- assert(o_item->state_ == item->state_); // sanity check!
+ if (item->state_.size() && models.NeedsStateErasure()) {
+ // node_states_ should still point to the unerased state.
+ node_states_[o_item->node_index_] = item->state_;
+ // sanity check!
+ FFState item_state(item->state_), o_item_state(o_item->state_);
+ models.EraseIgnoredBytes(&item_state);
+ models.EraseIgnoredBytes(&o_item_state);
+ assert(item_state == o_item_state);
+ } else {
+ assert(o_item->state_ == item->state_); // sanity check!
+ }
+
o_item->est_prob_ = item->est_prob_;
o_item->vit_prob_ = item->vit_prob_;
}
@@ -599,9 +623,10 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE
- || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
- || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
+ } else if (config.algorithm == IntersectionConfiguration::CUBE ||
+ config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING ||
+ config.algorithm ==
+ IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
const int max_pl_for_large=50;
if (pl > max_pl_for_large && in.nodes_.size() > 80000) {
@@ -628,4 +653,3 @@ void ApplyModelSet(const Hypergraph& in,
out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed
// automatically
}
-
diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc
new file mode 100644
index 00000000..fbde7e24
--- /dev/null
+++ b/decoder/bottom_up_parser-rs.cc
@@ -0,0 +1,341 @@
+#include "bottom_up_parser-rs.h"
+
+#include <iostream>
+#include <map>
+
+#include "node_state_hash.h"
+#include "nt_span.h"
+#include "hg.h"
+#include "array2d.h"
+#include "tdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+static WordID kEPS = 0;
+
+struct RSActiveItem;
+class RSChart {
+ public:
+ RSChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest);
+ ~RSChart();
+
+ void AddToChart(const RSActiveItem& x, int i, int j);
+ void ConsumeTerminal(const RSActiveItem& x, int i, int j, int k);
+ void ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k);
+ bool Parse();
+ inline bool GoalFound() const { return goal_idx_ >= 0; }
+ inline int GetGoalIndex() const { return goal_idx_; }
+
+ private:
+ void ApplyRules(const int i,
+ const int j,
+ const RuleBin* rules,
+ const Hypergraph::TailNodeVector& tail,
+ const float lattice_cost);
+
+ // returns true if a new node was added to the chart
+ // false otherwise
+ bool ApplyRule(const int i,
+ const int j,
+ const TRulePtr& r,
+ const Hypergraph::TailNodeVector& ant_nodes,
+ const float lattice_cost);
+
+ void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx);
+ void TopoSortUnaries();
+
+ const vector<GrammarPtr>& grammars_;
+ const Lattice& input_;
+ Hypergraph* forest_;
+ Array2D<vector<int>> chart_; // chart_(i,j) is the list of nodes (represented
+ // by their index in forest_->nodes_) derived spanning i,j
+ typedef map<int, int> Cat2NodeMap;
+ Array2D<Cat2NodeMap> nodemap_;
+ const WordID goal_cat_; // category that is being searched for at [0,n]
+ TRulePtr goal_rule_;
+ int goal_idx_; // index of goal node, if found
+ const int lc_fid_;
+ vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
+
+ static WordID kGOAL; // [Goal]
+};
+
+WordID RSChart::kGOAL = 0;
+
+// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span"
+struct RSActiveItem {
+ explicit RSActiveItem(const GrammarIter* g, int i) :
+ gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {}
+ void ExtendTerminal(int symbol, float src_cost) {
+ lattice_cost += src_cost;
+ if (symbol != kEPS)
+ gptr_ = gptr_->Extend(symbol);
+ }
+ void ExtendNonTerminal(const Hypergraph* hg, int node_index) {
+ gptr_ = gptr_->Extend(hg->nodes_[node_index].cat_);
+ ant_nodes_.push_back(node_index);
+ }
+ // returns false if the extension has failed
+ explicit operator bool() const {
+ return gptr_;
+ }
+ const GrammarIter* gptr_;
+ Hypergraph::TailNodeVector ant_nodes_;
+ float lattice_cost; // TODO: use SparseVector<double> to encode input features
+ short i_;
+};
+
+// some notes on the implementation
+// "X" in Rico's Algorithm 2 roughly looks like it is just a pointer into a grammar
+// trie, but it is actually a full "dotted item" since it needs to contain the information
+// to build the hypergraph (i.e., it must remember the antecedent nodes and where they are,
+// also any information about the path costs).
+
+RSChart::RSChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest) :
+ grammars_(grammars),
+ input_(input),
+ forest_(forest),
+ chart_(input.size()+1, input.size()+1),
+ nodemap_(input.size()+1, input.size()+1),
+ goal_cat_(TD::Convert(goal) * -1),
+ goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")),
+ goal_idx_(-1),
+ lc_fid_(FD::Convert("LatticeCost")),
+ unaries_() {
+ for (unsigned i = 0; i < grammars_.size(); ++i) {
+ const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules();
+ for (unsigned j = 0; j < u.size(); ++j)
+ unaries_.push_back(u[j]);
+ }
+ TopoSortUnaries();
+ if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
+ if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl;
+}
+
+static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) {
+ if (mark[node] == 1) {
+ cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n";
+ return false; // cycle detected
+ } else if (mark[node] == 2) {
+ return true; // already been
+ }
+ mark[node] = 1;
+ const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node);
+ if (nit != g.end()) {
+ const vector<TRulePtr>& edges = nit->second;
+ vector<bool> okay(edges.size(), true);
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark);
+ if (!okay[i]) {
+ cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl;
+ }
+ }
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ if (okay[i]) u.push_back(edges[i]);
+ //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl;
+ }
+ }
+ mark[node] = 2;
+ return true;
+}
+
+void RSChart::TopoSortUnaries() {
+ vector<TRulePtr> u(unaries_.size()); u.clear();
+ map<int, vector<TRulePtr> > g;
+ map<int, int> mark;
+ //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl;
+ mark[goal_cat_] = 2;
+ for (unsigned i = 0; i < unaries_.size(); ++i) {
+ //cerr << "Adding: " << unaries_[i]->AsString() << endl;
+ g[unaries_[i]->f()[0]].push_back(unaries_[i]);
+ }
+ //m[unaries_[i]->lhs_].push_back(unaries_[i]);
+ for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) {
+ //cerr << "PROC: " << TD::Convert(-it->first) << endl;
+ if (mark[it->first] > 0) {
+ //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n";
+ } else {
+ TopoSortVisit(it->first, u, g, mark);
+ }
+ }
+ unaries_.clear();
+ for (int i = u.size() - 1; i >= 0; --i)
+ unaries_.push_back(u[i]);
+}
+
+bool RSChart::ApplyRule(const int i,
+ const int j,
+ const TRulePtr& r,
+ const Hypergraph::TailNodeVector& ant_nodes,
+ const float lattice_cost) {
+ Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
+ //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
+ new_edge->prev_i_ = r->prev_i;
+ new_edge->prev_j_ = r->prev_j;
+ new_edge->i_ = i;
+ new_edge->j_ = j;
+ new_edge->feature_values_ = r->GetFeatureValues();
+ if (lattice_cost && lc_fid_)
+ new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+ Cat2NodeMap& c2n = nodemap_(i,j);
+ const bool is_goal = (r->GetLHS() == kGOAL);
+ const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
+ Hypergraph::Node* node = NULL;
+ bool added_node = false;
+ if (ni == c2n.end()) {
+ //cerr << "(" << i << "," << j << ") => " << TD::Convert(-r->GetLHS()) << endl;
+ added_node = true;
+ node = forest_->AddNode(r->GetLHS());
+ c2n[r->GetLHS()] = node->id_;
+ if (is_goal) {
+ assert(goal_idx_ == -1);
+ goal_idx_ = node->id_;
+ } else {
+ chart_(i,j).push_back(node->id_);
+ }
+ } else {
+ node = &forest_->nodes_[ni->second];
+ }
+ forest_->ConnectEdgeToHeadNode(new_edge, node);
+ return added_node;
+}
+
+void RSChart::ApplyRules(const int i,
+ const int j,
+ const RuleBin* rules,
+ const Hypergraph::TailNodeVector& tail,
+ const float lattice_cost) {
+ const int n = rules->GetNumRules();
+ //cerr << i << " " << j << ": NUM RULES: " << n << endl;
+ for (int k = 0; k < n; ++k) {
+ //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl;
+ TRulePtr rule = rules->GetIthRule(k);
+ // apply rule, and if we create a new node, apply any necessary
+ // unary rules
+ if (ApplyRule(i, j, rule, tail, lattice_cost)) {
+ unsigned nodeidx = nodemap_(i,j)[rule->lhs_];
+ ApplyUnaryRules(i, j, rule->lhs_, nodeidx);
+ }
+ }
+}
+
+void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx) {
+ for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
+ //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
+ if (unaries_[ri]->f()[0] == cat) {
+ //cerr << " --MATCH\n";
+ WordID new_lhs = unaries_[ri]->GetLHS();
+ const Hypergraph::TailNodeVector ant(1, nodeidx);
+ if (ApplyRule(i, j, unaries_[ri], ant, 0)) {
+ //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl;
+ unsigned nodeidx = nodemap_(i,j)[new_lhs];
+ ApplyUnaryRules(i, j, new_lhs, nodeidx);
+ }
+ }
+ }
+}
+
+void RSChart::AddToChart(const RSActiveItem& x, int i, int j) {
+ // deal with completed rules
+ const RuleBin* rb = x.gptr_->GetRules();
+ if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost);
+
+ //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n";
+ // continue looking for extensions of the rule to the right
+ for (unsigned k = j+1; k <= input_.size(); ++k) {
+ ConsumeTerminal(x, i, j, k);
+ ConsumeNonTerminal(x, i, j, k);
+ }
+}
+
+void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) {
+ //cerr << "ConsumeT(" << i << "," << j << "," << k << "):\n";
+
+ const unsigned check_edge_len = k - j;
+ // long-term TODO preindex this search so i->len->words is constant time rather than fan out
+ for (auto& in_edge : input_[j]) {
+ if (in_edge.dist2next == check_edge_len) {
+ //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl;
+ RSActiveItem copy = x;
+ copy.ExtendTerminal(in_edge.label, in_edge.cost);
+ if (copy) AddToChart(copy, i, k);
+ }
+ }
+}
+
+void RSChart::ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k) {
+ //cerr << "ConsumeNT(" << i << "," << j << "," << k << "):\n";
+ for (auto& nodeidx : chart_(j,k)) {
+ //cerr << " Found completed NT in (" << j << "," << k << ") of type " << TD::Convert(-forest_->nodes_[nodeidx].cat_) << endl;
+ RSActiveItem copy = x;
+ copy.ExtendNonTerminal(forest_, nodeidx);
+ if (copy) AddToChart(copy, i, k);
+ }
+}
+
+bool RSChart::Parse() {
+ size_t in_size_2 = input_.size() * input_.size();
+ forest_->nodes_.reserve(in_size_2 * 2);
+ size_t res = min(static_cast<size_t>(2000000), static_cast<size_t>(in_size_2 * 1000));
+ forest_->edges_.reserve(res);
+ goal_idx_ = -1;
+ const int N = input_.size();
+ for (int i = N - 1; i >= 0; --i) {
+ for (int j = i + 1; j <= N; ++j) {
+ for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
+ RSActiveItem item(grammars_[gi]->GetRoot(), i);
+ ConsumeTerminal(item, i, i, j);
+ }
+ for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
+ RSActiveItem item(grammars_[gi]->GetRoot(), i);
+ ConsumeNonTerminal(item, i, i, j);
+ }
+ }
+ }
+
+ // look for goal
+ const vector<int>& dh = chart_(0, input_.size());
+ for (unsigned di = 0; di < dh.size(); ++di) {
+ const Hypergraph::Node& node = forest_->nodes_[dh[di]];
+ if (node.cat_ == goal_cat_) {
+ Hypergraph::TailNodeVector ant(1, node.id_);
+ ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+ }
+ }
+ if (!SILENT) cerr << endl;
+
+ if (GoalFound())
+ forest_->PruneUnreachable(forest_->nodes_.size() - 1);
+ return GoalFound();
+}
+
+RSChart::~RSChart() {}
+
+RSExhaustiveBottomUpParser::RSExhaustiveBottomUpParser(
+ const string& goal_sym,
+ const vector<GrammarPtr>& grammars) :
+ goal_sym_(goal_sym),
+ grammars_(grammars) {}
+
+bool RSExhaustiveBottomUpParser::Parse(const Lattice& input,
+ Hypergraph* forest) const {
+ kEPS = TD::Convert("*EPS*");
+ RSChart chart(goal_sym_, grammars_, input, forest);
+ const bool result = chart.Parse();
+
+ if (result) {
+ for (auto& node : forest->nodes_) {
+ Span prev;
+ const Span s = forest->NodeSpan(node.id_, &prev);
+ node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r);
+ }
+ }
+ return result;
+}
diff --git a/decoder/bottom_up_parser-rs.h b/decoder/bottom_up_parser-rs.h
new file mode 100644
index 00000000..2e271e99
--- /dev/null
+++ b/decoder/bottom_up_parser-rs.h
@@ -0,0 +1,29 @@
+#ifndef RSBOTTOM_UP_PARSER_H_
+#define RSBOTTOM_UP_PARSER_H_
+
+#include <vector>
+#include <string>
+
+#include "lattice.h"
+#include "grammar.h"
+
+class Hypergraph;
+
+// implementation of Sennrich (2014) parser
+// http://aclweb.org/anthology/W/W14/W14-4011.pdf
+class RSExhaustiveBottomUpParser {
+ public:
+ RSExhaustiveBottomUpParser(const std::string& goal_sym,
+ const std::vector<GrammarPtr>& grammars);
+
+ // returns true if goal reached spanning the full input
+ // forest contains the full (i.e., unpruned) parse forest
+ bool Parse(const Lattice& input,
+ Hypergraph* forest) const;
+
+ private:
+ const std::string goal_sym_;
+ const std::vector<GrammarPtr> grammars_;
+};
+
+#endif
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 7f7e075b..973a643a 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -3,6 +3,7 @@
#include "ff.h"
#include "ff_basic.h"
#include "ff_context.h"
+#include "ff_const_reorder.h"
#include "ff_spans.h"
#include "ff_lm.h"
#include "ff_klm.h"
@@ -14,6 +15,7 @@
#include "ff_rules.h"
#include "ff_ruleshape.h"
#include "ff_bleu.h"
+#include "ff_soft_syn.h"
#include "ff_soft_syntax.h"
#include "ff_soft_syntax_mindist.h"
#include "ff_source_path.h"
@@ -77,6 +79,7 @@ void register_feature_functions() {
ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>);
ff_registry.Register("SourcePathFeatures", new FFFactory<SourcePathFeatures>);
ff_registry.Register("WordSet", new FFFactory<WordSet>);
+ ff_registry.Register("ConstReorderFeature", new FFFactory<ConstReorderFeature>);
ff_registry.Register("External", new FFFactory<ExternalFeature>);
+ ff_registry.Register("SoftSynFeature", new SoftSynFeatureFactory());
}
-
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 9e8d692a..1e6c3194 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; }
#include "fdict.h"
#include "timing_stats.h"
#include "verbose.h"
+#include "b64featvector.h"
#include "translator.h"
#include "phrasebased_translator.h"
@@ -195,7 +196,7 @@ struct DecoderImpl {
}
forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);
if (!forestname.empty()) forestname=" "+forestname;
- if (!SILENT) {
+ if (!SILENT) {
forest_stats(forest," Pruned "+forestname+" forest",false,false);
cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
}
@@ -261,7 +262,7 @@ struct DecoderImpl {
assert(ref);
LatticeTools::ConvertTextOrPLF(sref, ref);
}
- }
+ }
// used to construct the suffix string to get the name of arguments for multiple passes
// e.g., the "2" in --weights2
@@ -284,7 +285,7 @@ struct DecoderImpl {
boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
- bool graphviz;
+ bool graphviz;
bool joshua_viz;
bool encode_b64;
bool kbest;
@@ -301,6 +302,7 @@ struct DecoderImpl {
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
bool remove_intersected_rule_annotations;
+ bool mr_mira_compat; // Mr.MIRA compatibility mode.
boost::scoped_ptr<IncrementalBase> incremental;
@@ -404,6 +406,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory")
("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
+ ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -414,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to")
- ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
+ ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)")
+ ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)");
// ob.AddOptions(&opts);
po::options_description clo("Command line options");
@@ -665,7 +669,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
oracle.show_derivation=conf.count("show_derivations");
+ oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
+ mr_mira_compat = conf.count("mr_mira_compat");
combine_size = conf["combine_size"].as<int>();
if (combine_size < 1) combine_size = 1;
@@ -699,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string
static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);
}
+static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) {
+ SparseVector<weight_t> delta;
+ DecodeFeatureVector(delta_b64, &delta);
+ if (delta.empty()) return;
+ // Apply updates
+ for (SparseVector<weight_t>::iterator dit = delta.begin();
+ dit != delta.end(); ++dit) {
+ int feat_id = dit->first;
+ union { weight_t weight; unsigned long long repr; } feat_delta;
+ feat_delta.weight = dit->second;
+ if (!SILENT)
+ cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight
+ << " = " << hex << feat_delta.repr << endl;
+ if (weights->size() <= feat_id) weights->resize(feat_id + 1);
+ (*weights)[feat_id] += feat_delta.weight;
+ }
+}
+
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string buf = input;
NgramCache::Clear(); // clear ngram cache for remote LM (if used)
@@ -709,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (sgml.find("id") != sgml.end())
sent_id = atoi(sgml["id"].c_str());
+ // Add delta from input to weights before decoding
+ if (mr_mira_compat)
+ ApplyWeightDelta(sgml["delta"], init_weights.get());
+
if (!SILENT) {
cerr << "\nINPUT: ";
if (buf.size() < 100)
@@ -928,14 +956,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
- bool succeeded = writer.Write(new_hg, false);
+ bool succeeded = writer.Write(new_hg);
if (!succeeded) abort();
} else {
- bool succeeded = writer.Write(forest, false);
+ bool succeeded = writer.Write(forest);
if (!succeeded) abort();
}
}
@@ -947,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (kbest && !has_ref) {
//TODO: does this work properly?
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -1021,14 +1049,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
- bool succeeded = writer.Write(new_hg, false);
+ bool succeeded = writer.Write(new_hg);
if (!succeeded) abort();
} else {
- bool succeeded = writer.Write(forest, false);
+ bool succeeded = writer.Write(forest);
if (!succeeded) abort();
}
}
@@ -1078,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (conf.count("graphviz")) forest.PrintGraphviz();
if (kbest) {
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
}
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
@@ -1098,4 +1126,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
o->NotifyDecodingComplete(smeta);
return true;
}
-
diff --git a/decoder/ff.h b/decoder/ff.h
index eed1e3fb..d6487d97 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -17,11 +17,23 @@ class FeatureFunction {
friend class ExternalFeature;
public:
std::string name_; // set by FF factory using usage()
- FeatureFunction() : state_size_() {}
- explicit FeatureFunction(int state_size) : state_size_(state_size) {}
+ FeatureFunction() : state_size_(), ignored_state_size_() {}
+ explicit FeatureFunction(int state_size, int ignored_state_size = 0)
+ : state_size_(state_size), ignored_state_size_(ignored_state_size) {}
virtual ~FeatureFunction();
bool IsStateful() const { return state_size_ > 0; }
int StateSize() const { return state_size_; }
+ // Returns the number of bytes in the state that should be ignored during
+ // search. When non-zero, the last N bytes in the state should be ignored when
+ // splitting a hypernode by the state. This allows the feature function to
+ // store some side data and later retrieve it via the state bytes.
+ //
+ // In general, this should not be necessary and it should always be possible
+ // to replace this with a more appropriate design of state (if you find
+ // yourself having to ignore some part of the state, you are most likely
+ // storing redundant information in the state). Be sure that you
+ // understand how this affects ApplyModelSet() before using it.
+ int IgnoredStateSize() const { return ignored_state_size_; }
// override this. not virtual because we want to expose this to factory template for help before creating a FF
static std::string usage(bool show_params,bool show_details) {
@@ -71,12 +83,18 @@ class FeatureFunction {
SparseVector<double>* estimated_features,
void* context) const;
- // !!! ONLY call this from subclass *CONSTRUCTORS* !!!
+ // !!! ONLY call these from subclass *CONSTRUCTORS* !!!
void SetStateSize(size_t state_size) {
state_size_ = state_size;
}
+
+ // See document of IgnoredStateSize() above.
+ void SetIgnoredStateSize(size_t ignored_state_size) {
+ ignored_state_size_ = ignored_state_size;
+ }
+
private:
- int state_size_;
+ int state_size_, ignored_state_size_;
};
#endif
diff --git a/decoder/ff_conll.cc b/decoder/ff_conll.cc
new file mode 100644
index 00000000..8ded44b7
--- /dev/null
+++ b/decoder/ff_conll.cc
@@ -0,0 +1,250 @@
+#include "ff_conll.h"
+
+#include <stdlib.h>
+#include <sstream>
+#include <cassert>
+#include <cmath>
+#include <boost/lexical_cast.hpp>
+
+#include "hg.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+#include "tdict.h"
+
+CoNLLFeatures::CoNLLFeatures(const string& param) {
+ // cerr << "initializing CoNLLFeatures with parameters: " << param;
+ kSOS = TD::Convert("<s>");
+ kEOS = TD::Convert("</s>");
+ macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]");
+ ParseArgs(param);
+}
+
+string CoNLLFeatures::Escape(const string& x) const {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+}
+
+// replace %x[relative_location] or %y[relative_location] with actual_token
+// within feature_instance
+void CoNLLFeatures::ReplaceMacroWithString(
+ string& feature_instance, bool token_vs_label, int relative_location,
+ const string& actual_token) const {
+
+ stringstream macro;
+ if (token_vs_label) {
+ macro << "%x[";
+ } else {
+ macro << "%y[";
+ }
+ macro << relative_location << "]";
+ int macro_index = feature_instance.find(macro.str());
+ if (macro_index == string::npos) {
+ cerr << "Can't find macro " << macro.str() << " in feature template "
+ << feature_instance;
+ abort();
+ }
+ feature_instance.replace(macro_index, macro.str().size(), actual_token);
+}
+
+void CoNLLFeatures::ReplaceTokenMacroWithString(
+ string& feature_instance, int relative_location,
+ const string& actual_token) const {
+
+ ReplaceMacroWithString(feature_instance, true, relative_location,
+ actual_token);
+}
+
+void CoNLLFeatures::ReplaceLabelMacroWithString(
+ string& feature_instance, int relative_location,
+ const string& actual_token) const {
+
+ ReplaceMacroWithString(feature_instance, false, relative_location,
+ actual_token);
+}
+
+void CoNLLFeatures::Error(const string& error_message) const {
+ cerr << "Error: " << error_message << "\n\n"
+
+ << "CoNLLFeatures Usage: \n"
+ << " feature_function=CoNLLFeatures -t <TEMPLATE>\n\n"
+
+ << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n"
+
+ << "%x[k] and %y[k] are macros to be instantiated with an input\n"
+ << "token (for x) or a label (for y). k specifies the relative\n"
+ << "location of the input token or label with respect to the current\n"
+ << "position. For x, k is an integer value. For y, k must be 0 (to\n"
+ << "be extended).\n\n";
+
+ abort();
+}
+
+void CoNLLFeatures::ParseArgs(const string& in) {
+ which_feat = 0;
+ vector<string> const& argv = SplitOnWhitespace(in);
+ for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) {
+ string const& s = *i;
+ if (s[0] == '-') {
+ if (s.size() > 2) {
+ stringstream msg;
+ msg << s << " is an invalid option for CoNLLFeatures.";
+ Error(msg.str());
+ }
+
+ switch (s[1]) {
+
+ case 'w': {
+ if (++i == argv.end()) {
+ Error("Missing parameter to -w");
+ }
+ which_feat = boost::lexical_cast<unsigned>(*i);
+ break;
+ }
+ // feature template
+ case 't': {
+ if (++i == argv.end()) {
+ Error("Can't find template.");
+ }
+ feature_template = *i;
+ string::const_iterator start = feature_template.begin();
+ string::const_iterator end = feature_template.end();
+ smatch macro_match;
+
+ // parse the template
+ while (regex_search(start, end, macro_match, macro_regex)) {
+ // get the relative location
+ string relative_location_str(macro_match[2].first,
+ macro_match[2].second);
+ int relative_location = atoi(relative_location_str.c_str());
+ // add it to the list of relative locations for token or label
+ // (i.e. x or y)
+ bool valid_location = true;
+ if (*macro_match[1].first == 'x') {
+ // add it to token locations
+ token_relative_locations.push_back(relative_location);
+ } else {
+ if (relative_location != 0) { valid_location = false; }
+ // add it to label locations
+ label_relative_locations.push_back(relative_location);
+ }
+ if (!valid_location) {
+ stringstream msg;
+ msg << "Relative location " << relative_location
+ << " in feature template " << feature_template
+ << " is invalid.";
+ Error(msg.str());
+ }
+ start = macro_match[0].second;
+ }
+ break;
+ }
+
+ // TODO: arguments to specify kSOS and kEOS
+
+ default: {
+ stringstream msg;
+ msg << "Invalid option on CoNLLFeatures: " << s;
+ Error(msg.str());
+ break;
+ }
+ } // end of switch
+ } // end of if (token starts with hyphen)
+ } // end of for loop (over arguments)
+
+ // the -t (i.e. template) option is mandatory in this feature function
+ if (label_relative_locations.size() == 0 ||
+ token_relative_locations.size() == 0) {
+ stringstream msg;
+ msg << "Feature template must specify at least one"
+ << "token macro (e.g. x[-1]) and one label macro (e.g. y[0]).";
+ Error(msg.str());
+ }
+}
+
+void CoNLLFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ const Lattice& sl = smeta.GetSourceLattice();
+ current_input.resize(sl.size());
+ for (unsigned i = 0; i < sl.size(); ++i) {
+ if (sl[i].size() != 1) {
+ stringstream msg;
+ msg << "CoNLLFeatures don't support lattice inputs!\nid="
+ << smeta.GetSentenceId() << endl;
+ Error(msg.str());
+ }
+ current_input[i] = sl[i][0].label;
+ }
+ vector<WordID> wids;
+ string fn = "feat";
+ fn += boost::lexical_cast<string>(which_feat);
+ string feats = smeta.GetSGMLValue(fn);
+ if (feats.size() == 0) {
+ Error("Can't find " + fn + " in <seg>\n");
+ }
+ TD::ConvertSentence(feats, &wids);
+ assert(current_input.size() == wids.size());
+ current_input = wids;
+}
+
+void CoNLLFeatures::TraversalFeaturesImpl(
+ const SentenceMetadata& smeta, const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts, SparseVector<double>* features,
+ SparseVector<double>* estimated_features, void* context) const {
+
+ const TRule& rule = *edge.rule_;
+ // arity = 0, no nonterminals
+ // size = 1, predicted label is a single token
+ if (rule.Arity() != 0 ||
+ rule.e_.size() != 1) {
+ return;
+ }
+
+ // replace label macros with actual label strings
+ // NOTE: currently, this feature function doesn't allow any label
+ // macros except %y[0]. but you can look at as much of the source as you want
+ const WordID y0 = rule.e_[0];
+ string y0_str = TD::Convert(y0);
+
+ // start of the span in the input being labeled
+ const int from_src_index = edge.i_;
+ // end of the span in the input
+ const int to_src_index = edge.j_;
+
+ // in the case of tagging the size of the spans being labeled will
+ // always be 1, but in other formalisms, you can have bigger spans
+ if (to_src_index - from_src_index != 1) {
+ cerr << "CoNLLFeatures doesn't support input spans of length != 1";
+ abort();
+ }
+
+ string feature_instance = feature_template;
+ // replace token macros with actual token strings
+ for (unsigned i = 0; i < token_relative_locations.size(); ++i) {
+ int loc = token_relative_locations[i];
+ WordID x = loc < 0? kSOS: kEOS;
+ if(from_src_index + loc >= 0 &&
+ from_src_index + loc < current_input.size()) {
+ x = current_input[from_src_index + loc];
+ }
+ string x_str = TD::Convert(x);
+ ReplaceTokenMacroWithString(feature_instance, loc, x_str);
+ }
+
+ ReplaceLabelMacroWithString(feature_instance, 0, y0_str);
+
+ // pick a real value for this feature
+ double fval = 1.0;
+
+ // add it to the feature vector
+ // FD::Convert converts the feature string to a feature int
+ // Escape makes sure the feature string doesn't have any bad
+ // symbols that could confuse a parser somewhere
+ features->add_value(FD::Convert(Escape(feature_instance)), fval);
+}
diff --git a/decoder/ff_conll.h b/decoder/ff_conll.h
new file mode 100644
index 00000000..b37356d8
--- /dev/null
+++ b/decoder/ff_conll.h
@@ -0,0 +1,45 @@
+#ifndef FF_CONLL_H_
+#define FF_CONLL_H_
+
+#include <vector>
+#include <boost/xpressive/xpressive.hpp>
+#include "ff.h"
+
+using namespace boost::xpressive;
+using namespace std;
+
+class CoNLLFeatures : public FeatureFunction {
+ public:
+ CoNLLFeatures(const string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ virtual void ParseArgs(const string& in);
+ virtual string Escape(const string& x) const;
+ virtual void ReplaceMacroWithString(string& feature_instance,
+ bool token_vs_label,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void ReplaceTokenMacroWithString(string& feature_instance,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void ReplaceLabelMacroWithString(string& feature_instance,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void Error(const string&) const;
+
+ private:
+ vector<int> token_relative_locations, label_relative_locations;
+ string feature_template;
+ vector<WordID> current_input;
+ WordID kSOS, kEOS;
+ sregex macro_regex;
+ unsigned which_feat;
+};
+
+#endif
diff --git a/decoder/ff_const_reorder.cc b/decoder/ff_const_reorder.cc
new file mode 100644
index 00000000..f1a6f7cb
--- /dev/null
+++ b/decoder/ff_const_reorder.cc
@@ -0,0 +1,1118 @@
+#include "ff_const_reorder.h"
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "hg.h"
+#include "sentence_metadata.h"
+#include "hash.h"
+#include "ff_const_reorder_common.h"
+
+#include <sstream>
+#include <string>
+#include <vector>
+#include <stdio.h>
+
+using namespace std;
+using namespace const_reorder;
+
+typedef HASH_MAP<std::string, vector<double> > MapClassifier;
+
+inline bool is_inside(int i, int left, int right) {
+ if (i < left || i > right) return false;
+ return true;
+}
+
+/*
+ * assume i <= j
+ * [i, j] is inside [left, right] or [i, j] equates to [left, right]
+ */
+inline bool is_inside(int i, int j, int left, int right) {
+ if (i >= left && j <= right) return true;
+ return false;
+}
+
+/*
+ * assume i <= j
+ * [i, j] is inside [left, right], but [i, j] not equal to [left, right]
+ */
+inline bool is_proper_inside(int i, int j, int left, int right) {
+ if (i >= left && j <= right && right - left > j - i) return true;
+ return false;
+}
+
+/*
+ * assume i <= j
+ * [i, j] is proper proper inside [left, right]
+ */
+inline bool is_proper_proper_inside(int i, int j, int left, int right) {
+ if (i > left && j < right) return true;
+ return false;
+}
+
+inline bool is_overlap(int left1, int right1, int left2, int right2) {
+ if (is_inside(left1, left2, right2) || is_inside(left2, left1, right1))
+ return true;
+
+ return false;
+}
+
+inline void NewAndCopyCharArray(char** p, const char* q) {
+ if (q != NULL) {
+ (*p) = new char[strlen(q) + 1];
+ strcpy((*p), q);
+ } else
+ (*p) = NULL;
+}
+
+// TODO:to make the alignment more efficient
+struct TargetTranslation {
+ TargetTranslation(int begin_pos, int end_pos,int e_num_word)
+ : begin_pos_(begin_pos),
+ end_pos_(end_pos),
+ e_num_words_(e_num_word),
+ vec_left_most_(end_pos - begin_pos + 1, e_num_word),
+ vec_right_most_(end_pos - begin_pos + 1, -1),
+ vec_f_align_bit_array_(end_pos - begin_pos + 1),
+ vec_e_align_bit_array_(e_num_word) {
+ int len = end_pos - begin_pos + 1;
+ align_.reserve(1.5 * len);
+ }
+
+ void InsertAlignmentPoint(int s, int t) {
+ int i = s - begin_pos_;
+
+ vector<bool>& b = vec_f_align_bit_array_[i];
+ if (b.empty()) b.resize(e_num_words_);
+ b[t] = 1;
+
+ vector<bool>& a = vec_e_align_bit_array_[t];
+ if (a.empty()) a.resize(end_pos_ - begin_pos_ + 1);
+ a[i] = 1;
+
+ align_.push_back({s, t});
+
+ if (t > vec_right_most_[i]) vec_right_most_[i] = t;
+ if (t < vec_left_most_[i]) vec_left_most_[i] = t;
+ }
+
+ /*
+ * given a source span [begin, end], whether its target side is continuous,
+ * return "0": the source span is translated silently
+ * return "1": there is at least on word inside its target span, this word
+ * doesn't align to any word inside [begin, end], but outside [begin, end]
+ * return "2": otherwise
+ */
+ string IsTargetConstinousSpan(int begin, int end) const {
+ int target_begin, target_end;
+ FindLeftRightMostTargetSpan(begin, end, target_begin, target_end);
+ if (target_begin == -1) return "0";
+
+ for (int i = target_begin; i <= target_end; i++) {
+ if (vec_e_align_bit_array_[i].empty()) continue;
+ int j = begin;
+ for (; j <= end; j++) {
+ if (vec_e_align_bit_array_[i][j - begin_pos_]) break;
+ }
+ if (j == end + 1) // e[i] is aligned, but e[i] doesn't align to any
+ // source word in [begin_pos, end_pos]
+ return "1";
+ }
+ return "2";
+ }
+
+ string IsTargetConstinousSpan2(int begin, int end) const {
+ int target_begin, target_end;
+ FindLeftRightMostTargetSpan(begin, end, target_begin, target_end);
+ if (target_begin == -1) return "Unaligned";
+
+ for (int i = target_begin; i <= target_end; i++) {
+ if (vec_e_align_bit_array_[i].empty()) continue;
+ int j = begin;
+ for (; j <= end; j++) {
+ if (vec_e_align_bit_array_[i][j - begin_pos_]) break;
+ }
+ if (j == end + 1) // e[i] is aligned, but e[i] doesn't align to any
+ // source word in [begin_pos, end_pos]
+ return "Discon't";
+ }
+ return "Con't";
+ }
+
+ void FindLeftRightMostTargetSpan(int begin, int end, int& target_begin,
+ int& target_end) const {
+ int b = begin - begin_pos_;
+ int e = end - begin_pos_ + 1;
+
+ target_begin = vec_left_most_[b];
+ target_end = vec_right_most_[b];
+ for (int i = b + 1; i < e; i++) {
+ if (target_begin > vec_left_most_[i]) target_begin = vec_left_most_[i];
+ if (target_end < vec_right_most_[i]) target_end = vec_right_most_[i];
+ }
+ if (target_end == -1) target_begin = -1;
+ return;
+
+ target_begin = e_num_words_;
+ target_end = -1;
+
+ for (int i = begin - begin_pos_; i < end - begin_pos_ + 1; i++) {
+ if (vec_f_align_bit_array_[i].empty()) continue;
+ for (int j = 0; j < target_begin; j++)
+ if (vec_f_align_bit_array_[i][j]) {
+ target_begin = j;
+ break;
+ }
+ }
+ for (int i = end - begin_pos_; i > begin - begin_pos_ - 1; i--) {
+ if (vec_f_align_bit_array_[i].empty()) continue;
+ for (int j = e_num_words_ - 1; j > target_end; j--)
+ if (vec_f_align_bit_array_[i][j]) {
+ target_end = j;
+ break;
+ }
+ }
+
+ if (target_end == -1) target_begin = -1;
+ }
+
+ const uint16_t begin_pos_, end_pos_; // the position in input
+ const uint16_t e_num_words_;
+ vector<AlignmentPoint> align_;
+
+ private:
+ vector<short> vec_left_most_;
+ vector<short> vec_right_most_;
+ vector<vector<bool> > vec_f_align_bit_array_;
+ vector<vector<bool> > vec_e_align_bit_array_;
+};
+
+struct FocusedConstituent {
+ FocusedConstituent(const SParsedTree* pTree) {
+ if (pTree == NULL) return;
+ for (size_t i = 0; i < pTree->m_vecTerminals.size(); i++) {
+ STreeItem* pParent = pTree->m_vecTerminals[i]->m_ptParent;
+
+ while (pParent != NULL) {
+ // if (pParent->m_vecChildren.size() > 1 && pParent->m_iEnd -
+ // pParent->m_iBegin > 5) {
+ // if (pParent->m_vecChildren.size() > 1) {
+ if (true) {
+
+ // do constituent reordering for all children of pParent
+ if (strcmp(pParent->m_pszTerm, "ROOT"))
+ focus_parents_.push_back(pParent);
+ }
+ if (pParent->m_iBrotherIndex != 0) break;
+ pParent = pParent->m_ptParent;
+ }
+ }
+ }
+
+ ~FocusedConstituent() { // TODO
+ focus_parents_.clear();
+ }
+
+ vector<STreeItem*> focus_parents_;
+};
+
+typedef SPredicateItem FocusedPredicate;
+
+struct FocusedSRL {
+ FocusedSRL(const SSrlSentence* srl) {
+ if (srl == NULL) return;
+ for (size_t i = 0; i < srl->m_vecPred.size(); i++) {
+ if (strcmp(srl->m_pTree->m_vecTerminals[srl->m_vecPred[i]->m_iPosition]
+ ->m_ptParent->m_pszTerm,
+ "VA") == 0)
+ continue;
+ focus_predicates_.push_back(
+ new FocusedPredicate(srl->m_pTree, srl->m_vecPred[i]));
+ }
+ }
+
+ ~FocusedSRL() { focus_predicates_.clear(); }
+
+ vector<const FocusedPredicate*> focus_predicates_;
+};
+
+struct ConstReorderFeatureImpl {
+ ConstReorderFeatureImpl(const std::string& param) {
+
+ b_block_feature_ = false;
+ b_order_feature_ = false;
+ b_srl_block_feature_ = false;
+ b_srl_order_feature_ = false;
+
+ vector<string> terms;
+ SplitOnWhitespace(param, &terms);
+ if (terms.size() == 1) {
+ b_block_feature_ = true;
+ b_order_feature_ = true;
+ } else if (terms.size() >= 3) {
+ if (terms[1].compare("1") == 0) b_block_feature_ = true;
+ if (terms[2].compare("1") == 0) b_order_feature_ = true;
+ if (terms.size() == 6) {
+ if (terms[4].compare("1") == 0) b_srl_block_feature_ = true;
+ if (terms[5].compare("1") == 0) b_srl_order_feature_ = true;
+
+ assert(b_srl_block_feature_ || b_srl_order_feature_);
+ }
+
+ } else {
+ assert("ERROR");
+ }
+
+ const_reorder_classifier_left_ = NULL;
+ const_reorder_classifier_right_ = NULL;
+
+ srl_reorder_classifier_left_ = NULL;
+ srl_reorder_classifier_right_ = NULL;
+
+ if (b_order_feature_) {
+ InitializeClassifier((terms[0] + string(".left")).c_str(),
+ &const_reorder_classifier_left_);
+ InitializeClassifier((terms[0] + string(".right")).c_str(),
+ &const_reorder_classifier_right_);
+ }
+
+ if (b_srl_order_feature_) {
+ InitializeClassifier((terms[3] + string(".left")).c_str(),
+ &srl_reorder_classifier_left_);
+ InitializeClassifier((terms[3] + string(".right")).c_str(),
+ &srl_reorder_classifier_right_);
+ }
+
+ parsed_tree_ = NULL;
+ focused_consts_ = NULL;
+
+ srl_sentence_ = NULL;
+ focused_srl_ = NULL;
+
+ map_left_ = NULL;
+ map_right_ = NULL;
+
+ map_srl_left_ = NULL;
+ map_srl_right_ = NULL;
+
+ dict_block_status_ = new Dict();
+ dict_block_status_->Convert("Unaligned", false);
+ dict_block_status_->Convert("Discon't", false);
+ dict_block_status_->Convert("Con't", false);
+ }
+
+ ~ConstReorderFeatureImpl() {
+ if (const_reorder_classifier_left_) delete const_reorder_classifier_left_;
+ if (const_reorder_classifier_right_) delete const_reorder_classifier_right_;
+ if (srl_reorder_classifier_left_) delete srl_reorder_classifier_left_;
+ if (srl_reorder_classifier_right_) delete srl_reorder_classifier_right_;
+ FreeSentenceVariables();
+
+ delete dict_block_status_;
+ }
+
+ static int ReserveStateSize() { return 1 * sizeof(TargetTranslation*); }
+
+ void InitializeInputSentence(const std::string& parse_file,
+ const std::string& srl_file) {
+ FreeSentenceVariables();
+ if (b_srl_block_feature_ || b_srl_order_feature_) {
+ assert(srl_file != "");
+ srl_sentence_ = ReadSRLSentence(srl_file);
+ parsed_tree_ = srl_sentence_->m_pTree;
+ } else {
+ assert(parse_file != "");
+ srl_sentence_ = NULL;
+ parsed_tree_ = ReadParseTree(parse_file);
+ }
+
+ if (b_block_feature_ || b_order_feature_) {
+ focused_consts_ = new FocusedConstituent(parsed_tree_);
+
+ if (b_order_feature_) {
+ // we can do the classifier "off-line"
+ map_left_ = new MapClassifier();
+ map_right_ = new MapClassifier();
+ InitializeConstReorderClassifierOutput();
+ }
+ }
+
+ if (b_srl_block_feature_ || b_srl_order_feature_) {
+ focused_srl_ = new FocusedSRL(srl_sentence_);
+
+ if (b_srl_order_feature_) {
+ map_srl_left_ = new MapClassifier();
+ map_srl_right_ = new MapClassifier();
+ InitializeSRLReorderClassifierOutput();
+ }
+ }
+
+ if (parsed_tree_ != NULL) {
+ size_t i = parsed_tree_->m_vecTerminals.size();
+ vec_target_tran_.reserve(20 * i * i * i);
+ } else
+ vec_target_tran_.reserve(1000000);
+ }
+
+ void SetConstReorderFeature(const Hypergraph::Edge& edge,
+ SparseVector<double>* features,
+ const vector<const void*>& ant_states,
+ void* state) {
+ if (parsed_tree_ == NULL) return;
+
+ short int begin = edge.i_, end = edge.j_ - 1;
+
+ typedef TargetTranslation* PtrTargetTranslation;
+ PtrTargetTranslation* remnant =
+ reinterpret_cast<PtrTargetTranslation*>(state);
+
+ vector<const TargetTranslation*> vec_node;
+ vec_node.reserve(edge.tail_nodes_.size());
+ for (size_t i = 0; i < edge.tail_nodes_.size(); i++) {
+ const PtrTargetTranslation* astate =
+ reinterpret_cast<const PtrTargetTranslation*>(ant_states[i]);
+ vec_node.push_back(astate[0]);
+ }
+
+ int e_num_word = edge.rule_->e_.size();
+ for (size_t i = 0; i < vec_node.size(); i++) {
+ e_num_word += vec_node[i]->e_num_words_;
+ e_num_word--;
+ }
+
+ remnant[0] = new TargetTranslation(begin, end, e_num_word);
+ vec_target_tran_.push_back(remnant[0]);
+
+ // reset the alignment
+ // for the source side, we know its position in source sentence
+ // for the target side, we always assume its starting position is 0
+ unsigned vc = 0;
+ const TRulePtr rule = edge.rule_;
+ std::vector<int> f_index(rule->f_.size());
+ int index = edge.i_;
+ for (unsigned i = 0; i < rule->f_.size(); i++) {
+ f_index[i] = index;
+ const WordID& c = rule->f_[i];
+ if (c < 1)
+ index = vec_node[vc++]->end_pos_ + 1;
+ else
+ index++;
+ }
+ assert(vc == vec_node.size());
+ assert(index == edge.j_);
+
+ std::vector<int> e_index(rule->e_.size());
+ index = 0;
+ vc = 0;
+ for (unsigned i = 0; i < rule->e_.size(); i++) {
+ e_index[i] = index;
+ const WordID& c = rule->e_[i];
+ if (c < 1) {
+ index += vec_node[-c]->e_num_words_;
+ vc++;
+ } else
+ index++;
+ }
+ assert(vc == vec_node.size());
+
+ size_t nt_pos = 0;
+ for (size_t i = 0; i < edge.rule_->f_.size(); i++) {
+ if (edge.rule_->f_[i] > 0) continue;
+
+ // it's an NT
+ size_t j;
+ for (j = 0; j < edge.rule_->e_.size(); j++)
+ if (edge.rule_->e_[j] * -1 == nt_pos) break;
+ assert(j != edge.rule_->e_.size());
+ nt_pos++;
+
+ // i aligns j
+ int eindex = e_index[j];
+ const vector<AlignmentPoint>& align =
+ vec_node[-1 * edge.rule_->e_[j]]->align_;
+ for (size_t k = 0; k < align.size(); k++) {
+ remnant[0]->InsertAlignmentPoint(align[k].s_, eindex + align[k].t_);
+ }
+ }
+ for (size_t i = 0; i < edge.rule_->a_.size(); i++) {
+ int findex = f_index[edge.rule_->a_[i].s_];
+ int eindex = e_index[edge.rule_->a_[i].t_];
+ remnant[0]->InsertAlignmentPoint(findex, eindex);
+ }
+
+ // till now, we finished setting state values
+ // next, use the state values to calculate constituent reorder feature
+ SetConstReorderFeature(begin, end, features, remnant[0],
+ vec_node, f_index);
+ }
+
+ void SetConstReorderFeature(short int begin, short int end,
+ SparseVector<double>* features,
+ const TargetTranslation* target_translation,
+ const vector<const TargetTranslation*>& vec_node,
+ std::vector<int>& /*findex*/) {
+ if (b_srl_block_feature_ || b_srl_order_feature_) {
+ double logprob_srl_reorder_left = 0.0, logprob_srl_reorder_right = 0.0;
+ for (size_t i = 0; i < focused_srl_->focus_predicates_.size(); i++) {
+ const FocusedPredicate* pred = focused_srl_->focus_predicates_[i];
+ if (!is_overlap(begin, end, pred->begin_, pred->end_))
+ continue; // have no overlap between this predicate (with its
+ // argument) and the current edge
+
+ size_t j;
+ for (j = 0; j < vec_node.size(); j++) {
+ if (is_inside(pred->begin_, pred->end_, vec_node[j]->begin_pos_,
+ vec_node[j]->end_pos_))
+ break;
+ }
+ if (j < vec_node.size()) continue;
+
+ vector<int> vecBlockStatus;
+ vecBlockStatus.reserve(pred->vec_items_.size());
+ for (j = 0; j < pred->vec_items_.size(); j++) {
+ const STreeItem* con1 = pred->vec_items_[j]->tree_item_;
+ if (con1->m_iBegin < begin || con1->m_iEnd > end) {
+ vecBlockStatus.push_back(0);
+ continue;
+ } // the node is partially outside the current edge
+
+ string type = target_translation->IsTargetConstinousSpan2(
+ con1->m_iBegin, con1->m_iEnd);
+ vecBlockStatus.push_back(dict_block_status_->Convert(type, false));
+
+ if (!b_srl_block_feature_) continue;
+ // see if the node is covered by an NT
+ size_t k;
+ for (k = 0; k < vec_node.size(); k++) {
+ if (is_inside(con1->m_iBegin, con1->m_iEnd, vec_node[k]->begin_pos_,
+ vec_node[k]->end_pos_))
+ break;
+ }
+ if (k < vec_node.size()) continue;
+ int f_id = FD::Convert(string(pred->vec_items_[j]->role_) + type);
+ if (f_id) features->add_value(f_id, 1);
+ }
+
+ if (!b_srl_order_feature_) continue;
+
+ vector<int> vecPosition, vecRelativePosition;
+ vector<int> vecRightPosition, vecRelativeRightPosition;
+ vecPosition.reserve(pred->vec_items_.size());
+ vecRelativePosition.reserve(pred->vec_items_.size());
+ vecRightPosition.reserve(pred->vec_items_.size());
+ vecRelativeRightPosition.reserve(pred->vec_items_.size());
+ for (j = 0; j < pred->vec_items_.size(); j++) {
+ const STreeItem* con1 = pred->vec_items_[j]->tree_item_;
+ if (con1->m_iBegin < begin || con1->m_iEnd > end) {
+ vecPosition.push_back(-1);
+ vecRightPosition.push_back(-1);
+ continue;
+ } // the node is partially outside the current edge
+ int left1 = -1, right1 = -1;
+ target_translation->FindLeftRightMostTargetSpan(
+ con1->m_iBegin, con1->m_iEnd, left1, right1);
+ vecPosition.push_back(left1);
+ vecRightPosition.push_back(right1);
+ }
+ fnGetRelativePosition(vecPosition, vecRelativePosition);
+ fnGetRelativePosition(vecRightPosition, vecRelativeRightPosition);
+
+ for (j = 1; j < pred->vec_items_.size(); j++) {
+ const STreeItem* con1 = pred->vec_items_[j - 1]->tree_item_;
+ const STreeItem* con2 = pred->vec_items_[j]->tree_item_;
+
+ if (con1->m_iBegin < begin || con2->m_iEnd > end)
+ continue; // one of the two nodes is partially outside the current
+ // edge
+
+ // both con1 and con2 are covered, need to check if they are covered
+ // by the same NT
+ size_t k;
+ for (k = 0; k < vec_node.size(); k++) {
+ if (is_inside(con1->m_iBegin, con2->m_iEnd, vec_node[k]->begin_pos_,
+ vec_node[k]->end_pos_))
+ break;
+ }
+ if (k < vec_node.size()) continue;
+
+ // they are not covered bye the same NT
+ string outcome;
+ string key;
+ GenerateKey(pred->vec_items_[j - 1]->tree_item_,
+ pred->vec_items_[j]->tree_item_, vecBlockStatus[j - 1],
+ vecBlockStatus[j], key);
+
+ fnGetOutcome(vecRelativePosition[j - 1], vecRelativePosition[j],
+ outcome);
+ double prob = CalculateConstReorderProb(srl_reorder_classifier_left_,
+ map_srl_left_, key, outcome);
+ // printf("%s %s %f\n", ostr.str().c_str(), outcome.c_str(), prob);
+ logprob_srl_reorder_left += log10(prob);
+
+ fnGetOutcome(vecRelativeRightPosition[j - 1],
+ vecRelativeRightPosition[j], outcome);
+ prob = CalculateConstReorderProb(srl_reorder_classifier_right_,
+ map_srl_right_, key, outcome);
+ logprob_srl_reorder_right += log10(prob);
+ }
+ }
+
+ if (b_srl_order_feature_) {
+ int f_id = FD::Convert("SRLReorderFeatureLeft");
+ if (f_id && logprob_srl_reorder_left != 0.0)
+ features->set_value(f_id, logprob_srl_reorder_left);
+ f_id = FD::Convert("SRLReorderFeatureRight");
+ if (f_id && logprob_srl_reorder_right != 0.0)
+ features->set_value(f_id, logprob_srl_reorder_right);
+ }
+ }
+
+ if (b_block_feature_ || b_order_feature_) {
+ double logprob_const_reorder_left = 0.0,
+ logprob_const_reorder_right = 0.0;
+
+ for (size_t i = 0; i < focused_consts_->focus_parents_.size(); i++) {
+ STreeItem* parent = focused_consts_->focus_parents_[i];
+ if (!is_overlap(begin, end, parent->m_iBegin,
+ parent->m_iEnd))
+ continue; // have no overlap between this parent node and the current
+ // edge
+
+ size_t j;
+ for (j = 0; j < vec_node.size(); j++) {
+ if (is_inside(parent->m_iBegin, parent->m_iEnd,
+ vec_node[j]->begin_pos_, vec_node[j]->end_pos_))
+ break;
+ }
+ if (j < vec_node.size()) continue;
+
+ if (b_block_feature_) {
+ if (parent->m_iBegin >= begin &&
+ parent->m_iEnd <= end) {
+ string type = target_translation->IsTargetConstinousSpan2(
+ parent->m_iBegin, parent->m_iEnd);
+ int f_id = FD::Convert(string(parent->m_pszTerm) + type);
+ if (f_id) features->add_value(f_id, 1);
+ }
+ }
+
+ if (parent->m_vecChildren.size() == 1 || !b_order_feature_) continue;
+
+ vector<int> vecChunkBlock;
+ vecChunkBlock.reserve(parent->m_vecChildren.size());
+
+ for (j = 0; j < parent->m_vecChildren.size(); j++) {
+ STreeItem* con1 = parent->m_vecChildren[j];
+ if (con1->m_iBegin < begin || con1->m_iEnd > end) {
+ vecChunkBlock.push_back(0);
+ continue;
+ } // the node is partially outside the current edge
+
+ string type = target_translation->IsTargetConstinousSpan2(
+ con1->m_iBegin, con1->m_iEnd);
+ vecChunkBlock.push_back(dict_block_status_->Convert(type, false));
+
+ /*if (!b_block_feature_) continue;
+ //see if the node is covered by an NT
+ size_t k;
+ for (k = 0; k < vec_node.size(); k++) {
+ if (is_inside(con1->m_iBegin, con1->m_iEnd,
+ vec_node[k]->begin_pos_, vec_node[k]->end_pos_))
+ break;
+ }
+ if (k < vec_node.size()) continue;
+ int f_id = FD::Convert(string(con1->m_pszTerm) + type);
+ if (f_id)
+ features->add_value(f_id, 1);*/
+ }
+
+ if (!b_order_feature_) continue;
+
+ vector<int> vecPosition, vecRelativePosition;
+ vector<int> vecRightPosition, vecRelativeRightPosition;
+ vecPosition.reserve(parent->m_vecChildren.size());
+ vecRelativePosition.reserve(parent->m_vecChildren.size());
+ vecRightPosition.reserve(parent->m_vecChildren.size());
+ vecRelativeRightPosition.reserve(parent->m_vecChildren.size());
+ for (j = 0; j < parent->m_vecChildren.size(); j++) {
+ STreeItem* con1 = parent->m_vecChildren[j];
+ if (con1->m_iBegin < begin || con1->m_iEnd > end) {
+ vecPosition.push_back(-1);
+ vecRightPosition.push_back(-1);
+ continue;
+ } // the node is partially outside the current edge
+ int left1 = -1, right1 = -1;
+ target_translation->FindLeftRightMostTargetSpan(
+ con1->m_iBegin, con1->m_iEnd, left1, right1);
+ vecPosition.push_back(left1);
+ vecRightPosition.push_back(right1);
+ }
+ fnGetRelativePosition(vecPosition, vecRelativePosition);
+ fnGetRelativePosition(vecRightPosition, vecRelativeRightPosition);
+
+ for (j = 1; j < parent->m_vecChildren.size(); j++) {
+ STreeItem* con1 = parent->m_vecChildren[j - 1];
+ STreeItem* con2 = parent->m_vecChildren[j];
+
+ if (con1->m_iBegin < begin || con2->m_iEnd > end)
+ continue; // one of the two nodes is partially outside the current
+ // edge
+
+ // both con1 and con2 are covered, need to check if they are covered
+ // by the same NT
+ size_t k;
+ for (k = 0; k < vec_node.size(); k++) {
+ if (is_inside(con1->m_iBegin, con2->m_iEnd, vec_node[k]->begin_pos_,
+ vec_node[k]->end_pos_))
+ break;
+ }
+ if (k < vec_node.size()) continue;
+
+ // they are not covered bye the same NT
+ string outcome;
+ string key;
+ GenerateKey(parent->m_vecChildren[j - 1], parent->m_vecChildren[j],
+ vecChunkBlock[j - 1], vecChunkBlock[j], key);
+
+ fnGetOutcome(vecRelativePosition[j - 1], vecRelativePosition[j],
+ outcome);
+ double prob = CalculateConstReorderProb(
+ const_reorder_classifier_left_, map_left_, key, outcome);
+ // printf("%s %s %f\n", ostr.str().c_str(), outcome.c_str(), prob);
+ logprob_const_reorder_left += log10(prob);
+
+ fnGetOutcome(vecRelativeRightPosition[j - 1],
+ vecRelativeRightPosition[j], outcome);
+ prob = CalculateConstReorderProb(const_reorder_classifier_right_,
+ map_right_, key, outcome);
+ logprob_const_reorder_right += log10(prob);
+ }
+ }
+
+ if (b_order_feature_) {
+ int f_id = FD::Convert("ConstReorderFeatureLeft");
+ if (f_id && logprob_const_reorder_left != 0.0)
+ features->set_value(f_id, logprob_const_reorder_left);
+ f_id = FD::Convert("ConstReorderFeatureRight");
+ if (f_id && logprob_const_reorder_right != 0.0)
+ features->set_value(f_id, logprob_const_reorder_right);
+ }
+ }
+ }
+
+ private:
+ void Byte_to_Char(unsigned char* str, int n) {
+ str[0] = (n & 255);
+ str[1] = n / 256;
+ }
+ void GenerateKey(const STreeItem* pCon1, const STreeItem* pCon2,
+ int iBlockStatus1, int iBlockStatus2, string& key) {
+ assert(iBlockStatus1 != 0);
+ assert(iBlockStatus2 != 0);
+ unsigned char szTerm[1001];
+ Byte_to_Char(szTerm, pCon1->m_iBegin);
+ Byte_to_Char(szTerm + 2, pCon2->m_iEnd);
+ szTerm[4] = (char)iBlockStatus1;
+ szTerm[5] = (char)iBlockStatus2;
+ szTerm[6] = '\0';
+ // sprintf(szTerm, "%d|%d|%d|%d|%s|%s", pCon1->m_iBegin, pCon1->m_iEnd,
+ // pCon2->m_iBegin, pCon2->m_iEnd, strBlockStatus1.c_str(),
+ // strBlockStatus2.c_str());
+ key = string(szTerm, szTerm + 6);
+ }
+ void InitializeConstReorderClassifierOutput() {
+ if (!b_order_feature_) return;
+ int size_block_status = dict_block_status_->max();
+
+ for (size_t i = 0; i < focused_consts_->focus_parents_.size(); i++) {
+ STreeItem* parent = focused_consts_->focus_parents_[i];
+
+ for (size_t j = 1; j < parent->m_vecChildren.size(); j++) {
+ for (size_t k = 1; k <= size_block_status; k++) {
+ for (size_t l = 1; l <= size_block_status; l++) {
+ ostringstream ostr;
+ GenerateFeature(parsed_tree_, parent, j,
+ dict_block_status_->Convert(k),
+ dict_block_status_->Convert(l), ostr);
+
+ string strKey;
+ GenerateKey(parent->m_vecChildren[j - 1], parent->m_vecChildren[j],
+ k, l, strKey);
+
+ vector<double> vecOutput;
+ const_reorder_classifier_left_->fnEval(ostr.str().c_str(),
+ vecOutput);
+ (*map_left_)[strKey] = vecOutput;
+
+ const_reorder_classifier_right_->fnEval(ostr.str().c_str(),
+ vecOutput);
+ (*map_right_)[strKey] = vecOutput;
+ }
+ }
+ }
+ }
+ }
+
+ void InitializeSRLReorderClassifierOutput() {
+ if (!b_srl_order_feature_) return;
+ int size_block_status = dict_block_status_->max();
+
+ for (size_t i = 0; i < focused_srl_->focus_predicates_.size(); i++) {
+ const FocusedPredicate* pred = focused_srl_->focus_predicates_[i];
+
+ for (size_t j = 1; j < pred->vec_items_.size(); j++) {
+ for (size_t k = 1; k <= size_block_status; k++) {
+ for (size_t l = 1; l <= size_block_status; l++) {
+ ostringstream ostr;
+
+ SArgumentReorderModel::fnGenerateFeature(
+ parsed_tree_, pred->pred_, pred, j,
+ dict_block_status_->Convert(k), dict_block_status_->Convert(l),
+ ostr);
+
+ string strKey;
+ GenerateKey(pred->vec_items_[j - 1]->tree_item_,
+ pred->vec_items_[j]->tree_item_, k, l, strKey);
+
+ vector<double> vecOutput;
+ srl_reorder_classifier_left_->fnEval(ostr.str().c_str(), vecOutput);
+ (*map_srl_left_)[strKey] = vecOutput;
+
+ srl_reorder_classifier_right_->fnEval(ostr.str().c_str(),
+ vecOutput);
+ (*map_srl_right_)[strKey] = vecOutput;
+ }
+ }
+ }
+ }
+ }
+
+ double CalculateConstReorderProb(
+ const Tsuruoka_Maxent* const_reorder_classifier, const MapClassifier* map,
+ const string& key, const string& outcome) {
+ MapClassifier::const_iterator iter = (*map).find(key);
+ assert(iter != map->end());
+ int id = const_reorder_classifier->fnGetClassId(outcome);
+ return iter->second[id];
+ }
+
+ void FreeSentenceVariables() {
+ if (srl_sentence_ != NULL) {
+ delete srl_sentence_;
+ srl_sentence_ = NULL;
+ } else {
+ if (parsed_tree_ != NULL) delete parsed_tree_;
+ parsed_tree_ = NULL;
+ }
+
+ if (focused_consts_ != NULL) delete focused_consts_;
+ focused_consts_ = NULL;
+
+ for (size_t i = 0; i < vec_target_tran_.size(); i++)
+ delete vec_target_tran_[i];
+ vec_target_tran_.clear();
+
+ if (map_left_ != NULL) delete map_left_;
+ map_left_ = NULL;
+ if (map_right_ != NULL) delete map_right_;
+ map_right_ = NULL;
+
+ if (map_srl_left_ != NULL) delete map_srl_left_;
+ map_srl_left_ = NULL;
+ if (map_srl_right_ != NULL) delete map_srl_right_;
+ map_srl_right_ = NULL;
+ }
+
+ void InitializeClassifier(const char* pszFname,
+ Tsuruoka_Maxent** ppClassifier) {
+ (*ppClassifier) = new Tsuruoka_Maxent(pszFname);
+ }
+
+ void GenerateOutcome(const vector<int>& vecPos, vector<string>& vecOutcome) {
+ vecOutcome.clear();
+
+ for (size_t i = 1; i < vecPos.size(); i++) {
+ if (vecPos[i] == -1 || vecPos[i] == vecPos[i - 1]) {
+ vecOutcome.push_back("M"); // monotone
+ continue;
+ }
+
+ if (vecPos[i - 1] == -1) {
+ // vecPos[i] is not -1
+ size_t j = i - 2;
+ while (j > -1 && vecPos[j] == -1) j--;
+
+ size_t k;
+ for (k = 0; k < j; k++) {
+ if (vecPos[k] > vecPos[j] || vecPos[k] <= vecPos[i]) break;
+ }
+ if (k < j) {
+ vecOutcome.push_back("DM");
+ continue;
+ }
+
+ for (k = i + 1; k < vecPos.size(); k++)
+ if (vecPos[k] < vecPos[i] && (j == -1 && vecPos[k] >= vecPos[j]))
+ break;
+ if (k < vecPos.size()) {
+ vecOutcome.push_back("DM");
+ continue;
+ }
+ vecOutcome.push_back("M");
+ } else {
+ // neither of vecPos[i-1] and vecPos[i] is -1
+ if (vecPos[i - 1] < vecPos[i]) {
+ // monotone or discon't monotone
+ size_t j;
+ for (j = 0; j < i - 1; j++)
+ if (vecPos[j] > vecPos[i - 1] && vecPos[j] <= vecPos[i]) break;
+ if (j < i - 1) {
+ vecOutcome.push_back("DM");
+ continue;
+ }
+ for (j = i + 1; j < vecPos.size(); j++)
+ if (vecPos[j] >= vecPos[i - 1] && vecPos[j] < vecPos[i]) break;
+ if (j < vecPos.size()) {
+ vecOutcome.push_back("DM");
+ continue;
+ }
+ vecOutcome.push_back("M");
+ } else {
+ // swap or discon't swap
+ size_t j;
+ for (j = 0; j < i - 1; j++)
+ if (vecPos[j] > vecPos[i] && vecPos[j] <= vecPos[i - 1]) break;
+ if (j < i - 1) {
+ vecOutcome.push_back("DS");
+ continue;
+ }
+ for (j = i + 1; j < vecPos.size(); j++)
+ if (vecPos[j] >= vecPos[i] && vecPos[j] < vecPos[i - 1]) break;
+ if (j < vecPos.size()) {
+ vecOutcome.push_back("DS");
+ continue;
+ }
+ vecOutcome.push_back("S");
+ }
+ }
+ }
+
+ assert(vecOutcome.size() == vecPos.size() - 1);
+ }
+
+ void fnGetRelativePosition(const vector<int>& vecLeft,
+ vector<int>& vecPosition) {
+ vecPosition.clear();
+
+ vector<float> vec;
+ vec.reserve(vecLeft.size());
+ for (size_t i = 0; i < vecLeft.size(); i++) {
+ if (vecLeft[i] == -1) {
+ if (i == 0)
+ vec.push_back(-1);
+ else
+ vec.push_back(vecLeft[i - 1] + 0.1);
+ } else
+ vec.push_back(vecLeft[i]);
+ }
+
+ for (size_t i = 0; i < vecLeft.size(); i++) {
+ int count = 0;
+
+ for (size_t j = 0; j < vecLeft.size(); j++) {
+ if (j == i) continue;
+ if (vec[j] < vec[i]) {
+ count++;
+ } else if (vec[j] == vec[i] && j < i) {
+ count++;
+ }
+ }
+ vecPosition.push_back(count);
+ }
+
+ for (size_t i = 1; i < vecPosition.size(); i++) {
+ if (vecPosition[i - 1] == vecPosition[i]) {
+ for (size_t j = 0; j < vecLeft.size(); j++) cout << vecLeft[j] << " ";
+ cout << "\n";
+ assert(false);
+ }
+ }
+ }
+
+ inline void fnGetOutcome(int i1, int i2, string& strOutcome) {
+ assert(i1 != i2);
+ if (i1 < i2) {
+ if (i2 > i1 + 1)
+ strOutcome = string("DM");
+ else
+ strOutcome = string("M");
+ } else {
+ if (i1 > i2 + 1)
+ strOutcome = string("DS");
+ else
+ strOutcome = string("S");
+ }
+ }
+
+ // features in constituent_reorder_model.cc
+ void GenerateFeature(const SParsedTree* pTree, const STreeItem* pParent,
+ int iPos, const string& strBlockStatus1,
+ const string& strBlockStatus2, ostringstream& ostr) {
+ STreeItem* pCon1, *pCon2;
+ pCon1 = pParent->m_vecChildren[iPos - 1];
+ pCon2 = pParent->m_vecChildren[iPos];
+
+ string left_label = string(pCon1->m_pszTerm);
+ string right_label = string(pCon2->m_pszTerm);
+ string parent_label = string(pParent->m_pszTerm);
+
+ vector<string> vec_other_right_sibling;
+ for (int i = iPos + 1; i < pParent->m_vecChildren.size(); i++)
+ vec_other_right_sibling.push_back(
+ string(pParent->m_vecChildren[i]->m_pszTerm));
+ if (vec_other_right_sibling.size() == 0)
+ vec_other_right_sibling.push_back(string("NULL"));
+ vector<string> vec_other_left_sibling;
+ for (int i = 0; i < iPos - 1; i++)
+ vec_other_left_sibling.push_back(
+ string(pParent->m_vecChildren[i]->m_pszTerm));
+ if (vec_other_left_sibling.size() == 0)
+ vec_other_left_sibling.push_back(string("NULL"));
+
+ // generate features
+ // f1
+ ostr << "f1=" << left_label << "_" << right_label << "_" << parent_label;
+ // f2
+ for (int i = 0; i < vec_other_right_sibling.size(); i++)
+ ostr << " f2=" << left_label << "_" << right_label << "_" << parent_label
+ << "_" << vec_other_right_sibling[i];
+ // f3
+ for (int i = 0; i < vec_other_left_sibling.size(); i++)
+ ostr << " f3=" << left_label << "_" << right_label << "_" << parent_label
+ << "_" << vec_other_left_sibling[i];
+ // f4
+ ostr << " f4=" << left_label << "_" << right_label << "_"
+ << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_ptParent->m_pszTerm;
+ // f5
+ ostr << " f5=" << left_label << "_" << right_label << "_"
+ << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_pszTerm;
+ // f6
+ ostr << " f6=" << left_label << "_" << right_label << "_"
+ << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_ptParent->m_pszTerm;
+ // f7
+ ostr << " f7=" << left_label << "_" << right_label << "_"
+ << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_pszTerm;
+ // f8
+ ostr << " f8=" << left_label << "_" << right_label << "_"
+ << strBlockStatus1;
+ // f9
+ ostr << " f9=" << left_label << "_" << right_label << "_"
+ << strBlockStatus2;
+
+ // f10
+ ostr << " f10=" << left_label << "_" << parent_label;
+ // f11
+ ostr << " f11=" << right_label << "_" << parent_label;
+ }
+
+ SParsedTree* ReadParseTree(const std::string& parse_file) {
+ SParseReader* reader = new SParseReader(parse_file.c_str(), false);
+ SParsedTree* tree = reader->fnReadNextParseTree();
+ // assert(tree != NULL);
+ delete reader;
+ return tree;
+ }
+
+ SSrlSentence* ReadSRLSentence(const std::string& srl_file) {
+ SSrlSentenceReader* reader = new SSrlSentenceReader(srl_file.c_str());
+ SSrlSentence* srl = reader->fnReadNextSrlSentence();
+ // assert(srl != NULL);
+ delete reader;
+ return srl;
+ }
+
+ private:
+ Tsuruoka_Maxent* const_reorder_classifier_left_;
+ Tsuruoka_Maxent* const_reorder_classifier_right_;
+
+ Tsuruoka_Maxent* srl_reorder_classifier_left_;
+ Tsuruoka_Maxent* srl_reorder_classifier_right_;
+
+ MapClassifier* map_left_;
+ MapClassifier* map_right_;
+
+ MapClassifier* map_srl_left_;
+ MapClassifier* map_srl_right_;
+
+ SParsedTree* parsed_tree_;
+ FocusedConstituent* focused_consts_;
+ vector<TargetTranslation*> vec_target_tran_;
+
+ bool b_order_feature_;
+ bool b_block_feature_;
+
+ bool b_srl_block_feature_;
+ bool b_srl_order_feature_;
+ SSrlSentence* srl_sentence_;
+ FocusedSRL* focused_srl_;
+
+ Dict* dict_block_status_;
+};
+
+ConstReorderFeature::ConstReorderFeature(const std::string& param) {
+ pimpl_ = new ConstReorderFeatureImpl(param);
+ SetStateSize(ConstReorderFeatureImpl::ReserveStateSize());
+ SetIgnoredStateSize(ConstReorderFeatureImpl::ReserveStateSize());
+ name_ = "ConstReorderFeature";
+}
+
+ConstReorderFeature::~ConstReorderFeature() { // TODO
+ delete pimpl_;
+}
+
+void ConstReorderFeature::PrepareForInput(const SentenceMetadata& smeta) {
+ string parse_file = smeta.GetSGMLValue("parse");
+ if (parse_file.empty()) {
+ parse_file = smeta.GetSGMLValue("src_tree");
+ }
+ string srl_file = smeta.GetSGMLValue("srl");
+ assert(!(parse_file == "" && srl_file == ""));
+
+ pimpl_->InitializeInputSentence(parse_file, srl_file);
+}
+
+void ConstReorderFeature::TraversalFeaturesImpl(
+ const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states, SparseVector<double>* features,
+ SparseVector<double>* /*estimated_features*/, void* state) const {
+ pimpl_->SetConstReorderFeature(edge, features, ant_states, state);
+}
+
+string ConstReorderFeature::usage(bool show_params, bool show_details) {
+ ostringstream out;
+ out << "ConstReorderFeature";
+ if (show_params) {
+ out << " model_file_prefix [const_block=1 const_order=1] [srl_block=0 "
+ "srl_order=0]"
+ << "\nParameters:\n"
+ << " const_{block,order}: enable/disable constituency constraints.\n"
+ << " src_{block,order}: enable/disable semantic role labeling "
+ "constraints.\n";
+ }
+ if (show_details) {
+ out << "\n"
+ << "Soft reordering constraint features from "
+ "http://www.aclweb.org/anthology/P14-1106. To train the classifers, "
+ "use utils/const_reorder_model_trainer for constituency reordering "
+ "constraints and utils/argument_reorder_model_trainer for semantic "
+ "role labeling reordering constraints.\n"
+ << "Input segments should provide path to parse tree (resp. SRL parse) "
+ "as \"parse\" (resp. \"srl\") properties.\n";
+ }
+ return out.str();
+}
+
+boost::shared_ptr<FeatureFunction> CreateConstReorderModel(
+ const std::string& param) {
+ ConstReorderFeature* ret = new ConstReorderFeature(param);
+ return boost::shared_ptr<FeatureFunction>(ret);
+}
diff --git a/decoder/ff_const_reorder.h b/decoder/ff_const_reorder.h
new file mode 100644
index 00000000..a5be02d0
--- /dev/null
+++ b/decoder/ff_const_reorder.h
@@ -0,0 +1,43 @@
+/*
+ * ff_const_reorder.h
+ *
+ * Created on: Jul 11, 2013
+ * Author: junhuili
+ */
+
+#ifndef FF_CONST_REORDER_H_
+#define FF_CONST_REORDER_H_
+
+#include "ff_factory.h"
+#include "ff.h"
+
+struct ConstReorderFeatureImpl;
+
+// Soft reordering constraint features from
+// http://www.aclweb.org/anthology/P14-1106. To train the classifers,
+// use utils/const_reorder_model_trainer for constituency reordering
+// constraints and utils/argument_reorder_model_trainer for SRL
+// reordering constraints.
+//
+// Input segments should provide path to parse tree (resp. SRL parse)
+// as "parse" (resp. "srl") properties.
+class ConstReorderFeature : public FeatureFunction {
+ public:
+ ConstReorderFeature(const std::string& param);
+ ~ConstReorderFeature();
+ static std::string usage(bool param, bool verbose);
+
+ protected:
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+
+ virtual void TraversalFeaturesImpl(
+ const SentenceMetadata& smeta, const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features, SparseVector<double>* estimated_features,
+ void* out_context) const;
+
+ private:
+ ConstReorderFeatureImpl* pimpl_;
+};
+
+#endif /* FF_CONST_REORDER_H_ */
diff --git a/decoder/ff_const_reorder_common.h b/decoder/ff_const_reorder_common.h
new file mode 100644
index 00000000..755fd948
--- /dev/null
+++ b/decoder/ff_const_reorder_common.h
@@ -0,0 +1,1348 @@
+#ifndef _FF_CONST_REORDER_COMMON_H
+#define _FF_CONST_REORDER_COMMON_H
+
+#include <string>
+#include <assert.h>
+#include <stdio.h>
+#include <string.h>
+#include <string>
+#include <sstream>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "maxent.h"
+#include "stringlib.h"
+
+namespace const_reorder {
+
+struct STreeItem {
+ STreeItem(const char *pszTerm) {
+ m_pszTerm = new char[strlen(pszTerm) + 1];
+ strcpy(m_pszTerm, pszTerm);
+
+ m_ptParent = NULL;
+ m_iBegin = -1;
+ m_iEnd = -1;
+ m_iHeadChild = -1;
+ m_iHeadWord = -1;
+ m_iBrotherIndex = -1;
+ }
+ ~STreeItem() {
+ delete[] m_pszTerm;
+ for (size_t i = 0; i < m_vecChildren.size(); i++) delete m_vecChildren[i];
+ }
+ int fnAppend(STreeItem *ptChild) {
+ m_vecChildren.push_back(ptChild);
+ ptChild->m_iBrotherIndex = m_vecChildren.size() - 1;
+ ptChild->m_ptParent = this;
+ return m_vecChildren.size() - 1;
+ }
+ int fnGetChildrenNum() { return m_vecChildren.size(); }
+
+ bool fnIsPreTerminal(void) {
+ int I;
+ if (this == NULL || m_vecChildren.size() == 0) return false;
+
+ for (I = 0; I < m_vecChildren.size(); I++)
+ if (m_vecChildren[I]->m_vecChildren.size() > 0) return false;
+
+ return true;
+ }
+
+ public:
+ char *m_pszTerm;
+
+ std::vector<STreeItem *> m_vecChildren; // children items
+ STreeItem *m_ptParent; // the parent item
+
+ int m_iBegin;
+ int m_iEnd; // the node span words[m_iBegin, m_iEnd]
+ int m_iHeadChild; // the index of its head child
+ int m_iHeadWord; // the index of its head word
+ int m_iBrotherIndex; // the index in his brothers
+};
+
+struct SGetHeadWord {
+ typedef std::vector<std::string> CVectorStr;
+ SGetHeadWord() {}
+ ~SGetHeadWord() {}
+ int fnGetHeadWord(char *pszCFGLeft, CVectorStr vectRight) {
+ // 0 indicating from right to left while 1 indicating from left to right
+ char szaHeadLists[201] = "0";
+
+ /* //head rules for Egnlish
+ if( strcmp( pszCFGLeft, "ADJP" ) == 0 )
+ strcpy( szaHeadLists, "0NNS 0QP 0NN 0$ 0ADVP 0JJ 0VBN 0VBG 0ADJP
+ 0JJR 0NP 0JJS 0DT 0FW 0RBR 0RBS 0SBAR 0RB 0" );
+ else if( strcmp( pszCFGLeft, "ADVP" ) == 0 )
+ strcpy( szaHeadLists, "1RB 1RBR 1RBS 1FW 1ADVP 1TO 1CD 1JJR 1JJ 1IN
+ 1NP 1JJS 1NN 1" );
+ else if( strcmp( pszCFGLeft, "CONJP" ) == 0 )
+ strcpy( szaHeadLists, "1CC 1RB 1IN 1" );
+ else if( strcmp( pszCFGLeft, "FRAG" ) == 0 )
+ strcpy( szaHeadLists, "1" );
+ else if( strcmp( pszCFGLeft, "INTJ" ) == 0 )
+ strcpy( szaHeadLists, "0" );
+ else if( strcmp( pszCFGLeft, "LST" ) == 0 )
+ strcpy( szaHeadLists, "1LS 1: 1CLN 1" );
+ else if( strcmp( pszCFGLeft, "NAC" ) == 0 )
+ strcpy( szaHeadLists, "0NN 0NNS 0NNP 0NNPS 0NP 0NAC 0EX 0$ 0CD 0QP
+ 0PRP 0VBG 0JJ 0JJS 0JJR 0ADJP 0FW 0" );
+ else if( strcmp( pszCFGLeft, "PP" ) == 0 )
+ strcpy( szaHeadLists, "1IN 1TO 1VBG 1VBN 1RP 1FW 1" );
+ else if( strcmp( pszCFGLeft, "PRN" ) == 0 )
+ strcpy( szaHeadLists, "1" );
+ else if( strcmp( pszCFGLeft, "PRT" ) == 0 )
+ strcpy( szaHeadLists, "1RP 1" );
+ else if( strcmp( pszCFGLeft, "QP" ) == 0 )
+ strcpy( szaHeadLists, "0$ 0IN 0NNS 0NN 0JJ 0RB 0DT 0CD 0NCD 0QP 0JJR
+ 0JJS 0" );
+ else if( strcmp( pszCFGLeft, "RRC" ) == 0 )
+ strcpy( szaHeadLists, "1VP 1NP 1ADVP 1ADJP 1PP 1" );
+ else if( strcmp( pszCFGLeft, "S" ) == 0 )
+ strcpy( szaHeadLists, "0TO 0IN 0VP 0S 0SBAR 0ADJP 0UCP 0NP 0" );
+ else if( strcmp( pszCFGLeft, "SBAR" ) == 0 )
+ strcpy( szaHeadLists, "0WHNP 0WHPP 0WHADVP 0WHADJP 0IN 0DT 0S 0SQ
+ 0SINV 0SBAR 0FRAG 0" );
+ else if( strcmp( pszCFGLeft, "SBARQ" ) == 0 )
+ strcpy( szaHeadLists, "0SQ 0S 0SINV 0SBARQ 0FRAG 0" );
+ else if( strcmp( pszCFGLeft, "SINV" ) == 0 )
+ strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0S 0SINV 0ADJP 0NP
+ 0" );
+ else if( strcmp( pszCFGLeft, "SQ" ) == 0 )
+ strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0SQ 0" );
+ else if( strcmp( pszCFGLeft, "UCP" ) == 0 )
+ strcpy( szaHeadLists, "1" );
+ else if( strcmp( pszCFGLeft, "VP" ) == 0 )
+ strcpy( szaHeadLists, "0TO 0VBD 0VBN 0MD 0VBZ 0VB 0VBG 0VBP 0VP
+ 0ADJP 0NN 0NNS 0NP 0" );
+ else if( strcmp( pszCFGLeft, "WHADJP" ) == 0 )
+ strcpy( szaHeadLists, "0CC 0WRB 0JJ 0ADJP 0" );
+ else if( strcmp( pszCFGLeft, "WHADVP" ) == 0 )
+ strcpy( szaHeadLists, "1CC 1WRB 1" );
+ else if( strcmp( pszCFGLeft, "WHNP" ) == 0 )
+ strcpy( szaHeadLists, "0WDT 0WP 0WP$ 0WHADJP 0WHPP 0WHNP 0" );
+ else if( strcmp( pszCFGLeft, "WHPP" ) == 0 )
+ strcpy( szaHeadLists, "1IN 1TO FW 1" );
+ else if( strcmp( pszCFGLeft, "NP" ) == 0 )
+ strcpy( szaHeadLists, "0NN NNP NNS NNPS NX POS JJR 0NP 0$ ADJP PRN
+ 0CD 0JJ JJS RB QP 0" );
+ */
+
+ if (strcmp(pszCFGLeft, "ADJP") == 0)
+ strcpy(szaHeadLists, "0ADJP JJ 0AD NN CS 0");
+ else if (strcmp(pszCFGLeft, "ADVP") == 0)
+ strcpy(szaHeadLists, "0ADVP AD 0");
+ else if (strcmp(pszCFGLeft, "CLP") == 0)
+ strcpy(szaHeadLists, "0CLP M 0");
+ else if (strcmp(pszCFGLeft, "CP") == 0)
+ strcpy(szaHeadLists, "0DEC SP 1ADVP CS 0CP IP 0");
+ else if (strcmp(pszCFGLeft, "DNP") == 0)
+ strcpy(szaHeadLists, "0DNP DEG 0DEC 0");
+ else if (strcmp(pszCFGLeft, "DVP") == 0)
+ strcpy(szaHeadLists, "0DVP DEV 0");
+ else if (strcmp(pszCFGLeft, "DP") == 0)
+ strcpy(szaHeadLists, "1DP DT 1");
+ else if (strcmp(pszCFGLeft, "FRAG") == 0)
+ strcpy(szaHeadLists, "0VV NR NN 0");
+ else if (strcmp(pszCFGLeft, "INTJ") == 0)
+ strcpy(szaHeadLists, "0INTJ IJ 0");
+ else if (strcmp(pszCFGLeft, "LST") == 0)
+ strcpy(szaHeadLists, "1LST CD OD 1");
+ else if (strcmp(pszCFGLeft, "IP") == 0)
+ strcpy(szaHeadLists, "0IP VP 0VV 0");
+ // strcpy( szaHeadLists, "0VP 0VV 1IP 0" );
+ else if (strcmp(pszCFGLeft, "LCP") == 0)
+ strcpy(szaHeadLists, "0LCP LC 0");
+ else if (strcmp(pszCFGLeft, "NP") == 0)
+ strcpy(szaHeadLists, "0NP NN NT NR QP 0");
+ else if (strcmp(pszCFGLeft, "PP") == 0)
+ strcpy(szaHeadLists, "1PP P 1");
+ else if (strcmp(pszCFGLeft, "PRN") == 0)
+ strcpy(szaHeadLists, "0 NP IP VP NT NR NN 0");
+ else if (strcmp(pszCFGLeft, "QP") == 0)
+ strcpy(szaHeadLists, "0QP CLP CD OD 0");
+ else if (strcmp(pszCFGLeft, "VP") == 0)
+ strcpy(szaHeadLists, "1VP VA VC VE VV BA LB VCD VSB VRD VNV VCP 1");
+ else if (strcmp(pszCFGLeft, "VCD") == 0)
+ strcpy(szaHeadLists, "0VCD VV VA VC VE 0");
+ if (strcmp(pszCFGLeft, "VRD") == 0)
+ strcpy(szaHeadLists, "0VRD VV VA VC VE 0");
+ else if (strcmp(pszCFGLeft, "VSB") == 0)
+ strcpy(szaHeadLists, "0VSB VV VA VC VE 0");
+ else if (strcmp(pszCFGLeft, "VCP") == 0)
+ strcpy(szaHeadLists, "0VCP VV VA VC VE 0");
+ else if (strcmp(pszCFGLeft, "VNV") == 0)
+ strcpy(szaHeadLists, "0VNV VV VA VC VE 0");
+ else if (strcmp(pszCFGLeft, "VPT") == 0)
+ strcpy(szaHeadLists, "0VNV VV VA VC VE 0");
+ else if (strcmp(pszCFGLeft, "UCP") == 0)
+ strcpy(szaHeadLists, "0");
+ else if (strcmp(pszCFGLeft, "WHNP") == 0)
+ strcpy(szaHeadLists, "0WHNP NP NN NT NR QP 0");
+ else if (strcmp(pszCFGLeft, "WHPP") == 0)
+ strcpy(szaHeadLists, "1WHPP PP P 1");
+
+ /* //head rules for GENIA corpus
+ if( strcmp( pszCFGLeft, "ADJP" ) == 0 )
+ strcpy( szaHeadLists, "0NNS 0QP 0NN 0$ 0ADVP 0JJ 0VBN 0VBG 0ADJP
+ 0JJR 0NP 0JJS 0DT 0FW 0RBR 0RBS 0SBAR 0RB 0" );
+ else if( strcmp( pszCFGLeft, "ADVP" ) == 0 )
+ strcpy( szaHeadLists, "1RB 1RBR 1RBS 1FW 1ADVP 1TO 1CD 1JJR 1JJ 1IN
+ 1NP 1JJS 1NN 1" );
+ else if( strcmp( pszCFGLeft, "CONJP" ) == 0 )
+ strcpy( szaHeadLists, "1CC 1RB 1IN 1" );
+ else if( strcmp( pszCFGLeft, "FRAG" ) == 0 )
+ strcpy( szaHeadLists, "1" );
+ else if( strcmp( pszCFGLeft, "INTJ" ) == 0 )
+ strcpy( szaHeadLists, "0" );
+ else if( strcmp( pszCFGLeft, "LST" ) == 0 )
+ strcpy( szaHeadLists, "1LS 1: 1CLN 1" );
+ else if( strcmp( pszCFGLeft, "NAC" ) == 0 )
+ strcpy( szaHeadLists, "0NN 0NNS 0NNP 0NNPS 0NP 0NAC 0EX 0$ 0CD 0QP
+ 0PRP 0VBG 0JJ 0JJS 0JJR 0ADJP 0FW 0" );
+ else if( strcmp( pszCFGLeft, "PP" ) == 0 )
+ strcpy( szaHeadLists, "1IN 1TO 1VBG 1VBN 1RP 1FW 1" );
+ else if( strcmp( pszCFGLeft, "PRN" ) == 0 )
+ strcpy( szaHeadLists, "1" );
+ else if( strcmp( pszCFGLeft, "PRT" ) == 0 )
+ strcpy( szaHeadLists, "1RP 1" );
+ else if( strcmp( pszCFGLeft, "QP" ) == 0 )
+ strcpy( szaHeadLists, "0$ 0IN 0NNS 0NN 0JJ 0RB 0DT 0CD 0NCD 0QP 0JJR
+ 0JJS 0" );
+ else if( strcmp( pszCFGLeft, "RRC" ) == 0 )
+ strcpy( szaHeadLists, "1VP 1NP 1ADVP 1ADJP 1PP 1" );
+ else if( strcmp( pszCFGLeft, "S" ) == 0 )
+ strcpy( szaHeadLists, "0TO 0IN 0VP 0S 0SBAR 0ADJP 0UCP 0NP 0" );
+ else if( strcmp( pszCFGLeft, "SBAR" ) == 0 )
+ strcpy( szaHeadLists, "0WHNP 0WHPP 0WHADVP 0WHADJP 0IN 0DT 0S 0SQ
+ 0SINV 0SBAR 0FRAG 0" );
+ else if( strcmp( pszCFGLeft, "SBARQ" ) == 0 )
+ strcpy( szaHeadLists, "0SQ 0S 0SINV 0SBARQ 0FRAG 0" );
+ else if( strcmp( pszCFGLeft, "SINV" ) == 0 )
+ strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0S 0SINV 0ADJP 0NP
+ 0" );
+ else if( strcmp( pszCFGLeft, "SQ" ) == 0 )
+ strcpy( szaHeadLists, "0VBZ 0VBD 0VBP 0VB 0MD 0VP 0SQ 0" );
+ else if( strcmp( pszCFGLeft, "UCP" ) == 0 )
+ strcpy( szaHeadLists, "1" );
+ else if( strcmp( pszCFGLeft, "VP" ) == 0 )
+ strcpy( szaHeadLists, "0TO 0VBD 0VBN 0MD 0VBZ 0VB 0VBG 0VBP 0VP
+ 0ADJP 0NN 0NNS 0NP 0" );
+ else if( strcmp( pszCFGLeft, "WHADJP" ) == 0 )
+ strcpy( szaHeadLists, "0CC 0WRB 0JJ 0ADJP 0" );
+ else if( strcmp( pszCFGLeft, "WHADVP" ) == 0 )
+ strcpy( szaHeadLists, "1CC 1WRB 1" );
+ else if( strcmp( pszCFGLeft, "WHNP" ) == 0 )
+ strcpy( szaHeadLists, "0WDT 0WP 0WP$ 0WHADJP 0WHPP 0WHNP 0" );
+ else if( strcmp( pszCFGLeft, "WHPP" ) == 0 )
+ strcpy( szaHeadLists, "1IN 1TO FW 1" );
+ else if( strcmp( pszCFGLeft, "NP" ) == 0 )
+ strcpy( szaHeadLists, "0NN NNP NNS NNPS NX POS JJR 0NP 0$ ADJP PRN
+ 0CD 0JJ JJS RB QP 0" );
+ */
+
+ return fnMyOwnHeadWordRule(szaHeadLists, vectRight);
+ }
+
+ private:
+ int fnMyOwnHeadWordRule(char *pszaHeadLists, CVectorStr vectRight) {
+ char szHeadList[201], *p;
+ char szTerm[101];
+ int J;
+
+ p = pszaHeadLists;
+
+ int iCountRight;
+
+ iCountRight = vectRight.size();
+
+ szHeadList[0] = '\0';
+ while (1) {
+ szTerm[0] = '\0';
+ sscanf(p, "%s", szTerm);
+ if (strlen(szHeadList) == 0) {
+ if (strcmp(szTerm, "0") == 0) {
+ return iCountRight - 1;
+ }
+ if (strcmp(szTerm, "1") == 0) {
+ return 0;
+ }
+
+ sprintf(szHeadList, "%c %s ", szTerm[0], szTerm + 1);
+ p = strstr(p, szTerm);
+ p += strlen(szTerm);
+ } else {
+ if ((szTerm[0] == '0') || (szTerm[0] == '1')) {
+ if (szHeadList[0] == '0') {
+ for (J = iCountRight - 1; J >= 0; J--) {
+ sprintf(szTerm, " %s ", vectRight.at(J).c_str());
+ if (strstr(szHeadList, szTerm) != NULL) return J;
+ }
+ } else {
+ for (J = 0; J < iCountRight; J++) {
+ sprintf(szTerm, " %s ", vectRight.at(J).c_str());
+ if (strstr(szHeadList, szTerm) != NULL) return J;
+ }
+ }
+
+ szHeadList[0] = '\0';
+ } else {
+ strcat(szHeadList, szTerm);
+ strcat(szHeadList, " ");
+
+ p = strstr(p, szTerm);
+ p += strlen(szTerm);
+ }
+ }
+ }
+
+ return 0;
+ }
+};
+
+struct SParsedTree {
+ SParsedTree() { m_ptRoot = NULL; }
+ ~SParsedTree() {
+ if (m_ptRoot != NULL) delete m_ptRoot;
+ }
+ static SParsedTree *fnConvertFromString(const char *pszStr) {
+ if (strcmp(pszStr, "(())") == 0) return NULL;
+ SParsedTree *pTree = new SParsedTree();
+
+ std::vector<std::string> vecSyn;
+ fnReadSyntactic(pszStr, vecSyn);
+
+ int iLeft = 1, iRight = 1; //# left/right parenthesis
+
+ STreeItem *pcurrent;
+
+ pTree->m_ptRoot = new STreeItem(vecSyn[1].c_str());
+
+ pcurrent = pTree->m_ptRoot;
+
+ for (size_t i = 2; i < vecSyn.size() - 1; i++) {
+ if (strcmp(vecSyn[i].c_str(), "(") == 0)
+ iLeft++;
+ else if (strcmp(vecSyn[i].c_str(), ")") == 0) {
+ iRight++;
+ if (pcurrent == NULL) {
+ // error
+ fprintf(stderr, "ERROR in ConvertFromString\n");
+ fprintf(stderr, "%s\n", pszStr);
+ return NULL;
+ }
+ pcurrent = pcurrent->m_ptParent;
+ } else {
+ STreeItem *ptNewItem = new STreeItem(vecSyn[i].c_str());
+ pcurrent->fnAppend(ptNewItem);
+ pcurrent = ptNewItem;
+
+ if (strcmp(vecSyn[i - 1].c_str(), "(") != 0 &&
+ strcmp(vecSyn[i - 1].c_str(), ")") != 0) {
+ pTree->m_vecTerminals.push_back(ptNewItem);
+ pcurrent = pcurrent->m_ptParent;
+ }
+ }
+ }
+
+ if (iLeft != iRight) {
+ // error
+ fprintf(stderr, "the left and right parentheses are not matched!");
+ fprintf(stderr, "ERROR in ConvertFromString\n");
+ fprintf(stderr, "%s\n", pszStr);
+ return NULL;
+ }
+
+ return pTree;
+ }
+
+ int fnGetNumWord() { return m_vecTerminals.size(); }
+
+ void fnSetSpanInfo() {
+ int iNextNum = 0;
+ fnSuffixTraverseSetSpanInfo(m_ptRoot, iNextNum);
+ }
+
+ void fnSetHeadWord() {
+ for (size_t i = 0; i < m_vecTerminals.size(); i++)
+ m_vecTerminals[i]->m_iHeadWord = i;
+ SGetHeadWord *pGetHeadWord = new SGetHeadWord();
+ fnSuffixTraverseSetHeadWord(m_ptRoot, pGetHeadWord);
+ delete pGetHeadWord;
+ }
+
+ STreeItem *fnFindNodeForSpan(int iLeft, int iRight, bool bLowest) {
+ STreeItem *pTreeItem = m_vecTerminals[iLeft];
+
+ while (pTreeItem->m_iEnd < iRight) {
+ pTreeItem = pTreeItem->m_ptParent;
+ if (pTreeItem == NULL) break;
+ }
+ if (pTreeItem == NULL) return NULL;
+ if (pTreeItem->m_iEnd > iRight) return NULL;
+
+ assert(pTreeItem->m_iEnd == iRight);
+ if (bLowest) return pTreeItem;
+
+ while (pTreeItem->m_ptParent != NULL &&
+ pTreeItem->m_ptParent->fnGetChildrenNum() == 1)
+ pTreeItem = pTreeItem->m_ptParent;
+
+ return pTreeItem;
+ }
+
+ private:
+ void fnSuffixTraverseSetSpanInfo(STreeItem *ptItem, int &iNextNum) {
+ int I;
+ int iNumChildren = ptItem->fnGetChildrenNum();
+ for (I = 0; I < iNumChildren; I++)
+ fnSuffixTraverseSetSpanInfo(ptItem->m_vecChildren[I], iNextNum);
+
+ if (I == 0) {
+ ptItem->m_iBegin = iNextNum;
+ ptItem->m_iEnd = iNextNum++;
+ } else {
+ ptItem->m_iBegin = ptItem->m_vecChildren[0]->m_iBegin;
+ ptItem->m_iEnd = ptItem->m_vecChildren[I - 1]->m_iEnd;
+ }
+ }
+
+ void fnSuffixTraverseSetHeadWord(STreeItem *ptItem,
+ SGetHeadWord *pGetHeadWord) {
+ int I, iHeadchild;
+
+ if (ptItem->m_vecChildren.size() == 0) return;
+
+ for (I = 0; I < ptItem->m_vecChildren.size(); I++)
+ fnSuffixTraverseSetHeadWord(ptItem->m_vecChildren[I], pGetHeadWord);
+
+ std::vector<std::string> vecRight;
+
+ if (ptItem->m_vecChildren.size() == 1)
+ iHeadchild = 0;
+ else {
+ for (I = 0; I < ptItem->m_vecChildren.size(); I++)
+ vecRight.push_back(std::string(ptItem->m_vecChildren[I]->m_pszTerm));
+
+ iHeadchild = pGetHeadWord->fnGetHeadWord(ptItem->m_pszTerm, vecRight);
+ }
+
+ ptItem->m_iHeadChild = iHeadchild;
+ ptItem->m_iHeadWord = ptItem->m_vecChildren[iHeadchild]->m_iHeadWord;
+ }
+
+ static void fnReadSyntactic(const char *pszSyn,
+ std::vector<std::string> &vec) {
+ char *p;
+ int I;
+
+ int iLeftNum, iRightNum;
+ char *pszTmp, *pszTerm;
+ pszTmp = new char[strlen(pszSyn)];
+ pszTerm = new char[strlen(pszSyn)];
+ pszTmp[0] = pszTerm[0] = '\0';
+
+ vec.clear();
+
+ char *pszLine;
+ pszLine = new char[strlen(pszSyn) + 1];
+ strcpy(pszLine, pszSyn);
+
+ char *pszLine2;
+
+ while (1) {
+ while ((strlen(pszLine) > 0) && (pszLine[strlen(pszLine) - 1] > 0) &&
+ (pszLine[strlen(pszLine) - 1] <= ' '))
+ pszLine[strlen(pszLine) - 1] = '\0';
+
+ if (strlen(pszLine) == 0) break;
+
+ // printf( "%s\n", pszLine );
+ pszLine2 = pszLine;
+ while (pszLine2[0] <= ' ') pszLine2++;
+ if (pszLine2[0] == '<') continue;
+
+ sscanf(pszLine2 + 1, "%s", pszTmp);
+
+ if (pszLine2[0] == '(') {
+ iLeftNum = 0;
+ iRightNum = 0;
+ }
+
+ p = pszLine2;
+ while (1) {
+ pszTerm[0] = '\0';
+ sscanf(p, "%s", pszTerm);
+
+ if (strlen(pszTerm) == 0) break;
+ p = strstr(p, pszTerm);
+ p += strlen(pszTerm);
+
+ if ((pszTerm[0] == '(') || (pszTerm[strlen(pszTerm) - 1] == ')')) {
+ if (pszTerm[0] == '(') {
+ vec.push_back(std::string("("));
+ iLeftNum++;
+
+ I = 1;
+ while (pszTerm[I] == '(' && pszTerm[I] != '\0') {
+ vec.push_back(std::string("("));
+ iLeftNum++;
+
+ I++;
+ }
+
+ if (strlen(pszTerm) > 1) vec.push_back(std::string(pszTerm + I));
+ } else {
+ char *pTmp;
+ pTmp = pszTerm + strlen(pszTerm) - 1;
+ while ((pTmp[0] == ')') && (pTmp >= pszTerm)) pTmp--;
+ pTmp[1] = '\0';
+
+ if (strlen(pszTerm) > 0) vec.push_back(std::string(pszTerm));
+ pTmp += 2;
+
+ for (I = 0; I <= (int)strlen(pTmp); I++) {
+ vec.push_back(std::string(")"));
+ iRightNum++;
+ }
+ }
+ } else {
+ char *q;
+ q = strchr(pszTerm, ')');
+ if (q != NULL) {
+ q[0] = '\0';
+ if (pszTerm[0] != '\0') vec.push_back(std::string(pszTerm));
+ vec.push_back(std::string(")"));
+ iRightNum++;
+
+ q++;
+ while (q[0] == ')') {
+ vec.push_back(std::string(")"));
+ q++;
+ iRightNum++;
+ }
+
+ while (q[0] == '(') {
+ vec.push_back(std::string("("));
+ q++;
+ iLeftNum++;
+ }
+
+ if (q[0] != '\0') vec.push_back(std::string(q));
+ } else
+ vec.push_back(std::string(pszTerm));
+ }
+ }
+
+ if (iLeftNum != iRightNum) {
+ fprintf(stderr, "%s\n", pszSyn);
+ assert(iLeftNum == iRightNum);
+ }
+ /*if ( iLeftNum != iRightNum ) {
+ printf( "ERROR: left( and right ) is not matched, %d ( and %d
+ )\n", iLeftNum, iRightNum );
+ return;
+ }*/
+
+ if (vec.size() >= 2 && strcmp(vec[1].c_str(), "(") == 0) {
+ //( (IP..) )
+ std::vector<std::string>::iterator it;
+ it = vec.begin();
+ it++;
+ vec.insert(it, std::string("ROOT"));
+ }
+
+ break;
+ }
+
+ delete[] pszLine;
+ delete[] pszTmp;
+ delete[] pszTerm;
+ }
+
+ public:
+ STreeItem *m_ptRoot;
+ std::vector<STreeItem *> m_vecTerminals; // the leaf nodes
+};
+
+struct SParseReader {
+ SParseReader(const char *pszParse_Fname, bool bFlattened = false)
+ : m_bFlattened(bFlattened) {
+ m_fpIn = fopen(pszParse_Fname, "r");
+ assert(m_fpIn != NULL);
+ }
+ ~SParseReader() {
+ if (m_fpIn != NULL) fclose(m_fpIn);
+ }
+
+ SParsedTree *fnReadNextParseTree() {
+ SParsedTree *pTree = NULL;
+ char *pszLine = new char[100001];
+ int iLen;
+
+ while (fnReadNextSentence(pszLine, &iLen) == true) {
+ if (iLen == 0) continue;
+
+ pTree = SParsedTree::fnConvertFromString(pszLine);
+ if (pTree == NULL) break;
+ if (m_bFlattened)
+ fnPostProcessingFlattenedParse(pTree);
+ else {
+ pTree->fnSetSpanInfo();
+ pTree->fnSetHeadWord();
+ }
+ break;
+ }
+
+ delete[] pszLine;
+ return pTree;
+ }
+
+ SParsedTree *fnReadNextParseTreeWithProb(double *pProb) {
+ SParsedTree *pTree = NULL;
+ char *pszLine = new char[100001];
+ int iLen;
+
+ while (fnReadNextSentence(pszLine, &iLen) == true) {
+ if (iLen == 0) continue;
+
+ char *p = strchr(pszLine, ' ');
+ assert(p != NULL);
+ p[0] = '\0';
+ p++;
+ if (pProb) (*pProb) = atof(pszLine);
+
+ pTree = SParsedTree::fnConvertFromString(p);
+ if (m_bFlattened)
+ fnPostProcessingFlattenedParse(pTree);
+ else {
+ pTree->fnSetSpanInfo();
+ pTree->fnSetHeadWord();
+ }
+ break;
+ }
+
+ delete[] pszLine;
+ return pTree;
+ }
+
+ private:
+ /*
+ * since to the parse tree is a flattened tree, use the head mark to identify
+ * head info.
+ * the head node will be marked as "*XP*"
+ */
+ void fnSetParseTreeHeadInfo(SParsedTree *pTree) {
+ for (size_t i = 0; i < pTree->m_vecTerminals.size(); i++)
+ pTree->m_vecTerminals[i]->m_iHeadWord = i;
+ fnSuffixTraverseSetHeadWord(pTree->m_ptRoot);
+ }
+
+ void fnSuffixTraverseSetHeadWord(STreeItem *pTreeItem) {
+ if (pTreeItem->m_vecChildren.size() == 0) return;
+
+ for (size_t i = 0; i < pTreeItem->m_vecChildren.size(); i++)
+ fnSuffixTraverseSetHeadWord(pTreeItem->m_vecChildren[i]);
+
+ std::vector<std::string> vecRight;
+
+ int iHeadchild;
+
+ if (pTreeItem->fnIsPreTerminal()) {
+ iHeadchild = 0;
+ } else {
+ size_t i;
+ for (i = 0; i < pTreeItem->m_vecChildren.size(); i++) {
+ char *p = pTreeItem->m_vecChildren[i]->m_pszTerm;
+ if (p[0] == '*' && p[strlen(p) - 1] == '*') {
+ iHeadchild = i;
+ p[strlen(p) - 1] = '\0';
+ std::string str = p + 1;
+ strcpy(p, str.c_str()); // erase the "*..*"
+ break;
+ }
+ }
+ assert(i < pTreeItem->m_vecChildren.size());
+ }
+
+ pTreeItem->m_iHeadChild = iHeadchild;
+ pTreeItem->m_iHeadWord = pTreeItem->m_vecChildren[iHeadchild]->m_iHeadWord;
+ }
+ void fnPostProcessingFlattenedParse(SParsedTree *pTree) {
+ pTree->fnSetSpanInfo();
+ fnSetParseTreeHeadInfo(pTree);
+ }
+ bool fnReadNextSentence(char *pszLine, int *piLength) {
+ if (feof(m_fpIn) == true) return false;
+
+ int iLen;
+
+ pszLine[0] = '\0';
+
+ fgets(pszLine, 10001, m_fpIn);
+ iLen = strlen(pszLine);
+ while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) {
+ pszLine[iLen - 1] = '\0';
+ iLen--;
+ }
+
+ if (piLength != NULL) (*piLength) = iLen;
+
+ return true;
+ }
+
+ private:
+ FILE *m_fpIn;
+ const bool m_bFlattened;
+};
+
+/*
+ * Note:
+ * m_vec_s_align.size() may not be equal to the length of source side
+ *sentence
+ * due to the last words may not be aligned
+ *
+ */
+struct SAlignment {
+ typedef std::vector<int> SingleAlign;
+ SAlignment(const char* pszAlign) { fnInitializeAlignment(pszAlign); }
+ ~SAlignment() {}
+
+ bool fnIsAligned(int i, bool s) const {
+ const std::vector<SingleAlign>* palign;
+ if (s == true)
+ palign = &m_vec_s_align;
+ else
+ palign = &m_vec_t_align;
+ if ((*palign)[i].size() == 0) return false;
+ return true;
+ }
+
+ /*
+ * return true if [b, e] is aligned phrases on source side (if s==true) or on
+ * the target side (if s==false);
+ * return false, otherwise.
+ */
+ bool fnIsAlignedPhrase(int b, int e, bool s, int* pob, int* poe) const {
+ int ob, oe; //[b, e] on the other side
+ if (s == true)
+ fnGetLeftRightMost(b, e, m_vec_s_align, ob, oe);
+ else
+ fnGetLeftRightMost(b, e, m_vec_t_align, ob, oe);
+
+ if (ob == -1) {
+ if (pob != NULL) (*pob) = -1;
+ if (poe != NULL) (*poe) = -1;
+ return false; // no aligned word among [b, e]
+ }
+ if (pob != NULL) (*pob) = ob;
+ if (poe != NULL) (*poe) = oe;
+
+ int bb, be; //[b, e] back given [ob, oe] on the other side
+ if (s == true)
+ fnGetLeftRightMost(ob, oe, m_vec_t_align, bb, be);
+ else
+ fnGetLeftRightMost(ob, oe, m_vec_s_align, bb, be);
+
+ if (bb < b || be > e) return false;
+ return true;
+ }
+
+ bool fnIsAlignedTightPhrase(int b, int e, bool s, int* pob, int* poe) const {
+ const std::vector<SingleAlign>* palign;
+ if (s == true)
+ palign = &m_vec_s_align;
+ else
+ palign = &m_vec_t_align;
+
+ if ((*palign).size() <= e || (*palign)[b].size() == 0 ||
+ (*palign)[e].size() == 0)
+ return false;
+
+ return fnIsAlignedPhrase(b, e, s, pob, poe);
+ }
+
+ void fnGetLeftRightMost(int b, int e, bool s, int& ob, int& oe) const {
+ if (s == true)
+ fnGetLeftRightMost(b, e, m_vec_s_align, ob, oe);
+ else
+ fnGetLeftRightMost(b, e, m_vec_t_align, ob, oe);
+ }
+
+ /*
+ * look the translation of source[b, e] is continuous or not
+ * 1) return "Unaligned": if the source[b, e] is translated silently;
+ * 2) return "Con't": if none of target words in target[.., ..] is exclusively
+ * aligned to any word outside source[b, e]
+ * 3) return "Discon't": otherwise;
+ */
+ std::string fnIsContinuous(int b, int e) const {
+ int ob, oe;
+ fnGetLeftRightMost(b, e, true, ob, oe);
+ if (ob == -1) return "Unaligned";
+
+ for (int i = ob; i <= oe; i++) {
+ if (!fnIsAligned(i, false)) continue;
+ const SingleAlign& a = m_vec_t_align[i];
+ int j;
+ for (j = 0; j < a.size(); j++)
+ if (a[j] >= b && a[j] <= e) break;
+ if (j == a.size()) return "Discon't";
+ }
+ return "Con't";
+ }
+
+ const SingleAlign* fnGetSingleWordAlign(int i, bool s) const {
+ if (s == true) {
+ if (i >= m_vec_s_align.size()) return NULL;
+ return &(m_vec_s_align[i]);
+ } else {
+ if (i >= m_vec_t_align.size()) return NULL;
+ return &(m_vec_t_align[i]);
+ }
+ }
+
+ private:
+ void fnGetLeftRightMost(int b, int e, const std::vector<SingleAlign>& align,
+ int& ob, int& oe) const {
+ ob = oe = -1;
+ for (int i = b; i <= e && i < align.size(); i++) {
+ if (align[i].size() > 0) {
+ if (align[i][0] < ob || ob == -1) ob = align[i][0];
+ if (oe < align[i][align[i].size() - 1])
+ oe = align[i][align[i].size() - 1];
+ }
+ }
+ }
+ void fnInitializeAlignment(const char* pszAlign) {
+ m_vec_s_align.clear();
+ m_vec_t_align.clear();
+
+ std::vector<std::string> terms = SplitOnWhitespace(std::string(pszAlign));
+ int si, ti;
+ for (size_t i = 0; i < terms.size(); i++) {
+ sscanf(terms[i].c_str(), "%d-%d", &si, &ti);
+
+ while (m_vec_s_align.size() <= si) {
+ SingleAlign sa;
+ m_vec_s_align.push_back(sa);
+ }
+ while (m_vec_t_align.size() <= ti) {
+ SingleAlign sa;
+ m_vec_t_align.push_back(sa);
+ }
+
+ m_vec_s_align[si].push_back(ti);
+ m_vec_t_align[ti].push_back(si);
+ }
+
+ // sort
+ for (size_t i = 0; i < m_vec_s_align.size(); i++) {
+ std::sort(m_vec_s_align[i].begin(), m_vec_s_align[i].end());
+ }
+ for (size_t i = 0; i < m_vec_t_align.size(); i++) {
+ std::sort(m_vec_t_align[i].begin(), m_vec_t_align[i].end());
+ }
+ }
+
+ private:
+ std::vector<SingleAlign> m_vec_s_align; // source side words' alignment
+ std::vector<SingleAlign> m_vec_t_align; // target side words' alignment
+};
+
+struct SAlignmentReader {
+ SAlignmentReader(const char* pszFname) {
+ m_fpIn = fopen(pszFname, "r");
+ assert(m_fpIn != NULL);
+ }
+ ~SAlignmentReader() {
+ if (m_fpIn != NULL) fclose(m_fpIn);
+ }
+ SAlignment* fnReadNextAlignment() {
+ if (feof(m_fpIn) == true) return NULL;
+ char* pszLine = new char[100001];
+ pszLine[0] = '\0';
+ fgets(pszLine, 10001, m_fpIn);
+ int iLen = strlen(pszLine);
+ if (iLen == 0) return NULL;
+ while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) {
+ pszLine[iLen - 1] = '\0';
+ iLen--;
+ }
+ SAlignment* pAlign = new SAlignment(pszLine);
+ delete[] pszLine;
+ return pAlign;
+ }
+
+ private:
+ FILE* m_fpIn;
+};
+
+struct SArgument {
+ SArgument(const char* pszRole, int iBegin, int iEnd, float fProb) {
+ m_pszRole = new char[strlen(pszRole) + 1];
+ strcpy(m_pszRole, pszRole);
+ m_iBegin = iBegin;
+ m_iEnd = iEnd;
+ m_fProb = fProb;
+ m_pTreeItem = NULL;
+ }
+ ~SArgument() { delete[] m_pszRole; }
+
+ void fnSetTreeItem(STreeItem* pTreeItem) {
+ m_pTreeItem = pTreeItem;
+ if (m_pTreeItem != NULL && m_pTreeItem->m_iBegin != -1) {
+ assert(m_pTreeItem->m_iBegin == m_iBegin);
+ assert(m_pTreeItem->m_iEnd == m_iEnd);
+ }
+ }
+
+ char* m_pszRole; // argument rule, e.g., ARG0, ARGM-TMP
+ int m_iBegin;
+ int m_iEnd; // the span of the argument, [m_iBegin, m_iEnd]
+ float m_fProb; // the probability of this role,
+ STreeItem* m_pTreeItem;
+};
+
+struct SPredicate {
+ SPredicate(const char* pszLemma, int iPosition) {
+ if (pszLemma != NULL) {
+ m_pszLemma = new char[strlen(pszLemma) + 1];
+ strcpy(m_pszLemma, pszLemma);
+ } else
+ m_pszLemma = NULL;
+ m_iPosition = iPosition;
+ }
+ ~SPredicate() {
+ if (m_pszLemma != NULL) delete[] m_pszLemma;
+ for (size_t i = 0; i < m_vecArgt.size(); i++) delete m_vecArgt[i];
+ }
+ int fnAppend(const char* pszRole, int iBegin, int iEnd) {
+ SArgument* pArgt = new SArgument(pszRole, iBegin, iEnd, 1.0);
+ return fnAppend(pArgt);
+ }
+ int fnAppend(SArgument* pArgt) {
+ m_vecArgt.push_back(pArgt);
+ int iPosition = m_vecArgt.size() - 1;
+ return iPosition;
+ }
+
+ char* m_pszLemma; // lemma of the predicate, for Chinese, it's always as same
+ // as the predicate itself
+ int m_iPosition; // the position in sentence
+ std::vector<SArgument*> m_vecArgt; // arguments associated to the predicate
+};
+
+struct SSrlSentence {
+ SSrlSentence() { m_pTree = NULL; }
+ ~SSrlSentence() {
+ if (m_pTree != NULL) delete m_pTree;
+
+ for (size_t i = 0; i < m_vecPred.size(); i++) delete m_vecPred[i];
+ }
+ int fnAppend(const char* pszLemma, int iPosition) {
+ SPredicate* pPred = new SPredicate(pszLemma, iPosition);
+ return fnAppend(pPred);
+ }
+ int fnAppend(SPredicate* pPred) {
+ m_vecPred.push_back(pPred);
+ int iPosition = m_vecPred.size() - 1;
+ return iPosition;
+ }
+ int GetPredicateNum() { return m_vecPred.size(); }
+
+ SParsedTree* m_pTree;
+ std::vector<SPredicate*> m_vecPred;
+};
+
+struct SSrlSentenceReader {
+ SSrlSentenceReader(const char* pszSrlFname) {
+ m_fpIn = fopen(pszSrlFname, "r");
+ assert(m_fpIn != NULL);
+ }
+ ~SSrlSentenceReader() {
+ if (m_fpIn != NULL) fclose(m_fpIn);
+ }
+
+ inline void fnReplaceAll(std::string& str, const std::string& from,
+ const std::string& to) {
+ size_t start_pos = 0;
+ while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
+ str.replace(start_pos, from.length(), to);
+ start_pos += to.length(); // In case 'to' contains 'from', like replacing
+ // 'x' with 'yx'
+ }
+ }
+
+ // TODO: here only considers flat predicate-argument structure
+ // i.e., no overlap among them
+ SSrlSentence* fnReadNextSrlSentence() {
+ std::vector<std::vector<std::string> > vecContent;
+ if (fnReadNextContent(vecContent) == false) return NULL;
+
+ SSrlSentence* pSrlSentence = new SSrlSentence();
+ int iSize = vecContent.size();
+ // put together syntactic text
+ std::ostringstream ostr;
+ for (int i = 0; i < iSize; i++) {
+ std::string strSynSeg =
+ vecContent[i][5]; // the 5th column is the syntactic segment
+ size_t iPosition = strSynSeg.find_first_of('*');
+ assert(iPosition != std::string::npos);
+ std::ostringstream ostrTmp;
+ ostrTmp << "(" << vecContent[i][2] << " " << vecContent[i][0]
+ << ")"; // the 2th column is POS-tag, and the 0th column is word
+ strSynSeg.replace(iPosition, 1, ostrTmp.str());
+ fnReplaceAll(strSynSeg, "(", " (");
+ ostr << strSynSeg;
+ }
+ std::string strSyn = ostr.str();
+ pSrlSentence->m_pTree = SParsedTree::fnConvertFromString(strSyn.c_str());
+ pSrlSentence->m_pTree->fnSetHeadWord();
+ pSrlSentence->m_pTree->fnSetSpanInfo();
+
+ // read predicate-argument structure
+ int iNumPred = vecContent[0].size() - 8;
+ for (int i = 0; i < iNumPred; i++) {
+ std::vector<std::string> vecRole;
+ std::vector<int> vecBegin;
+ std::vector<int> vecEnd;
+ int iPred = -1;
+ for (int j = 0; j < iSize; j++) {
+ const char* p = vecContent[j][i + 8].c_str();
+ const char* q;
+ if (p[0] == '(') {
+ // starting position of an argument(or predicate)
+ vecBegin.push_back(j);
+ q = strchr(p, '*');
+ assert(q != NULL);
+ vecRole.push_back(vecContent[j][i + 8].substr(1, q - p - 1));
+ if (vecRole.back().compare("V") == 0) {
+ assert(iPred == -1);
+ iPred = vecRole.size() - 1;
+ }
+ }
+ if (p[strlen(p) - 1] == ')') {
+ // end position of an argument(or predicate)
+ vecEnd.push_back(j);
+ assert(vecBegin.size() == vecEnd.size());
+ }
+ }
+ assert(iPred != -1);
+ SPredicate* pPred = new SPredicate(
+ pSrlSentence->m_pTree->m_vecTerminals[vecBegin[iPred]]->m_pszTerm,
+ vecBegin[iPred]);
+ pSrlSentence->fnAppend(pPred);
+ for (size_t j = 0; j < vecBegin.size(); j++) {
+ if (j == iPred) continue;
+ pPred->fnAppend(vecRole[j].c_str(), vecBegin[j], vecEnd[j]);
+ pPred->m_vecArgt.back()->fnSetTreeItem(
+ pSrlSentence->m_pTree->fnFindNodeForSpan(vecBegin[j], vecEnd[j],
+ false));
+ }
+ }
+ return pSrlSentence;
+ }
+
+ private:
+ bool fnReadNextContent(std::vector<std::vector<std::string> >& vecContent) {
+ vecContent.clear();
+ if (feof(m_fpIn) == true) return false;
+ char* pszLine;
+ pszLine = new char[100001];
+ pszLine[0] = '\0';
+ int iLen;
+ while (!feof(m_fpIn)) {
+ fgets(pszLine, 10001, m_fpIn);
+ iLen = strlen(pszLine);
+ while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) {
+ pszLine[iLen - 1] = '\0';
+ iLen--;
+ }
+ if (iLen == 0) break; // end of this sentence
+
+ std::vector<std::string> terms = SplitOnWhitespace(std::string(pszLine));
+ assert(terms.size() > 7);
+ vecContent.push_back(terms);
+ }
+ delete[] pszLine;
+ return true;
+ }
+
+ private:
+ FILE* m_fpIn;
+};
+
+typedef std::unordered_map<std::string, int> Map;
+typedef std::unordered_map<std::string, int>::iterator Iterator;
+
+struct Tsuruoka_Maxent {
+ Tsuruoka_Maxent(const char* pszModelFName) {
+ if (pszModelFName != NULL) {
+ m_pModel = new maxent::ME_Model();
+ m_pModel->load_from_file(pszModelFName);
+ } else
+ m_pModel = NULL;
+ }
+
+ ~Tsuruoka_Maxent() {
+ if (m_pModel != NULL) delete m_pModel;
+ }
+
+ void fnEval(const char* pszContext, std::vector<double>& vecOutput) const {
+ std::vector<std::string> vecContext;
+ maxent::ME_Sample* pmes = new maxent::ME_Sample();
+ SplitOnWhitespace(std::string(pszContext), &vecContext);
+
+ vecOutput.clear();
+
+ for (size_t i = 0; i < vecContext.size(); i++)
+ pmes->add_feature(vecContext[i]);
+ std::vector<double> vecProb = m_pModel->classify(*pmes);
+
+ for (size_t i = 0; i < vecProb.size(); i++) {
+ std::string label = m_pModel->get_class_label(i);
+ vecOutput.push_back(vecProb[i]);
+ }
+ delete pmes;
+ }
+ int fnGetClassId(const std::string& strLabel) const {
+ return m_pModel->get_class_id(strLabel);
+ }
+
+ private:
+ maxent::ME_Model* m_pModel;
+};
+
+// an argument item or a predicate item (the verb itself)
+struct SSRLItem {
+ SSRLItem(const STreeItem *tree_item, std::string role)
+ : tree_item_(tree_item), role_(role) {}
+ ~SSRLItem() {}
+ const STreeItem *tree_item_;
+ const std::string role_;
+};
+
+struct SPredicateItem {
+ SPredicateItem(const SParsedTree *tree, const SPredicate *pred)
+ : pred_(pred) {
+ vec_items_.reserve(pred->m_vecArgt.size() + 1);
+ for (int i = 0; i < pred->m_vecArgt.size(); i++) {
+ vec_items_.push_back(
+ new SSRLItem(pred->m_vecArgt[i]->m_pTreeItem,
+ std::string(pred->m_vecArgt[i]->m_pszRole)));
+ }
+ vec_items_.push_back(
+ new SSRLItem(tree->m_vecTerminals[pred->m_iPosition]->m_ptParent,
+ std::string("Pred")));
+ sort(vec_items_.begin(), vec_items_.end(), SortFunction);
+
+ begin_ = vec_items_[0]->tree_item_->m_iBegin;
+ end_ = vec_items_[vec_items_.size() - 1]->tree_item_->m_iEnd;
+ }
+
+ ~SPredicateItem() { vec_items_.clear(); }
+
+ static bool SortFunction(SSRLItem *i, SSRLItem *j) {
+ return (i->tree_item_->m_iBegin < j->tree_item_->m_iBegin);
+ }
+
+ std::vector<SSRLItem *> vec_items_;
+ int begin_;
+ int end_;
+ const SPredicate *pred_;
+};
+
+struct SArgumentReorderModel {
+ public:
+ static std::string fnGetBlockOutcome(int iBegin, int iEnd,
+ SAlignment *pAlign) {
+ return pAlign->fnIsContinuous(iBegin, iEnd);
+ }
+ static void fnGetReorderType(SPredicateItem *pPredItem, SAlignment *pAlign,
+ std::vector<std::string> &vecStrLeftReorder,
+ std::vector<std::string> &vecStrRightReorder) {
+ std::vector<int> vecLeft, vecRight;
+ for (int i = 0; i < pPredItem->vec_items_.size(); i++) {
+ const STreeItem *pCon1 = pPredItem->vec_items_[i]->tree_item_;
+ int iLeft1, iRight1;
+ pAlign->fnGetLeftRightMost(pCon1->m_iBegin, pCon1->m_iEnd, true, iLeft1,
+ iRight1);
+ vecLeft.push_back(iLeft1);
+ vecRight.push_back(iRight1);
+ }
+ std::vector<int> vecLeftPosition;
+ fnGetRelativePosition(vecLeft, vecLeftPosition);
+ std::vector<int> vecRightPosition;
+ fnGetRelativePosition(vecRight, vecRightPosition);
+
+ vecStrLeftReorder.clear();
+ vecStrRightReorder.clear();
+ for (int i = 1; i < vecLeftPosition.size(); i++) {
+ std::string strOutcome;
+ fnGetOutcome(vecLeftPosition[i - 1], vecLeftPosition[i], strOutcome);
+ vecStrLeftReorder.push_back(strOutcome);
+ fnGetOutcome(vecRightPosition[i - 1], vecRightPosition[i], strOutcome);
+ vecStrRightReorder.push_back(strOutcome);
+ }
+ }
+
+ /*
+ * features:
+ * f1: (left_label, right_label, parent_label)
+ * f2: (left_label, right_label, parent_label, other_right_sibling_label)
+ * f3: (left_label, right_label, parent_label, other_left_sibling_label)
+ * f4: (left_label, right_label, left_head_pos)
+ * f5: (left_label, right_label, left_head_word)
+ * f6: (left_label, right_label, right_head_pos)
+ * f7: (left_label, right_label, right_head_word)
+ * f8: (left_label, right_label, left_chunk_status)
+ * f9: (left_label, right_label, right_chunk_status)
+ * f10: (left_label, parent_label)
+ * f11: (right_label, parent_label)
+ *
+ * f1: (left_role, right_role, predicate_term)
+ * f2: (left_role, right_role, predicate_term, other_right_role)
+ * f3: (left_role, right_role, predicate_term, other_left_role)
+ * f4: (left_role, right_role, left_head_pos)
+ * f5: (left_role, right_role, left_head_word)
+ * f6: (left_role, right_role, left_syntactic_label)
+ * f7: (left_role, right_role, right_head_pos)
+ * f8: (left_role, right_role, right_head_word)
+ * f8: (left_role, right_role, right_syntactic_label)
+ * f8: (left_role, right_role, left_chunk_status)
+ * f9: (left_role, right_role, right_chunk_status)
+ * f10: (left_role, right_role, left_chunk_status)
+ * f11: (left_role, right_role, right_chunk_status)
+ * f12: (left_label, parent_label)
+ * f13: (right_label, parent_label)
+ */
+ static void fnGenerateFeature(const SParsedTree *pTree,
+ const SPredicate *pPred,
+ const SPredicateItem *pPredItem, int iPos,
+ const std::string &strBlock1,
+ const std::string &strBlock2,
+ std::ostringstream &ostr) {
+ SSRLItem *pSRLItem1 = pPredItem->vec_items_[iPos - 1];
+ SSRLItem *pSRLItem2 = pPredItem->vec_items_[iPos];
+ const STreeItem *pCon1 = pSRLItem1->tree_item_;
+ const STreeItem *pCon2 = pSRLItem2->tree_item_;
+
+ std::string left_role = pSRLItem1->role_;
+ std::string right_role = pSRLItem2->role_;
+
+ std::string predicate_term =
+ pTree->m_vecTerminals[pPred->m_iPosition]->m_pszTerm;
+
+ std::vector<std::string> vec_other_right_sibling;
+ for (int i = iPos + 1; i < pPredItem->vec_items_.size(); i++)
+ vec_other_right_sibling.push_back(
+ std::string(pPredItem->vec_items_[i]->role_));
+ if (vec_other_right_sibling.size() == 0)
+ vec_other_right_sibling.push_back(std::string("NULL"));
+
+ std::vector<std::string> vec_other_left_sibling;
+ for (int i = 0; i < iPos - 1; i++)
+ vec_other_right_sibling.push_back(
+ std::string(pPredItem->vec_items_[i]->role_));
+ if (vec_other_left_sibling.size() == 0)
+ vec_other_left_sibling.push_back(std::string("NULL"));
+
+ // generate features
+ // f1
+ ostr << "f1=" << left_role << "_" << right_role << "_" << predicate_term;
+ ostr << "f1=" << left_role << "_" << right_role;
+
+ // f2
+ for (int i = 0; i < vec_other_right_sibling.size(); i++) {
+ ostr << " f2=" << left_role << "_" << right_role << "_" << predicate_term
+ << "_" << vec_other_right_sibling[i];
+ ostr << " f2=" << left_role << "_" << right_role << "_"
+ << vec_other_right_sibling[i];
+ }
+ // f3
+ for (int i = 0; i < vec_other_left_sibling.size(); i++) {
+ ostr << " f3=" << left_role << "_" << right_role << "_" << predicate_term
+ << "_" << vec_other_left_sibling[i];
+ ostr << " f3=" << left_role << "_" << right_role << "_"
+ << vec_other_left_sibling[i];
+ }
+ // f4
+ ostr << " f4=" << left_role << "_" << right_role << "_"
+ << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_ptParent->m_pszTerm;
+ // f5
+ ostr << " f5=" << left_role << "_" << right_role << "_"
+ << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_pszTerm;
+ // f6
+ ostr << " f6=" << left_role << "_" << right_role << "_" << pCon2->m_pszTerm;
+ // f7
+ ostr << " f7=" << left_role << "_" << right_role << "_"
+ << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_ptParent->m_pszTerm;
+ // f8
+ ostr << " f8=" << left_role << "_" << right_role << "_"
+ << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_pszTerm;
+ // f9
+ ostr << " f9=" << left_role << "_" << right_role << "_" << pCon2->m_pszTerm;
+ // f10
+ ostr << " f10=" << left_role << "_" << right_role << "_" << strBlock1;
+ // f11
+ ostr << " f11=" << left_role << "_" << right_role << "_" << strBlock2;
+ // f12
+ ostr << " f12=" << left_role << "_" << predicate_term;
+ ostr << " f12=" << left_role;
+ // f13
+ ostr << " f13=" << right_role << "_" << predicate_term;
+ ostr << " f13=" << right_role;
+ }
+
+ private:
+ static void fnGetOutcome(int i1, int i2, std::string &strOutcome) {
+ assert(i1 != i2);
+ if (i1 < i2) {
+ if (i2 > i1 + 1)
+ strOutcome = std::string("DM");
+ else
+ strOutcome = std::string("M");
+ } else {
+ if (i1 > i2 + 1)
+ strOutcome = std::string("DS");
+ else
+ strOutcome = std::string("S");
+ }
+ }
+
+ static void fnGetRelativePosition(const std::vector<int> &vecLeft,
+ std::vector<int> &vecPosition) {
+ vecPosition.clear();
+
+ std::vector<float> vec;
+ for (int i = 0; i < vecLeft.size(); i++) {
+ if (vecLeft[i] == -1) {
+ if (i == 0)
+ vec.push_back(-1);
+ else
+ vec.push_back(vecLeft[i - 1] + 0.1);
+ } else
+ vec.push_back(vecLeft[i]);
+ }
+
+ for (int i = 0; i < vecLeft.size(); i++) {
+ int count = 0;
+
+ for (int j = 0; j < vecLeft.size(); j++) {
+ if (j == i) continue;
+ if (vec[j] < vec[i]) {
+ count++;
+ } else if (vec[j] == vec[i] && j < i) {
+ count++;
+ }
+ }
+ vecPosition.push_back(count);
+ }
+ }
+};
+} // namespace const_reorder
+
+#endif // _FF_CONST_REORDER_COMMON_H
diff --git a/decoder/ff_soft_syn.cc b/decoder/ff_soft_syn.cc
new file mode 100644
index 00000000..2d4369d1
--- /dev/null
+++ b/decoder/ff_soft_syn.cc
@@ -0,0 +1,265 @@
+/*
+ * ff_soft_syn.cc
+ *
+ */
+#include "ff_soft_syn.h"
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "hg.h"
+#include "sentence_metadata.h"
+#include "ff_const_reorder_common.h"
+
+#include <string>
+#include <vector>
+#include <stdio.h>
+
+using namespace std;
+using namespace const_reorder;
+
+typedef HASH_MAP<std::string, vector<string> > MapFeatures;
+
+struct SoftSynFeatureImpl {
+ SoftSynFeatureImpl(const string& /*params*/) {
+ parsed_tree_ = NULL;
+ map_features_ = NULL;
+ }
+
+ ~SoftSynFeatureImpl() { FreeSentenceVariables(); }
+
+ void InitializeInputSentence(const std::string& parse_file) {
+ FreeSentenceVariables();
+ parsed_tree_ = ReadParseTree(parse_file);
+
+ // we can do the features "off-line"
+ map_features_ = new MapFeatures();
+ InitializeFeatures(map_features_);
+ }
+
+ /*
+ * ff_const_reorder.cc::ConstReorderFeatureImpl also defines this function
+ */
+ void FindConsts(const SParsedTree* tree, int begin, int end,
+ vector<STreeItem*>& consts) {
+ STreeItem* item;
+ item = tree->m_vecTerminals[begin]->m_ptParent;
+ while (true) {
+ while (item->m_ptParent != NULL &&
+ item->m_ptParent->m_iBegin == item->m_iBegin &&
+ item->m_ptParent->m_iEnd <= end)
+ item = item->m_ptParent;
+
+ if (item->m_ptParent == NULL && item->m_vecChildren.size() == 1 &&
+ strcmp(item->m_pszTerm, "ROOT") == 0)
+ item = item->m_vecChildren[0]; // we automatically add a "ROOT" node at
+ // the top, skip it if necessary.
+
+ consts.push_back(item);
+ if (item->m_iEnd < end)
+ item = tree->m_vecTerminals[item->m_iEnd + 1]->m_ptParent;
+ else
+ break;
+ }
+ }
+
+ /*
+ * according to Marton & Resnik (2008)
+ * a span cann't have both X+ style and X= style features
+ * a constituent XP is crossed only if the span not only covers parts of XP's
+ *content, but also covers one or more words outside XP
+ * a span may have X+, Y+
+ *
+ * (note, we refer X* features to X= features in Marton & Resnik (2008))
+ */
+ void GenerateSoftFeature(int begin, int end, const SParsedTree* tree,
+ vector<string>& vecFeature) {
+ vector<STreeItem*> vecNode;
+ FindConsts(tree, begin, end, vecNode);
+
+ if (vecNode.size() == 1) {
+ // match to one constituent
+ string feature_name = string(vecNode[0]->m_pszTerm) + string("*");
+ vecFeature.push_back(feature_name);
+ } else {
+ // match to multiple constituents, find the lowest common parent (lcp)
+ STreeItem* lcp = vecNode[0];
+ while (lcp->m_iEnd < end) lcp = lcp->m_ptParent;
+
+ for (size_t i = 0; i < vecNode.size(); i++) {
+ STreeItem* item = vecNode[i];
+
+ while (item != lcp) {
+ if (item->m_iBegin < begin || item->m_iEnd > end) {
+ // item is crossed
+ string feature_name = string(item->m_pszTerm) + string("+");
+ vecFeature.push_back(feature_name);
+ }
+ if (item->m_iBrotherIndex > 0 &&
+ item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1]
+ ->m_iBegin >= begin &&
+ item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1]
+ ->m_iEnd <= end)
+ break; // we don't want to collect crossed constituents twice
+ item = item->m_ptParent;
+ }
+ }
+ }
+ }
+
+ void GenerateSoftFeatureFromFlattenedTree(int begin, int end,
+ const SParsedTree* tree,
+ vector<string>& vecFeature) {
+ vector<STreeItem*> vecNode;
+ FindConsts(tree, begin, end, vecNode);
+
+ if (vecNode.size() == 1) {
+ // match to one constituent
+ string feature_name = string(vecNode[0]->m_pszTerm) + string("*");
+ vecFeature.push_back(feature_name);
+ } else {
+ // match to multiple constituents, see if they have a common parent
+ size_t i = 0;
+ for (i = 1; i < vecNode.size(); i++) {
+ if (vecNode[i]->m_ptParent != vecNode[0]->m_ptParent) break;
+ }
+ if (i == vecNode.size()) {
+ // they share a common parent
+ string feature_name =
+ string(vecNode[0]->m_ptParent->m_pszTerm) + string("&");
+ vecFeature.push_back(feature_name);
+ } else {
+ // they don't share a common parent, find the lowest common parent (lcp)
+ STreeItem* lcp = vecNode[0];
+ while (lcp->m_iEnd < end) lcp = lcp->m_ptParent;
+
+ for (size_t i = 0; i < vecNode.size(); i++) {
+ STreeItem* item = vecNode[i];
+
+ while (item != lcp) {
+ if (item->m_iBegin < begin || item->m_iEnd > end) {
+ // item is crossed
+ string feature_name = string(item->m_pszTerm) + string("+");
+ vecFeature.push_back(feature_name);
+ }
+ if (item->m_iBrotherIndex > 0 &&
+ item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1]
+ ->m_iBegin >= begin &&
+ item->m_ptParent->m_vecChildren[item->m_iBrotherIndex - 1]
+ ->m_iEnd <= end)
+ break; // we don't want to collect crossed constituents twice
+ item = item->m_ptParent;
+ }
+ }
+ }
+ }
+ }
+
+ void SetSoftSynFeature(const Hypergraph::Edge& edge,
+ SparseVector<double>* features) {
+ if (parsed_tree_ == NULL) return;
+
+ // soft feature for the whole span
+ const vector<string> vecFeature =
+ GenerateSoftFeature(edge.i_, edge.j_ - 1, map_features_);
+ for (size_t i = 0; i < vecFeature.size(); i++) {
+ int f_id = FD::Convert(vecFeature[i]);
+ if (f_id) features->set_value(f_id, 1);
+ }
+ }
+
+ private:
+ const vector<string>& GenerateSoftFeature(int begin, int end,
+ MapFeatures* map_features) {
+ string key;
+ GenerateKey(begin, end, key);
+ MapFeatures::const_iterator iter = (*map_features).find(key);
+ assert(iter != map_features->end());
+ return iter->second;
+ }
+
+ void Byte_to_Char(unsigned char* str, int n) {
+ str[0] = (n & 255);
+ str[1] = n / 256;
+ }
+
+ void GenerateKey(int begin, int end, string& key) {
+ unsigned char szTerm[1001];
+ Byte_to_Char(szTerm, begin);
+ Byte_to_Char(szTerm + 2, end);
+ szTerm[4] = '\0';
+ key = string(szTerm, szTerm + 4);
+ }
+
+ void InitializeFeatures(MapFeatures* map_features) {
+ if (parsed_tree_ == NULL) return;
+
+ for (size_t i = 0; i < parsed_tree_->m_vecTerminals.size(); i++)
+ for (size_t j = i; j < parsed_tree_->m_vecTerminals.size(); j++) {
+ vector<string> vecFeature;
+ GenerateSoftFeature(i, j, parsed_tree_, vecFeature);
+ string key;
+ GenerateKey(i, j, key);
+ (*map_features)[key] = vecFeature;
+ }
+ }
+
+ void FreeSentenceVariables() {
+ if (parsed_tree_ != NULL) delete parsed_tree_;
+ if (map_features_ != NULL) delete map_features_;
+ }
+
+ SParsedTree* ReadParseTree(const std::string& parse_file) {
+ SParseReader* reader = new SParseReader(parse_file.c_str(), false);
+ SParsedTree* tree = reader->fnReadNextParseTree();
+ // assert(tree != NULL);
+ delete reader;
+ return tree;
+ }
+
+ private:
+ SParsedTree* parsed_tree_;
+
+ MapFeatures* map_features_;
+};
+
+SoftSynFeature::SoftSynFeature(std::string param) {
+ pimpl_ = new SoftSynFeatureImpl(param);
+ name_ = "SoftSynFeature";
+}
+
+SoftSynFeature::~SoftSynFeature() { delete pimpl_; }
+
+void SoftSynFeature::PrepareForInput(const SentenceMetadata& smeta) {
+ string parse_file = smeta.GetSGMLValue("parse");
+ if (parse_file.empty()) {
+ parse_file = smeta.GetSGMLValue("src_tree");
+ }
+ assert(parse_file.size());
+ pimpl_->InitializeInputSentence(parse_file);
+}
+
+void SoftSynFeature::TraversalFeaturesImpl(
+ const SentenceMetadata& /*smeta*/, const Hypergraph::Edge& edge,
+ const vector<const void*>& /*ant_states*/, SparseVector<double>* features,
+ SparseVector<double>* /*estimated_features*/, void* /*state*/) const {
+ pimpl_->SetSoftSynFeature(edge, features);
+}
+
+string SoftSynFeature::usage(bool /*param*/, bool /*verbose*/) {
+ return "SoftSynFeature";
+}
+
+boost::shared_ptr<FeatureFunction> CreateSoftSynFeatureModel(
+ std::string param) {
+ SoftSynFeature* ret = new SoftSynFeature(param);
+ return boost::shared_ptr<FeatureFunction>(ret);
+}
+
+boost::shared_ptr<FeatureFunction> SoftSynFeatureFactory::Create(
+ std::string param) const {
+ return CreateSoftSynFeatureModel(param);
+}
+
+std::string SoftSynFeatureFactory::usage(bool params, bool verbose) const {
+ return SoftSynFeature::usage(params, verbose);
+}
diff --git a/decoder/ff_soft_syn.h b/decoder/ff_soft_syn.h
new file mode 100644
index 00000000..df9a6cc8
--- /dev/null
+++ b/decoder/ff_soft_syn.h
@@ -0,0 +1,38 @@
+/*
+ * ff_soft_syn.h
+ *
+ */
+
+#ifndef FF_SOFT_SYN_H_
+#define FF_SOFT_SYN_H_
+
+#include "ff_factory.h"
+#include "ff.h"
+
+struct SoftSynFeatureImpl;
+
+class SoftSynFeature : public FeatureFunction {
+ public:
+ SoftSynFeature(std::string param);
+ ~SoftSynFeature();
+ static std::string usage(bool param, bool verbose);
+
+ protected:
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+
+ virtual void TraversalFeaturesImpl(
+ const SentenceMetadata& smeta, const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features, SparseVector<double>* estimated_features,
+ void* out_context) const;
+
+ private:
+ SoftSynFeatureImpl* pimpl_;
+};
+
+struct SoftSynFeatureFactory : public FactoryBase<FeatureFunction> {
+ FP Create(std::string param) const;
+ std::string usage(bool params, bool verbose) const;
+};
+
+#endif /* FF_SOFT_SYN_H_ */
diff --git a/decoder/ffset.cc b/decoder/ffset.cc
index 5820f421..8ba70389 100644
--- a/decoder/ffset.cc
+++ b/decoder/ffset.cc
@@ -14,6 +14,11 @@ ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>
for (int i = 0; i < models_.size(); ++i) {
model_state_pos_[i] = state_size_;
state_size_ += models_[i]->StateSize();
+ int num_ignored_bytes = models_[i]->IgnoredStateSize();
+ if (num_ignored_bytes > 0) {
+ ranges_to_erase_.push_back(
+ {state_size_ - num_ignored_bytes, state_size_});
+ }
}
}
@@ -70,3 +75,13 @@ void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMet
edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
}
+bool ModelSet::NeedsStateErasure() const { return !ranges_to_erase_.empty(); }
+
+void ModelSet::EraseIgnoredBytes(FFState* state) const {
+ // TODO: can we memset?
+ for (const auto& range : ranges_to_erase_) {
+ for (int i = range.first; i < range.second; ++i) {
+ (*state)[i] = 0;
+ }
+ }
+}
diff --git a/decoder/ffset.h b/decoder/ffset.h
index b7322ee2..84f9fdb9 100644
--- a/decoder/ffset.h
+++ b/decoder/ffset.h
@@ -1,6 +1,7 @@
#ifndef FFSET_H_
#define FFSET_H_
+#include <utility>
#include <vector>
#include "value_array.h"
#include "prob.h"
@@ -47,11 +48,18 @@ class ModelSet {
bool stateless() const { return !state_size_; }
+ // Part of a feature state may be used for storing some side data for
+ // calculating feature values but not necessary for splitting hypernodes. Such
+ // bytes needs to be erased for hypernode splitting.
+ bool NeedsStateErasure() const;
+ void EraseIgnoredBytes(FFState* state) const;
+
private:
std::vector<const FeatureFunction*> models_;
const std::vector<double>& weights_;
int state_size_;
std::vector<int> model_state_pos_;
+ std::vector<std::pair<int, int> > ranges_to_erase_;
};
#endif
diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc
index 6e4cccb3..cc9094d7 100644
--- a/decoder/forest_writer.cc
+++ b/decoder/forest_writer.cc
@@ -11,13 +11,13 @@
using namespace std;
ForestWriter::ForestWriter(const std::string& path, int num) :
- fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {}
+ fname_(path + '/' + boost::lexical_cast<string>(num) + ".bin.gz"), used_(false) {}
-bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) {
+bool ForestWriter::Write(const Hypergraph& forest) {
assert(!used_);
used_ = true;
cerr << " Writing forest to " << fname_ << endl;
WriteFile wf(fname_);
- return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream());
+ return HypergraphIO::WriteToBinary(forest, wf.stream());
}
diff --git a/decoder/forest_writer.h b/decoder/forest_writer.h
index 4d28de77..54e83470 100644
--- a/decoder/forest_writer.h
+++ b/decoder/forest_writer.h
@@ -7,7 +7,7 @@ class Hypergraph;
struct ForestWriter {
ForestWriter(const std::string& path, int num);
- bool Write(const Hypergraph& forest, bool minimal_rules);
+ bool Write(const Hypergraph& forest);
const std::string fname_;
bool used_;
diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc
index 50e6adcc..fe28f4c6 100644
--- a/decoder/fst_translator.cc
+++ b/decoder/fst_translator.cc
@@ -27,11 +27,15 @@ struct FSTTranslatorImpl {
const vector<double>& weights,
Hypergraph* forest) {
bool composed = false;
- if (input.find("{\"rules\"") == 0) {
+ if (input.find("::forest::") == 0) {
istringstream is(input);
+ string header, fname;
+ is >> header >> fname;
+ ReadFile rf(fname);
+ if (!rf) { cerr << "Failed to open " << fname << endl; abort(); }
Hypergraph src_cfg_hg;
- if (!HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)) {
- cerr << "Failed to read HG from JSON.\n";
+ if (!HypergraphIO::ReadFromBinary(rf.stream(), &src_cfg_hg)) {
+ cerr << "Failed to read HG.\n";
abort();
}
if (add_pass_through_rules) {
diff --git a/decoder/hg.h b/decoder/hg.h
index 256f650f..c756012e 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <boost/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
#include "feature_vector.h"
#include "small_vector.h"
@@ -69,6 +70,18 @@ namespace HG {
short int j_;
short int prev_i_;
short int prev_j_;
+ template<class Archive>
+ void serialize(Archive & ar, const unsigned int /*version*/) {
+ ar & head_node_;
+ ar & tail_nodes_;
+ ar & rule_;
+ ar & feature_values_;
+ ar & i_;
+ ar & j_;
+ ar & prev_i_;
+ ar & prev_j_;
+ ar & id_;
+ }
void show(std::ostream &o,unsigned mask=SPAN|RULE) const {
o<<'{';
if (mask&CATEGORY)
@@ -149,6 +162,24 @@ namespace HG {
WordID NT() const { return -cat_; }
EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_
EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
+ template<class Archive>
+ void save(Archive & ar, const unsigned int /*version*/) const {
+ ar & node_hash;
+ ar & id_;
+ ar & TD::Convert(-cat_);
+ ar & in_edges_;
+ ar & out_edges_;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int /*version*/) {
+ ar & node_hash;
+ ar & id_;
+ std::string cat; ar & cat;
+ cat_ = -TD::Convert(cat);
+ ar & in_edges_;
+ ar & out_edges_;
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting
node_hash = o.node_hash;
cat_=o.cat_;
@@ -492,6 +523,27 @@ public:
void set_ids(); // resync edge,node .id_
void check_ids() const; // assert that .id_ have been kept in sync
+ template<class Archive>
+ void save(Archive & ar, const unsigned int /*version*/) const {
+ unsigned ns = nodes_.size(); ar & ns;
+ unsigned es = edges_.size(); ar & es;
+ for (auto& n : nodes_) ar & n;
+ for (auto& e : edges_) ar & e;
+ int x;
+ x = edges_topo_; ar & x;
+ x = is_linear_chain_; ar & x;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int /*version*/) {
+ unsigned ns; ar & ns; nodes_.resize(ns);
+ unsigned es; ar & es; edges_.resize(es);
+ for (auto& n : nodes_) ar & n;
+ for (auto& e : edges_) ar & e;
+ int x;
+ ar & x; edges_topo_ = x;
+ ar & x; is_linear_chain_ = x;
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {}
};
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index eb0be3d4..626b2954 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -6,362 +6,27 @@
#include <sstream>
#include <iostream>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+
#include "fast_lexical_cast.hpp"
#include "tdict.h"
-#include "json_parse.h"
#include "hg.h"
using namespace std;
-struct HGReader : public JSONParser {
- HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; }
-
- void CreateNode(const string& cat, const string& shash, const vector<int>& in_edges) {
- WordID c = TD::Convert("X") * -1;
- if (!cat.empty()) c = TD::Convert(cat) * -1;
- Hypergraph::Node* node = hg.AddNode(c);
- char* dend;
- if (shash.size())
- node->node_hash = strtoull(shash.c_str(), &dend, 16);
- else
- node->node_hash = 0;
- for (int i = 0; i < in_edges.size(); ++i) {
- if (in_edges[i] >= hg.edges_.size()) {
- cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i]
- << ", but hg only has " << hg.edges_.size() << " edges!\n";
- abort();
- }
- hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node);
- }
- }
- void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVectorUnsigned& tail) {
- Hypergraph::Edge* edge = hg.AddEdge(rule, tail);
- feats->swap(edge->feature_values_);
- edge->i_ = spans[0];
- edge->j_ = spans[1];
- edge->prev_i_ = spans[2];
- edge->prev_j_ = spans[3];
- }
-
- bool HandleJSONEvent(int type, const JSON_value* value) {
- switch(state) {
- case -1:
- assert(type == JSON_T_OBJECT_BEGIN);
- state = 0;
- break;
- case 0:
- if (type == JSON_T_OBJECT_END) {
- //cerr << "HG created\n"; // TODO, signal some kind of callback
- } else if (type == JSON_T_KEY) {
- string val = value->vu.str.value;
- if (val == "features") { assert(fdict.empty()); state = 1; }
- else if (val == "is_sorted") { state = 3; }
- else if (val == "rules") { assert(rules.empty()); state = 4; }
- else if (val == "node") { state = 8; }
- else if (val == "edges") { state = 13; }
- else { cerr << "Unexpected key: " << val << endl; return false; }
- }
- break;
-
- // features
- case 1:
- if(type == JSON_T_NULL) { state = 0; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 2;
- break;
- case 2:
- if(type == JSON_T_ARRAY_END) { state = 0; break; }
- assert(type == JSON_T_STRING);
- fdict.push_back(FD::Convert(value->vu.str.value));
- assert(fdict.back() > 0);
- break;
-
- // is_sorted
- case 3:
- assert(type == JSON_T_TRUE || type == JSON_T_FALSE);
- is_sorted = (type == JSON_T_TRUE);
- if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; }
- state = 0;
- break;
-
- // rules
- case 4:
- if(type == JSON_T_NULL) { state = 0; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 5;
- break;
- case 5:
- if(type == JSON_T_ARRAY_END) { state = 0; break; }
- assert(type == JSON_T_INTEGER);
- state = 6;
- rule_id = value->vu.integer_value;
- break;
- case 6:
- assert(type == JSON_T_STRING);
- rules[rule_id] = TRulePtr(new TRule(value->vu.str.value));
- state = 5;
- break;
-
- // Nodes
- case 8:
- assert(type == JSON_T_OBJECT_BEGIN);
- ++nodes;
- in_edges.clear();
- cat.clear();
- shash.clear();
- state = 9; break;
- case 9:
- if (type == JSON_T_OBJECT_END) {
- //cerr << "Creating NODE\n";
- CreateNode(cat, shash, in_edges);
- state = 0; break;
- }
- assert(type == JSON_T_KEY);
- cur_key = value->vu.str.value;
- if (cur_key == "cat") { assert(cat.empty()); state = 10; break; }
- if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; }
- if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; }
- cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n";
- return false;
- case 10:
- assert(type == JSON_T_STRING || type == JSON_T_NULL);
- cat = value->vu.str.value;
- state = 9; break;
- case 11:
- if (type == JSON_T_NULL) { state = 9; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 12; break;
- case 12:
- if (type == JSON_T_ARRAY_END) { state = 9; break; }
- assert(type == JSON_T_INTEGER);
- //cerr << "in_edges: " << value->vu.integer_value << endl;
- in_edges.push_back(value->vu.integer_value);
- break;
-
- // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12},
- // { "tail": null, "feats" : [0,0.87,1,0.02], "spans":[1,2,3,4], "rule": 17},
- // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}]
- case 13:
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 14;
- break;
- case 14:
- if (type == JSON_T_ARRAY_END) { state = 0; break; }
- assert(type == JSON_T_OBJECT_BEGIN);
- //cerr << "New edge\n";
- ++edges;
- cur_rule.reset(); feats.clear(); tail.clear();
- state = 15; break;
- case 15:
- if (type == JSON_T_OBJECT_END) {
- CreateEdge(cur_rule, &feats, tail);
- state = 14; break;
- }
- assert(type == JSON_T_KEY);
- cur_key = value->vu.str.value;
- //cerr << "edge key " << cur_key << endl;
- if (cur_key == "rule") { assert(!cur_rule); state = 16; break; }
- if (cur_key == "spans") { assert(!cur_rule); state = 22; break; }
- if (cur_key == "feats") { assert(feats.empty()); state = 17; break; }
- if (cur_key == "tail") { assert(tail.empty()); state = 20; break; }
- cerr << "Unexpected key " << cur_key << " in edge specification\n";
- return false;
- case 16: // edge.rule
- if (type == JSON_T_INTEGER) {
- int rule_id = value->vu.integer_value;
- if (rules.find(rule_id) == rules.end()) {
- // rules list must come before the edge definitions!
- cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n";
- return false;
- }
- cur_rule = rules[rule_id];
- } else if (type == JSON_T_STRING) {
- cur_rule.reset(new TRule(value->vu.str.value));
- } else {
- cerr << "Rule must be either a rule id or a rule string" << endl;
- return false;
- }
- // cerr << "Edge: rule=" << cur_rule->AsString() << endl;
- state = 15;
- break;
- case 17: // edge.feats
- if (type == JSON_T_NULL) { state = 15; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 18; break;
- case 18:
- if (type == JSON_T_ARRAY_END) { state = 15; break; }
- if (type != JSON_T_INTEGER && type != JSON_T_STRING) {
- cerr << "Unexpected feature id type\n"; return false;
- }
- if (type == JSON_T_INTEGER) {
- fid = value->vu.integer_value;
- assert(fid < fdict.size());
- fid = fdict[fid];
- } else if (JSON_T_STRING) {
- fid = FD::Convert(value->vu.str.value);
- } else { abort(); }
- state = 19;
- break;
- case 19:
- {
- assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT);
- double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) :
- strtod(value->vu.str.value, NULL));
- feats.set_value(fid, val);
- state = 18;
- break;
- }
- case 20: // edge.tail
- if (type == JSON_T_NULL) { state = 15; break; }
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 21; break;
- case 21:
- if (type == JSON_T_ARRAY_END) { state = 15; break; }
- assert(type == JSON_T_INTEGER);
- tail.push_back(value->vu.integer_value);
- break;
- case 22: // edge.spans
- assert(type == JSON_T_ARRAY_BEGIN);
- state = 23;
- spans[0] = spans[1] = spans[2] = spans[3] = -1;
- spanc = 0;
- break;
- case 23:
- if (type == JSON_T_ARRAY_END) { state = 15; break; }
- assert(type == JSON_T_INTEGER);
- assert(spanc < 4);
- spans[spanc] = value->vu.integer_value;
- ++spanc;
- break;
- case 24: // read node hash
- assert(type == JSON_T_STRING);
- shash = value->vu.str.value;
- state = 9;
- break;
- }
- return true;
- }
- string rp;
- string cat;
- SmallVectorUnsigned tail;
- vector<int> in_edges;
- string shash;
- TRulePtr cur_rule;
- map<int, TRulePtr> rules;
- vector<int> fdict;
- SparseVector<double> feats;
- int state;
- int fid;
- int nodes;
- int edges;
- int spans[4];
- int spanc;
- string cur_key;
- Hypergraph& hg;
- int rule_id;
- bool nodes_needed;
- bool edges_needed;
- bool is_sorted;
-};
-
-bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
+bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) {
+ boost::archive::binary_iarchive oa(*in);
hg->clear();
- HGReader reader(hg);
- return reader.Parse(in);
-}
-
-static void WriteRule(const TRule& r, ostream* out) {
- if (!r.lhs_) { (*out) << "[X] ||| "; }
- JSONParser::WriteEscapedString(r.AsString(), out);
+ oa >> *hg;
+ return true;
}
-bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
- if (hg.empty()) { *out << "{}\n"; return true; }
- map<const TRule*, int> rid;
- ostream& o = *out;
- rid[NULL] = 0;
- o << '{';
- if (!remove_rules) {
- o << "\"rules\":[";
- for (int i = 0; i < hg.edges_.size(); ++i) {
- const TRule* r = hg.edges_[i].rule_.get();
- int &id = rid[r];
- if (!id) {
- id=rid.size() - 1;
- if (id > 1) o << ',';
- o << id << ',';
- WriteRule(*r, &o);
- };
- }
- o << "],";
- }
- const bool use_fdict = FD::NumFeats() < 1000;
- if (use_fdict) {
- o << "\"features\":[";
- for (int i = 1; i < FD::NumFeats(); ++i) {
- o << (i==1 ? "":",");
- JSONParser::WriteEscapedString(FD::Convert(i), &o);
- }
- o << "],";
- }
- vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
- int edge_count = 0;
- for (int i = 0; i < hg.nodes_.size(); ++i) {
- const Hypergraph::Node& node = hg.nodes_[i];
- if (i > 0) { o << ","; }
- o << "\"edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
- edgemap[edge.id_] = edge_count;
- ++edge_count;
- o << (j == 0 ? "" : ",") << "{";
-
- o << "\"tail\":[";
- for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
- o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
- }
- o << "],";
-
- o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],";
-
- o << "\"feats\":[";
- bool first = true;
- for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
- if (!it->second) continue; // don't write features that have a zero value
- if (!it->first) continue; // if the feature set was frozen this might happen
- if (!first) o << ',';
- if (use_fdict)
- o << (it->first - 1);
- else {
- JSONParser::WriteEscapedString(FD::Convert(it->first), &o);
- }
- o << ',' << it->second;
- first = false;
- }
- o << "]";
- if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
- o << "}";
- }
- o << "],";
-
- o << "\"node\":{\"in_edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- int mapped_edge = edgemap[node.in_edges_[j]];
- assert(mapped_edge >= 0);
- o << (j == 0 ? "" : ",") << mapped_edge;
- }
- o << "]";
- if (node.cat_ < 0) {
- o << ",\"cat\":";
- JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o);
- }
- char buf[48];
- sprintf(buf, "%016lX", node.node_hash);
- o << ",\"node_hash\":\"" << buf << "\"";
- o << "}";
- }
- o << "}\n";
+bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) {
+ boost::archive::binary_oarchive oa(*out);
+ oa << hg;
return true;
}
diff --git a/decoder/hg_io.h b/decoder/hg_io.h
index 5a2bd808..93a9e280 100644
--- a/decoder/hg_io.h
+++ b/decoder/hg_io.h
@@ -9,19 +9,11 @@ class Hypergraph;
struct HypergraphIO {
- // the format is basically a list of nodes and edges in topological order
- // any edge you read, you must have already read its tail nodes
- // any node you read, you must have already read its incoming edges
- // this may make writing a bit more challenging if your forest is not
- // topologically sorted (but that probably doesn't happen very often),
- // but it makes reading much more memory efficient.
- // see test_data/small.json.gz for an email encoding
- static bool ReadFromJSON(std::istream* in, Hypergraph* out);
+ static bool ReadFromBinary(std::istream* in, Hypergraph* out);
+ static bool WriteToBinary(const Hypergraph& hg, std::ostream* out);
// if remove_rules is used, the hypergraph is serialized without rule information
// (so it only contains structure and feature information)
- static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out);
-
static void WriteAsCFG(const Hypergraph& hg);
// Write only the target size information in bottom-up order.
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 5cb8626a..366b269d 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -1,10 +1,14 @@
#define BOOST_TEST_MODULE hg_test
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
+#include <sstream>
#include <iostream>
#include "tdict.h"
-#include "json_parse.h"
#include "hg_intersect.h"
#include "hg_union.h"
#include "viterbi.h"
@@ -394,16 +398,6 @@ BOOST_AUTO_TEST_CASE(Small) {
BOOST_CHECK_CLOSE(2.1431036, log(c2), 1e-4);
}
-BOOST_AUTO_TEST_CASE(JSONTest) {
- std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
- ostringstream os;
- JSONParser::WriteEscapedString("\"I don't know\", she said.", &os);
- BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str());
- ostringstream os2;
- JSONParser::WriteEscapedString("yes", &os2);
- BOOST_CHECK_EQUAL("\"yes\"", os2.str());
-}
-
BOOST_AUTO_TEST_CASE(TestGenericKBest) {
std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
Hypergraph hg;
@@ -427,19 +421,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) {
}
}
-BOOST_AUTO_TEST_CASE(TestReadWriteHG) {
+BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {
std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
- Hypergraph hg,hg2;
- CreateHG(path, &hg);
- hg.edges_.front().j_ = 23;
- hg.edges_.back().prev_i_ = 99;
- ostringstream os;
- HypergraphIO::WriteToJSON(hg, false, &os);
- istringstream is(os.str());
- HypergraphIO::ReadFromJSON(&is, &hg2);
- BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
- BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
- BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ Hypergraph hg;
+ Hypergraph hg2;
+ std::string out;
+ {
+ CreateHG(path, &hg);
+ hg.edges_.front().j_ = 23;
+ hg.edges_.back().prev_i_ = 99;
+ ostringstream os;
+ boost::archive::text_oarchive oa(os);
+ oa << hg;
+ out = os.str();
+ }
+ {
+ cerr << out << endl;
+ istringstream is(out);
+ boost::archive::text_iarchive ia(is);
+ ia >> hg2;
+ BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
+ BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
+ BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ }
}
BOOST_AUTO_TEST_SUITE_END()
diff --git a/decoder/hg_test.h b/decoder/hg_test.h
index b7bab3c2..575b9c54 100644
--- a/decoder/hg_test.h
+++ b/decoder/hg_test.h
@@ -12,12 +12,8 @@ namespace {
typedef char const* Name;
-Name urdu_json="urdu.json.gz";
-Name urdu_wts="Arity_0 1.70741473606976 Arity_1 1.12426238048012 Arity_2 1.14986187839554 Glue -0.04589037041388 LanguageModel 1.09051 PassThrough -3.66226367902928 PhraseModel_0 -1.94633451863252 PhraseModel_1 -0.1475347695476 PhraseModel_2 -1.614818994946 WordPenalty -3.0 WordPenaltyFsa -0.56028442964748 ShorterThanPrev -10 LongerThanPrev -10";
-Name small_json="small.json.gz";
+Name small_json="small.bin.gz";
Name small_wts="Model_0 -2 Model_1 -.5 Model_2 -1.1 Model_3 -1 Model_4 -1 Model_5 .5 Model_6 .2 Model_7 -.3";
-Name perro_json="perro.json.gz";
-Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 Glue -1.0 EgivenF -0.5 FgivenE -0.5 LexEgivenF -0.5 LexFgivenE -0.5 LM 1";
}
@@ -32,7 +28,7 @@ struct HGSetup {
static void JsonFile(Hypergraph *hg,std::string f) {
ReadFile rf(f);
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
static void JsonTestFile(Hypergraph *hg,std::string path,std::string n) {
JsonFile(hg,path + "/"+n);
@@ -48,35 +44,35 @@ void AddNullEdge(Hypergraph* hg) {
}
void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.tiny_lattice");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.tiny_lattice.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
AddNullEdge(hg);
}
void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.lattice");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.lattice.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
AddNullEdge(hg);
}
void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.tiny");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.tiny.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.hg_int");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.hg_int.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.hg");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.hg.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) {
- ReadFile rf(path + "/hg_test.hg_balanced");
- HypergraphIO::ReadFromJSON(rf.stream(), hg);
+ ReadFile rf(path + "/hg_test.hg_balanced.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), hg);
}
#endif
diff --git a/decoder/json_parse.cc b/decoder/json_parse.cc
deleted file mode 100644
index f6fdfea8..00000000
--- a/decoder/json_parse.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-#include "json_parse.h"
-
-#include <string>
-#include <iostream>
-
-using namespace std;
-
-static const char *json_hex_chars = "0123456789abcdef";
-
-void JSONParser::WriteEscapedString(const string& in, ostream* out) {
- int pos = 0;
- int start_offset = 0;
- unsigned char c = 0;
- (*out) << '"';
- while(pos < in.size()) {
- c = in[pos];
- switch(c) {
- case '\b':
- case '\n':
- case '\r':
- case '\t':
- case '"':
- case '\\':
- case '/':
- if(pos - start_offset > 0)
- (*out) << in.substr(start_offset, pos - start_offset);
- if(c == '\b') (*out) << "\\b";
- else if(c == '\n') (*out) << "\\n";
- else if(c == '\r') (*out) << "\\r";
- else if(c == '\t') (*out) << "\\t";
- else if(c == '"') (*out) << "\\\"";
- else if(c == '\\') (*out) << "\\\\";
- else if(c == '/') (*out) << "\\/";
- start_offset = ++pos;
- break;
- default:
- if(c < ' ') {
- cerr << "Warning, bad character (" << static_cast<int>(c) << ") in string\n";
- if(pos - start_offset > 0)
- (*out) << in.substr(start_offset, pos - start_offset);
- (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf];
- start_offset = ++pos;
- } else pos++;
- }
- }
- if(pos - start_offset > 0)
- (*out) << in.substr(start_offset, pos - start_offset);
- (*out) << '"';
-}
-
diff --git a/decoder/json_parse.h b/decoder/json_parse.h
deleted file mode 100644
index 85e2eff1..00000000
--- a/decoder/json_parse.h
+++ /dev/null
@@ -1,58 +0,0 @@
-#ifndef JSON_WRAPPER_H_
-#define JSON_WRAPPER_H_
-
-#include <iostream>
-#include <cassert>
-#include "JSON_parser.h"
-
-class JSONParser {
- public:
- JSONParser() {
- init_JSON_config(&config);
- hack.mf = &JSONParser::Callback;
- config.depth = 10;
- config.callback_ctx = reinterpret_cast<void*>(this);
- config.callback = hack.cb;
- config.allow_comments = 1;
- config.handle_floats_manually = 1;
- jc = new_JSON_parser(&config);
- }
- virtual ~JSONParser() {
- delete_JSON_parser(jc);
- }
- bool Parse(std::istream* in) {
- int count = 0;
- int lc = 1;
- for (; in ; ++count) {
- int next_char = in->get();
- if (!in->good()) break;
- if (lc == '\n') { ++lc; }
- if (!JSON_parser_char(jc, next_char)) {
- std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl;
- return false;
- }
- }
- if (!JSON_parser_done(jc)) {
- std::cerr << "JSON_parser_done: syntax error\n";
- return false;
- }
- return true;
- }
- static void WriteEscapedString(const std::string& in, std::ostream* out);
- protected:
- virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0;
- private:
- int Callback(int type, const JSON_value* value) {
- if (HandleJSONEvent(type, value)) return 1;
- return 0;
- }
- JSON_parser_struct* jc;
- JSON_config config;
- typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value);
- union CBHack {
- JSON_parser_callback cb;
- MF mf;
- } hack;
-};
-
-#endif
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index d2c4715c..cd587833 100644
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -21,6 +21,7 @@
#include "kbest.h"
#include "timing_stats.h"
#include "sentences.h"
+#include "b64featvector.h"
//TODO: put function impls into .cc
//TODO: move Translation into its own .h and use in cdec
@@ -252,19 +253,31 @@ struct OracleBleu {
}
bool show_derivation;
+ int show_derivation_mask;
+
template <class Filter>
- void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) {
+ void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat,
+ int src_len, std::ostream& kbest_out = std::cout,
+ std::ostream& deriv_out = std::cerr) {
using namespace std;
using namespace boost;
typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K;
K kbest(forest,k);
//add length (f side) src length of this sentence to the psuedo-doc src length count
float curr_src_length = doc_src_length + tmp_src_length;
- for (int i = 0; i < k; ++i) {
+ if (mr_mira_compat) kbest_out << k << "\n";
+ int i = 0;
+ for (; i < k; ++i) {
typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score);
+ kbest_out << sent_id << " ||| ";
+ if (mr_mira_compat) kbest_out << src_len << " ||| ";
+ kbest_out << TD::GetString(d->yield) << " ||| ";
+ if (mr_mira_compat)
+ kbest_out << EncodeFeatureVector(d->feature_values);
+ else
+ kbest_out << d->feature_values;
+ kbest_out << " ||| " << log(d->score);
if (!refs.empty()) {
ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
@@ -275,14 +288,21 @@ struct OracleBleu {
if (show_derivation) {
deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k
deriv_out<<log(d->score)<<"\n";
- deriv_out<<kbest.derivation_tree(*d,true);
+ deriv_out<<kbest.derivation_tree(*d,true, show_derivation_mask);
deriv_out<<"\n"<<flush;
}
}
+ if (mr_mira_compat) {
+ for (; i < k; ++i) kbest_out << "\n";
+ kbest_out << flush;
+ }
}
// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
- void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) {
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k,
+ const bool unique, const bool mr_mira_compat,
+ const int src_len, std::string const& kbest_out_filename_,
+ std::string const& deriv_out_filename_) {
WriteFile ko(kbest_out_filename_);
std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl;
@@ -295,9 +315,11 @@ struct OracleBleu {
WriteFile oderiv(sderiv.str());
if (!unique)
- kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::NoFilter<std::vector<WordID> > >(
+ sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get());
else {
- kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len,
+ ko.get(), oderiv.get());
}
}
@@ -305,7 +327,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo
{
std::ostringstream kbest_string_stream;
kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-");
+ DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(),
+ "-");
}
};
diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc
index 18c83c56..2c5fa9c4 100644
--- a/decoder/rescore_translator.cc
+++ b/decoder/rescore_translator.cc
@@ -3,6 +3,7 @@
#include <sstream>
#include <boost/shared_ptr.hpp>
+#include "filelib.h"
#include "sentence_metadata.h"
#include "hg.h"
#include "hg_io.h"
@@ -20,16 +21,18 @@ struct RescoreTranslatorImpl {
bool Translate(const string& input,
const vector<double>& weights,
Hypergraph* forest) {
- if (input == "{}") return false;
- if (input.find("{\"rules\"") == 0) {
- istringstream is(input);
- Hypergraph src_cfg_hg;
- if (!HypergraphIO::ReadFromJSON(&is, forest)) {
- cerr << "Parse error while reading HG from JSON.\n";
- abort();
- }
- } else {
- cerr << "Can only read HG input from JSON: use training/grammar_convert\n";
+ istringstream is(input);
+ string header, fname;
+ is >> header >> fname;
+ if (header != "::forest::") {
+ cerr << "RescoreTranslator: expected input lines of form ::forest:: filename.gz\n";
+ abort();
+ }
+ ReadFile rf(fname);
+ if (!rf) { cerr << "Can't read " << fname << endl; abort(); }
+ Hypergraph src_cfg_hg;
+ if (!HypergraphIO::ReadFromBinary(rf.stream(), forest)) {
+ cerr << "Parse error while reading HG.\n";
abort();
}
Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll
index d4a8d86b..8b48ab7b 100644
--- a/decoder/rule_lexer.ll
+++ b/decoder/rule_lexer.ll
@@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const
void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) {
init_default_feature_names();
+ scfglex_fname = srule;
lex_mono_rules = mono;
lex_line = 1;
rule_callback_extra = extra;
diff --git a/decoder/test_data/hg_test.hg.bin.gz b/decoder/test_data/hg_test.hg.bin.gz
new file mode 100644
index 00000000..c07fbe8c
--- /dev/null
+++ b/decoder/test_data/hg_test.hg.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.hg_balanced.bin.gz b/decoder/test_data/hg_test.hg_balanced.bin.gz
new file mode 100644
index 00000000..896d3d60
--- /dev/null
+++ b/decoder/test_data/hg_test.hg_balanced.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.hg_int.bin.gz b/decoder/test_data/hg_test.hg_int.bin.gz
new file mode 100644
index 00000000..e0bd6187
--- /dev/null
+++ b/decoder/test_data/hg_test.hg_int.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.lattice.bin.gz b/decoder/test_data/hg_test.lattice.bin.gz
new file mode 100644
index 00000000..8a8c05f4
--- /dev/null
+++ b/decoder/test_data/hg_test.lattice.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.tiny.bin.gz b/decoder/test_data/hg_test.tiny.bin.gz
new file mode 100644
index 00000000..0e68eb40
--- /dev/null
+++ b/decoder/test_data/hg_test.tiny.bin.gz
Binary files differ
diff --git a/decoder/test_data/hg_test.tiny_lattice.bin.gz b/decoder/test_data/hg_test.tiny_lattice.bin.gz
new file mode 100644
index 00000000..97e8dc05
--- /dev/null
+++ b/decoder/test_data/hg_test.tiny_lattice.bin.gz
Binary files differ
diff --git a/decoder/test_data/perro.json.gz b/decoder/test_data/perro.json.gz
deleted file mode 100644
index 41de5758..00000000
--- a/decoder/test_data/perro.json.gz
+++ /dev/null
Binary files differ
diff --git a/decoder/test_data/small.bin.gz b/decoder/test_data/small.bin.gz
new file mode 100644
index 00000000..1c5a1631
--- /dev/null
+++ b/decoder/test_data/small.bin.gz
Binary files differ
diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz
deleted file mode 100644
index f6f37293..00000000
--- a/decoder/test_data/small.json.gz
+++ /dev/null
Binary files differ
diff --git a/decoder/test_data/urdu.json.gz b/decoder/test_data/urdu.json.gz
deleted file mode 100644
index 84535402..00000000
--- a/decoder/test_data/urdu.json.gz
+++ /dev/null
Binary files differ
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 08dae64c..cdd83ffc 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -267,6 +267,7 @@ struct Tree2StringTranslatorImpl {
for (auto sym : rule_src)
cur = &cur->next[sym];
TRulePtr rule(new TRule(rhse, rhsf, lhs));
+ rule->a_.push_back(AlignmentPoint(0, 0));
rule->ComputeArity();
rule->scores_.set_value(ntfid, 1.0);
rule->scores_.set_value(kFID, 1.0);
diff --git a/decoder/trule.h b/decoder/trule.h
index adef7cc7..7af46747 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -11,6 +11,7 @@
#include "sparse_vector.h"
#include "wordid.h"
+#include "tdict.h"
class TRule;
typedef boost::shared_ptr<TRule> TRulePtr;
@@ -26,6 +27,7 @@ struct AlignmentPoint {
short s_;
short t_;
};
+
inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) {
return os << static_cast<int>(p.s_) << '-' << static_cast<int>(p.t_);
}
@@ -163,6 +165,66 @@ class TRule {
// optional, shows internal structure of TSG rules
boost::shared_ptr<cdec::TreeFragment> tree_structure;
+ friend class boost::serialization::access;
+ template<class Archive>
+ void save(Archive & ar, const unsigned int /*version*/) const {
+ ar & TD::Convert(-lhs_);
+ unsigned f_size = f_.size();
+ ar & f_size;
+ assert(f_size <= (sizeof(size_t) * 8));
+ size_t f_nt_mask = 0;
+ for (int i = f_.size() - 1; i >= 0; --i) {
+ f_nt_mask <<= 1;
+ f_nt_mask |= (f_[i] <= 0 ? 1 : 0);
+ }
+ ar & f_nt_mask;
+ for (unsigned i = 0; i < f_.size(); ++i)
+ ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]);
+ unsigned e_size = e_.size();
+ ar & e_size;
+ size_t e_nt_mask = 0;
+ assert(e_size <= (sizeof(size_t) * 8));
+ for (int i = e_.size() - 1; i >= 0; --i) {
+ e_nt_mask <<= 1;
+ e_nt_mask |= (e_[i] <= 0 ? 1 : 0);
+ }
+ ar & e_nt_mask;
+ for (unsigned i = 0; i < e_.size(); ++i)
+ if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]);
+ ar & arity_;
+ ar & scores_;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int /*version*/) {
+ std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs);
+ unsigned f_size; ar & f_size;
+ f_.resize(f_size);
+ size_t f_nt_mask; ar & f_nt_mask;
+ std::string sym;
+ for (unsigned i = 0; i < f_size; ++i) {
+ bool mask = (f_nt_mask & 1);
+ ar & sym;
+ f_[i] = TD::Convert(sym) * (mask ? -1 : 1);
+ f_nt_mask >>= 1;
+ }
+ unsigned e_size; ar & e_size;
+ e_.resize(e_size);
+ size_t e_nt_mask; ar & e_nt_mask;
+ for (unsigned i = 0; i < e_size; ++i) {
+ bool mask = (e_nt_mask & 1);
+ if (mask) {
+ ar & e_[i];
+ } else {
+ ar & sym;
+ e_[i] = TD::Convert(sym);
+ }
+ e_nt_mask >>= 1;
+ }
+ ar & arity_;
+ ar & scores_;
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}
};
diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc
index 0cb7e2e8..d75c2016 100644
--- a/decoder/trule_test.cc
+++ b/decoder/trule_test.cc
@@ -4,6 +4,10 @@
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
#include <iostream>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+#include <sstream>
#include "tdict.h"
using namespace std;
@@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) {
BOOST_CHECK_EQUAL(t6.e_[3], 0);
}
+BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {
+ string str;
+ string t7str;
+ {
+ TRule t7;
+ t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121");
+ cerr << t7.AsString() << endl;
+ ostringstream os;
+ TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2"));
+ TRulePtr tp2 = tp1;
+ boost::archive::text_oarchive oa(os);
+ oa << t7;
+ oa << tp1;
+ oa << tp2;
+ str = os.str();
+ t7str = t7.AsString();
+ }
+ {
+ istringstream is(str);
+ boost::archive::text_iarchive ia(is);
+ TRule t8;
+ ia >> t8;
+ TRulePtr tp3, tp4;
+ ia >> tp3;
+ ia >> tp4;
+ cerr << t8.AsString() << endl;
+ BOOST_CHECK_EQUAL(t7str, t8.AsString());
+ cerr << tp3->AsString() << endl;
+ cerr << tp4->AsString() << endl;
+ }
+}
+
diff --git a/python/cdec/hypergraph.pxd b/python/cdec/hypergraph.pxd
index 1e150bbc..9780cf8b 100644
--- a/python/cdec/hypergraph.pxd
+++ b/python/cdec/hypergraph.pxd
@@ -63,7 +63,8 @@ cdef extern from "decoder/viterbi.h":
cdef extern from "decoder/hg_io.h" namespace "HypergraphIO":
# Hypergraph JSON I/O
bint ReadFromJSON(istream* inp, Hypergraph* out)
- bint WriteToJSON(Hypergraph& hg, bint remove_rules, ostream* out)
+ bint ReadFromBinary(istream* inp, Hypergraph* out)
+ bint WriteToBinary(Hypergraph& hg, ostream* out)
# Hypergraph PLF I/O
void ReadFromPLF(string& inp, Hypergraph* out)
string AsPLF(Hypergraph& hg, bint include_global_parentheses)
diff --git a/python/cdec/sa/compile.py b/python/cdec/sa/compile.py
index a5bd0699..78ab729d 100644
--- a/python/cdec/sa/compile.py
+++ b/python/cdec/sa/compile.py
@@ -119,7 +119,7 @@ def main():
a = cdec.sa.Alignment(from_text=args.alignment)
a.write_binary(a_bin)
stop_time = monitor_cpu()
- logger.info('Compiling alignment took %f seonds', stop_time - start_time)
+ logger.info('Compiling alignment took %f seconds', stop_time - start_time)
start_time = monitor_cpu()
logger.info('Compiling bilexical dictionary')
diff --git a/tests/system_tests/cfg_rescore/input.txt b/tests/system_tests/cfg_rescore/input.txt
index 2999a5fb..99624d85 100644
--- a/tests/system_tests/cfg_rescore/input.txt
+++ b/tests/system_tests/cfg_rescore/input.txt
@@ -1 +1 @@
-{"rules":[1,"[S] ||| [NP1] [VP] ||| [1] [2] ||| Active=1",2,"[S] ||| [NP2] [VPSV] by [NP1] ||| [1] [2] by [3] ||| Passive=1",3,"[VP] ||| [V] [NP2] ||| [1] [2]",4,"[V] ||| ate ||| ate",5,"[VPSV] ||| was eaten ||| was eaten",6,"[NP1] ||| John ||| John",7,"[NP2] ||| broccoli ||| broccoli",8,"[NP2] ||| the broccoli ||| the broccoli ||| Definite=1",9,"[Goal] ||| [X] ||| [1]"],"features":["PhraseModel_0","PhraseModel_1","PhraseModel_2","PhraseModel_3","PhraseModel_4","PhraseModel_5","PhraseModel_6","PhraseModel_7","PhraseModel_8","PhraseModel_9","PhraseModel_10","PhraseModel_11","PhraseModel_12","PhraseModel_13","PhraseModel_14","PhraseModel_15","PhraseModel_16","PhraseModel_17","PhraseModel_18","PhraseModel_19","PhraseModel_20","PhraseModel_21","PhraseModel_22","PhraseModel_23","PhraseModel_24","PhraseModel_25","PhraseModel_26","PhraseModel_27","PhraseModel_28","PhraseModel_29","PhraseModel_30","PhraseModel_31","PhraseModel_32","PhraseModel_33","PhraseModel_34","PhraseModel_35","PhraseModel_36","PhraseModel_37","PhraseModel_38","PhraseModel_39","PhraseModel_40","PhraseModel_41","PhraseModel_42","PhraseModel_43","PhraseModel_44","PhraseModel_45","PhraseModel_46","PhraseModel_47","PhraseModel_48","PhraseModel_49","PhraseModel_50","PhraseModel_51","PhraseModel_52","PhraseModel_53","PhraseModel_54","PhraseModel_55","PhraseModel_56","PhraseModel_57","PhraseModel_58","PhraseModel_59","PhraseModel_60","PhraseModel_61","PhraseModel_62","PhraseModel_63","PhraseModel_64","PhraseModel_65","PhraseModel_66","PhraseModel_67","PhraseModel_68","PhraseModel_69","PhraseModel_70","PhraseModel_71","PhraseModel_72","PhraseModel_73","PhraseModel_74","PhraseModel_75","PhraseModel_76","PhraseModel_77","PhraseModel_78","PhraseModel_79","PhraseModel_80","PhraseModel_81","PhraseModel_82","PhraseModel_83","PhraseModel_84","PhraseModel_85","PhraseModel_86","PhraseModel_87","PhraseModel_88","PhraseModel_89","PhraseModel_90","PhraseModel_91","PhraseModel_92","PhraseModel_93","PhraseModel_94","PhraseModel_95","PhraseModel_96","PhraseModel_97","PhraseModel_98","PhraseModel_99","Active","Passive","Definite"],"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":6}],"node":{"in_edges":[0],"cat":"NP1","node_hash":"0000000000000006"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":4}],"node":{"in_edges":[1],"cat":"V","node_hash":"0000000000000004"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":7},{"tail":[],"spans":[-1,-1,-1,-1],"feats":[102,1],"rule":8}],"node":{"in_edges":[2,3],"cat":"NP2","node_hash":"0000000000000008"},"edges":[{"tail":[1,2],"spans":[-1,-1,-1,-1],"feats":[],"rule":3}],"node":{"in_edges":[4],"cat":"VP","node_hash":"0000000000000003"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":5}],"node":{"in_edges":[5],"cat":"VPSV","node_hash":"0000000000000005"},"edges":[{"tail":[0,3],"spans":[-1,-1,-1,-1],"feats":[100,1],"rule":1},{"tail":[2,4,0],"spans":[-1,-1,-1,-1],"feats":[101,1],"rule":2}],"node":{"in_edges":[6,7],"cat":"S","node_hash":"0000000000000002"},"edges":[{"tail":[5],"spans":[-1,-1,-1,-1],"feats":[],"rule":9}],"node":{"in_edges":[8],"cat":"Goal","node_hash":"000000000000003D"}}
+::forest:: input0.hg.bin.gz
diff --git a/tests/system_tests/cfg_rescore/input0.hg.bin.gz b/tests/system_tests/cfg_rescore/input0.hg.bin.gz
new file mode 100644
index 00000000..051e1e32
--- /dev/null
+++ b/tests/system_tests/cfg_rescore/input0.hg.bin.gz
Binary files differ
diff --git a/tests/system_tests/conll/README b/tests/system_tests/conll/README
new file mode 100644
index 00000000..261e6a05
--- /dev/null
+++ b/tests/system_tests/conll/README
@@ -0,0 +1,8 @@
+To generate the input file, run:
+
+ ~/cdec/corpus/conll2cdec.pl input.conll > input.txt
+
+This will create a training corpus (i.e., an input is present as well as
+gold standard output is present) in input.txt.
+
+See cdec.ini for examples of how to include features in the model.
diff --git a/tests/system_tests/conll/cdec.ini b/tests/system_tests/conll/cdec.ini
new file mode 100644
index 00000000..f214857a
--- /dev/null
+++ b/tests/system_tests/conll/cdec.ini
@@ -0,0 +1,13 @@
+formalism=tagger
+tagger_tagset=tagset.txt
+
+# grab the second feature column from the conll input (-w 2) and
+# create a feature of i-1,i-2 conjoined with y_i
+feature_function=CoNLLFeatures -w 2 -t xxy:%x[-1]_%x[0]:%y[0]
+
+# grab the second feature column from the conll input (-w 2) and
+# create a feature of i-1,i-2 conjoined with y_i
+feature_function=CoNLLFeatures -w 1 -t xy:%x[0]:%y[0]
+
+intersection_strategy=full
+
diff --git a/tests/system_tests/conll/gold.statistics b/tests/system_tests/conll/gold.statistics
new file mode 100644
index 00000000..17366689
--- /dev/null
+++ b/tests/system_tests/conll/gold.statistics
@@ -0,0 +1,20 @@
+-lm_nodes 12
+-lm_edges 24
+-lm_paths 729
++lm_nodes 12
++lm_edges 24
++lm_paths 729
++lm_trans O O O B I O
+constr_nodes 12
+constr_edges 12
+constr_paths 1
+-lm_nodes 10
+-lm_edges 20
+-lm_paths 243
++lm_nodes 10
++lm_edges 20
++lm_paths 243
++lm_trans O B I I O
+constr_nodes 10
+constr_edges 10
+constr_paths 1
diff --git a/tests/system_tests/conll/gold.stdout b/tests/system_tests/conll/gold.stdout
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/system_tests/conll/gold.stdout
diff --git a/tests/system_tests/conll/input.conll b/tests/system_tests/conll/input.conll
new file mode 100644
index 00000000..507510ca
--- /dev/null
+++ b/tests/system_tests/conll/input.conll
@@ -0,0 +1,13 @@
+the the DT O
+angry angry JJ O
+dog dog NN O
+bit bite VBD B
+me I PRN I
+. . . O
+
+i i PRN O
+ate eat VBD B
+some some DT I
+pie pie NN I
+. . . O
+
diff --git a/tests/system_tests/conll/input.txt b/tests/system_tests/conll/input.txt
new file mode 100644
index 00000000..6a1a0230
--- /dev/null
+++ b/tests/system_tests/conll/input.txt
@@ -0,0 +1,2 @@
+<seg id="0" feat1="the angry dog bite I ." feat2="DT JJ NN VBD PRN ."> the angry dog bit me . ||| O O O B I O </seg>
+<seg id="1" feat1="i eat some pie ." feat2="PRN VBD DT NN ."> i ate some pie . ||| O B I I O </seg>
diff --git a/tests/system_tests/conll/tagset.txt b/tests/system_tests/conll/tagset.txt
new file mode 100644
index 00000000..bd0e6b60
--- /dev/null
+++ b/tests/system_tests/conll/tagset.txt
@@ -0,0 +1 @@
+B I O
diff --git a/tests/system_tests/conll/weights b/tests/system_tests/conll/weights
new file mode 100644
index 00000000..de130cb6
--- /dev/null
+++ b/tests/system_tests/conll/weights
@@ -0,0 +1,64 @@
+# Objective = 7.63544 (eval count=5)
+xxy:<s>_DT:B -0.19295226006843877
+xy:the:B -0.19295226006843877
+xxy:<s>_DT:I -0.19295226006843877
+xy:the:I -0.19295226006843877
+xxy:<s>_DT:O 0.38590452013687793
+xy:the:O 0.38590452013687793
+xxy:DT_JJ:B -0.19295226006843877
+xy:angry:B -0.19295226006843877
+xxy:DT_JJ:I -0.19295226006843877
+xy:angry:I -0.19295226006843877
+xxy:DT_JJ:O 0.38590452013687793
+xy:angry:O 0.38590452013687793
+xxy:JJ_NN:B -0.19295226006843885
+xy:dog:B -0.19295226006843885
+xxy:JJ_NN:I -0.19295226006843885
+xy:dog:I -0.19295226006843885
+xxy:JJ_NN:O 0.38590452013687765
+xy:dog:O 0.38590452013687765
+xxy:NN_VBD:B 0.38590452013687765
+xy:bite:B 0.38590452013687765
+xxy:NN_VBD:I -0.19295226006843885
+xy:bite:I -0.19295226006843885
+xxy:NN_VBD:O -0.19295226006843885
+xy:bite:O -0.19295226006843885
+xxy:VBD_PRN:B -0.19295226006843885
+xy:I:B -0.19295226006843885
+xxy:VBD_PRN:I 0.38590452013687765
+xy:I:I 0.38590452013687765
+xxy:VBD_PRN:O -0.19295226006843885
+xy:I:O -0.19295226006843885
+xxy:PRN_.:B -0.16038191506717553
+xy:.:B -0.32076383013435106
+xxy:PRN_.:I -0.16038191506717553
+xy:.:I -0.32076383013435106
+xxy:PRN_.:O 0.32076383013435134
+xy:.:O 0.64152766026870267
+xxy:<s>_PRN:B -0.19295226006843871
+xy:i:B -0.19295226006843871
+xxy:<s>_PRN:I -0.19295226006843871
+xy:i:I -0.19295226006843871
+xxy:<s>_PRN:O 0.38590452013687804
+xy:i:O 0.38590452013687804
+xxy:PRN_VBD:B 0.38590452013687804
+xy:eat:B 0.38590452013687804
+xxy:PRN_VBD:I -0.19295226006843871
+xy:eat:I -0.19295226006843871
+xxy:PRN_VBD:O -0.19295226006843871
+xy:eat:O -0.19295226006843871
+xxy:VBD_DT:B -0.19295226006843877
+xy:some:B -0.19295226006843877
+xxy:VBD_DT:I 0.38590452013687798
+xy:some:I 0.38590452013687798
+xxy:VBD_DT:O -0.19295226006843877
+xy:some:O -0.19295226006843877
+xxy:DT_NN:B -0.19295226006843877
+xy:pie:B -0.19295226006843877
+xxy:DT_NN:I 0.38590452013687798
+xy:pie:I 0.38590452013687798
+xxy:DT_NN:O -0.19295226006843877
+xy:pie:O -0.19295226006843877
+xxy:NN_.:B -0.16038191506717553
+xxy:NN_.:I -0.16038191506717553
+xxy:NN_.:O 0.32076383013435134
diff --git a/tests/system_tests/ftrans/input.txt b/tests/system_tests/ftrans/input.txt
index aa37b2e7..99624d85 100644
--- a/tests/system_tests/ftrans/input.txt
+++ b/tests/system_tests/ftrans/input.txt
@@ -1 +1 @@
-{"rules":[1,"[B] ||| b ||| b",2,"[C] ||| c ||| c",3,"[A] ||| [B,1] [C,2] ||| [1] [2] ||| Mono=1",4,"[A] ||| [C,1] [B,2] ||| [1] [2] ||| Inv=1",5,"[S] ||| [A,1] ||| [1]"],"features":["Mono","Inv"],"edges":[{"tail":[],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"B"},"edges":[{"tail":[],"feats":[],"rule":2}],"node":{"in_edges":[1],"cat":"C"},"edges":[{"tail":[0,1],"feats":[0,1],"rule":3},{"tail":[1,0],"feats":[1,1],"rule":4}],"node":{"in_edges":[2,3],"cat":"A"},"edges":[{"tail":[2],"feats":[],"rule":5}],"node":{"in_edges":[4],"cat":"S"}}
+::forest:: input0.hg.bin.gz
diff --git a/tests/system_tests/ftrans/input0.hg.bin.gz b/tests/system_tests/ftrans/input0.hg.bin.gz
new file mode 100644
index 00000000..210f4a44
--- /dev/null
+++ b/tests/system_tests/ftrans/input0.hg.bin.gz
Binary files differ
diff --git a/training/Makefile.am b/training/Makefile.am
index 8ef3c939..2812a9be 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -8,5 +8,5 @@ SUBDIRS = \
dtrain \
latent_svm \
mira \
- rampion
-
+ rampion \
+ const_reorder
diff --git a/training/const_reorder/Makefile.am b/training/const_reorder/Makefile.am
new file mode 100644
index 00000000..2c681398
--- /dev/null
+++ b/training/const_reorder/Makefile.am
@@ -0,0 +1,8 @@
+bin_PROGRAMS = const_reorder_model_trainer argument_reorder_model_trainer
+
+AM_CPPFLAGS = -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder
+
+const_reorder_model_trainer_SOURCES = constituent_reorder_model.cc trainer.h trainer.cc
+const_reorder_model_trainer_LDADD = ../../utils/libutils.a
+argument_reorder_model_trainer_SOURCES = argument_reorder_model.cc trainer.h trainer.cc
+argument_reorder_model_trainer_LDADD = ../../utils/libutils.a
diff --git a/training/const_reorder/argument_reorder_model.cc b/training/const_reorder/argument_reorder_model.cc
new file mode 100644
index 00000000..87f2ce2f
--- /dev/null
+++ b/training/const_reorder/argument_reorder_model.cc
@@ -0,0 +1,307 @@
+/*
+ * argument_reorder_model.cc
+ *
+ * Created on: Dec 15, 2013
+ * Author: lijunhui
+ */
+
+#include <boost/program_options.hpp>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "utils/filelib.h"
+
+#include "trainer.h"
+
+using namespace std;
+using namespace const_reorder;
+
+inline void fnPreparingTrainingdata(const char* pszFName, int iCutoff,
+ const char* pszNewFName) {
+ Map hashPredicate;
+ {
+ ReadFile in(pszFName);
+ string line;
+ while (getline(*in.stream(), line)) {
+ if (!line.size()) continue;
+ vector<string> terms;
+ SplitOnWhitespace(line, &terms);
+ for (const auto& i : terms) {
+ ++hashPredicate[i];
+ }
+ }
+ }
+
+ {
+ ReadFile in(pszFName);
+ WriteFile out(pszNewFName);
+ string line;
+ while (getline(*in.stream(), line)) {
+ if (!line.size()) continue;
+ vector<string> terms;
+ SplitOnWhitespace(line, &terms);
+ bool written = false;
+ for (const auto& i : terms) {
+ if (hashPredicate[i] >= iCutoff) {
+ (*out.stream()) << i << " ";
+ written = true;
+ }
+ }
+ if (written) {
+ (*out.stream()) << "\n";
+ }
+ }
+ }
+}
+
+struct SArgumentReorderTrainer {
+ SArgumentReorderTrainer(
+ const char* pszSRLFname, // source-side srl tree file name
+ const char* pszAlignFname, // alignment filename
+ const char* pszSourceFname, // source file name
+ const char* pszTargetFname, // target file name
+ const char* pszTopPredicateFname, // target file name
+ const char* pszInstanceFname, // training instance file name
+ const char* pszModelFname, // classifier model file name
+ int iCutoff) {
+ fnGenerateInstanceFiles(pszSRLFname, pszAlignFname, pszSourceFname,
+ pszTargetFname, pszTopPredicateFname,
+ pszInstanceFname);
+
+ string strInstanceFname, strModelFname;
+ strInstanceFname = string(pszInstanceFname) + string(".left");
+ strModelFname = string(pszModelFname) + string(".left");
+ fnTraining(strInstanceFname.c_str(), strModelFname.c_str(), iCutoff);
+ strInstanceFname = string(pszInstanceFname) + string(".right");
+ strModelFname = string(pszModelFname) + string(".right");
+ fnTraining(strInstanceFname.c_str(), strModelFname.c_str(), iCutoff);
+ }
+
+ ~SArgumentReorderTrainer() {}
+
+ private:
+ void fnTraining(const char* pszInstanceFname, const char* pszModelFname,
+ int iCutoff) {
+ char* pszNewInstanceFName = new char[strlen(pszInstanceFname) + 50];
+ if (iCutoff > 0) {
+ sprintf(pszNewInstanceFName, "%s.tmp", pszInstanceFname);
+ fnPreparingTrainingdata(pszInstanceFname, iCutoff, pszNewInstanceFName);
+ } else {
+ strcpy(pszNewInstanceFName, pszInstanceFname);
+ }
+
+ Tsuruoka_Maxent_Trainer* pMaxent = new Tsuruoka_Maxent_Trainer;
+ pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname);
+ delete pMaxent;
+
+ if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) {
+ sprintf(pszNewInstanceFName, "rm %s.tmp", pszInstanceFname);
+ system(pszNewInstanceFName);
+ }
+ delete[] pszNewInstanceFName;
+ }
+
+ void fnGenerateInstanceFiles(
+ const char* pszSRLFname, // source-side flattened parse tree file name
+ const char* pszAlignFname, // alignment filename
+ const char* pszSourceFname, // source file name
+ const char* pszTargetFname, // target file name
+ const char* pszTopPredicateFname, // top predicate file name (we only
+ // consider predicates with 100+
+ // occurrences
+ const char* pszInstanceFname // training instance file name
+ ) {
+ SAlignmentReader* pAlignReader = new SAlignmentReader(pszAlignFname);
+ SSrlSentenceReader* pSRLReader = new SSrlSentenceReader(pszSRLFname);
+ ReadFile source_file(pszSourceFname);
+ ReadFile target_file(pszTargetFname);
+
+ Map* pMapPredicate;
+ if (pszTopPredicateFname != NULL)
+ pMapPredicate = fnLoadTopPredicates(pszTopPredicateFname);
+ else
+ pMapPredicate = NULL;
+
+ string line;
+
+ WriteFile left_file(pszInstanceFname + string(".left"));
+ WriteFile right_file(pszInstanceFname + string(".right"));
+
+ // read sentence by sentence
+ SAlignment* pAlign;
+ SSrlSentence* pSRL;
+ SParsedTree* pTree;
+ int iSentNum = 0;
+ while ((pAlign = pAlignReader->fnReadNextAlignment()) != NULL) {
+ pSRL = pSRLReader->fnReadNextSrlSentence();
+ assert(pSRL != NULL);
+ pTree = pSRL->m_pTree;
+ assert(getline(*source_file.stream(), line));
+ vector<string> vecSTerms;
+ SplitOnWhitespace(line, &vecSTerms);
+ assert(getline(*target_file.stream(), line));
+ vector<string> vecTTerms;
+ SplitOnWhitespace(line, &vecTTerms);
+ // vecTPOSTerms.size() == 0, given the case when an english sentence fails
+ // parsing
+
+ if (pTree != NULL) {
+ for (size_t i = 0; i < pSRL->m_vecPred.size(); i++) {
+ SPredicate* pPred = pSRL->m_vecPred[i];
+ if (strcmp(pTree->m_vecTerminals[pPred->m_iPosition]
+ ->m_ptParent->m_pszTerm,
+ "VA") == 0)
+ continue;
+ string strPred =
+ string(pTree->m_vecTerminals[pPred->m_iPosition]->m_pszTerm);
+ if (pMapPredicate != NULL) {
+ Map::iterator iter_map = pMapPredicate->find(strPred);
+ if (pMapPredicate != NULL && iter_map == pMapPredicate->end())
+ continue;
+ }
+
+ SPredicateItem* pPredItem = new SPredicateItem(pTree, pPred);
+
+ vector<string> vecStrBlock;
+ for (size_t j = 0; j < pPredItem->vec_items_.size(); j++) {
+ SSRLItem* pItem1 = pPredItem->vec_items_[j];
+ vecStrBlock.push_back(SArgumentReorderModel::fnGetBlockOutcome(
+ pItem1->tree_item_->m_iBegin, pItem1->tree_item_->m_iEnd,
+ pAlign));
+ }
+
+ vector<string> vecStrLeftReorderType;
+ vector<string> vecStrRightReorderType;
+ SArgumentReorderModel::fnGetReorderType(
+ pPredItem, pAlign, vecStrLeftReorderType, vecStrRightReorderType);
+ for (int j = 1; j < pPredItem->vec_items_.size(); j++) {
+ string strLeftOutcome, strRightOutcome;
+ strLeftOutcome = vecStrLeftReorderType[j - 1];
+ strRightOutcome = vecStrRightReorderType[j - 1];
+ ostringstream ostr;
+ SArgumentReorderModel::fnGenerateFeature(pTree, pPred, pPredItem, j,
+ vecStrBlock[j - 1],
+ vecStrBlock[j], ostr);
+
+ // fprintf(stderr, "%s %s\n", ostr.str().c_str(),
+ // strOutcome.c_str());
+ // fprintf(fpOut, "sentid=%d %s %s\n", iSentNum, ostr.str().c_str(),
+ // strOutcome.c_str());
+ (*left_file.stream()) << ostr.str() << " " << strLeftOutcome
+ << "\n";
+ (*right_file.stream()) << ostr.str() << " " << strRightOutcome
+ << "\n";
+ }
+ }
+ }
+ delete pSRL;
+
+ delete pAlign;
+ iSentNum++;
+
+ if (iSentNum % 100000 == 0) fprintf(stderr, "#%d\n", iSentNum);
+ }
+
+ delete pAlignReader;
+ delete pSRLReader;
+ }
+
+ Map* fnLoadTopPredicates(const char* pszTopPredicateFname) {
+ if (pszTopPredicateFname == NULL) return NULL;
+
+ Map* pMapPredicate = new Map();
+ // STxtFileReader* pReader = new STxtFileReader(pszTopPredicateFname);
+ ReadFile in(pszTopPredicateFname);
+ // char* pszLine = new char[50001];
+ string line;
+ int iNumCount = 0;
+ while (getline(*in.stream(), line)) {
+ if (line.size() && line[0] == '#') continue;
+ auto p = line.find(' ');
+ assert(p != string::npos);
+ int iCount = atoi(line.substr(p + 1).c_str());
+ if (iCount < 100) break;
+ (*pMapPredicate)[line] = iNumCount++;
+ }
+ return pMapPredicate;
+ }
+};
+
+namespace po = boost::program_options;
+
+inline void print_options(std::ostream& out,
+ po::options_description const& opts) {
+ typedef std::vector<boost::shared_ptr<po::option_description> > Ds;
+ Ds const& ds = opts.options();
+ out << '"';
+ for (unsigned i = 0; i < ds.size(); ++i) {
+ if (i) out << ' ';
+ out << "--" << ds[i]->long_name();
+ }
+ out << '\n';
+}
+inline string str(char const* name, po::variables_map const& conf) {
+ return conf[name].as<string>();
+}
+
+//--srl_file /scratch0/mt_exp/gale-align/gale-align.nw.srl.cn --align_file
+/// scratch0/mt_exp/gale-align/gale-align.nw.al --source_file
+/// scratch0/mt_exp/gale-align/gale-align.nw.cn --target_file
+/// scratch0/mt_exp/gale-align/gale-align.nw.en --instance_file
+/// scratch0/mt_exp/gale-align/gale-align.nw.argreorder.instance --model_prefix
+/// scratch0/mt_exp/gale-align/gale-align.nw.argreorder.model --feature_cutoff 2
+//--srl_file /scratch0/mt_exp/gale-ctb/gale-ctb.srl.cn --align_file
+/// scratch0/mt_exp/gale-ctb/gale-ctb.align --source_file
+/// scratch0/mt_exp/gale-ctb/gale-ctb.cn --target_file
+/// scratch0/mt_exp/gale-ctb/gale-ctb.en0 --instance_file
+/// scratch0/mt_exp/gale-ctb/gale-ctb.argreorder.instance --model_prefix
+/// scratch0/mt_exp/gale-ctb/gale-ctb.argreorder.model --feature_cutoff 2
+int main(int argc, char** argv) {
+
+ po::options_description opts("Configuration options");
+ opts.add_options()("srl_file", po::value<string>(), "srl file path (input)")(
+ "align_file", po::value<string>(), "Alignment file path (input)")(
+ "source_file", po::value<string>(), "Source text file path (input)")(
+ "target_file", po::value<string>(), "Target text file path (input)")(
+ "instance_file", po::value<string>(), "Instance file path (output)")(
+ "model_prefix", po::value<string>(),
+ "Model file path prefix (output): three files will be generated")(
+ "feature_cutoff", po::value<int>()->default_value(100),
+ "Feature cutoff threshold")("help", "produce help message");
+
+ po::variables_map vm;
+ if (argc) {
+ po::store(po::parse_command_line(argc, argv, opts), vm);
+ po::notify(vm);
+ }
+
+ if (vm.count("help")) {
+ print_options(cout, opts);
+ return 1;
+ }
+
+ if (!vm.count("srl_file") || !vm.count("align_file") ||
+ !vm.count("source_file") || !vm.count("target_file") ||
+ !vm.count("instance_file") || !vm.count("model_prefix")) {
+ print_options(cout, opts);
+ if (!vm.count("parse_file")) cout << "--parse_file NOT FOUND\n";
+ if (!vm.count("align_file")) cout << "--align_file NOT FOUND\n";
+ if (!vm.count("source_file")) cout << "--source_file NOT FOUND\n";
+ if (!vm.count("target_file")) cout << "--target_file NOT FOUND\n";
+ if (!vm.count("instance_file")) cout << "--instance_file NOT FOUND\n";
+ if (!vm.count("model_prefix")) cout << "--model_prefix NOT FOUND\n";
+ exit(0);
+ }
+
+ SArgumentReorderTrainer* pTrainer = new SArgumentReorderTrainer(
+ str("srl_file", vm).c_str(), str("align_file", vm).c_str(),
+ str("source_file", vm).c_str(), str("target_file", vm).c_str(), NULL,
+ str("instance_file", vm).c_str(), str("model_prefix", vm).c_str(),
+ vm["feature_cutoff"].as<int>());
+ delete pTrainer;
+
+ return 1;
+}
diff --git a/training/const_reorder/constituent_reorder_model.cc b/training/const_reorder/constituent_reorder_model.cc
new file mode 100644
index 00000000..d3ad0f2b
--- /dev/null
+++ b/training/const_reorder/constituent_reorder_model.cc
@@ -0,0 +1,636 @@
+/*
+ * constituent_reorder_model.cc
+ *
+ * Created on: Jul 10, 2013
+ * Author: junhuili
+ */
+
+#include <string>
+#include <unordered_map>
+
+#include <boost/program_options.hpp>
+
+#include "utils/filelib.h"
+
+#include "trainer.h"
+
+using namespace std;
+using namespace const_reorder;
+
+typedef std::unordered_map<std::string, int> Map;
+typedef std::unordered_map<std::string, int>::iterator Iterator;
+
+namespace po = boost::program_options;
+
+inline void fnPreparingTrainingdata(const char* pszFName, int iCutoff,
+ const char* pszNewFName) {
+ Map hashPredicate;
+ {
+ ReadFile f(pszFName);
+ string line;
+ while (getline(*f.stream(), line)) {
+ if (!line.size()) continue;
+ vector<string> terms;
+ SplitOnWhitespace(line, &terms);
+ for (const auto& i : terms) {
+ ++hashPredicate[i];
+ }
+ }
+ }
+
+ {
+ ReadFile in(pszFName);
+ WriteFile out(pszNewFName);
+ string line;
+ while (getline(*in.stream(), line)) {
+ if (!line.size()) continue;
+ vector<string> terms;
+ SplitOnWhitespace(line, &terms);
+ bool written = false;
+ for (const auto& i : terms) {
+ if (hashPredicate[i] >= iCutoff) {
+ (*out.stream()) << i << " ";
+ written = true;
+ }
+ }
+ if (written) {
+ (*out.stream()) << "\n";
+ }
+ }
+ }
+}
+
+struct SConstReorderTrainer {
+ SConstReorderTrainer(
+ const char* pszSynFname, // source-side flattened parse tree file name
+ const char* pszAlignFname, // alignment filename
+ const char* pszSourceFname, // source file name
+ const char* pszTargetFname, // target file name
+ const char* pszInstanceFname, // training instance file name
+ const char* pszModelPrefix, // classifier model file name prefix
+ int iCutoff, // feature count threshold
+ const char* /*pszOption*/ // other classifier parameters (for svmlight)
+ ) {
+ fnGenerateInstanceFile(pszSynFname, pszAlignFname, pszSourceFname,
+ pszTargetFname, pszInstanceFname);
+
+ string strInstanceLeftFname = string(pszInstanceFname) + string(".left");
+ string strInstanceRightFname = string(pszInstanceFname) + string(".right");
+
+ string strModelLeftFname = string(pszModelPrefix) + string(".left");
+ string strModelRightFname = string(pszModelPrefix) + string(".right");
+
+ fprintf(stdout, "...Training the left ordering model\n");
+ fnTraining(strInstanceLeftFname.c_str(), strModelLeftFname.c_str(),
+ iCutoff);
+ fprintf(stdout, "...Training the right ordering model\n");
+ fnTraining(strInstanceRightFname.c_str(), strModelRightFname.c_str(),
+ iCutoff);
+ }
+ ~SConstReorderTrainer() {}
+
+ private:
+ void fnTraining(const char* pszInstanceFname, const char* pszModelFname,
+ int iCutoff) {
+ char* pszNewInstanceFName = new char[strlen(pszInstanceFname) + 50];
+ if (iCutoff > 0) {
+ sprintf(pszNewInstanceFName, "%s.tmp", pszInstanceFname);
+ fnPreparingTrainingdata(pszInstanceFname, iCutoff, pszNewInstanceFName);
+ } else {
+ strcpy(pszNewInstanceFName, pszInstanceFname);
+ }
+
+ /*Zhangle_Maxent *pZhangleMaxent = new Zhangle_Maxent(NULL);
+pZhangleMaxent->fnTrain(pszInstanceFname, "lbfgs", pszModelFname, 100, 2.0);
+delete pZhangleMaxent;*/
+
+ Tsuruoka_Maxent_Trainer* pMaxent = new Tsuruoka_Maxent_Trainer;
+ pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname);
+ delete pMaxent;
+
+ if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) {
+ sprintf(pszNewInstanceFName, "rm %s.tmp", pszInstanceFname);
+ system(pszNewInstanceFName);
+ }
+ delete[] pszNewInstanceFName;
+ }
+
+ inline bool fnIsVerbPOS(const char* pszTerm) {
+ if (strcmp(pszTerm, "VV") == 0 || strcmp(pszTerm, "VA") == 0 ||
+ strcmp(pszTerm, "VC") == 0 || strcmp(pszTerm, "VE") == 0)
+ return true;
+ return false;
+ }
+
+ inline void fnGetOutcome(int iL1, int iR1, int iL2, int iR2,
+ const SAlignment* /*pAlign*/, string& strOutcome) {
+ if (iL1 == -1 && iL2 == -1)
+ strOutcome = "BU"; // 1. both are untranslated
+ else if (iL1 == -1)
+ strOutcome = "1U"; // 2. XP1 is untranslated
+ else if (iL2 == -1)
+ strOutcome = "2U"; // 3. XP2 is untranslated
+ else if (iL1 == iL2 && iR1 == iR2)
+ strOutcome = "SS"; // 4. Have same scope
+ else if (iL1 <= iL2 && iR1 >= iR2)
+ strOutcome = "1C2"; // 5. XP1's translation covers XP2's
+ else if (iL1 >= iL2 && iR1 <= iR2)
+ strOutcome = "2C1"; // 6. XP2's translation covers XP1's
+ else if (iR1 < iL2) {
+ int i = iR1 + 1;
+ /*while (i < iL2) {
+ if (pAlign->fnIsAligned(i, false))
+ break;
+ i++;
+ }*/
+ if (i == iL2)
+ strOutcome = "M"; // 7. Monotone
+ else
+ strOutcome = "DM"; // 8. Discontinuous monotone
+ } else if (iL1 < iL2 && iL2 <= iR1 && iR1 < iR2)
+ strOutcome = "OM"; // 9. Overlap monotone
+ else if (iR2 < iL1) {
+ int i = iR2 + 1;
+ /*while (i < iL1) {
+ if (pAlign->fnIsAligned(i, false))
+ break;
+ i++;
+ }*/
+ if (i == iL1)
+ strOutcome = "S"; // 10. Swap
+ else
+ strOutcome = "DS"; // 11. Discontinuous swap
+ } else if (iL2 < iL1 && iL1 <= iR2 && iR2 < iR1)
+ strOutcome = "OS"; // 12. Overlap swap
+ else
+ assert(false);
+ }
+
+ inline void fnGetOutcome(int i1, int i2, string& strOutcome) {
+ assert(i1 != i2);
+ if (i1 < i2) {
+ if (i2 > i1 + 1)
+ strOutcome = string("DM");
+ else
+ strOutcome = string("M");
+ } else {
+ if (i1 > i2 + 1)
+ strOutcome = string("DS");
+ else
+ strOutcome = string("S");
+ }
+ }
+
+ inline void fnGetRelativePosition(const vector<int>& vecLeft,
+ vector<int>& vecPosition) {
+ vecPosition.clear();
+
+ vector<float> vec;
+ for (size_t i = 0; i < vecLeft.size(); i++) {
+ if (vecLeft[i] == -1) {
+ if (i == 0)
+ vec.push_back(-1);
+ else
+ vec.push_back(vecLeft[i - 1] + 0.1);
+ } else
+ vec.push_back(vecLeft[i]);
+ }
+
+ for (size_t i = 0; i < vecLeft.size(); i++) {
+ int count = 0;
+
+ for (size_t j = 0; j < vecLeft.size(); j++) {
+ if (j == i) continue;
+ if (vec[j] < vec[i]) {
+ count++;
+ } else if (vec[j] == vec[i] && j < i) {
+ count++;
+ }
+ }
+ vecPosition.push_back(count);
+ }
+ }
+
+ /*
+ * features:
+ * f1: (left_label, right_label, parent_label)
+ * f2: (left_label, right_label, parent_label, other_right_sibling_label)
+ * f3: (left_label, right_label, parent_label, other_left_sibling_label)
+ * f4: (left_label, right_label, left_head_pos)
+ * f5: (left_label, right_label, left_head_word)
+ * f6: (left_label, right_label, right_head_pos)
+ * f7: (left_label, right_label, right_head_word)
+ * f8: (left_label, right_label, left_chunk_status)
+ * f9: (left_label, right_label, right_chunk_status)
+ * f10: (left_label, parent_label)
+ * f11: (right_label, parent_label)
+ */
+ void fnGenerateInstance(const SParsedTree* pTree, const STreeItem* pParent,
+ int iPos, const vector<string>& vecChunkStatus,
+ const vector<int>& vecPosition,
+ const vector<string>& vecSTerms,
+ const vector<string>& /*vecTTerms*/, string& strOutcome,
+ ostringstream& ostr) {
+ STreeItem* pCon1, *pCon2;
+ pCon1 = pParent->m_vecChildren[iPos - 1];
+ pCon2 = pParent->m_vecChildren[iPos];
+
+ fnGetOutcome(vecPosition[iPos - 1], vecPosition[iPos], strOutcome);
+
+ string left_label = string(pCon1->m_pszTerm);
+ string right_label = string(pCon2->m_pszTerm);
+ string parent_label = string(pParent->m_pszTerm);
+
+ vector<string> vec_other_right_sibling;
+ for (int i = iPos + 1; i < pParent->m_vecChildren.size(); i++)
+ vec_other_right_sibling.push_back(
+ string(pParent->m_vecChildren[i]->m_pszTerm));
+ if (vec_other_right_sibling.size() == 0)
+ vec_other_right_sibling.push_back(string("NULL"));
+ vector<string> vec_other_left_sibling;
+ for (int i = 0; i < iPos - 1; i++)
+ vec_other_left_sibling.push_back(
+ string(pParent->m_vecChildren[i]->m_pszTerm));
+ if (vec_other_left_sibling.size() == 0)
+ vec_other_left_sibling.push_back(string("NULL"));
+
+ // generate features
+ // f1
+ ostr << "f1=" << left_label << "_" << right_label << "_" << parent_label;
+ // f2
+ for (int i = 0; i < vec_other_right_sibling.size(); i++)
+ ostr << " f2=" << left_label << "_" << right_label << "_" << parent_label
+ << "_" << vec_other_right_sibling[i];
+ // f3
+ for (int i = 0; i < vec_other_left_sibling.size(); i++)
+ ostr << " f3=" << left_label << "_" << right_label << "_" << parent_label
+ << "_" << vec_other_left_sibling[i];
+ // f4
+ ostr << " f4=" << left_label << "_" << right_label << "_"
+ << pTree->m_vecTerminals[pCon1->m_iHeadWord]->m_ptParent->m_pszTerm;
+ // f5
+ ostr << " f5=" << left_label << "_" << right_label << "_"
+ << vecSTerms[pCon1->m_iHeadWord];
+ // f6
+ ostr << " f6=" << left_label << "_" << right_label << "_"
+ << pTree->m_vecTerminals[pCon2->m_iHeadWord]->m_ptParent->m_pszTerm;
+ // f7
+ ostr << " f7=" << left_label << "_" << right_label << "_"
+ << vecSTerms[pCon2->m_iHeadWord];
+ // f8
+ ostr << " f8=" << left_label << "_" << right_label << "_"
+ << vecChunkStatus[iPos - 1];
+ // f9
+ ostr << " f9=" << left_label << "_" << right_label << "_"
+ << vecChunkStatus[iPos];
+ // f10
+ ostr << " f10=" << left_label << "_" << parent_label;
+ // f11
+ ostr << " f11=" << right_label << "_" << parent_label;
+ }
+
+ /*
+ * Source side (11 features):
+ * f1: the categories of XP1 and XP2 (f1_1, f1_2)
+ * f2: the head words of XP1 and XP2 (f2_1, f2_2)
+ * f3: the first and last word of XP1 (f3_f, f3_l)
+ * f4: the first and last word of XP2 (f4_f, f4_l)
+ * f5: is XP1 or XP2 the head node (f5_1, f5_2)
+ * f6: the category of the common parent
+ * Target side (6 features):
+ * f7: the first and the last word of XP1's translation (f7_f, f7_l)
+ * f8: the first and the last word of XP2's translation (f8_f, f8_l)
+ * f9: the translation of XP1's and XP2's head word (f9_1, f9_2)
+ */
+ void fnGenerateInstance(const SParsedTree* /*pTree*/, const STreeItem* pParent,
+ const STreeItem* pCon1, const STreeItem* pCon2,
+ const SAlignment* pAlign,
+ const vector<string>& vecSTerms,
+ const vector<string>& /*vecTTerms*/, string& strOutcome,
+ ostringstream& ostr) {
+
+ int iLeft1, iRight1, iLeft2, iRight2;
+ pAlign->fnGetLeftRightMost(pCon1->m_iBegin, pCon1->m_iEnd, true, iLeft1,
+ iRight1);
+ pAlign->fnGetLeftRightMost(pCon2->m_iBegin, pCon2->m_iEnd, true, iLeft2,
+ iRight2);
+
+ fnGetOutcome(iLeft1, iRight1, iLeft2, iRight2, pAlign, strOutcome);
+
+ // generate features
+ // f1
+ ostr << "f1_1=" << pCon1->m_pszTerm << " f1_2=" << pCon2->m_pszTerm;
+ // f2
+ ostr << " f2_1=" << vecSTerms[pCon1->m_iHeadWord] << " f2_2"
+ << vecSTerms[pCon2->m_iHeadWord];
+ // f3
+ ostr << " f3_f=" << vecSTerms[pCon1->m_iBegin]
+ << " f3_l=" << vecSTerms[pCon1->m_iEnd];
+ // f4
+ ostr << " f4_f=" << vecSTerms[pCon2->m_iBegin]
+ << " f4_l=" << vecSTerms[pCon2->m_iEnd];
+ // f5
+ if (pParent->m_iHeadChild == pCon1->m_iBrotherIndex)
+ ostr << " f5_1=1";
+ else
+ ostr << " f5_1=0";
+ if (pParent->m_iHeadChild == pCon2->m_iBrotherIndex)
+ ostr << " f5_2=1";
+ else
+ ostr << " f5_2=0";
+ // f6
+ ostr << " f6=" << pParent->m_pszTerm;
+
+ /*//f7
+ if (iLeft1 != -1) {
+ ostr << " f7_f=" << vecTTerms[iLeft1] << " f7_l=" <<
+ vecTTerms[iRight1];
+ }
+ if (iLeft2 != -1) {
+ ostr << " f8_f=" << vecTTerms[iLeft2] << " f8_l=" <<
+ vecTTerms[iRight2];
+ }
+
+ const vector<int>* pvecTarget =
+ pAlign->fnGetSingleWordAlign(pCon1->m_iHeadWord, true);
+ string str = "";
+ for (size_t i = 0; pvecTarget != NULL && i < pvecTarget->size(); i++) {
+ str += vecTTerms[(*pvecTarget)[i]] + "_";
+ }
+ if (str.length() > 0) {
+ ostr << " f9_1=" << str.substr(0, str.size()-1);
+ }
+ pvecTarget = pAlign->fnGetSingleWordAlign(pCon2->m_iHeadWord, true);
+ str = "";
+ for (size_t i = 0; pvecTarget != NULL && i < pvecTarget->size(); i++) {
+ str += vecTTerms[(*pvecTarget)[i]] + "_";
+ }
+ if (str.length() > 0) {
+ ostr << " f9_2=" << str.substr(0, str.size()-1);
+ } */
+ }
+
+ void fnGetFocusedParentNodes(const SParsedTree* pTree,
+ vector<STreeItem*>& vecFocused) {
+ for (size_t i = 0; i < pTree->m_vecTerminals.size(); i++) {
+ STreeItem* pParent = pTree->m_vecTerminals[i]->m_ptParent;
+
+ while (pParent != NULL) {
+ // if (pParent->m_vecChildren.size() > 1 && pParent->m_iEnd -
+ // pParent->m_iBegin > 5) {
+ if (pParent->m_vecChildren.size() > 1) {
+ // do constituent reordering for all children of pParent
+ vecFocused.push_back(pParent);
+ }
+ if (pParent->m_iBrotherIndex != 0) break;
+ pParent = pParent->m_ptParent;
+ }
+ }
+ }
+
+ void fnGenerateInstanceFile(
+ const char* pszSynFname, // source-side flattened parse tree file name
+ const char* pszAlignFname, // alignment filename
+ const char* pszSourceFname, // source file name
+ const char* pszTargetFname, // target file name
+ const char* pszInstanceFname // training instance file name
+ ) {
+ SAlignmentReader* pAlignReader = new SAlignmentReader(pszAlignFname);
+ SParseReader* pParseReader = new SParseReader(pszSynFname, false);
+
+ ReadFile source_file(pszSourceFname);
+ ReadFile target_file(pszTargetFname);
+ string strInstanceLeftFname = string(pszInstanceFname) + string(".left");
+ string strInstanceRightFname = string(pszInstanceFname) + string(".right");
+ WriteFile left_file(strInstanceLeftFname);
+ WriteFile right_file(strInstanceRightFname);
+
+ // read sentence by sentence
+ SAlignment* pAlign;
+ SParsedTree* pTree;
+ string line;
+ int iSentNum = 0;
+ while ((pAlign = pAlignReader->fnReadNextAlignment()) != NULL) {
+ pTree = pParseReader->fnReadNextParseTree();
+
+ assert(getline(*source_file.stream(), line));
+ vector<string> vecSTerms;
+ SplitOnWhitespace(line, &vecSTerms);
+
+ assert(getline(*target_file.stream(), line));
+ vector<string> vecTTerms;
+ SplitOnWhitespace(line, &vecTTerms);
+
+ if (pTree != NULL) {
+
+ vector<STreeItem*> vecFocused;
+ fnGetFocusedParentNodes(pTree, vecFocused);
+
+ for (size_t i = 0; i < vecFocused.size(); i++) {
+
+ STreeItem* pParent = vecFocused[i];
+
+ vector<int> vecLeft, vecRight;
+ for (size_t j = 0; j < pParent->m_vecChildren.size(); j++) {
+ STreeItem* pCon1 = pParent->m_vecChildren[j];
+ int iLeft1, iRight1;
+ pAlign->fnGetLeftRightMost(pCon1->m_iBegin, pCon1->m_iEnd, true,
+ iLeft1, iRight1);
+ vecLeft.push_back(iLeft1);
+ vecRight.push_back(iRight1);
+ }
+ vector<int> vecLeftPosition;
+ fnGetRelativePosition(vecLeft, vecLeftPosition);
+ vector<int> vecRightPosition;
+ fnGetRelativePosition(vecRight, vecRightPosition);
+
+ vector<string> vecChunkStatus;
+ for (size_t j = 0; j < pParent->m_vecChildren.size(); j++) {
+ string strOutcome =
+ pAlign->fnIsContinuous(pParent->m_vecChildren[j]->m_iBegin,
+ pParent->m_vecChildren[j]->m_iEnd);
+ vecChunkStatus.push_back(strOutcome);
+ }
+
+ for (size_t j = 1; j < pParent->m_vecChildren.size(); j++) {
+ // children[j-1] vs. children[j] reordering
+
+ string strLeftOutcome;
+ ostringstream ostr;
+
+ fnGenerateInstance(pTree, pParent, j, vecChunkStatus,
+ vecLeftPosition, vecSTerms, vecTTerms,
+ strLeftOutcome, ostr);
+
+ string ostr_str = ostr.str();
+
+ // fprintf(stderr, "%s %s\n", ostr.str().c_str(),
+ // strLeftOutcome.c_str());
+ (*left_file.stream()) << ostr_str << " " << strLeftOutcome << "\n";
+
+ string strRightOutcome;
+ fnGetOutcome(vecRightPosition[j - 1], vecRightPosition[j],
+ strRightOutcome);
+ (*right_file.stream()) << ostr_str
+ << " LeftOrder=" << strLeftOutcome << " "
+ << strRightOutcome << "\n";
+ }
+ }
+ delete pTree;
+ }
+
+ delete pAlign;
+ iSentNum++;
+
+ if (iSentNum % 100000 == 0) fprintf(stderr, "#%d\n", iSentNum);
+ }
+
+ delete pAlignReader;
+ delete pParseReader;
+ }
+
+ void fnGenerateInstanceFile2(
+ const char* pszSynFname, // source-side flattened parse tree file name
+ const char* pszAlignFname, // alignment filename
+ const char* pszSourceFname, // source file name
+ const char* pszTargetFname, // target file name
+ const char* pszInstanceFname // training instance file name
+ ) {
+ SAlignmentReader* pAlignReader = new SAlignmentReader(pszAlignFname);
+ SParseReader* pParseReader = new SParseReader(pszSynFname, false);
+
+ ReadFile source_file(pszSourceFname);
+ ReadFile target_file(pszTargetFname);
+
+ WriteFile output_file(pszInstanceFname);
+
+ // read sentence by sentence
+ SAlignment* pAlign;
+ SParsedTree* pTree;
+ string line;
+ int iSentNum = 0;
+ while ((pAlign = pAlignReader->fnReadNextAlignment()) != NULL) {
+ pTree = pParseReader->fnReadNextParseTree();
+ assert(getline(*source_file.stream(), line));
+ vector<string> vecSTerms;
+ SplitOnWhitespace(line, &vecSTerms);
+
+ assert(getline(*target_file.stream(), line));
+ vector<string> vecTTerms;
+ SplitOnWhitespace(line, &vecTTerms);
+
+ if (pTree != NULL) {
+
+ vector<STreeItem*> vecFocused;
+ fnGetFocusedParentNodes(pTree, vecFocused);
+
+ for (size_t i = 0;
+ i < vecFocused.size() && pTree->m_vecTerminals.size() > 10; i++) {
+
+ STreeItem* pParent = vecFocused[i];
+
+ for (size_t j = 1; j < pParent->m_vecChildren.size(); j++) {
+ // children[j-1] vs. children[j] reordering
+
+ string strOutcome;
+ ostringstream ostr;
+
+ fnGenerateInstance(pTree, pParent, pParent->m_vecChildren[j - 1],
+ pParent->m_vecChildren[j], pAlign, vecSTerms,
+ vecTTerms, strOutcome, ostr);
+
+ // fprintf(stderr, "%s %s\n", ostr.str().c_str(),
+ // strOutcome.c_str());
+ (*output_file.stream()) << ostr.str() << " " << strOutcome << "\n";
+ }
+ }
+ delete pTree;
+ }
+
+ delete pAlign;
+ iSentNum++;
+
+ if (iSentNum % 100000 == 0) fprintf(stderr, "#%d\n", iSentNum);
+ }
+
+ delete pAlignReader;
+ delete pParseReader;
+ }
+};
+
+inline void print_options(std::ostream& out,
+ po::options_description const& opts) {
+ typedef std::vector<boost::shared_ptr<po::option_description> > Ds;
+ Ds const& ds = opts.options();
+ out << '"';
+ for (unsigned i = 0; i < ds.size(); ++i) {
+ if (i) out << ' ';
+ out << "--" << ds[i]->long_name();
+ }
+ out << '\n';
+}
+inline string str(char const* name, po::variables_map const& conf) {
+ return conf[name].as<string>();
+}
+
+//--parse_file /scratch0/mt_exp/gq-ctb/data/train.srl.cn --align_file
+/// scratch0/mt_exp/gq-ctb/data/aligned.grow-diag-final-and --source_file
+/// scratch0/mt_exp/gq-ctb/data/train.cn --target_file
+/// scratch0/mt_exp/gq-ctb/data/train.en --instance_file
+/// scratch0/mt_exp/gq-ctb/data/srl-instance --model_prefix
+/// scratch0/mt_exp/gq-ctb/data/srl-instance --feature_cutoff 10
+int main(int argc, char** argv) {
+
+ po::options_description opts("Configuration options");
+ opts.add_options()("parse_file", po::value<string>(),
+ "parse file path (input)")(
+ "align_file", po::value<string>(), "Alignment file path (input)")(
+ "source_file", po::value<string>(), "Source text file path (input)")(
+ "target_file", po::value<string>(), "Target text file path (input)")(
+ "instance_file", po::value<string>(), "Instance file path (output)")(
+ "model_prefix", po::value<string>(),
+ "Model file path prefix (output): three files will be generated")(
+ "feature_cutoff", po::value<int>()->default_value(100),
+ "Feature cutoff threshold")("svm_option", po::value<string>(),
+ "Parameters for SVMLight classifier")(
+ "help", "produce help message");
+
+ po::variables_map vm;
+ if (argc) {
+ po::store(po::parse_command_line(argc, argv, opts), vm);
+ po::notify(vm);
+ }
+
+ if (vm.count("help")) {
+ print_options(cout, opts);
+ return 1;
+ }
+
+ if (!vm.count("parse_file") || !vm.count("align_file") ||
+ !vm.count("source_file") || !vm.count("target_file") ||
+ !vm.count("instance_file") || !vm.count("model_prefix")) {
+ print_options(cout, opts);
+ if (!vm.count("parse_file")) cout << "--parse_file NOT FOUND\n";
+ if (!vm.count("align_file")) cout << "--align_file NOT FOUND\n";
+ if (!vm.count("source_file")) cout << "--source_file NOT FOUND\n";
+ if (!vm.count("target_file")) cout << "--target_file NOT FOUND\n";
+ if (!vm.count("instance_file")) cout << "--instance_file NOT FOUND\n";
+ if (!vm.count("model_prefix")) cout << "--model_prefix NOT FOUND\n";
+ exit(0);
+ }
+
+ const char* pOption;
+ if (vm.count("svm_option"))
+ pOption = str("svm_option", vm).c_str();
+ else
+ pOption = NULL;
+
+ SConstReorderTrainer* pTrainer = new SConstReorderTrainer(
+ str("parse_file", vm).c_str(), str("align_file", vm).c_str(),
+ str("source_file", vm).c_str(), str("target_file", vm).c_str(),
+ str("instance_file", vm).c_str(), str("model_prefix", vm).c_str(),
+ vm["feature_cutoff"].as<int>(), pOption);
+ delete pTrainer;
+
+ return 0;
+}
diff --git a/training/const_reorder/trainer.cc b/training/const_reorder/trainer.cc
new file mode 100644
index 00000000..1d388eec
--- /dev/null
+++ b/training/const_reorder/trainer.cc
@@ -0,0 +1,69 @@
+#include "trainer.h"
+
+#include "utils/maxent.h"
+
+Tsuruoka_Maxent_Trainer::Tsuruoka_Maxent_Trainer()
+ : const_reorder::Tsuruoka_Maxent(NULL) {}
+
+void Tsuruoka_Maxent_Trainer::fnTrain(const char* pszInstanceFName,
+ const char* pszAlgorithm,
+ const char* pszModelFName) {
+ assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 ||
+ strcmp(pszAlgorithm, "sgd") == 0 || strcmp(pszAlgorithm, "SGD") == 0);
+ FILE* fpIn = fopen(pszInstanceFName, "r");
+
+ maxent::ME_Model* pModel = new maxent::ME_Model();
+
+ char* pszLine = new char[100001];
+ int iNumInstances = 0;
+ int iLen;
+ while (!feof(fpIn)) {
+ pszLine[0] = '\0';
+ fgets(pszLine, 20000, fpIn);
+ if (strlen(pszLine) == 0) {
+ continue;
+ }
+
+ iLen = strlen(pszLine);
+ while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) {
+ pszLine[iLen - 1] = '\0';
+ iLen--;
+ }
+
+ iNumInstances++;
+
+ maxent::ME_Sample* pmes = new maxent::ME_Sample();
+
+ char* p = strrchr(pszLine, ' ');
+ assert(p != NULL);
+ p[0] = '\0';
+ p++;
+ std::vector<std::string> vecContext;
+ SplitOnWhitespace(std::string(pszLine), &vecContext);
+
+ pmes->label = std::string(p);
+ for (size_t i = 0; i < vecContext.size(); i++)
+ pmes->add_feature(vecContext[i]);
+ pModel->add_training_sample((*pmes));
+ if (iNumInstances % 100000 == 0)
+ fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances);
+ delete pmes;
+ }
+ fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances);
+ fclose(fpIn);
+
+ if (strcmp(pszAlgorithm, "l1") == 0)
+ pModel->use_l1_regularizer(1.0);
+ else if (strcmp(pszAlgorithm, "l2") == 0)
+ pModel->use_l2_regularizer(1.0);
+ else
+ pModel->use_SGD();
+
+ pModel->train();
+ pModel->save_to_file(pszModelFName);
+
+ delete pModel;
+ fprintf(stdout, "......Finished Training\n");
+ fprintf(stdout, "......Model saved as %s\n", pszModelFName);
+ delete[] pszLine;
+}
diff --git a/training/const_reorder/trainer.h b/training/const_reorder/trainer.h
new file mode 100644
index 00000000..e574a536
--- /dev/null
+++ b/training/const_reorder/trainer.h
@@ -0,0 +1,12 @@
+#ifndef TRAINING_CONST_REORDER_TRAINER_H_
+#define TRAINING_CONST_REORDER_TRAINER_H_
+
+#include "decoder/ff_const_reorder_common.h"
+
+struct Tsuruoka_Maxent_Trainer : const_reorder::Tsuruoka_Maxent {
+ Tsuruoka_Maxent_Trainer();
+ void fnTrain(const char* pszInstanceFName, const char* pszAlgorithm,
+ const char* pszModelFName);
+};
+
+#endif // TRAINING_CONST_REORDER_TRAINER_H_
diff --git a/training/dpmert/lo_test.cc b/training/dpmert/lo_test.cc
index b8776169..69e5aa3f 100644
--- a/training/dpmert/lo_test.cc
+++ b/training/dpmert/lo_test.cc
@@ -56,10 +56,11 @@ BOOST_AUTO_TEST_CASE(TestConvexHull) {
}
BOOST_AUTO_TEST_CASE(TestConvexHullInside) {
- const string json = "{\"rules\":[1,\"[X] ||| a ||| a\",2,\"[X] ||| A [X] ||| A [1]\",3,\"[X] ||| c ||| c\",4,\"[X] ||| C [X] ||| C [1]\",5,\"[X] ||| [X] B [X] ||| [1] B [2]\",6,\"[X] ||| [X] b [X] ||| [1] b [2]\",7,\"[X] ||| X [X] ||| X [1]\",8,\"[X] ||| Z [X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}";
+ std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
Hypergraph hg;
- istringstream instr(json);
- HypergraphIO::ReadFromJSON(&instr, &hg);
+ ReadFile rf(path + "/test-ch-inside.bin.gz");
+ assert(rf);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
SparseVector<double> wts;
wts.set_value(FD::Convert("f1"), 0.4);
wts.set_value(FD::Convert("f2"), 1.0);
@@ -121,13 +122,13 @@ BOOST_AUTO_TEST_CASE( TestS1) {
std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
Hypergraph hg;
- ReadFile rf(path + "/0.json.gz");
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ ReadFile rf(path + "/0.bin.gz");
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
hg.Reweight(wts);
Hypergraph hg2;
- ReadFile rf2(path + "/1.json.gz");
- HypergraphIO::ReadFromJSON(rf2.stream(), &hg2);
+ ReadFile rf2(path + "/1.bin.gz");
+ HypergraphIO::ReadFromBinary(rf2.stream(), &hg2);
hg2.Reweight(wts);
vector<vector<WordID> > refs1(4);
@@ -193,10 +194,11 @@ BOOST_AUTO_TEST_CASE( TestS1) {
}
BOOST_AUTO_TEST_CASE(TestZeroOrigin) {
- const string json = "{\"rules\":[1,\"[X7] ||| blA ||| without ||| LHSProb=3.92173 LexE2F=2.90799 LexF2E=1.85003 GenerativeProb=10.5381 RulePenalty=1 XFE=2.77259 XEF=0.441833 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=0.693147\",2,\"[X7] ||| blA ||| except ||| LHSProb=4.92173 LexE2F=3.90799 LexF2E=1.85003 GenerativeProb=11.5381 RulePenalty=1 XFE=2.77259 XEF=1.44183 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=1.69315\",3,\"[S] ||| [X7,1] ||| [1] ||| GlueTop=1\",4,\"[X28] ||| EnwAn ||| title ||| LHSProb=3.96802 LexE2F=2.22462 LexF2E=1.83258 GenerativeProb=10.0863 RulePenalty=1 XFE=0 XEF=1.20397 LabelledEF=1.20397 LabelledFE=-1.98341e-08 LogRuleCount=1.09861\",5,\"[X0] ||| EnwAn ||| funny ||| LHSProb=3.98479 LexE2F=1.79176 LexF2E=3.21888 GenerativeProb=11.1681 RulePenalty=1 XFE=0 XEF=2.30259 LabelledEF=2.30259 LabelledFE=0 LogRuleCount=0 SingletonRule=1\",6,\"[X8] ||| [X7,1] EnwAn ||| entitled [1] ||| LHSProb=3.82533 LexE2F=3.21888 LexF2E=2.52573 GenerativeProb=11.3276 RulePenalty=1 XFE=1.20397 XEF=1.20397 LabelledEF=2.30259 LabelledFE=2.30259 LogRuleCount=0 SingletonRule=1\",7,\"[S] ||| [S,1] [X28,2] ||| [1] [2] ||| Glue=1\",8,\"[S] ||| [S,1] [X0,2] ||| [1] [2] ||| Glue=1\",9,\"[S] ||| [X8,1] ||| [1] ||| GlueTop=1\",10,\"[Goal] ||| [S,1] ||| [1]\"],\"features\":[\"PassThrough\",\"Glue\",\"GlueTop\",\"LanguageModel\",\"WordPenalty\",\"LHSProb\",\"LexE2F\",\"LexF2E\",\"GenerativeProb\",\"RulePenalty\",\"XFE\",\"XEF\",\"LabelledEF\",\"LabelledFE\",\"LogRuleCount\",\"SingletonRule\"],\"edges\":[{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,3.92173,6,2.90799,7,1.85003,8,10.5381,9,1,10,2.77259,11,0.441833,12,2.63906,13,4.96981,14,0.693147],\"rule\":1},{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,4.92173,6,3.90799,7,1.85003,8,11.5381,9,1,10,2.77259,11,1.44183,12,2.63906,13,4.96981,14,1.69315],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X7\"},\"edges\":[{\"tail\":[0],\"spans\":[0,1,-1,-1],\"feats\":[2,1],\"rule\":3}],\"node\":{\"in_edges\":[2],\"cat\":\"S\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.96802,6,2.22462,7,1.83258,8,10.0863,9,1,11,1.20397,12,1.20397,13,-1.98341e-08,14,1.09861],\"rule\":4}],\"node\":{\"in_edges\":[3],\"cat\":\"X28\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.98479,6,1.79176,7,3.21888,8,11.1681,9,1,11,2.30259,12,2.30259,15,1],\"rule\":5}],\"node\":{\"in_edges\":[4],\"cat\":\"X0\"},\"edges\":[{\"tail\":[0],\"spans\":[0,2,-1,-1],\"feats\":[5,3.82533,6,3.21888,7,2.52573,8,11.3276,9,1,10,1.20397,11,1.20397,12,2.30259,13,2.30259,15,1],\"rule\":6}],\"node\":{\"in_edges\":[5],\"cat\":\"X8\"},\"edges\":[{\"tail\":[1,2],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":7},{\"tail\":[1,3],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":8},{\"tail\":[4],\"spans\":[0,2,-1,-1],\"feats\":[2,1],\"rule\":9}],\"node\":{\"in_edges\":[6,7,8],\"cat\":\"S\"},\"edges\":[{\"tail\":[5],\"spans\":[0,2,-1,-1],\"feats\":[],\"rule\":10}],\"node\":{\"in_edges\":[9],\"cat\":\"Goal\"}}";
+ std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
+ ReadFile rf(path + "/test-zero-origin.bin.gz");
+ assert(rf);
Hypergraph hg;
- istringstream instr(json);
- HypergraphIO::ReadFromJSON(&instr, &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
SparseVector<double> wts;
wts.set_value(FD::Convert("PassThrough"), -0.929201533002898);
hg.Reweight(wts);
diff --git a/training/dpmert/mr_dpmert_generate_mapper_input.cc b/training/dpmert/mr_dpmert_generate_mapper_input.cc
index 199cd23a..3fa2f476 100644
--- a/training/dpmert/mr_dpmert_generate_mapper_input.cc
+++ b/training/dpmert/mr_dpmert_generate_mapper_input.cc
@@ -70,7 +70,7 @@ int main(int argc, char** argv) {
unsigned dev_set_size = conf["dev_set_size"].as<unsigned>();
for (unsigned i = 0; i < dev_set_size; ++i) {
for (unsigned j = 0; j < directions.size(); ++j) {
- cout << forest_repository << '/' << i << ".json.gz " << i << ' ';
+ cout << forest_repository << '/' << i << ".bin.gz " << i << ' ';
print(cout, origin, "=", ";");
cout << ' ';
print(cout, directions[j], "=", ";");
diff --git a/training/dpmert/mr_dpmert_map.cc b/training/dpmert/mr_dpmert_map.cc
index d1efcf96..2bf3f8fc 100644
--- a/training/dpmert/mr_dpmert_map.cc
+++ b/training/dpmert/mr_dpmert_map.cc
@@ -83,7 +83,7 @@ int main(int argc, char** argv) {
istringstream is(line);
int sent_id;
string file, s_origin, s_direction;
- // path-to-file (JSON) sent_ed starting-point search-direction
+ // path-to-file sent_ed starting-point search-direction
is >> file >> sent_id >> s_origin >> s_direction;
SparseVector<double> origin;
ReadSparseVectorString(s_origin, &origin);
@@ -93,7 +93,7 @@ int main(int argc, char** argv) {
if (last_file != file) {
last_file = file;
ReadFile rf(file);
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
}
const ConvexHullWeightFunction wf(origin, direction);
const ConvexHull hull = Inside<ConvexHull, ConvexHullWeightFunction>(hg, NULL, wf);
diff --git a/training/dpmert/test_data/0.bin.gz b/training/dpmert/test_data/0.bin.gz
new file mode 100644
index 00000000..388298e9
--- /dev/null
+++ b/training/dpmert/test_data/0.bin.gz
Binary files differ
diff --git a/training/dpmert/test_data/0.json.gz b/training/dpmert/test_data/0.json.gz
deleted file mode 100644
index 30f8dd77..00000000
--- a/training/dpmert/test_data/0.json.gz
+++ /dev/null
Binary files differ
diff --git a/training/dpmert/test_data/1.bin.gz b/training/dpmert/test_data/1.bin.gz
new file mode 100644
index 00000000..44f9e0ff
--- /dev/null
+++ b/training/dpmert/test_data/1.bin.gz
Binary files differ
diff --git a/training/dpmert/test_data/1.json.gz b/training/dpmert/test_data/1.json.gz
deleted file mode 100644
index c82cc179..00000000
--- a/training/dpmert/test_data/1.json.gz
+++ /dev/null
Binary files differ
diff --git a/training/dpmert/test_data/test-ch-inside.bin.gz b/training/dpmert/test_data/test-ch-inside.bin.gz
new file mode 100644
index 00000000..392f08c6
--- /dev/null
+++ b/training/dpmert/test_data/test-ch-inside.bin.gz
Binary files differ
diff --git a/training/dpmert/test_data/test-zero-origin.bin.gz b/training/dpmert/test_data/test-zero-origin.bin.gz
new file mode 100644
index 00000000..c641faaf
--- /dev/null
+++ b/training/dpmert/test_data/test-zero-origin.bin.gz
Binary files differ
diff --git a/training/minrisk/minrisk_optimize.cc b/training/minrisk/minrisk_optimize.cc
index da8b5260..a2938fb0 100644
--- a/training/minrisk/minrisk_optimize.cc
+++ b/training/minrisk/minrisk_optimize.cc
@@ -178,7 +178,7 @@ int main(int argc, char** argv) {
ReadFile rf(file);
if (kis.size() % 5 == 0) { cerr << '.'; }
if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; }
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
hg.Reweight(weights);
curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]);
if (kbest_file.size())
diff --git a/training/pro/mr_pro_map.cc b/training/pro/mr_pro_map.cc
index da58cd24..b142fd05 100644
--- a/training/pro/mr_pro_map.cc
+++ b/training/pro/mr_pro_map.cc
@@ -203,7 +203,7 @@ int main(int argc, char** argv) {
const string kbest_file = os.str();
if (FileExists(kbest_file))
J_i.ReadFromFile(kbest_file);
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
hg.Reweight(weights);
J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]);
J_i.WriteToFile(kbest_file);
diff --git a/training/rampion/rampion_cccp.cc b/training/rampion/rampion_cccp.cc
index 1e36dc51..1c45bac5 100644
--- a/training/rampion/rampion_cccp.cc
+++ b/training/rampion/rampion_cccp.cc
@@ -136,7 +136,7 @@ int main(int argc, char** argv) {
ReadFile rf(file);
if (kis.size() % 5 == 0) { cerr << '.'; }
if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; }
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
hg.Reweight(weights);
curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]);
if (kbest_file.size())
diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc
index 5c1b4d4a..04f1eb77 100644
--- a/training/utils/grammar_convert.cc
+++ b/training/utils/grammar_convert.cc
@@ -43,7 +43,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::notify(*conf);
if (conf->count("help") || conf->count("input") == 0) {
- cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n";
+ cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into serialized hypergraph.\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -254,7 +254,8 @@ void ProcessHypergraph(const vector<double>& w, const po::variables_map& conf, c
if (w.size() > 0) { hg->Reweight(w); }
if (conf.count("collapse_weights")) CollapseWeights(hg);
if (conf["output"].as<string>() == "json") {
- HypergraphIO::WriteToJSON(*hg, false, &cout);
+ cerr << "NOT IMPLEMENTED ... talk to cdyer if you need this functionality\n";
+ // HypergraphIO::WriteToBinary(*hg, &cout);
if (!ref.empty()) { cerr << "REF: " << ref << endl; }
} else {
vector<WordID> onebest;
@@ -315,11 +316,11 @@ int main(int argc, char **argv) {
line = line.substr(0, pos + 2);
}
istringstream is(line);
- if (HypergraphIO::ReadFromJSON(&is, &hg)) {
+ if (HypergraphIO::ReadFromBinary(&is, &hg)) {
ProcessHypergraph(w, conf, ref, &hg);
hg.clear();
} else {
- cerr << "Error reading grammar from JSON: line " << lc << endl;
+ cerr << "Error reading grammar line " << lc << endl;
exit(1);
}
} else {
diff --git a/utils/Makefile.am b/utils/Makefile.am
index 727fa8a5..dd74ddc0 100644
--- a/utils/Makefile.am
+++ b/utils/Makefile.am
@@ -22,6 +22,7 @@ libutils_a_SOURCES = \
alias_sampler.h \
alignment_io.h \
array2d.h \
+ b64featvector.h \
b64tools.h \
batched_append.h \
city.h \
@@ -40,6 +41,8 @@ libutils_a_SOURCES = \
kernel_string_subseq.h \
logval.h \
m.h \
+ maxent.h \
+ maxent.cpp \
murmur_hash3.h \
murmur_hash3.cc \
named_enum.h \
@@ -70,6 +73,7 @@ libutils_a_SOURCES = \
fast_lexical_cast.hpp \
intrusive_refcount.hpp \
alignment_io.cc \
+ b64featvector.cc \
b64tools.cc \
corpus_tools.cc \
dict.cc \
@@ -117,4 +121,3 @@ stringlib_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_U
# do NOT NOT NOT add any other -I includes NO NO NO NO NO ######
AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -I. -I$(top_srcdir) -DTEST_DATA=\"$(top_srcdir)/utils/test_data\"
################################################################
-
diff --git a/utils/b64featvector.cc b/utils/b64featvector.cc
new file mode 100644
index 00000000..c7d08b29
--- /dev/null
+++ b/utils/b64featvector.cc
@@ -0,0 +1,55 @@
+#include "b64featvector.h"
+
+#include <sstream>
+#include <boost/scoped_array.hpp>
+#include "b64tools.h"
+#include "fdict.h"
+
+using namespace std;
+
+static inline void EncodeFeatureWeight(const string &featname, weight_t weight,
+ ostream *output) {
+ output->write(featname.data(), featname.size() + 1);
+ output->write(reinterpret_cast<char *>(&weight), sizeof(weight_t));
+}
+
+string EncodeFeatureVector(const SparseVector<weight_t> &vec) {
+ string b64;
+ {
+ ostringstream base64_strm;
+ {
+ ostringstream strm;
+ for (SparseVector<weight_t>::const_iterator it = vec.begin();
+ it != vec.end(); ++it)
+ if (it->second != 0)
+ EncodeFeatureWeight(FD::Convert(it->first), it->second, &strm);
+ string data(strm.str());
+ B64::b64encode(data.data(), data.size(), &base64_strm);
+ }
+ b64 = base64_strm.str();
+ }
+ return b64;
+}
+
+void DecodeFeatureVector(const string &data, SparseVector<weight_t> *vec) {
+ vec->clear();
+ if (data.empty()) return;
+ // Decode data
+ size_t b64_len = data.size(), len = b64_len / 4 * 3;
+ boost::scoped_array<char> buf(new char[len]);
+ bool res =
+ B64::b64decode(reinterpret_cast<const unsigned char *>(data.data()),
+ b64_len, buf.get(), len);
+ assert(res);
+ // Apply updates
+ size_t cur = 0;
+ while (cur < len) {
+ string feat_name(buf.get() + cur);
+ if (feat_name.empty()) break; // Encountered trailing \0
+ int feat_id = FD::Convert(feat_name);
+ weight_t feat_delta =
+ *reinterpret_cast<weight_t *>(buf.get() + cur + feat_name.size() + 1);
+ (*vec)[feat_id] = feat_delta;
+ cur += feat_name.size() + 1 + sizeof(weight_t);
+ }
+}
diff --git a/utils/b64featvector.h b/utils/b64featvector.h
new file mode 100644
index 00000000..6ac04d44
--- /dev/null
+++ b/utils/b64featvector.h
@@ -0,0 +1,12 @@
+#ifndef _B64FEATVECTOR_H_
+#define _B64FEATVECTOR_H_
+
+#include <string>
+
+#include "sparse_vector.h"
+#include "weights.h"
+
+std::string EncodeFeatureVector(const SparseVector<weight_t> &);
+void DecodeFeatureVector(const std::string &, SparseVector<weight_t> *);
+
+#endif // _B64FEATVECTOR_H_
diff --git a/utils/maxent.cpp b/utils/maxent.cpp
new file mode 100644
index 00000000..fd772e08
--- /dev/null
+++ b/utils/maxent.cpp
@@ -0,0 +1,1127 @@
+/*
+ * $Id: maxent.cpp,v 1.1.1.1 2007/05/15 08:30:35 kyoshida Exp $
+ */
+
+#include "maxent.h"
+
+#include <vector>
+#include <iostream>
+#include <cmath>
+#include <cstdio>
+
+using namespace std;
+
+namespace maxent {
+double ME_Model::FunctionGradient(const vector<double>& x,
+ vector<double>& grad) {
+ assert((int)_fb.Size() == x.size());
+ for (size_t i = 0; i < x.size(); i++) {
+ _vl[i] = x[i];
+ }
+
+ double score = update_model_expectation();
+
+ if (_l2reg == 0) {
+ for (size_t i = 0; i < x.size(); i++) {
+ grad[i] = -(_vee[i] - _vme[i]);
+ }
+ } else {
+ const double c = _l2reg * 2;
+ for (size_t i = 0; i < x.size(); i++) {
+ grad[i] = -(_vee[i] - _vme[i] - c * _vl[i]);
+ }
+ }
+
+ return -score;
+}
+
+int ME_Model::perform_GIS(int C) {
+ cerr << "C = " << C << endl;
+ C = 1;
+ cerr << "performing AGIS" << endl;
+ vector<double> pre_v;
+ double pre_logl = -999999;
+ for (int iter = 0; iter < 200; iter++) {
+
+ double logl = update_model_expectation();
+ fprintf(stderr, "iter = %2d C = %d f = %10.7f train_err = %7.5f", iter,
+ C, logl, _train_error);
+ if (_heldout.size() > 0) {
+ double hlogl = heldout_likelihood();
+ fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", hlogl,
+ _heldout_error);
+ }
+ cerr << endl;
+
+ if (logl < pre_logl) {
+ C += 1;
+ _vl = pre_v;
+ iter--;
+ continue;
+ }
+ if (C > 1 && iter % 10 == 0) C--;
+
+ pre_logl = logl;
+ pre_v = _vl;
+ for (int i = 0; i < _fb.Size(); i++) {
+ double coef = _vee[i] / _vme[i];
+ _vl[i] += log(coef) / C;
+ }
+ }
+ cerr << endl;
+
+ return 0;
+}
+
+int ME_Model::perform_QUASI_NEWTON() {
+ const int dim = _fb.Size();
+ vector<double> x0(dim);
+
+ for (int i = 0; i < dim; i++) {
+ x0[i] = _vl[i];
+ }
+
+ vector<double> x;
+ if (_l1reg > 0) {
+ cerr << "performing OWLQN" << endl;
+ x = perform_OWLQN(x0, _l1reg);
+ } else {
+ cerr << "performing LBFGS" << endl;
+ x = perform_LBFGS(x0);
+ }
+
+ for (int i = 0; i < dim; i++) {
+ _vl[i] = x[i];
+ }
+
+ return 0;
+}
+
+int ME_Model::conditional_probability(const Sample& s,
+ std::vector<double>& membp) const {
+ // int num_classes = membp.size();
+ double sum = 0;
+ int max_label = 0;
+ // double maxp = 0;
+
+ vector<double> powv(_num_classes, 0.0);
+ for (vector<int>::const_iterator j = s.positive_features.begin();
+ j != s.positive_features.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[*j].begin();
+ k != _feature2mef[*j].end(); k++) {
+ powv[_fb.Feature(*k).label()] += _vl[*k];
+ }
+ }
+ for (vector<pair<int, double> >::const_iterator j = s.rvfeatures.begin();
+ j != s.rvfeatures.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[j->first].begin();
+ k != _feature2mef[j->first].end(); k++) {
+ powv[_fb.Feature(*k).label()] += _vl[*k] * j->second;
+ }
+ }
+
+ std::vector<double>::const_iterator pmax =
+ max_element(powv.begin(), powv.end());
+ double offset = max(0.0, *pmax - 700); // to avoid overflow
+ for (int label = 0; label < _num_classes; label++) {
+ double pow = powv[label] - offset;
+ double prod = exp(pow);
+ // cout << pow << " " << prod << ", ";
+ // if (_ref_modelp != NULL) prod *= _train_refpd[n][label];
+ if (_ref_modelp != NULL) prod *= s.ref_pd[label];
+ assert(prod != 0);
+ membp[label] = prod;
+ sum += prod;
+ }
+ for (int label = 0; label < _num_classes; label++) {
+ membp[label] /= sum;
+ if (membp[label] > membp[max_label]) max_label = label;
+ }
+ return max_label;
+}
+
+int ME_Model::make_feature_bag(const int cutoff) {
+ int max_num_features = 0;
+
+// count the occurrences of features
+#ifdef USE_HASH_MAP
+ typedef std::unordered_map<unsigned int, int> map_type;
+#else
+ typedef std::map<unsigned int, int> map_type;
+#endif
+ map_type count;
+ if (cutoff > 0) {
+ for (std::vector<Sample>::const_iterator i = _vs.begin(); i != _vs.end();
+ i++) {
+ for (std::vector<int>::const_iterator j = i->positive_features.begin();
+ j != i->positive_features.end(); j++) {
+ count[ME_Feature(i->label, *j).body()]++;
+ }
+ for (std::vector<pair<int, double> >::const_iterator j =
+ i->rvfeatures.begin();
+ j != i->rvfeatures.end(); j++) {
+ count[ME_Feature(i->label, j->first).body()]++;
+ }
+ }
+ }
+
+ int n = 0;
+ for (std::vector<Sample>::const_iterator i = _vs.begin(); i != _vs.end();
+ i++, n++) {
+ max_num_features =
+ max(max_num_features, (int)(i->positive_features.size()));
+ for (std::vector<int>::const_iterator j = i->positive_features.begin();
+ j != i->positive_features.end(); j++) {
+ const ME_Feature feature(i->label, *j);
+ // if (cutoff > 0 && count[feature.body()] < cutoff) continue;
+ if (cutoff > 0 && count[feature.body()] <= cutoff) continue;
+ _fb.Put(feature);
+ // cout << i->label << "\t" << *j << "\t" << id << endl;
+ // feature2sample[id].push_back(n);
+ }
+ for (std::vector<pair<int, double> >::const_iterator j =
+ i->rvfeatures.begin();
+ j != i->rvfeatures.end(); j++) {
+ const ME_Feature feature(i->label, j->first);
+ // if (cutoff > 0 && count[feature.body()] < cutoff) continue;
+ if (cutoff > 0 && count[feature.body()] <= cutoff) continue;
+ _fb.Put(feature);
+ }
+ }
+ count.clear();
+
+ // cerr << "num_classes = " << _num_classes << endl;
+ // cerr << "max_num_features = " << max_num_features << endl;
+
+ init_feature2mef();
+
+ return max_num_features;
+}
+
+double ME_Model::heldout_likelihood() {
+ double logl = 0;
+ int ncorrect = 0;
+ for (std::vector<Sample>::const_iterator i = _heldout.begin();
+ i != _heldout.end(); i++) {
+ vector<double> membp(_num_classes);
+ int l = classify(*i, membp);
+ logl += log(membp[i->label]);
+ if (l == i->label) ncorrect++;
+ }
+ _heldout_error = 1 - (double)ncorrect / _heldout.size();
+
+ return logl /= _heldout.size();
+}
+
+double ME_Model::update_model_expectation() {
+ double logl = 0;
+ int ncorrect = 0;
+
+ _vme.resize(_fb.Size());
+ for (int i = 0; i < _fb.Size(); i++) _vme[i] = 0;
+
+ int n = 0;
+ for (vector<Sample>::const_iterator i = _vs.begin(); i != _vs.end();
+ i++, n++) {
+ vector<double> membp(_num_classes);
+ int max_label = conditional_probability(*i, membp);
+
+ logl += log(membp[i->label]);
+ // cout << membp[*i] << " " << logl << " ";
+ if (max_label == i->label) ncorrect++;
+
+ // model_expectation
+ for (vector<int>::const_iterator j = i->positive_features.begin();
+ j != i->positive_features.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[*j].begin();
+ k != _feature2mef[*j].end(); k++) {
+ _vme[*k] += membp[_fb.Feature(*k).label()];
+ }
+ }
+ for (vector<pair<int, double> >::const_iterator j = i->rvfeatures.begin();
+ j != i->rvfeatures.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[j->first].begin();
+ k != _feature2mef[j->first].end(); k++) {
+ _vme[*k] += membp[_fb.Feature(*k).label()] * j->second;
+ }
+ }
+ }
+
+ for (int i = 0; i < _fb.Size(); i++) {
+ _vme[i] /= _vs.size();
+ }
+
+ _train_error = 1 - (double)ncorrect / _vs.size();
+
+ logl /= _vs.size();
+
+ if (_l2reg > 0) {
+ const double c = _l2reg;
+ for (int i = 0; i < _fb.Size(); i++) {
+ logl -= _vl[i] * _vl[i] * c;
+ }
+ }
+
+ // logl /= _vs.size();
+
+ // fprintf(stderr, "iter =%3d logl = %10.7f train_acc = %7.5f\n", iter,
+ // logl, (double)ncorrect/train.size());
+ // fprintf(stderr, "logl = %10.7f train_acc = %7.5f\n", logl,
+ // (double)ncorrect/_train.size());
+
+ return logl;
+}
+
+int ME_Model::train(const vector<ME_Sample>& vms) {
+ _vs.clear();
+ for (vector<ME_Sample>::const_iterator i = vms.begin(); i != vms.end(); i++) {
+ add_training_sample(*i);
+ }
+
+ return train();
+}
+
+void ME_Model::add_training_sample(const ME_Sample& mes) {
+ Sample s;
+ s.label = _label_bag.Put(mes.label);
+ if (s.label > ME_Feature::MAX_LABEL_TYPES) {
+ cerr << "error: too many types of labels." << endl;
+ exit(1);
+ }
+ for (vector<string>::const_iterator j = mes.features.begin();
+ j != mes.features.end(); j++) {
+ s.positive_features.push_back(_featurename_bag.Put(*j));
+ }
+ for (vector<pair<string, double> >::const_iterator j = mes.rvfeatures.begin();
+ j != mes.rvfeatures.end(); j++) {
+ s.rvfeatures.push_back(
+ pair<int, double>(_featurename_bag.Put(j->first), j->second));
+ }
+ if (_ref_modelp != NULL) {
+ ME_Sample tmp = mes;
+ ;
+ s.ref_pd = _ref_modelp->classify(tmp);
+ }
+ // cout << s.label << "\t";
+ // for (vector<int>::const_iterator j = s.positive_features.begin(); j !=
+ // s.positive_features.end(); j++){
+ // cout << *j << " ";
+ // }
+ // cout << endl;
+
+ _vs.push_back(s);
+}
+
+int ME_Model::train() {
+ if (_l1reg > 0 && _l2reg > 0) {
+ cerr << "error: L1 and L2 regularizers cannot be used simultaneously."
+ << endl;
+ return 0;
+ }
+ if (_vs.size() == 0) {
+ cerr << "error: no training data." << endl;
+ return 0;
+ }
+ if (_nheldout >= (int)_vs.size()) {
+ cerr << "error: too much heldout data. no training data is available."
+ << endl;
+ return 0;
+ }
+ // if (_nheldout > 0) random_shuffle(_vs.begin(), _vs.end());
+
+ int max_label = 0;
+ for (std::vector<Sample>::const_iterator i = _vs.begin(); i != _vs.end();
+ i++) {
+ max_label = max(max_label, i->label);
+ }
+ _num_classes = max_label + 1;
+ if (_num_classes != _label_bag.Size()) {
+ cerr << "warning: _num_class != _label_bag.Size()" << endl;
+ }
+
+ if (_ref_modelp != NULL) {
+ cerr << "setting reference distribution...";
+ for (int i = 0; i < _ref_modelp->num_classes(); i++) {
+ _label_bag.Put(_ref_modelp->get_class_label(i));
+ }
+ _num_classes = _label_bag.Size();
+ for (vector<Sample>::iterator i = _vs.begin(); i != _vs.end(); i++) {
+ set_ref_dist(*i);
+ }
+ cerr << "done" << endl;
+ }
+
+ for (int i = 0; i < _nheldout; i++) {
+ _heldout.push_back(_vs.back());
+ _vs.pop_back();
+ }
+
+ sort(_vs.begin(), _vs.end());
+
+ int cutoff = 0;
+ if (cutoff > 0) cerr << "cutoff threshold = " << cutoff << endl;
+ if (_l1reg > 0) cerr << "L1 regularizer = " << _l1reg << endl;
+ if (_l2reg > 0) cerr << "L2 regularizer = " << _l2reg << endl;
+
+ // normalize
+ _l1reg /= _vs.size();
+ _l2reg /= _vs.size();
+
+ cerr << "preparing for estimation...";
+ make_feature_bag(cutoff);
+ // _vs.clear();
+ cerr << "done" << endl;
+ cerr << "number of samples = " << _vs.size() << endl;
+ cerr << "number of features = " << _fb.Size() << endl;
+
+ cerr << "calculating empirical expectation...";
+ _vee.resize(_fb.Size());
+ for (int i = 0; i < _fb.Size(); i++) {
+ _vee[i] = 0;
+ }
+ for (int n = 0; n < (int)_vs.size(); n++) {
+ const Sample* i = &_vs[n];
+ for (vector<int>::const_iterator j = i->positive_features.begin();
+ j != i->positive_features.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[*j].begin();
+ k != _feature2mef[*j].end(); k++) {
+ if (_fb.Feature(*k).label() == i->label) _vee[*k] += 1.0;
+ }
+ }
+
+ for (vector<pair<int, double> >::const_iterator j = i->rvfeatures.begin();
+ j != i->rvfeatures.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[j->first].begin();
+ k != _feature2mef[j->first].end(); k++) {
+ if (_fb.Feature(*k).label() == i->label) _vee[*k] += j->second;
+ }
+ }
+ }
+ for (int i = 0; i < _fb.Size(); i++) {
+ _vee[i] /= _vs.size();
+ }
+ cerr << "done" << endl;
+
+ _vl.resize(_fb.Size());
+ for (int i = 0; i < _fb.Size(); i++) _vl[i] = 0.0;
+
+ if (_optimization_method == SGD) {
+ perform_SGD();
+ } else {
+ perform_QUASI_NEWTON();
+ }
+
+ int num_active = 0;
+ for (int i = 0; i < _fb.Size(); i++) {
+ if (_vl[i] != 0) num_active++;
+ }
+ cerr << "number of active features = " << num_active << endl;
+
+ return 0;
+}
+
+void ME_Model::get_features(list<pair<pair<string, string>, double> >& fl) {
+ fl.clear();
+ // for (int i = 0; i < _fb.Size(); i++) {
+ // ME_Feature f = _fb.Feature(i);
+ // fl.push_back( make_pair(make_pair(_label_bag.Str(f.label()),
+ // _featurename_bag.Str(f.feature())), _vl[i]));
+ // }
+ for (MiniStringBag::map_type::const_iterator i = _featurename_bag.begin();
+ i != _featurename_bag.end(); i++) {
+ for (int j = 0; j < _label_bag.Size(); j++) {
+ string label = _label_bag.Str(j);
+ string history = i->first;
+ int id = _fb.Id(ME_Feature(j, i->second));
+ if (id < 0) continue;
+ fl.push_back(make_pair(make_pair(label, history), _vl[id]));
+ }
+ }
+}
+
+void ME_Model::clear() {
+ _vl.clear();
+ _label_bag.Clear();
+ _featurename_bag.Clear();
+ _fb.Clear();
+ _feature2mef.clear();
+ _vee.clear();
+ _vme.clear();
+ _vs.clear();
+ _heldout.clear();
+}
+
+bool ME_Model::load_from_file(const string& filename) {
+ FILE* fp = fopen(filename.c_str(), "r");
+ if (!fp) {
+ cerr << "error: cannot open " << filename << "!" << endl;
+ return false;
+ }
+
+ _vl.clear();
+ _label_bag.Clear();
+ _featurename_bag.Clear();
+ _fb.Clear();
+ char buf[1024];
+ while (fgets(buf, 1024, fp)) {
+ string line(buf);
+ string::size_type t1 = line.find_first_of('\t');
+ string::size_type t2 = line.find_last_of('\t');
+ string classname = line.substr(0, t1);
+ string featurename = line.substr(t1 + 1, t2 - (t1 + 1));
+ float lambda;
+ string w = line.substr(t2 + 1);
+ sscanf(w.c_str(), "%f", &lambda);
+
+ int label = _label_bag.Put(classname);
+ int feature = _featurename_bag.Put(featurename);
+ _fb.Put(ME_Feature(label, feature));
+ _vl.push_back(lambda);
+ }
+
+ _num_classes = _label_bag.Size();
+
+ init_feature2mef();
+
+ fclose(fp);
+
+ return true;
+}
+
+void ME_Model::init_feature2mef() {
+ _feature2mef.clear();
+ for (int i = 0; i < _featurename_bag.Size(); i++) {
+ vector<int> vi;
+ for (int k = 0; k < _num_classes; k++) {
+ int id = _fb.Id(ME_Feature(k, i));
+ if (id >= 0) vi.push_back(id);
+ }
+ _feature2mef.push_back(vi);
+ }
+}
+
+bool ME_Model::load_from_array(const ME_Model_Data data[]) {
+ _vl.clear();
+ for (int i = 0;; i++) {
+ if (string(data[i].label) == "///") break;
+ int label = _label_bag.Put(data[i].label);
+ int feature = _featurename_bag.Put(data[i].feature);
+ _fb.Put(ME_Feature(label, feature));
+ _vl.push_back(data[i].weight);
+ }
+ _num_classes = _label_bag.Size();
+
+ init_feature2mef();
+
+ return true;
+}
+
+bool ME_Model::save_to_file(const string& filename, const double th) const {
+ FILE* fp = fopen(filename.c_str(), "w");
+ if (!fp) {
+ cerr << "error: cannot open " << filename << "!" << endl;
+ return false;
+ }
+
+ // for (int i = 0; i < _fb.Size(); i++) {
+ // if (_vl[i] == 0) continue; // ignore zero-weight features
+ // ME_Feature f = _fb.Feature(i);
+ // fprintf(fp, "%s\t%s\t%f\n", _label_bag.Str(f.label()).c_str(),
+ // _featurename_bag.Str(f.feature()).c_str(), _vl[i]);
+ // }
+ for (MiniStringBag::map_type::const_iterator i = _featurename_bag.begin();
+ i != _featurename_bag.end(); i++) {
+ for (int j = 0; j < _label_bag.Size(); j++) {
+ string label = _label_bag.Str(j);
+ string history = i->first;
+ int id = _fb.Id(ME_Feature(j, i->second));
+ if (id < 0) continue;
+ if (_vl[id] == 0) continue; // ignore zero-weight features
+ if (fabs(_vl[id]) < th) continue; // cut off low-weight features
+ fprintf(fp, "%s\t%s\t%f\n", label.c_str(), history.c_str(), _vl[id]);
+ }
+ }
+
+ fclose(fp);
+
+ return true;
+}
+
+void ME_Model::set_ref_dist(Sample& s) const {
+ vector<double> v0 = s.ref_pd;
+ vector<double> v(_num_classes);
+ for (unsigned int i = 0; i < v.size(); i++) {
+ v[i] = 0;
+ string label = get_class_label(i);
+ int id_ref = _ref_modelp->get_class_id(label);
+ if (id_ref != -1) {
+ v[i] = v0[id_ref];
+ }
+ if (v[i] == 0) v[i] = 0.001; // to avoid -inf logl
+ }
+ s.ref_pd = v;
+}
+
+int ME_Model::classify(const Sample& nbs, vector<double>& membp) const {
+ // vector<double> membp(_num_classes);
+ assert(_num_classes == (int)membp.size());
+ conditional_probability(nbs, membp);
+ int max_label = 0;
+ double max = 0.0;
+ for (int i = 0; i < (int)membp.size(); i++) {
+ // cout << membp[i] << " ";
+ if (membp[i] > max) {
+ max_label = i;
+ max = membp[i];
+ }
+ }
+ // cout << endl;
+ return max_label;
+}
+
+vector<double> ME_Model::classify(ME_Sample& mes) const {
+ Sample s;
+ for (vector<string>::const_iterator j = mes.features.begin();
+ j != mes.features.end(); j++) {
+ int id = _featurename_bag.Id(*j);
+ if (id >= 0) s.positive_features.push_back(id);
+ }
+ for (vector<pair<string, double> >::const_iterator j = mes.rvfeatures.begin();
+ j != mes.rvfeatures.end(); j++) {
+ int id = _featurename_bag.Id(j->first);
+ if (id >= 0) {
+ s.rvfeatures.push_back(pair<int, double>(id, j->second));
+ }
+ }
+ if (_ref_modelp != NULL) {
+ s.ref_pd = _ref_modelp->classify(mes);
+ set_ref_dist(s);
+ }
+
+ vector<double> vp(_num_classes);
+ int label = classify(s, vp);
+ mes.label = get_class_label(label);
+ return vp;
+}
+
+// template<class FuncGrad>
+// std::vector<double>
+// perform_LBFGS(FuncGrad func_grad, const std::vector<double> & x0);
+
+std::vector<double> perform_LBFGS(
+ double (*func_grad)(const std::vector<double> &, std::vector<double> &),
+ const std::vector<double> &x0);
+
+std::vector<double> perform_OWLQN(
+ double (*func_grad)(const std::vector<double> &, std::vector<double> &),
+ const std::vector<double> &x0, const double C);
+
+const int LBFGS_M = 10;
+
+const static int M = LBFGS_M;
+const static double LINE_SEARCH_ALPHA = 0.1;
+const static double LINE_SEARCH_BETA = 0.5;
+
+// stopping criteria
+int LBFGS_MAX_ITER = 300;
+const static double MIN_GRAD_NORM = 0.0001;
+
+// LBFGS
+
+double ME_Model::backtracking_line_search(const Vec& x0, const Vec& grad0,
+ const double f0, const Vec& dx,
+ Vec& x, Vec& grad1) {
+ double t = 1.0 / LINE_SEARCH_BETA;
+
+ double f;
+ do {
+ t *= LINE_SEARCH_BETA;
+ x = x0 + t * dx;
+ f = FunctionGradient(x.STLVec(), grad1.STLVec());
+ // cout << "*";
+ } while (f > f0 + LINE_SEARCH_ALPHA * t * dot_product(dx, grad0));
+
+ return f;
+}
+
+//
+// Jorge Nocedal, "Updating Quasi-Newton Matrices With Limited Storage",
+// Mathematics of Computation, Vol. 35, No. 151, pp. 773-782, 1980.
+//
+Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[],
+ const Vec y[], const double z[]) {
+ int offset, bound;
+ if (iter <= M) {
+ offset = 0;
+ bound = iter;
+ } else {
+ offset = iter - M;
+ bound = M;
+ }
+
+ Vec q = grad;
+ double alpha[M], beta[M];
+ for (int i = bound - 1; i >= 0; i--) {
+ const int j = (i + offset) % M;
+ alpha[i] = z[j] * dot_product(s[j], q);
+ q += -alpha[i] * y[j];
+ }
+ if (iter > 0) {
+ const int j = (iter - 1) % M;
+ const double gamma = ((1.0 / z[j]) / dot_product(y[j], y[j]));
+ // static double gamma;
+ // if (gamma == 0) gamma = ((1.0 / z[j]) / dot_product(y[j], y[j]));
+ q *= gamma;
+ }
+ for (int i = 0; i <= bound - 1; i++) {
+ const int j = (i + offset) % M;
+ beta[i] = z[j] * dot_product(y[j], q);
+ q += s[j] * (alpha[i] - beta[i]);
+ }
+
+ return q;
+}
+
+vector<double> ME_Model::perform_LBFGS(const vector<double>& x0) {
+ const size_t dim = x0.size();
+ Vec x = x0;
+
+ Vec grad(dim), dx(dim);
+ double f = FunctionGradient(x.STLVec(), grad.STLVec());
+
+ Vec s[M], y[M];
+ double z[M]; // rho
+
+ for (int iter = 0; iter < LBFGS_MAX_ITER; iter++) {
+
+ fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error);
+ if (_nheldout > 0) {
+ const double heldout_logl = heldout_likelihood();
+ fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl,
+ _heldout_error);
+ }
+ fprintf(stderr, "\n");
+
+ if (sqrt(dot_product(grad, grad)) < MIN_GRAD_NORM) break;
+
+ dx = -1 * approximate_Hg(iter, grad, s, y, z);
+
+ Vec x1(dim), grad1(dim);
+ f = backtracking_line_search(x, grad, f, dx, x1, grad1);
+
+ s[iter % M] = x1 - x;
+ y[iter % M] = grad1 - grad;
+ z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]);
+ x = x1;
+ grad = grad1;
+ }
+
+ return x.STLVec();
+}
+
+// OWLQN
+
+// stopping criteria
+int OWLQN_MAX_ITER = 300;
+
+Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[],
+ const Vec y[], const double z[]);
+
+inline int sign(double x) {
+ if (x > 0) return 1;
+ if (x < 0) return -1;
+ return 0;
+};
+
+static Vec pseudo_gradient(const Vec& x, const Vec& grad0, const double C) {
+ Vec grad = grad0;
+ for (size_t i = 0; i < x.Size(); i++) {
+ if (x[i] != 0) {
+ grad[i] += C * sign(x[i]);
+ continue;
+ }
+ const double gm = grad0[i] - C;
+ if (gm > 0) {
+ grad[i] = gm;
+ continue;
+ }
+ const double gp = grad0[i] + C;
+ if (gp < 0) {
+ grad[i] = gp;
+ continue;
+ }
+ grad[i] = 0;
+ }
+
+ return grad;
+}
+
+double ME_Model::regularized_func_grad(const double C, const Vec& x,
+ Vec& grad) {
+ double f = FunctionGradient(x.STLVec(), grad.STLVec());
+ for (size_t i = 0; i < x.Size(); i++) {
+ f += C * fabs(x[i]);
+ }
+
+ return f;
+}
+
+double ME_Model::constrained_line_search(double C, const Vec& x0,
+ const Vec& grad0, const double f0,
+ const Vec& dx, Vec& x, Vec& grad1) {
+ // compute the orthant to explore
+ Vec orthant = x0;
+ for (size_t i = 0; i < orthant.Size(); i++) {
+ if (orthant[i] == 0) orthant[i] = -grad0[i];
+ }
+
+ double t = 1.0 / LINE_SEARCH_BETA;
+
+ double f;
+ do {
+ t *= LINE_SEARCH_BETA;
+ x = x0 + t * dx;
+ x.Project(orthant);
+ // for (size_t i = 0; i < x.Size(); i++) {
+ // if (x0[i] != 0 && sign(x[i]) != sign(x0[i])) x[i] = 0;
+ // }
+
+ f = regularized_func_grad(C, x, grad1);
+ // cout << "*";
+ } while (f > f0 + LINE_SEARCH_ALPHA * dot_product(x - x0, grad0));
+
+ return f;
+}
+
+vector<double> ME_Model::perform_OWLQN(const vector<double>& x0,
+ const double C) {
+ const size_t dim = x0.size();
+ Vec x = x0;
+
+ Vec grad(dim), dx(dim);
+ double f = regularized_func_grad(C, x, grad);
+
+ Vec s[M], y[M];
+ double z[M]; // rho
+
+ for (int iter = 0; iter < OWLQN_MAX_ITER; iter++) {
+ Vec pg = pseudo_gradient(x, grad, C);
+
+ fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error);
+ if (_nheldout > 0) {
+ const double heldout_logl = heldout_likelihood();
+ fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl,
+ _heldout_error);
+ }
+ fprintf(stderr, "\n");
+
+ if (sqrt(dot_product(pg, pg)) < MIN_GRAD_NORM) break;
+
+ dx = -1 * approximate_Hg(iter, pg, s, y, z);
+ if (dot_product(dx, pg) >= 0) dx.Project(-1 * pg);
+
+ Vec x1(dim), grad1(dim);
+ f = constrained_line_search(C, x, pg, f, dx, x1, grad1);
+
+ s[iter % M] = x1 - x;
+ y[iter % M] = grad1 - grad;
+ z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]);
+
+ x = x1;
+ grad = grad1;
+ }
+
+ return x.STLVec();
+}
+
+// SGD
+
+// const double SGD_ETA0 = 1;
+// const double SGD_ITER = 30;
+// const double SGD_ALPHA = 0.85;
+
+//#define FOLOS_NAIVE
+//#define FOLOS_LAZY
+#define SGD_CP
+
+inline void apply_l1_penalty(const int i, const double u, vector<double>& _vl,
+ vector<double>& q) {
+ double& w = _vl[i];
+ const double z = w;
+ double& qi = q[i];
+ if (w > 0) {
+ w = max(0.0, w - (u + qi));
+ } else if (w < 0) {
+ w = min(0.0, w + (u - qi));
+ }
+ qi += w - z;
+}
+
+static double l1norm(const vector<double>& v) {
+ double sum = 0;
+ for (size_t i = 0; i < v.size(); i++) sum += abs(v[i]);
+ return sum;
+}
+
+inline void update_folos_lazy(const int iter_sample, const int k,
+ vector<double>& _vl,
+ const vector<double>& sum_eta,
+ vector<int>& last_updated) {
+ const double penalty = sum_eta[iter_sample] - sum_eta[last_updated[k]];
+ double& x = _vl[k];
+ if (x > 0)
+ x = max(0.0, x - penalty);
+ else
+ x = min(0.0, x + penalty);
+ last_updated[k] = iter_sample;
+}
+
+int ME_Model::perform_SGD() {
+ if (_l2reg > 0) {
+ cerr << "error: L2 regularization is currently not supported in SGD mode."
+ << endl;
+ exit(1);
+ }
+
+ cerr << "performing SGD" << endl;
+
+ const double l1param = _l1reg;
+
+ const int d = _fb.Size();
+
+ vector<int> ri(_vs.size());
+ for (size_t i = 0; i < ri.size(); i++) ri[i] = i;
+
+ vector<double> grad(d);
+ int iter_sample = 0;
+ const double eta0 = SGD_ETA0;
+
+ // cerr << "l1param = " << l1param << endl;
+ cerr << "eta0 = " << eta0 << " alpha = " << SGD_ALPHA << endl;
+
+ double u = 0;
+ vector<double> q(d, 0);
+ vector<int> last_updated(d, 0);
+ vector<double> sum_eta;
+ sum_eta.push_back(0);
+
+ for (int iter = 0; iter < SGD_ITER; iter++) {
+
+ random_shuffle(ri.begin(), ri.end());
+
+ double logl = 0;
+ int ncorrect = 0, ntotal = 0;
+ for (size_t i = 0; i < _vs.size(); i++, ntotal++, iter_sample++) {
+ const Sample& s = _vs[ri[i]];
+
+#ifdef FOLOS_LAZY
+ for (vector<int>::const_iterator j = s.positive_features.begin();
+ j != s.positive_features.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[*j].begin();
+ k != _feature2mef[*j].end(); k++) {
+ update_folos_lazy(iter_sample, *k, _vl, sum_eta, last_updated);
+ }
+ }
+#endif
+
+ vector<double> membp(_num_classes);
+ const int max_label = conditional_probability(s, membp);
+
+ const double eta =
+ eta0 * pow(SGD_ALPHA,
+ (double)iter_sample / _vs.size()); // exponential decay
+ // const double eta = eta0 / (1.0 + (double)iter_sample /
+ // _vs.size());
+
+ // if (iter_sample % _vs.size() == 0) cerr << "eta = " << eta <<
+ // endl;
+ u += eta * l1param;
+
+ sum_eta.push_back(sum_eta.back() + eta * l1param);
+
+ logl += log(membp[s.label]);
+ if (max_label == s.label) ncorrect++;
+
+ // binary features
+ for (vector<int>::const_iterator j = s.positive_features.begin();
+ j != s.positive_features.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[*j].begin();
+ k != _feature2mef[*j].end(); k++) {
+ const double me = membp[_fb.Feature(*k).label()];
+ const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0);
+ const double grad = (me - ee);
+ _vl[*k] -= eta * grad;
+#ifdef SGD_CP
+ apply_l1_penalty(*k, u, _vl, q);
+#endif
+ }
+ }
+ // real-valued features
+ for (vector<pair<int, double> >::const_iterator j = s.rvfeatures.begin();
+ j != s.rvfeatures.end(); j++) {
+ for (vector<int>::const_iterator k = _feature2mef[j->first].begin();
+ k != _feature2mef[j->first].end(); k++) {
+ const double me = membp[_fb.Feature(*k).label()];
+ const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0);
+ const double grad = (me - ee) * j->second;
+ _vl[*k] -= eta * grad;
+#ifdef SGD_CP
+ apply_l1_penalty(*k, u, _vl, q);
+#endif
+ }
+ }
+
+#ifdef FOLOS_NAIVE
+ for (size_t j = 0; j < d; j++) {
+ double& x = _vl[j];
+ if (x > 0)
+ x = max(0.0, x - eta * l1param);
+ else
+ x = min(0.0, x + eta * l1param);
+ }
+#endif
+ }
+ logl /= _vs.size();
+// fprintf(stderr, "%4d logl = %8.3f acc = %6.4f ", iter, logl,
+// (double)ncorrect / ntotal);
+
+#ifdef FOLOS_LAZY
+ if (l1param > 0) {
+ for (size_t j = 0; j < d; j++)
+ update_folos_lazy(iter_sample, j, _vl, sum_eta, last_updated);
+ }
+#endif
+
+ double f = logl;
+ if (l1param > 0) {
+ const double l1 =
+ l1norm(_vl); // this is not accurate when lazy update is used
+ // cerr << "f0 = " << update_model_expectation() - l1param * l1 << "
+ // ";
+ f -= l1param * l1;
+ int nonzero = 0;
+ for (int j = 0; j < d; j++)
+ if (_vl[j] != 0) nonzero++;
+ // cerr << " f = " << f << " l1 = " << l1 << " nonzero_features = "
+ // << nonzero << endl;
+ }
+ // fprintf(stderr, "%4d obj = %7.3f acc = %6.4f", iter+1, f,
+ // (double)ncorrect/ntotal);
+ // fprintf(stderr, "%4d obj = %f", iter+1, f);
+ fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, f,
+ 1 - (double)ncorrect / ntotal);
+
+ if (_nheldout > 0) {
+ double heldout_logl = heldout_likelihood();
+ // fprintf(stderr, " heldout_logl = %f acc = %6.4f\n",
+ // heldout_logl, 1 - _heldout_error);
+ fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl,
+ _heldout_error);
+ }
+ fprintf(stderr, "\n");
+ }
+
+ return 0;
+}
+
+} // namespace maxent
+
+/*
+ * $Log: maxent.cpp,v $
+ * Revision 1.1.1.1 2007/05/15 08:30:35 kyoshida
+ * stepp tagger, by Okanohara and Tsuruoka
+ *
+ * Revision 1.28 2006/08/21 17:30:38 tsuruoka
+ * use MAX_LABEL_TYPES
+ *
+ * Revision 1.27 2006/07/25 13:19:53 tsuruoka
+ * sort _vs[]
+ *
+ * Revision 1.26 2006/07/18 11:13:15 tsuruoka
+ * modify comments
+ *
+ * Revision 1.25 2006/07/18 10:02:15 tsuruoka
+ * remove sample2feature[]
+ * speed up conditional_probability()
+ *
+ * Revision 1.24 2006/07/18 05:10:51 tsuruoka
+ * add ref_dist
+ *
+ * Revision 1.23 2005/12/24 07:05:32 tsuruoka
+ * modify conditional_probability() to avoid overflow
+ *
+ * Revision 1.22 2005/12/24 07:01:25 tsuruoka
+ * add cutoff for real-valued features
+ *
+ * Revision 1.21 2005/12/23 10:33:02 tsuruoka
+ * support real-valued features
+ *
+ * Revision 1.20 2005/12/23 09:15:29 tsuruoka
+ * modify _train to reduce memory consumption
+ *
+ * Revision 1.19 2005/10/28 13:10:14 tsuruoka
+ * fix for overflow (thanks to Ming Li)
+ *
+ * Revision 1.18 2005/10/28 13:03:07 tsuruoka
+ * add progress_bar
+ *
+ * Revision 1.17 2005/09/12 13:51:16 tsuruoka
+ * Sample: list -> vector
+ *
+ * Revision 1.16 2005/09/12 13:27:10 tsuruoka
+ * add add_training_sample()
+ *
+ * Revision 1.15 2005/04/27 11:22:27 tsuruoka
+ * bugfix
+ * ME_Sample: list -> vector
+ *
+ * Revision 1.14 2005/04/27 10:00:42 tsuruoka
+ * remove tmpfb
+ *
+ * Revision 1.13 2005/04/26 14:25:53 tsuruoka
+ * add MiniStringBag, USE_HASH_MAP
+ *
+ * Revision 1.12 2005/02/11 10:20:08 tsuruoka
+ * modify cutoff
+ *
+ * Revision 1.11 2004/10/04 05:50:25 tsuruoka
+ * add Clear()
+ *
+ * Revision 1.10 2004/08/26 16:52:26 tsuruoka
+ * fix load_from_file()
+ *
+ * Revision 1.9 2004/08/09 12:27:21 tsuruoka
+ * change messages
+ *
+ * Revision 1.8 2004/08/04 13:55:18 tsuruoka
+ * modify _sample2feature
+ *
+ * Revision 1.7 2004/07/28 13:42:58 tsuruoka
+ * add AGIS
+ *
+ * Revision 1.6 2004/07/28 05:54:13 tsuruoka
+ * get_class_name() -> get_class_label()
+ * ME_Feature: bugfix
+ *
+ * Revision 1.5 2004/07/27 16:58:47 tsuruoka
+ * modify the interface of classify()
+ *
+ * Revision 1.4 2004/07/26 17:23:46 tsuruoka
+ * _sample2feature: list -> vector
+ *
+ * Revision 1.3 2004/07/26 15:49:23 tsuruoka
+ * modify ME_Feature
+ *
+ * Revision 1.2 2004/07/26 13:52:18 tsuruoka
+ * modify cutoff
+ *
+ * Revision 1.1 2004/07/26 13:10:55 tsuruoka
+ * add files
+ *
+ * Revision 1.20 2004/07/22 08:34:45 tsuruoka
+ * modify _sample2feature[]
+ *
+ * Revision 1.19 2004/07/21 16:33:01 tsuruoka
+ * remove some comments
+ *
+ */
diff --git a/utils/maxent.h b/utils/maxent.h
new file mode 100644
index 00000000..74d13a6f
--- /dev/null
+++ b/utils/maxent.h
@@ -0,0 +1,477 @@
+/*
+ * $Id: maxent.h,v 1.1.1.1 2007/05/15 08:30:35 kyoshida Exp $
+ */
+
+#ifndef __MAXENT_H_
+#define __MAXENT_H_
+
+#include <algorithm>
+#include <iostream>
+#include <list>
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <cassert>
+
+namespace maxent {
+class Vec {
+ private:
+ std::vector<double> _v;
+
+ public:
+ Vec(const size_t n = 0, const double val = 0) { _v.resize(n, val); }
+ Vec(const std::vector<double>& v) : _v(v) {}
+ const std::vector<double>& STLVec() const { return _v; }
+ std::vector<double>& STLVec() { return _v; }
+ size_t Size() const { return _v.size(); }
+ double& operator[](int i) { return _v[i]; }
+ const double& operator[](int i) const { return _v[i]; }
+ Vec& operator+=(const Vec& b) {
+ assert(b.Size() == _v.size());
+ for (size_t i = 0; i < _v.size(); i++) {
+ _v[i] += b[i];
+ }
+ return *this;
+ }
+ Vec& operator*=(const double c) {
+ for (size_t i = 0; i < _v.size(); i++) {
+ _v[i] *= c;
+ }
+ return *this;
+ }
+ void Project(const Vec& y) {
+ for (size_t i = 0; i < _v.size(); i++) {
+ // if (sign(_v[i]) != sign(y[i])) _v[i] = 0;
+ if (_v[i] * y[i] <= 0) _v[i] = 0;
+ }
+ }
+};
+
+inline double dot_product(const Vec& a, const Vec& b) {
+ double sum = 0;
+ for (size_t i = 0; i < a.Size(); i++) {
+ sum += a[i] * b[i];
+ }
+ return sum;
+}
+
+inline std::ostream& operator<<(std::ostream& s, const Vec& a) {
+ s << "(";
+ for (size_t i = 0; i < a.Size(); i++) {
+ if (i != 0) s << ", ";
+ s << a[i];
+ }
+ s << ")";
+ return s;
+}
+
+inline const Vec operator+(const Vec& a, const Vec& b) {
+ Vec v(a.Size());
+ assert(a.Size() == b.Size());
+ for (size_t i = 0; i < a.Size(); i++) {
+ v[i] = a[i] + b[i];
+ }
+ return v;
+}
+
+inline const Vec operator-(const Vec& a, const Vec& b) {
+ Vec v(a.Size());
+ assert(a.Size() == b.Size());
+ for (size_t i = 0; i < a.Size(); i++) {
+ v[i] = a[i] - b[i];
+ }
+ return v;
+}
+
+inline const Vec operator*(const Vec& a, const double c) {
+ Vec v(a.Size());
+ for (size_t i = 0; i < a.Size(); i++) {
+ v[i] = a[i] * c;
+ }
+ return v;
+}
+
+inline const Vec operator*(const double c, const Vec& a) { return a * c; }
+
+//
+// data format for each sample for training/testing
+//
+struct ME_Sample {
+ public:
+ ME_Sample() : label("") {};
+ ME_Sample(const std::string& l) : label(l) {};
+ void set_label(const std::string& l) { label = l; }
+
+ // to add a binary feature
+ void add_feature(const std::string& f) { features.push_back(f); }
+
+ // to add a real-valued feature
+ void add_feature(const std::string& s, const double d) {
+ rvfeatures.push_back(std::pair<std::string, double>(s, d));
+ }
+
+ public:
+ std::string label;
+ std::vector<std::string> features;
+ std::vector<std::pair<std::string, double> > rvfeatures;
+
+ // obsolete
+ void add_feature(const std::pair<std::string, double>& f) {
+ rvfeatures.push_back(f); // real-valued features
+ }
+};
+
+//
+// for those who want to use load_from_array()
+//
+typedef struct ME_Model_Data {
+ char* label;
+ char* feature;
+ double weight;
+} ME_Model_Data;
+
+class ME_Model {
+ public:
+ void add_training_sample(const ME_Sample& s);
+ int train();
+ std::vector<double> classify(ME_Sample& s) const;
+ bool load_from_file(const std::string& filename);
+ bool save_to_file(const std::string& filename, const double th = 0) const;
+ int num_classes() const { return _num_classes; }
+ std::string get_class_label(int i) const { return _label_bag.Str(i); }
+ int get_class_id(const std::string& s) const { return _label_bag.Id(s); }
+ void get_features(
+ std::list<std::pair<std::pair<std::string, std::string>, double> >& fl);
+ void set_heldout(const int h, const int n = 0) {
+ _nheldout = h;
+ _early_stopping_n = n;
+ };
+ void use_l1_regularizer(const double v) { _l1reg = v; }
+ void use_l2_regularizer(const double v) { _l2reg = v; }
+ void use_SGD(int iter = 30, double eta0 = 1, double alpha = 0.85) {
+ _optimization_method = SGD;
+ SGD_ITER = iter;
+ SGD_ETA0 = eta0;
+ SGD_ALPHA = alpha;
+ }
+ bool load_from_array(const ME_Model_Data data[]);
+ void set_reference_model(const ME_Model& ref_model) {
+ _ref_modelp = &ref_model;
+ };
+ void clear();
+
+ ME_Model() {
+ _l1reg = _l2reg = 0;
+ _nheldout = 0;
+ _early_stopping_n = 0;
+ _ref_modelp = NULL;
+ _optimization_method = LBFGS;
+ }
+
+ public:
+ // obsolete. just for downward compatibility
+ int train(const std::vector<ME_Sample>& train);
+
+ private:
+ enum OPTIMIZATION_METHOD {
+ LBFGS,
+ OWLQN,
+ SGD
+ } _optimization_method;
+ // OWLQN and SGD are available only for L1-regularization
+
+ int SGD_ITER;
+ double SGD_ETA0;
+ double SGD_ALPHA;
+
+ double _l1reg, _l2reg;
+
+ struct Sample {
+ int label;
+ std::vector<int> positive_features;
+ std::vector<std::pair<int, double> > rvfeatures;
+ std::vector<double> ref_pd; // reference probability distribution
+ bool operator<(const Sample& x) const {
+ for (unsigned int i = 0; i < positive_features.size(); i++) {
+ if (i >= x.positive_features.size()) return false;
+ int v0 = positive_features[i];
+ int v1 = x.positive_features[i];
+ if (v0 < v1) return true;
+ if (v0 > v1) return false;
+ }
+ return false;
+ }
+ };
+
+ struct ME_Feature {
+ enum {
+ MAX_LABEL_TYPES = 255
+ };
+
+ // ME_Feature(const int l, const int f) : _body((l << 24) + f) {
+ // assert(l >= 0 && l < 256);
+ // assert(f >= 0 && f <= 0xffffff);
+ // };
+ // int label() const { return _body >> 24; }
+ // int feature() const { return _body & 0xffffff; }
+ ME_Feature(const int l, const int f) : _body((f << 8) + l) {
+ assert(l >= 0 && l <= MAX_LABEL_TYPES);
+ assert(f >= 0 && f <= 0xffffff);
+ };
+ int label() const { return _body & 0xff; }
+ int feature() const { return _body >> 8; }
+ unsigned int body() const { return _body; }
+
+ private:
+ unsigned int _body;
+ };
+
+ struct ME_FeatureBag {
+#ifdef USE_HASH_MAP
+ typedef std::unordered_map<unsigned int, int> map_type;
+#else
+ typedef std::map<unsigned int, int> map_type;
+#endif
+ map_type mef2id;
+ std::vector<ME_Feature> id2mef;
+ int Put(const ME_Feature& i) {
+ map_type::const_iterator j = mef2id.find(i.body());
+ if (j == mef2id.end()) {
+ int id = id2mef.size();
+ id2mef.push_back(i);
+ mef2id[i.body()] = id;
+ return id;
+ }
+ return j->second;
+ }
+ int Id(const ME_Feature& i) const {
+ map_type::const_iterator j = mef2id.find(i.body());
+ if (j == mef2id.end()) {
+ return -1;
+ }
+ return j->second;
+ }
+ ME_Feature Feature(int id) const {
+ assert(id >= 0 && id < (int)id2mef.size());
+ return id2mef[id];
+ }
+ int Size() const { return id2mef.size(); }
+ void Clear() {
+ mef2id.clear();
+ id2mef.clear();
+ }
+ };
+
+ struct hashfun_str {
+ size_t operator()(const std::string& s) const {
+ assert(sizeof(int) == 4 && sizeof(char) == 1);
+ const int* p = reinterpret_cast<const int*>(s.c_str());
+ size_t v = 0;
+ int n = s.size() / 4;
+ for (int i = 0; i < n; i++, p++) {
+ // v ^= *p;
+ v ^= *p << (4 * (i % 2)); // note) 0 <= char < 128
+ }
+ int m = s.size() % 4;
+ for (int i = 0; i < m; i++) {
+ v ^= s[4 * n + i] << (i * 8);
+ }
+ return v;
+ }
+ };
+
+ struct MiniStringBag {
+#ifdef USE_HASH_MAP
+ typedef std::unordered_map<std::string, int, hashfun_str> map_type;
+#else
+ typedef std::map<std::string, int> map_type;
+#endif
+ int _size;
+ map_type str2id;
+ MiniStringBag() : _size(0) {}
+ int Put(const std::string& i) {
+ map_type::const_iterator j = str2id.find(i);
+ if (j == str2id.end()) {
+ int id = _size;
+ _size++;
+ str2id[i] = id;
+ return id;
+ }
+ return j->second;
+ }
+ int Id(const std::string& i) const {
+ map_type::const_iterator j = str2id.find(i);
+ if (j == str2id.end()) return -1;
+ return j->second;
+ }
+ int Size() const { return _size; }
+ void Clear() {
+ str2id.clear();
+ _size = 0;
+ }
+ map_type::const_iterator begin() const { return str2id.begin(); }
+ map_type::const_iterator end() const { return str2id.end(); }
+ };
+
+ struct StringBag : public MiniStringBag {
+ std::vector<std::string> id2str;
+ int Put(const std::string& i) {
+ map_type::const_iterator j = str2id.find(i);
+ if (j == str2id.end()) {
+ int id = id2str.size();
+ id2str.push_back(i);
+ str2id[i] = id;
+ return id;
+ }
+ return j->second;
+ }
+ std::string Str(const int id) const {
+ assert(id >= 0 && id < (int)id2str.size());
+ return id2str[id];
+ }
+ int Size() const { return id2str.size(); }
+ void Clear() {
+ str2id.clear();
+ id2str.clear();
+ }
+ };
+
+ std::vector<Sample> _vs; // vector of training_samples
+ StringBag _label_bag;
+ MiniStringBag _featurename_bag;
+ std::vector<double> _vl; // vector of lambda
+ ME_FeatureBag _fb;
+ int _num_classes;
+ std::vector<double> _vee; // empirical expectation
+ std::vector<double> _vme; // empirical expectation
+ std::vector<std::vector<int> > _feature2mef;
+ std::vector<Sample> _heldout;
+ double _train_error; // current error rate on the training data
+ double _heldout_error; // current error rate on the heldout data
+ int _nheldout;
+ int _early_stopping_n;
+ std::vector<double> _vhlogl;
+ const ME_Model* _ref_modelp;
+
+ double heldout_likelihood();
+ int conditional_probability(const Sample& nbs,
+ std::vector<double>& membp) const;
+ int make_feature_bag(const int cutoff);
+ int classify(const Sample& nbs, std::vector<double>& membp) const;
+ double update_model_expectation();
+ int perform_QUASI_NEWTON();
+ int perform_SGD();
+ int perform_GIS(int C);
+ std::vector<double> perform_LBFGS(const std::vector<double>& x0);
+ std::vector<double> perform_OWLQN(const std::vector<double>& x0,
+ const double C);
+ double backtracking_line_search(const Vec& x0, const Vec& grad0,
+ const double f0, const Vec& dx, Vec& x,
+ Vec& grad1);
+ double regularized_func_grad(const double C, const Vec& x, Vec& grad);
+ double constrained_line_search(double C, const Vec& x0, const Vec& grad0,
+ const double f0, const Vec& dx, Vec& x,
+ Vec& grad1);
+
+ void set_ref_dist(Sample& s) const;
+ void init_feature2mef();
+
+ double FunctionGradient(const std::vector<double>& x,
+ std::vector<double>& grad);
+ static double FunctionGradientWrapper(const std::vector<double>& x,
+ std::vector<double>& grad);
+};
+} // namespace maxent
+
+#endif
+
+/*
+ * $Log: maxent.h,v $
+ * Revision 1.1.1.1 2007/05/15 08:30:35 kyoshida
+ * stepp tagger, by Okanohara and Tsuruoka
+ *
+ * Revision 1.24 2006/08/21 17:30:38 tsuruoka
+ * use MAX_LABEL_TYPES
+ *
+ * Revision 1.23 2006/07/25 13:19:53 tsuruoka
+ * sort _vs[]
+ *
+ * Revision 1.22 2006/07/18 11:13:15 tsuruoka
+ * modify comments
+ *
+ * Revision 1.21 2006/07/18 10:02:15 tsuruoka
+ * remove sample2feature[]
+ * speed up conditional_probability()
+ *
+ * Revision 1.20 2006/07/18 05:10:51 tsuruoka
+ * add ref_dist
+ *
+ * Revision 1.19 2005/12/23 10:33:02 tsuruoka
+ * support real-valued features
+ *
+ * Revision 1.18 2005/12/23 09:15:29 tsuruoka
+ * modify _train to reduce memory consumption
+ *
+ * Revision 1.17 2005/10/28 13:02:34 tsuruoka
+ * set_heldout(): add default value
+ * Feature()
+ *
+ * Revision 1.16 2005/09/12 13:51:16 tsuruoka
+ * Sample: list -> vector
+ *
+ * Revision 1.15 2005/09/12 13:27:10 tsuruoka
+ * add add_training_sample()
+ *
+ * Revision 1.14 2005/04/27 11:22:27 tsuruoka
+ * bugfix
+ * ME_Sample: list -> vector
+ *
+ * Revision 1.13 2005/04/27 10:20:19 tsuruoka
+ * MiniStringBag -> StringBag
+ *
+ * Revision 1.12 2005/04/27 10:00:42 tsuruoka
+ * remove tmpfb
+ *
+ * Revision 1.11 2005/04/26 14:25:53 tsuruoka
+ * add MiniStringBag, USE_HASH_MAP
+ *
+ * Revision 1.10 2004/10/04 05:50:25 tsuruoka
+ * add Clear()
+ *
+ * Revision 1.9 2004/08/09 12:27:21 tsuruoka
+ * change messages
+ *
+ * Revision 1.8 2004/08/04 13:55:19 tsuruoka
+ * modify _sample2feature
+ *
+ * Revision 1.7 2004/07/29 05:51:13 tsuruoka
+ * remove modeldata.h
+ *
+ * Revision 1.6 2004/07/28 13:42:58 tsuruoka
+ * add AGIS
+ *
+ * Revision 1.5 2004/07/28 05:54:14 tsuruoka
+ * get_class_name() -> get_class_label()
+ * ME_Feature: bugfix
+ *
+ * Revision 1.4 2004/07/27 16:58:47 tsuruoka
+ * modify the interface of classify()
+ *
+ * Revision 1.3 2004/07/26 17:23:46 tsuruoka
+ * _sample2feature: list -> vector
+ *
+ * Revision 1.2 2004/07/26 15:49:23 tsuruoka
+ * modify ME_Feature
+ *
+ * Revision 1.1 2004/07/26 13:10:55 tsuruoka
+ * add files
+ *
+ * Revision 1.18 2004/07/22 08:34:45 tsuruoka
+ * modify _sample2feature[]
+ *
+ * Revision 1.17 2004/07/21 16:33:01 tsuruoka
+ * remove some comments
+ *
+ */
diff --git a/utils/small_vector.h b/utils/small_vector.h
index c8cbcb2c..f16bc898 100644
--- a/utils/small_vector.h
+++ b/utils/small_vector.h
@@ -15,6 +15,7 @@
#include <new>
#include <stdint.h>
#include <boost/functional/hash.hpp>
+#include <boost/serialization/map.hpp>
//sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1
@@ -297,6 +298,21 @@ public:
return hash_range(data_.ptr,data_.ptr+size_);
}
+ template<class Archive>
+ void save(Archive & ar, const unsigned int) const {
+ ar & size_;
+ for (unsigned i = 0; i < size_; ++i)
+ ar & (*this)[i];
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int) {
+ uint16_t s;
+ ar & s;
+ this->resize(s);
+ for (unsigned i = 0; i < size_; ++i)
+ ar & (*this)[i];
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
union StorageType {
T vals[SV_MAX];
diff --git a/utils/small_vector_test.cc b/utils/small_vector_test.cc
index a4eb89ae..9e1a148d 100644
--- a/utils/small_vector_test.cc
+++ b/utils/small_vector_test.cc
@@ -3,6 +3,10 @@
#define BOOST_TEST_MODULE svTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <string>
+#include <sstream>
#include <iostream>
#include <vector>
@@ -128,3 +132,29 @@ BOOST_AUTO_TEST_CASE(Small) {
cerr << sizeof(SmallVectorInt) << endl;
cerr << sizeof(vector<int>) << endl;
}
+
+BOOST_AUTO_TEST_CASE(Serialize) {
+ std::string in;
+ {
+ SmallVectorInt v;
+ v.push_back(0);
+ v.push_back(1);
+ v.push_back(-2);
+ ostringstream os;
+ boost::archive::text_oarchive oa(os);
+ oa << v;
+ in = os.str();
+ cerr << in;
+ }
+ {
+ istringstream is(in);
+ boost::archive::text_iarchive ia(is);
+ SmallVectorInt v;
+ ia >> v;
+ BOOST_CHECK_EQUAL(v.size(), 3);
+ BOOST_CHECK_EQUAL(v[0], 0);
+ BOOST_CHECK_EQUAL(v[1], 1);
+ BOOST_CHECK_EQUAL(v[2], -2);
+ }
+}
+
diff --git a/utils/sv_test.cc b/utils/sv_test.cc
index 67df8c57..b006e66d 100644
--- a/utils/sv_test.cc
+++ b/utils/sv_test.cc
@@ -1,7 +1,12 @@
#define BOOST_TEST_MODULE WeightsTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <sstream>
+#include <string>
#include "sparse_vector.h"
+#include "fdict.h"
using namespace std;
@@ -33,3 +38,29 @@ BOOST_AUTO_TEST_CASE(Division) {
x /= -1;
BOOST_CHECK(x == y);
}
+
+BOOST_AUTO_TEST_CASE(Serialization) {
+ string arc;
+ FD::dict_.clear();
+ {
+ SparseVector<double> x;
+ x.set_value(FD::Convert("Feature1"), 1.0);
+ x.set_value(FD::Convert("Pi"), 3.14);
+ ostringstream os;
+ boost::archive::text_oarchive oa(os);
+ oa << x;
+ arc = os.str();
+ }
+ FD::dict_.clear();
+ FD::Convert("SomeNewString");
+ {
+ SparseVector<double> x;
+ istringstream is(arc);
+ boost::archive::text_iarchive ia(is);
+ ia >> x;
+ cerr << x << endl;
+ BOOST_CHECK_CLOSE(x.get(FD::Convert("Pi")), 3.14, 1e-9);
+ BOOST_CHECK_CLOSE(x.get(FD::Convert("Feature1")), 1.0, 1e-9);
+ }
+}
+