From 671c21451542e2dd20e45b4033d44d8e8735f87b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 3 Dec 2009 16:33:55 -0500 Subject: initial check in --- Makefile.am | 3 + README | 24 + configure.ac | 46 ++ src/JSON_parser.h | 152 ++++ src/Makefile.am | 67 ++ src/aligner.cc | 204 +++++ src/aligner.h | 23 + src/apply_models.cc | 344 ++++++++ src/apply_models.h | 20 + src/array2d.h | 171 ++++ src/bottom_up_parser.cc | 260 ++++++ src/bottom_up_parser.h | 27 + src/cdec.cc | 474 +++++++++++ src/cdec_ff.cc | 18 + src/collapse_weights.cc | 102 +++ src/dict.h | 40 + src/dict_test.cc | 30 + src/earley_composer.cc | 726 ++++++++++++++++ src/earley_composer.h | 29 + src/exp_semiring.h | 71 ++ src/fdict.cc | 4 + src/fdict.h | 21 + src/ff.cc | 93 +++ src/ff.h | 121 +++ src/ff_factory.cc | 35 + src/ff_factory.h | 39 + src/ff_itg_span.h | 7 + src/ff_test.cc | 134 +++ src/ff_wordalign.cc | 221 +++++ src/ff_wordalign.h | 133 +++ src/filelib.cc | 22 + src/filelib.h | 66 ++ src/forest_writer.cc | 23 + src/forest_writer.h | 16 + src/freqdict.cc | 23 + src/freqdict.h | 19 + src/fst_translator.cc | 91 ++ src/grammar.cc | 163 ++++ src/grammar.h | 83 ++ src/grammar_test.cc | 59 ++ src/gzstream.cc | 165 ++++ src/gzstream.h | 121 +++ src/hg.cc | 483 +++++++++++ src/hg.h | 225 +++++ src/hg_intersect.cc | 121 +++ src/hg_intersect.h | 13 + src/hg_io.cc | 585 +++++++++++++ src/hg_io.h | 37 + src/hg_test.cc | 441 ++++++++++ src/ibm_model1.cc | 4 + src/inside_outside.h | 111 +++ src/json_parse.cc | 50 ++ src/json_parse.h | 58 ++ src/kbest.h | 207 +++++ src/lattice.cc | 27 + src/lattice.h | 31 + src/lexcrf.cc | 112 +++ src/lexcrf.h | 18 + src/lm_ff.cc | 328 ++++++++ src/lm_ff.h | 32 + src/logval.h | 136 +++ src/maxtrans_blunsom.cc | 287 +++++++ src/parser_test.cc | 35 + src/phrasebased_translator.cc | 206 +++++ src/phrasebased_translator.h | 18 + src/phrasetable_fst.cc | 141 ++++ src/phrasetable_fst.h | 34 + src/prob.h | 8 + src/sampler.h | 136 +++ src/scfg_translator.cc | 66 ++ src/sentence_metadata.h | 42 + src/small_vector.h | 187 +++++ src/small_vector_test.cc | 129 +++ src/sparse_vector.cc | 98 +++ src/sparse_vector.h | 264 ++++++ src/stringlib.cc | 97 +++ src/stringlib.h | 91 ++ src/synparse.cc | 212 +++++ src/tdict.cc | 49 ++ src/tdict.h | 19 + src/timing_stats.cc | 24 + src/timing_stats.h | 25 + src/translator.h | 54 ++ src/trule.cc | 237 ++++++ src/trule.h | 122 +++ src/trule_test.cc | 65 ++ src/ttables.cc | 31 + src/ttables.h | 87 ++ src/viterbi.cc | 39 + src/viterbi.h | 130 +++ src/weights.cc | 73 ++ src/weights.h | 21 + src/weights_test.cc | 28 + src/wordid.h | 6 + training/Makefile.am | 39 + training/atools.cc | 207 +++++ training/cluster-em.pl | 110 +++ training/cluster-ptrain.pl | 144 ++++ training/grammar_convert.cc | 316 +++++++ training/lbfgs.h | 1459 +++++++++++++++++++++++++++++++++ training/lbfgs_test.cc | 112 +++ training/make-lexcrf-grammar.pl | 236 ++++++ training/model1.cc | 103 +++ training/mr_em_train.cc | 270 ++++++ training/mr_optimize_reduce.cc | 243 ++++++ training/optimize.cc | 114 +++ training/optimize.h | 104 +++ training/optimize_test.cc | 105 +++ training/plftools.cc | 93 +++ vest/Makefile.am | 32 + vest/comb_scorer.cc | 81 ++ vest/dist-vest.pl | 642 +++++++++++++++ vest/error_surface.cc | 46 ++ vest/fast_score.cc | 74 ++ vest/line_optimizer.cc | 101 +++ vest/lo_test.cc | 201 +++++ vest/mr_vest_generate_mapper_input.cc | 72 ++ vest/mr_vest_map.cc | 98 +++ vest/mr_vest_reduce.cc | 80 ++ vest/scorer.cc | 485 +++++++++++ vest/scorer_test.cc | 178 ++++ vest/ter.cc | 518 ++++++++++++ vest/test_data/0.json.gz | Bin 0 -> 13709 bytes vest/test_data/1.json.gz | Bin 0 -> 204803 bytes vest/test_data/c2e.txt.0 | 2 + vest/test_data/c2e.txt.1 | 2 + vest/test_data/c2e.txt.2 | 2 + vest/test_data/c2e.txt.3 | 2 + vest/test_data/re.txt.0 | 5 + vest/test_data/re.txt.1 | 5 + vest/test_data/re.txt.2 | 5 + vest/test_data/re.txt.3 | 5 + vest/union_forests.cc | 73 ++ vest/viterbi_envelope.cc | 167 ++++ 134 files changed, 17101 insertions(+) create mode 100644 Makefile.am create mode 100644 README create mode 100644 configure.ac create mode 100644 src/JSON_parser.h create mode 100644 src/Makefile.am create mode 100644 src/aligner.cc create mode 100644 src/aligner.h create mode 100644 src/apply_models.cc create mode 100644 src/apply_models.h create mode 100644 src/array2d.h create mode 100644 src/bottom_up_parser.cc create mode 100644 src/bottom_up_parser.h create mode 100644 src/cdec.cc create mode 100644 src/cdec_ff.cc create mode 100644 src/collapse_weights.cc create mode 100644 src/dict.h create mode 100644 src/dict_test.cc create mode 100644 src/earley_composer.cc create mode 100644 src/earley_composer.h create mode 100644 src/exp_semiring.h create mode 100644 src/fdict.cc create mode 100644 src/fdict.h create mode 100644 src/ff.cc create mode 100644 src/ff.h create mode 100644 src/ff_factory.cc create mode 100644 src/ff_factory.h create mode 100644 src/ff_itg_span.h create mode 100644 src/ff_test.cc create mode 100644 src/ff_wordalign.cc create mode 100644 src/ff_wordalign.h create mode 100644 src/filelib.cc create mode 100644 src/filelib.h create mode 100644 src/forest_writer.cc create mode 100644 src/forest_writer.h create mode 100644 src/freqdict.cc create mode 100644 src/freqdict.h create mode 100644 src/fst_translator.cc create mode 100644 src/grammar.cc create mode 100644 src/grammar.h create mode 100644 src/grammar_test.cc create mode 100644 src/gzstream.cc create mode 100644 src/gzstream.h create mode 100644 src/hg.cc create mode 100644 src/hg.h create mode 100644 src/hg_intersect.cc create mode 100644 src/hg_intersect.h create mode 100644 src/hg_io.cc create mode 100644 src/hg_io.h create mode 100644 src/hg_test.cc create mode 100644 src/ibm_model1.cc create mode 100644 src/inside_outside.h create mode 100644 src/json_parse.cc create mode 100644 src/json_parse.h create mode 100644 src/kbest.h create mode 100644 src/lattice.cc create mode 100644 src/lattice.h create mode 100644 src/lexcrf.cc create mode 100644 src/lexcrf.h create mode 100644 src/lm_ff.cc create mode 100644 src/lm_ff.h create mode 100644 src/logval.h create mode 100644 src/maxtrans_blunsom.cc create mode 100644 src/parser_test.cc create mode 100644 src/phrasebased_translator.cc create mode 100644 src/phrasebased_translator.h create mode 100644 src/phrasetable_fst.cc create mode 100644 src/phrasetable_fst.h create mode 100644 src/prob.h create mode 100644 src/sampler.h create mode 100644 src/scfg_translator.cc create mode 100644 src/sentence_metadata.h create mode 100644 src/small_vector.h create mode 100644 src/small_vector_test.cc create mode 100644 src/sparse_vector.cc create mode 100644 src/sparse_vector.h create mode 100644 src/stringlib.cc create mode 100644 src/stringlib.h create mode 100644 src/synparse.cc create mode 100644 src/tdict.cc create mode 100644 src/tdict.h create mode 100644 src/timing_stats.cc create mode 100644 src/timing_stats.h create mode 100644 src/translator.h create mode 100644 src/trule.cc create mode 100644 src/trule.h create mode 100644 src/trule_test.cc create mode 100644 src/ttables.cc create mode 100644 src/ttables.h create mode 100644 src/viterbi.cc create mode 100644 src/viterbi.h create mode 100644 src/weights.cc create mode 100644 src/weights.h create mode 100644 src/weights_test.cc create mode 100644 src/wordid.h create mode 100644 training/Makefile.am create mode 100644 training/atools.cc create mode 100755 training/cluster-em.pl create mode 100755 training/cluster-ptrain.pl create mode 100644 training/grammar_convert.cc create mode 100644 training/lbfgs.h create mode 100644 training/lbfgs_test.cc create mode 100755 training/make-lexcrf-grammar.pl create mode 100644 training/model1.cc create mode 100644 training/mr_em_train.cc create mode 100644 training/mr_optimize_reduce.cc create mode 100644 training/optimize.cc create mode 100644 training/optimize.h create mode 100644 training/optimize_test.cc create mode 100644 training/plftools.cc create mode 100644 vest/Makefile.am create mode 100644 vest/comb_scorer.cc create mode 100755 vest/dist-vest.pl create mode 100644 vest/error_surface.cc create mode 100644 vest/fast_score.cc create mode 100644 vest/line_optimizer.cc create mode 100644 vest/lo_test.cc create mode 100644 vest/mr_vest_generate_mapper_input.cc create mode 100644 vest/mr_vest_map.cc create mode 100644 vest/mr_vest_reduce.cc create mode 100644 vest/scorer.cc create mode 100644 vest/scorer_test.cc create mode 100644 vest/ter.cc create mode 100644 vest/test_data/0.json.gz create mode 100644 vest/test_data/1.json.gz create mode 100644 vest/test_data/c2e.txt.0 create mode 100644 vest/test_data/c2e.txt.1 create mode 100644 vest/test_data/c2e.txt.2 create mode 100644 vest/test_data/c2e.txt.3 create mode 100644 vest/test_data/re.txt.0 create mode 100644 vest/test_data/re.txt.1 create mode 100644 vest/test_data/re.txt.2 create mode 100644 vest/test_data/re.txt.3 create mode 100644 vest/union_forests.cc create mode 100644 vest/viterbi_envelope.cc diff --git a/Makefile.am b/Makefile.am new file mode 100644 index 00000000..38e9e59a --- /dev/null +++ b/Makefile.am @@ -0,0 +1,3 @@ +SUBDIRS = src training vest +AUTOMAKE_OPTIONS = foreign + diff --git a/README b/README new file mode 100644 index 00000000..63990ffb --- /dev/null +++ b/README @@ -0,0 +1,24 @@ +cdec is a fast decoder. + + .. more coming ... + +COPYRIGHT AND LICENSE +------------------------------------------------------------------------------ +Copyright (c) 2009 by Chris Dyer + +Licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain +a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +The LBFGS implementation contains code from the Computational +Crystallography Toolbox which is copyright (c) 2006 by The Regents of the +University of California, through Lawrence Berkeley National Laboratory. +For more information on their license, refer to http://cctbx.sourceforge.net/ diff --git a/configure.ac b/configure.ac new file mode 100644 index 00000000..76307998 --- /dev/null +++ b/configure.ac @@ -0,0 +1,46 @@ +AC_INIT +AM_INIT_AUTOMAKE(cdec,0.1) +AC_CONFIG_HEADERS(config.h) +AC_PROG_RANLIB +AC_PROG_CC +AC_PROG_CXX +AC_LANG_CPLUSPLUS +AX_BOOST_BASE +AX_BOOST_PROGRAM_OPTIONS +CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" +AC_CHECK_HEADER(boost/math/special_functions/digamma.hpp, + [AC_DEFINE([HAVE_BOOST_DIGAMMA], [], [flag for boost::math::digamma])]) + +GTEST_LIB_CHECK +AC_PROG_INSTALL + +AC_ARG_WITH(srilm, + [AC_HELP_STRING([--with-srilm=PATH], [(optional) path to SRI's LM toolkit])], + [with_srilm=$withval], + [with_srilm=no] + ) + +AM_CONDITIONAL([SRI_LM], false) + +if test "x$with_srilm" != 'xno' +then + SAVE_CPPFLAGS="$CPPFLAGS" + CPPFLAGS="$CPPFLAGS -I${with_srilm}/include" + + AC_CHECK_HEADER(Ngram.h, + [AC_DEFINE([HAVE_SRILM], [], [flag for SRILM])], + [AC_MSG_ERROR([Cannot find SRILM!])]) + + LIB_SRILM="-loolm -ldstruct -lmisc" + # ROOT/lib/i686-m64/liboolm.a + # ROOT/lib/i686-m64/libdstruct.a + # ROOT/lib/i686-m64/libmisc.a + MY_ARCH=`${with_srilm}/sbin/machine-type` + LDFLAGS="$LDFLAGS -L${with_srilm}/lib/${MY_ARCH}" + LIBS="$LIBS $LIB_SRILM" + FMTLIBS="$FMTLIBS liboolm.a libdstruct.a libmisc.a" + AM_CONDITIONAL([SRI_LM], true) +fi + +AC_OUTPUT(Makefile src/Makefile vest/Makefile) + diff --git a/src/JSON_parser.h b/src/JSON_parser.h new file mode 100644 index 00000000..ceb5b24b --- /dev/null +++ b/src/JSON_parser.h @@ -0,0 +1,152 @@ +#ifndef JSON_PARSER_H +#define JSON_PARSER_H + +/* JSON_parser.h */ + + +#include + +/* Windows DLL stuff */ +#ifdef _WIN32 +# ifdef JSON_PARSER_DLL_EXPORTS +# define JSON_PARSER_DLL_API __declspec(dllexport) +# else +# define JSON_PARSER_DLL_API __declspec(dllimport) +# endif +#else +# define JSON_PARSER_DLL_API +#endif + +/* Determine the integer type use to parse non-floating point numbers */ +#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1 +typedef long long JSON_int_t; +#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld" +#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld" +#else +typedef long JSON_int_t; +#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld" +#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld" +#endif + + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum +{ + JSON_T_NONE = 0, + JSON_T_ARRAY_BEGIN, // 1 + JSON_T_ARRAY_END, // 2 + JSON_T_OBJECT_BEGIN, // 3 + JSON_T_OBJECT_END, // 4 + JSON_T_INTEGER, // 5 + JSON_T_FLOAT, // 6 + JSON_T_NULL, // 7 + JSON_T_TRUE, // 8 + JSON_T_FALSE, // 9 + JSON_T_STRING, // 10 + JSON_T_KEY, // 11 + JSON_T_MAX // 12 +} JSON_type; + +typedef struct JSON_value_struct { + union { + JSON_int_t integer_value; + + long double float_value; + + struct { + const char* value; + size_t length; + } str; + } vu; +} JSON_value; + +typedef struct JSON_parser_struct* JSON_parser; + +/*! \brief JSON parser callback + + \param ctx The pointer passed to new_JSON_parser. + \param type An element of JSON_type but not JSON_T_NONE. + \param value A representation of the parsed value. This parameter is NULL for + JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END, + JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned + as zero-terminated C strings. + + \return Non-zero if parsing should continue, else zero. +*/ +typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value); + + +/*! \brief The structure used to configure a JSON parser object + + \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise + the depth is the limit + \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity. + \param Callback context. This parameter may be NULL. + \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting. + \param allowComments. To allow C style comments in JSON, set to non-zero. + \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero. + + \return The parser object. +*/ +typedef struct { + JSON_parser_callback callback; + void* callback_ctx; + int depth; + int allow_comments; + int handle_floats_manually; +} JSON_config; + + +/*! \brief Initializes the JSON parser configuration structure to default values. + + The default configuration is + - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c) + - no parsing, just checking for JSON syntax + - no comments + + \param config. Used to configure the parser. +*/ +JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config); + +/*! \brief Create a JSON parser object + + \param config. Used to configure the parser. Set to NULL to use the default configuration. + See init_JSON_config + + \return The parser object. +*/ +JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config); + +/*! \brief Destroy a previously created JSON parser object. */ +JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc); + +/*! \brief Parse a character. + + \return Non-zero, if all characters passed to this function are part of are valid JSON. +*/ +JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char); + +/*! \brief Finalize parsing. + + Call this method once after all input characters have been consumed. + + \return Non-zero, if all parsed characters are valid JSON, zero otherwise. +*/ +JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc); + +/*! \brief Determine if a given string is valid JSON white space + + \return Non-zero if the string is valid, zero otherwise. +*/ +JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s); + + +#ifdef __cplusplus +} +#endif + + +#endif /* JSON_PARSER_H */ diff --git a/src/Makefile.am b/src/Makefile.am new file mode 100644 index 00000000..2af6cab7 --- /dev/null +++ b/src/Makefile.am @@ -0,0 +1,67 @@ +bin_PROGRAMS = \ + dict_test \ + weights_test \ + trule_test \ + hg_test \ + ff_test \ + parser_test \ + grammar_test \ + cdec \ + small_vector_test + +cdec_SOURCES = cdec.cc forest_writer.cc maxtrans_blunsom.cc cdec_ff.cc ff_factory.cc timing_stats.cc +small_vector_test_SOURCES = small_vector_test.cc +small_vector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +parser_test_SOURCES = parser_test.cc +parser_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +dict_test_SOURCES = dict_test.cc +dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +ff_test_SOURCES = ff_test.cc +ff_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +grammar_test_SOURCES = grammar_test.cc +grammar_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +hg_test_SOURCES = hg_test.cc +hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +trule_test_SOURCES = trule_test.cc +trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +weights_test_SOURCES = weights_test.cc +weights_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a + +LDADD = libhg.a + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) +AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB) -lz + +noinst_LIBRARIES = libhg.a + +libhg_a_SOURCES = \ + fst_translator.cc \ + scfg_translator.cc \ + hg.cc \ + hg_io.cc \ + viterbi.cc \ + lattice.cc \ + logging.cc \ + aligner.cc \ + gzstream.cc \ + apply_models.cc \ + earley_composer.cc \ + phrasetable_fst.cc \ + sparse_vector.cc \ + trule.cc \ + filelib.cc \ + stringlib.cc \ + fdict.cc \ + tdict.cc \ + weights.cc \ + ttables.cc \ + ff.cc \ + lm_ff.cc \ + ff_wordalign.cc \ + hg_intersect.cc \ + lexcrf.cc \ + bottom_up_parser.cc \ + phrasebased_translator.cc \ + JSON_parser.c \ + json_parse.cc \ + grammar.cc diff --git a/src/aligner.cc b/src/aligner.cc new file mode 100644 index 00000000..d9d067e5 --- /dev/null +++ b/src/aligner.cc @@ -0,0 +1,204 @@ +#include "aligner.h" + +#include "array2d.h" +#include "hg.h" +#include "inside_outside.h" +#include + +using namespace std; + +struct EdgeCoverageInfo { + set src_indices; + set trg_indices; +}; + +static bool is_digit(char x) { return x >= '0' && x <= '9'; } + +boost::shared_ptr > AlignerTools::ReadPharaohAlignmentGrid(const string& al) { + int max_x = 0; + int max_y = 0; + int i = 0; + while (i < al.size()) { + int x = 0; + while(i < al.size() && is_digit(al[i])) { + x *= 10; + x += al[i] - '0'; + ++i; + } + if (x > max_x) max_x = x; + assert(i < al.size()); + assert(al[i] == '-'); + ++i; + int y = 0; + while(i < al.size() && is_digit(al[i])) { + y *= 10; + y += al[i] - '0'; + ++i; + } + if (y > max_y) max_y = y; + while(i < al.size() && al[i] == ' ') { ++i; } + } + + boost::shared_ptr > grid(new Array2D(max_x + 1, max_y + 1)); + i = 0; + while (i < al.size()) { + int x = 0; + while(i < al.size() && is_digit(al[i])) { + x *= 10; + x += al[i] - '0'; + ++i; + } + assert(i < al.size()); + assert(al[i] == '-'); + ++i; + int y = 0; + while(i < al.size() && is_digit(al[i])) { + y *= 10; + y += al[i] - '0'; + ++i; + } + (*grid)(x, y) = true; + while(i < al.size() && al[i] == ' ') { ++i; } + } + // cerr << *grid << endl; + return grid; +} + +void AlignerTools::SerializePharaohFormat(const Array2D& alignment, ostream* out) { + bool need_space = false; + for (int i = 0; i < alignment.width(); ++i) + for (int j = 0; j < alignment.height(); ++j) + if (alignment(i,j)) { + if (need_space) (*out) << ' '; else need_space = true; + (*out) << i << '-' << j; + } + (*out) << endl; +} + +// compute the coverage vectors of each edge +// prereq: all derivations yield the same string pair +void ComputeCoverages(const Hypergraph& g, + vector* pcovs) { + for (int i = 0; i < g.edges_.size(); ++i) { + const Hypergraph::Edge& edge = g.edges_[i]; + EdgeCoverageInfo& cov = (*pcovs)[i]; + // no words + if (edge.rule_->EWords() == 0 || edge.rule_->FWords() == 0) + continue; + // aligned to NULL (crf ibm variant only) + if (edge.prev_i_ == -1 || edge.i_ == -1) + continue; + assert(edge.j_ >= 0); + assert(edge.prev_j_ >= 0); + if (edge.Arity() == 0) { + for (int k = edge.i_; k < edge.j_; ++k) + cov.trg_indices.insert(k); + for (int k = edge.prev_i_; k < edge.prev_j_; ++k) + cov.src_indices.insert(k); + } else { + // note: this code, which handles mixed NT and terminal + // rules assumes that nodes uniquely define a src and trg + // span. + int k = edge.prev_i_; + int j = 0; + const vector& f = edge.rule_->e(); // rules are inverted + while (k < edge.prev_j_) { + if (f[j] > 0) { + cov.src_indices.insert(k); + // cerr << "src: " << k << endl; + ++k; + ++j; + } else { + const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[-f[j]]]; + assert(tailnode.in_edges_.size() > 0); + // any edge will do: + const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()]; + //cerr << "skip " << (rep_edge.prev_j_ - rep_edge.prev_i_) << endl; // src span + k += (rep_edge.prev_j_ - rep_edge.prev_i_); // src span + ++j; + } + } + int tc = 0; + const vector& e = edge.rule_->f(); // rules are inverted + k = edge.i_; + j = 0; + // cerr << edge.rule_->AsString() << endl; + // cerr << "i=" << k << " j=" << edge.j_ << endl; + while (k < edge.j_) { + //cerr << " k=" << k << endl; + if (e[j] > 0) { + cov.trg_indices.insert(k); + // cerr << "trg: " << k << endl; + ++k; + ++j; + } else { + assert(tc < edge.tail_nodes_.size()); + const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[tc]]; + assert(tailnode.in_edges_.size() > 0); + // any edge will do: + const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()]; + // cerr << "t skip " << (rep_edge.j_ - rep_edge.i_) << endl; // src span + k += (rep_edge.j_ - rep_edge.i_); // src span + ++j; + ++tc; + } + } + //abort(); + } + } +} + +void AlignerTools::WriteAlignment(const string& input, + const Lattice& ref, + const Hypergraph& g, + bool map_instead_of_viterbi) { + if (!map_instead_of_viterbi) { + assert(!"not implemented!"); + } + vector edge_posteriors(g.edges_.size()); + { + SparseVector posts; + InsideOutside, TransitionEventWeightFunction>(g, &posts); + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = posts[i]; + } + vector edge2cov(g.edges_.size()); + ComputeCoverages(g, &edge2cov); + + Lattice src; + // currently only dealing with src text, even if the + // model supports lattice translation (which it probably does) + LatticeTools::ConvertTextToLattice(input, &src); + // TODO assert that src is a "real lattice" + + Array2D align(src.size(), ref.size(), prob_t::Zero()); + for (int c = 0; c < g.edges_.size(); ++c) { + const prob_t& p = edge_posteriors[c]; + const EdgeCoverageInfo& eci = edge2cov[c]; + for (set::const_iterator si = eci.src_indices.begin(); + si != eci.src_indices.end(); ++si) { + for (set::const_iterator ti = eci.trg_indices.begin(); + ti != eci.trg_indices.end(); ++ti) { + align(*si, *ti) += p; + } + } + } + prob_t threshold(0.9); + const bool use_soft_threshold = true; // TODO configure + + Array2D grid(src.size(), ref.size(), false); + for (int j = 0; j < ref.size(); ++j) { + if (use_soft_threshold) { + threshold = prob_t::Zero(); + for (int i = 0; i < src.size(); ++i) + if (align(i, j) > threshold) threshold = align(i, j); + //threshold *= prob_t(0.99); + } + for (int i = 0; i < src.size(); ++i) + grid(i, j) = align(i, j) >= threshold; + } + cerr << align << endl; + cerr << grid << endl; + SerializePharaohFormat(grid, &cout); +}; + diff --git a/src/aligner.h b/src/aligner.h new file mode 100644 index 00000000..970c72f2 --- /dev/null +++ b/src/aligner.h @@ -0,0 +1,23 @@ +#ifndef _ALIGNER_H_ + +#include +#include +#include +#include "array2d.h" +#include "lattice.h" + +class Hypergraph; + +struct AlignerTools { + static boost::shared_ptr > ReadPharaohAlignmentGrid(const std::string& al); + static void SerializePharaohFormat(const Array2D& alignment, std::ostream* out); + + // assumption: g contains derivations of input/ref and + // ONLY input/ref. + static void WriteAlignment(const std::string& input, + const Lattice& ref, + const Hypergraph& g, + bool map_instead_of_viterbi = true); +}; + +#endif diff --git a/src/apply_models.cc b/src/apply_models.cc new file mode 100644 index 00000000..8efb331b --- /dev/null +++ b/src/apply_models.cc @@ -0,0 +1,344 @@ +#include "apply_models.h" + +#include +#include +#include +#include + +#include + +#include "hg.h" +#include "ff.h" + +using namespace std; +using namespace std::tr1; + +struct Candidate; +typedef SmallVector JVector; +typedef vector CandidateHeap; +typedef vector CandidateList; + +// life cycle: candidates are created, placed on the heap +// and retrieved by their estimated cost, when they're +// retrieved, they're incorporated into the +LM hypergraph +// where they also know the head node index they are +// attached to. After they are added to the +LM hypergraph +// vit_prob_ and est_prob_ fields may be updated as better +// derivations are found (this happens since the successor's +// of derivation d may have a better score- they are +// explored lazily). However, the updates don't happen +// when a candidate is in the heap so maintaining the heap +// property is not an issue. +struct Candidate { + int node_index_; // -1 until incorporated + // into the +LM forest + const Hypergraph::Edge* in_edge_; // in -LM forest + Hypergraph::Edge out_edge_; + string state_; + const JVector j_; + prob_t vit_prob_; // these are fixed until the cand + // is popped, then they may be updated + prob_t est_prob_; + + Candidate(const Hypergraph::Edge& e, + const JVector& j, + const Hypergraph& out_hg, + const vector& D, + const SentenceMetadata& smeta, + const ModelSet& models, + bool is_goal) : + node_index_(-1), + in_edge_(&e), + j_(j) { + InitializeCandidate(out_hg, smeta, D, models, is_goal); + } + + // used to query uniqueness + Candidate(const Hypergraph::Edge& e, + const JVector& j) : in_edge_(&e), j_(j) {} + + bool IsIncorporatedIntoHypergraph() const { + return node_index_ >= 0; + } + + void InitializeCandidate(const Hypergraph& out_hg, + const SentenceMetadata& smeta, + const vector >& D, + const ModelSet& models, + const bool is_goal) { + const Hypergraph::Edge& in_edge = *in_edge_; + out_edge_.rule_ = in_edge.rule_; + out_edge_.feature_values_ = in_edge.feature_values_; + out_edge_.i_ = in_edge.i_; + out_edge_.j_ = in_edge.j_; + out_edge_.prev_i_ = in_edge.prev_i_; + out_edge_.prev_j_ = in_edge.prev_j_; + Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_; + tail.resize(j_.size()); + prob_t p = prob_t::One(); + // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl; + for (int i = 0; i < tail.size(); ++i) { + const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]]; + assert(ant.IsIncorporatedIntoHypergraph()); + tail[i] = ant.node_index_; + p *= ant.vit_prob_; + } + prob_t edge_estimate = prob_t::One(); + if (is_goal) { + assert(tail.size() == 1); + const string& ant_state = out_hg.nodes_[tail.front()].state_; + models.AddFinalFeatures(ant_state, &out_edge_); + } else { + models.AddFeaturesToEdge(smeta, out_hg, &out_edge_, &state_, &edge_estimate); + } + vit_prob_ = out_edge_.edge_prob_ * p; + est_prob_ = vit_prob_ * edge_estimate; + } +}; + +ostream& operator<<(ostream& os, const Candidate& cand) { + os << "CAND["; + if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; } + else { os << "+LM_node=" << cand.node_index_; } + os << " edge=" << cand.in_edge_->id_; + os << " j=<"; + for (int i = 0; i < cand.j_.size(); ++i) + os << (i==0 ? "" : " ") << cand.j_[i]; + os << "> vit=" << log(cand.vit_prob_); + os << " est=" << log(cand.est_prob_); + return os << ']'; +} + +struct HeapCandCompare { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ < r->est_prob_; + } +}; + +struct EstProbSorter { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ > r->est_prob_; + } +}; + +// the same candidate can be added multiple times if +// j is multidimensional (if you're going NW in Manhattan, you +// can first go north, then west, or you can go west then north) +// this is a hash function on the relevant variables from +// Candidate to enforce this. +struct CandidateUniquenessHash { + size_t operator()(const Candidate* c) const { + size_t x = 5381; + x = ((x << 5) + x) ^ c->in_edge_->id_; + for (int i = 0; i < c->j_.size(); ++i) + x = ((x << 5) + x) ^ c->j_[i]; + return x; + } +}; + +struct CandidateUniquenessEquals { + bool operator()(const Candidate* a, const Candidate* b) const { + return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_); + } +}; + +typedef unordered_set UniqueCandidateSet; +typedef unordered_map > State2Node; + +class CubePruningRescorer { + +public: + CubePruningRescorer(const ModelSet& m, + const SentenceMetadata& sm, + const Hypergraph& i, + int pop_limit, + Hypergraph* o) : + models(m), + smeta(sm), + in(i), + out(*o), + D(in.nodes_.size()), + pop_limit_(pop_limit) { + cerr << " Rescoring forest (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; + } + + void Apply() { + int num_nodes = in.nodes_.size(); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + int every = 1; + if (num_nodes > 100) every = 10; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + cerr << " "; + for (int i = 0; i < in.nodes_.size(); ++i) { + if (i % every == 0) cerr << '.'; + KBest(i, i == goal_id); + } + cerr << endl; + cerr << " Best path: " << log(D[goal_id].front()->vit_prob_) + << "\t" << log(D[goal_id].front()->est_prob_) << endl; + out.PruneUnreachable(D[goal_id].front()->node_index_); + FreeAll(); + } + + private: + void FreeAll() { + for (int i = 0; i < D.size(); ++i) { + CandidateList& D_i = D[i]; + for (int j = 0; j < D_i.size(); ++j) + delete D_i[j]; + } + D.clear(); + } + + void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_); + new_edge->feature_values_ = item->out_edge_.feature_values_; + new_edge->edge_prob_ = item->out_edge_.edge_prob_; + new_edge->i_ = item->out_edge_.i_; + new_edge->j_ = item->out_edge_.j_; + new_edge->prev_i_ = item->out_edge_.prev_i_; + new_edge->prev_j_ = item->out_edge_.prev_j_; + Candidate*& o_item = (*s2n)[item->state_]; + if (!o_item) o_item = item; + + int& node_id = o_item->node_index_; + if (node_id < 0) { + Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, item->state_); + node_id = new_node->id_; + } + Hypergraph::Node* node = &out.nodes_[node_id]; + out.ConnectEdgeToHeadNode(new_edge, node); + + // update candidate if we have a better derivation + // note: the difference between the vit score and the estimated + // score is the same for all items with a common residual DP + // state + if (item->vit_prob_ > o_item->vit_prob_) { + assert(o_item->state_ == item->state_); // sanity check! + o_item->est_prob_ = item->est_prob_; + o_item->vit_prob_ = item->vit_prob_; + } + if (item != o_item) freelist->push_back(item); + } + + void KBest(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_cands; + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, smeta, models, is_goal)); + assert(unique_cands.insert(cand.back()).second); // these should all be unique! + } +// cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + PushSucc(*item, is_goal, &cand, &unique_cands); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i) + D_v[c++] = i->second; + sort(D_v.begin(), D_v.end(), EstProbSorter()); + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (cs->count(&query_unique) == 0) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check + } + } + } + } + + const ModelSet& models; + const SentenceMetadata& smeta; + const Hypergraph& in; + Hypergraph& out; + + vector D; // maps nodes in in-HG to the + // equivalent nodes (many due to state + // splits) in the out-HG. + const int pop_limit_; +}; + +struct NoPruningRescorer { + NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) : + models(m), + in(i), + out(*o) { + cerr << " Rescoring forest (full intersection)\n"; + } + + void RescoreNode(const int node_num, const bool is_goal) { + } + + void Apply() { + int num_nodes = in.nodes_.size(); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + int every = 1; + if (num_nodes > 100) every = 10; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + cerr << " "; + for (int i = 0; i < in.nodes_.size(); ++i) { + if (i % every == 0) cerr << '.'; + RescoreNode(i, i == goal_id); + } + cerr << endl; + } + + private: + const ModelSet& models; + const Hypergraph& in; + Hypergraph& out; +}; + +// each node in the graph has one of these, it keeps track of +void ApplyModelSet(const Hypergraph& in, + const SentenceMetadata& smeta, + const ModelSet& models, + const PruningConfiguration& config, + Hypergraph* out) { + int pl = config.pop_limit; + if (pl > 100 && in.nodes_.size() > 80000) { + cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; + pl = 30; + } + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); +} + diff --git a/src/apply_models.h b/src/apply_models.h new file mode 100644 index 00000000..08fce037 --- /dev/null +++ b/src/apply_models.h @@ -0,0 +1,20 @@ +#ifndef _APPLY_MODELS_H_ +#define _APPLY_MODELS_H_ + +struct ModelSet; +struct Hypergraph; +struct SentenceMetadata; + +struct PruningConfiguration { + const int algorithm; // 0 = full intersection, 1 = cube pruning + const int pop_limit; // max number of pops off the heap at each node + explicit PruningConfiguration(int k) : algorithm(1), pop_limit(k) {} +}; + +void ApplyModelSet(const Hypergraph& in, + const SentenceMetadata& smeta, + const ModelSet& models, + const PruningConfiguration& config, + Hypergraph* out); + +#endif diff --git a/src/array2d.h b/src/array2d.h new file mode 100644 index 00000000..09d84d0b --- /dev/null +++ b/src/array2d.h @@ -0,0 +1,171 @@ +#ifndef ARRAY2D_H_ +#define ARRAY2D_H_ + +#include +#include +#include +#include +#include + +template +class Array2D { + public: + typedef typename std::vector::reference reference; + typedef typename std::vector::const_reference const_reference; + typedef typename std::vector::iterator iterator; + typedef typename std::vector::const_iterator const_iterator; + Array2D() : width_(0), height_(0) {} + Array2D(int w, int h, const T& d = T()) : + width_(w), height_(h), data_(w*h, d) {} + Array2D(const Array2D& rhs) : + width_(rhs.width_), height_(rhs.height_), data_(rhs.data_) {} + void resize(int w, int h, const T& d = T()) { + data_.resize(w * h, d); + width_ = w; + height_ = h; + } + const Array2D& operator=(const Array2D& rhs) { + data_ = rhs.data_; + width_ = rhs.width_; + height_ = rhs.height_; + return *this; + } + void fill(const T& v) { data_.assign(data_.size(), v); } + int width() const { return width_; } + int height() const { return height_; } + reference operator()(int i, int j) { + return data_[offset(i, j)]; + } + void clear() { data_.clear(); width_=0; height_=0; } + const_reference operator()(int i, int j) const { + return data_[offset(i, j)]; + } + iterator begin_col(int j) { + return data_.begin() + offset(0,j); + } + const_iterator begin_col(int j) const { + return data_.begin() + offset(0,j); + } + iterator end_col(int j) { + return data_.begin() + offset(0,j) + width_; + } + const_iterator end_col(int j) const { + return data_.begin() + offset(0,j) + width_; + } + iterator end() { return data_.end(); } + const_iterator end() const { return data_.end(); } + const Array2D& operator*=(const T& x) { + std::transform(data_.begin(), data_.end(), data_.begin(), + std::bind2nd(std::multiplies(), x)); + } + const Array2D& operator/=(const T& x) { + std::transform(data_.begin(), data_.end(), data_.begin(), + std::bind2nd(std::divides(), x)); + } + const Array2D& operator+=(const Array2D& m) { + std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::plus()); + } + const Array2D& operator-=(const Array2D& m) { + std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::minus()); + } + + private: + inline int offset(int i, int j) const { + assert(i data_; +}; + +template +Array2D operator*(const Array2D& l, const T& scalar) { + Array2D res(l); + res *= scalar; + return res; +} + +template +Array2D operator*(const T& scalar, const Array2D& l) { + Array2D res(l); + res *= scalar; + return res; +} + +template +Array2D operator/(const Array2D& l, const T& scalar) { + Array2D res(l); + res /= scalar; + return res; +} + +template +Array2D operator+(const Array2D& l, const Array2D& r) { + Array2D res(l); + res += r; + return res; +} + +template +Array2D operator-(const Array2D& l, const Array2D& r) { + Array2D res(l); + res -= r; + return res; +} + +template +inline std::ostream& operator<<(std::ostream& os, const Array2D& m) { + for (int i=0; i& m) { + os << ' '; + for (int j=0; j >& m) { + os << ' '; + for (int j=0; j& ar = m(i,j); + for (int k=0; k + +#include "hg.h" +#include "array2d.h" +#include "tdict.h" + +using namespace std; + +class ActiveChart; +class PassiveChart { + public: + PassiveChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest); + ~PassiveChart(); + + inline const vector& operator()(int i, int j) const { return chart_(i,j); } + bool Parse(); + inline int size() const { return chart_.width(); } + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() const { return goal_idx_; } + + private: + void ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail); + void ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes); + void ApplyUnaryRules(const int i, const int j); + + const vector& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D > chart_; // chart_(i,j) is the list of nodes derived spanning i,j + typedef map Cat2NodeMap; + Array2D nodemap_; + vector act_chart_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + + static WordID kGOAL; // [Goal] +}; + +WordID PassiveChart::kGOAL = 0; + +class ActiveChart { + public: + ActiveChart(const Hypergraph* hg, const PassiveChart& psv_chart) : + hg_(hg), + act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} + + struct ActiveItem { + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, double lcost) : + gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} + explicit ActiveItem(const GrammarIter* g) : + gptr_(g), ant_nodes_(), lattice_cost() {} + + void ExtendTerminal(int symbol, double src_cost, vector* out_cell) const { + const GrammarIter* ni = gptr_->Extend(symbol); + if (ni) out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector* out_cell) const { + int symbol = hg->nodes_[node_index].cat_; + const GrammarIter* ni = gptr_->Extend(symbol); + if (!ni) return; + Hypergraph::TailNodeVector na(ant_nodes_.size() + 1); + for (int i = 0; i < ant_nodes_.size(); ++i) + na[i] = ant_nodes_[i]; + na[ant_nodes_.size()] = node_index; + out_cell->push_back(ActiveItem(ni, na, lattice_cost)); + } + + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + double lattice_cost; // TODO? use SparseVector + }; + + inline const vector& operator()(int i, int j) const { return act_chart_(i,j); } + void SeedActiveChart(const Grammar& g) { + int size = act_chart_.width(); + for (int i = 0; i < size; ++i) + if (g.HasRuleForSpan(i,i)) + act_chart_(i,i).push_back(ActiveItem(g.GetRoot())); + } + + void ExtendActiveItems(int i, int k, int j) { + //cerr << " LOOK(" << i << "," << k << ") for completed items in (" << k << "," << j << ")\n"; + vector& cell = act_chart_(i,j); + const vector& icell = act_chart_(i,k); + const vector& idxs = psv_chart_(k, j); + //if (!idxs.empty()) { cerr << "FOUND IN (" << k << "," << j << ")\n"; } + for (vector::const_iterator di = icell.begin(); di != icell.end(); ++di) { + for (vector::const_iterator ni = idxs.begin(); ni != idxs.end(); ++ni) { + di->ExtendNonTerminal(hg_, *ni, &cell); + } + } + } + + void AdvanceDotsForAllItemsInCell(int i, int j, const vector >& input) { + //cerr << "ADVANCE(" << i << "," << j << ")\n"; + for (int k=i+1; k < j; ++k) + ExtendActiveItems(i, k, j); + + const vector& out_arcs = input[j-1]; + for (vector::const_iterator ai = out_arcs.begin(); + ai != out_arcs.end(); ++ai) { + const WordID& f = ai->label; + const double& c = ai->cost; + const int& len = ai->dist2next; + //VLOG(1) << "F: " << TD::Convert(f) << endl; + const vector& ec = act_chart_(i, j-1); + for (vector::const_iterator di = ec.begin(); di != ec.end(); ++di) + di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); + } + } + + private: + const Hypergraph* hg_; + Array2D > act_chart_; + const PassiveChart& psv_chart_; +}; + +PassiveChart::PassiveChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), + goal_idx_(-1) { + act_chart_.resize(grammars_.size()); + for (int i = 0; i < grammars_.size(); ++i) + act_chart_[i] = new ActiveChart(forest, *this); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + cerr << " Goal category: [" << goal << ']' << endl; +} + +void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + if (ni == c2n.end()) { + node = forest_->AddNode(r->GetLHS(), ""); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); +} + +void PassiveChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail) { + const int n = rules->GetNumRules(); + for (int k = 0; k < n; ++k) + ApplyRule(i, j, rules->GetIthRule(k), tail); +} + +void PassiveChart::ApplyUnaryRules(const int i, const int j) { + const vector& nodes = chart_(i,j); // reference is important! + for (int gi = 0; gi < grammars_.size(); ++gi) { + if (!grammars_[gi]->HasRuleForSpan(i,j)) continue; + for (int di = 0; di < nodes.size(); ++di) { + const WordID& cat = forest_->nodes_[nodes[di]].cat_; + const vector& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); + for (int ri = 0; ri < unaries.size(); ++ri) { + // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; + const Hypergraph::TailNodeVector ant(1, nodes[di]); + ApplyRule(i, j, unaries[ri], ant); // may update nodes + } + } + } +} + +bool PassiveChart::Parse() { + forest_->nodes_.reserve(input_.size() * input_.size() * 2); + forest_->edges_.reserve(input_.size() * input_.size() * 1000); // TODO: reservation?? + goal_idx_ = -1; + for (int gi = 0; gi < grammars_.size(); ++gi) + act_chart_[gi]->SeedActiveChart(*grammars_[gi]); + + cerr << " "; + for (int l=1; lAdvanceDotsForAllItemsInCell(i, j, input_); + + const vector& cell = (*act_chart_[gi])(i,j); + for (vector::const_iterator ai = cell.begin(); + ai != cell.end(); ++ai) { + const RuleBin* rules = (ai->gptr_->GetRules()); + if (!rules) continue; + ApplyRules(i, j, rules, ai->ant_nodes_); + } + } + } + ApplyUnaryRules(i,j); + + for (int gi = 0; gi < grammars_.size(); ++gi) { + const Grammar& g = *grammars_[gi]; + // deal with non-terminals that were just proved + if (g.HasRuleForSpan(i, j)) + act_chart_[gi]->ExtendActiveItems(i, i, j); + } + } + const vector& dh = chart_(0, input_.size()); + for (int di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant); + } + } + } + cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +PassiveChart::~PassiveChart() { + for (int i = 0; i < act_chart_.size(); ++i) + delete act_chart_[i]; +} + +ExhaustiveBottomUpParser::ExhaustiveBottomUpParser( + const string& goal_sym, + const vector& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool ExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + PassiveChart chart(goal_sym_, grammars_, input, forest); + return chart.Parse(); +} diff --git a/src/bottom_up_parser.h b/src/bottom_up_parser.h new file mode 100644 index 00000000..546bfb54 --- /dev/null +++ b/src/bottom_up_parser.h @@ -0,0 +1,27 @@ +#ifndef _BOTTOM_UP_PARSER_H_ +#define _BOTTOM_UP_PARSER_H_ + +#include +#include + +#include "lattice.h" +#include "grammar.h" + +class Hypergraph; + +class ExhaustiveBottomUpParser { + public: + ExhaustiveBottomUpParser(const std::string& goal_sym, + const std::vector& grammars); + + // returns true if goal reached spanning the full input + // forest contains the full (i.e., unpruned) parse forest + bool Parse(const Lattice& input, + Hypergraph* forest) const; + + private: + const std::string goal_sym_; + const std::vector grammars_; +}; + +#endif diff --git a/src/cdec.cc b/src/cdec.cc new file mode 100644 index 00000000..c5780cef --- /dev/null +++ b/src/cdec.cc @@ -0,0 +1,474 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "timing_stats.h" +#include "translator.h" +#include "phrasebased_translator.h" +#include "aligner.h" +#include "stringlib.h" +#include "forest_writer.h" +#include "filelib.h" +#include "sampler.h" +#include "sparse_vector.h" +#include "lexcrf.h" +#include "weights.h" +#include "tdict.h" +#include "ff.h" +#include "ff_factory.h" +#include "hg_intersect.h" +#include "apply_models.h" +#include "viterbi.h" +#include "kbest.h" +#include "inside_outside.h" +#include "exp_semiring.h" +#include "sentence_metadata.h" + +using namespace std; +using namespace std::tr1; +using boost::shared_ptr; +namespace po = boost::program_options; + +// some globals ... +boost::shared_ptr > rng; + +namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); } + +void ShowBanner() { + cerr << "cdec v1.0 (c) 2009 by Chris Dyer\n"; +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("formalism,f",po::value()->default_value("scfg"),"Translation formalism; values include SCFG, FST, or PB. Specify LexicalCRF for experimental unsupervised CRF word alignment") + ("input,i",po::value()->default_value("-"),"Source file") + ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") + ("weights,w",po::value(),"Feature weights file") + ("feature_function,F",po::value >()->composing(), "Additional feature function(s) (-L for list)") + ("list_feature_functions,L","List available feature functions") + ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") + ("k_best,k",po::value(),"Extract the k best derivations") + ("unique_k_best,r", "Unique k-best translation list") + ("aligner,a", "Run as a word/phrase aligner (src & ref required)") + ("cubepruning_pop_limit,K",po::value()->default_value(200), "Max number of pops from the candidate heap at each node") + ("goal",po::value()->default_value("S"),"Goal symbol (SCFG & FST)") + ("scfg_extra_glue_grammar", po::value(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)") + ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)") + ("scfg_default_nt,d",po::value()->default_value("X"),"Default non-terminal symbol in SCFG") + ("scfg_max_span_limit,S",po::value()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)") + ("show_tree_structure,T", "Show the Viterbi derivation structure") + ("show_expected_length", "Show the expected translation length under the model") + ("show_partition,z", "Compute and show the partition (inside score)") + ("extract_rules", po::value(), "Extract the rules used in translation (de-duped) to this file") + ("graphviz","Show (constrained) translation forest in GraphViz format") + ("max_translation_beam,x", po::value(), "Beam approximation to get max translation from the chart") + ("max_translation_sample,X", po::value(), "Sample the max translation from the chart") + ("pb_max_distortion,D", po::value()->default_value(4), "Phrase-based decoder: maximum distortion") + ("gradient,G","Compute d log p(e|f) / d lambda_i and write to STDOUT (src & ref required)") + ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") + ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") + ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") + ("forest_output,O",po::value(),"Directory to write forests to") + ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc."); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + const string cfg = (*conf)["config"].as(); + cerr << "Configuration file: " << cfg << endl; + ifstream config(cfg.c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("list_feature_functions")) { + cerr << "Available feature functions (specify with -F):\n"; + global_ff_registry->DisplayList(); + cerr << endl; + exit(1); + } + + if (conf->count("help") || conf->count("grammar") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } + + const string formalism = LowercaseString((*conf)["formalism"].as()); + if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', or 'lexcrf'\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +// TODO move out of cdec into some sampling decoder file +void SampleRecurse(const Hypergraph& hg, const vector& ss, int n, vector* out) { + const SampleSet& s = ss[n]; + int i = rng->SelectSample(s); + const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]]; + vector > ants(edge.tail_nodes_.size()); + for (int j = 0; j < ants.size(); ++j) + SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]); + + vector*> pants(ants.size()); + for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j]; + edge.rule_->ESubstitute(pants, out); +} + +struct SampleSort { + bool operator()(const pair& a, const pair& b) const { + return a.first > b.first; + } +}; + +// TODO move out of cdec into some sampling decoder file +void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { + unordered_map > m; + hg->PushWeightsToGoal(); + const int num_nodes = hg->nodes_.size(); + vector ss(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + SampleSet& s = ss[i]; + const vector& in_edges = hg->nodes_[i].in_edges_; + for (int j = 0; j < in_edges.size(); ++j) { + s.add(hg->edges_[in_edges[j]].edge_prob_); + } + } + for (int i = 0; i < samples; ++i) { + vector yield; + SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield); + const string trans = TD::GetString(yield); + ++m[trans]; + } + vector > dist; + for (unordered_map >::iterator i = m.begin(); + i != m.end(); ++i) { + dist.push_back(make_pair(i->second, i->first)); + } + sort(dist.begin(), dist.end(), SampleSort()); + if (k) { + for (int i = 0; i < k; ++i) + cout << dist[i].first << " ||| " << dist[i].second << endl; + } else { + cout << dist[0].second << endl; + } +} + +// TODO decoder output should probably be moved to another file +void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique) { + if (unique) { + KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } +} + +struct ELengthWeightFunction { + double operator()(const Hypergraph::Edge& e) const { + return e.rule_->ELength() - e.rule_->Arity(); + } +}; + + +struct TRPHash { + size_t operator()(const TRulePtr& o) const { return reinterpret_cast(o.get()); } +}; +static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) { + static unordered_set written; + for (int i = 0; i < hg.edges_.size(); ++i) { + const TRulePtr& rule = hg.edges_[i].rule_; + if (written.insert(rule).second) { + (*os) << rule->AsString() << endl; + } + } +} + +void register_feature_functions(); + +int main(int argc, char** argv) { + global_ff_registry.reset(new FFRegistry); + register_feature_functions(); + ShowBanner(); + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const bool write_gradient = conf.count("gradient"); + const bool feature_expectations = conf.count("feature_expectations"); + if (write_gradient && feature_expectations) { + cerr << "You can only specify --gradient or --feature_expectations, not both!\n"; + exit(1); + } + const bool output_training_vector = (write_gradient || feature_expectations); + + boost::shared_ptr translator; + const string formalism = LowercaseString(conf["formalism"].as()); + if (formalism == "scfg") + translator.reset(new SCFGTranslator(conf)); + else if (formalism == "fst") + translator.reset(new FSTTranslator(conf)); + else if (formalism == "pb") + translator.reset(new PhraseBasedTranslator(conf)); + else if (formalism == "lexcrf") + translator.reset(new LexicalCRF(conf)); + else + assert(!"error"); + + vector wv; + Weights w; + if (conf.count("weights")) { + w.InitFromFile(conf["weights"].as()); + wv.resize(FD::NumFeats()); + w.InitVector(&wv); + } + + // set up additional scoring features + vector > pffs; + vector late_ffs; + if (conf.count("feature_function") > 0) { + const vector& add_ffs = conf["feature_function"].as >(); + for (int i = 0; i < add_ffs.size(); ++i) { + string ff, param; + SplitCommandAndParam(add_ffs[i], &ff, ¶m); + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + shared_ptr pff = global_ff_registry->Create(ff, param); + if (!pff) { exit(1); } + // TODO check that multiple features aren't trying to set the same fid + pffs.push_back(pff); + late_ffs.push_back(pff.get()); + } + } + ModelSet late_models(wv, late_ffs); + + const int sample_max_trans = conf.count("max_translation_sample") ? + conf["max_translation_sample"].as() : 0; + if (sample_max_trans) + rng.reset(new RandomNumberGenerator); + const bool aligner_mode = conf.count("aligner"); + const bool minimal_forests = conf.count("minimal_forests"); + const bool graphviz = conf.count("graphviz"); + const bool encode_b64 = conf["vector_format"].as() == "b64"; + const bool kbest = conf.count("k_best"); + const bool unique_kbest = conf.count("unique_k_best"); + shared_ptr extract_file; + if (conf.count("extract_rules")) + extract_file.reset(new WriteFile(conf["extract_rules"].as())); + + int combine_size = conf["combine_size"].as(); + if (combine_size < 1) combine_size = 1; + const string input = conf["input"].as(); + cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl; + ReadFile in_read(input); + istream *in = in_read.stream(); + assert(*in); + + SparseVector acc_vec; // accumulate gradient + double acc_obj = 0; // accumulate objective + int g_count = 0; // number of gradient pieces computed + int sent_id = -1; // line counter + + while(*in) { + Timer::Summarize(); + ++sent_id; + string buf; + getline(*in, buf); + if (buf.empty()) continue; + map sgml; + ProcessAndStripSGML(&buf, &sgml); + if (sgml.find("id") != sgml.end()) + sent_id = atoi(sgml["id"].c_str()); + + cerr << "\nINPUT: "; + if (buf.size() < 100) + cerr << buf << endl; + else { + size_t x = buf.rfind(" ", 100); + if (x == string::npos) x = 100; + cerr << buf.substr(0, x) << " ..." << endl; + } + cerr << " id = " << sent_id << endl; + string to_translate; + Lattice ref; + ParseTranslatorInputLattice(buf, &to_translate, &ref); + const bool has_ref = ref.size() > 0; + SentenceMetadata smeta(sent_id, ref); + const bool hadoop_counters = (write_gradient); + Hypergraph forest; // -LM forest + Timer t("Translation"); + if (!translator->Translate(to_translate, &smeta, wv, &forest)) { + cerr << " NO PARSE FOUND.\n"; + if (hadoop_counters) + cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl; + cout << endl << flush; + continue; + } + cerr << " -LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " -LM forest (paths): " << forest.NumberOfPaths() << endl; + if (conf.count("show_expected_length")) { + const PRPair res = + Inside, + PRWeightFunction >(forest); + cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + } + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " -LM partition log(Z): " << log(z) << endl; + } + if (extract_file) + ExtractRulesDedupe(forest, extract_file->stream()); + vector trans; + const prob_t vs = ViterbiESentence(forest, &trans); + cerr << " -LM Viterbi: " << TD::GetString(trans) << endl; + if (conf.count("show_tree_structure")) + cerr << " -LM tree: " << ViterbiETree(forest) << endl;; + cerr << " -LM Viterbi: " << log(vs) << endl; + + bool has_late_models = !late_models.empty(); + if (has_late_models) { + forest.Reweight(wv); + forest.SortInEdgesByEdgeWeights(); + Hypergraph lm_forest; + int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as(); + ApplyModelSet(forest, + smeta, + late_models, + PruningConfiguration(cubepruning_pop_limit), + &lm_forest); + forest.swap(lm_forest); + forest.Reweight(wv); + trans.clear(); + ViterbiESentence(forest, &trans); + cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl; + cerr << " +LM Viterbi: " << TD::GetString(trans) << endl; + } + if (conf.count("forest_output") && !has_ref) { + ForestWriter writer(conf["forest_output"].as(), sent_id); + assert(writer.Write(forest, minimal_forests)); + } + + if (sample_max_trans) { + MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as() : 0); + } else { + if (kbest) { + DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest); + } else { + if (!graphviz && !has_ref) { + cout << TD::GetString(trans) << endl << flush; + } + } + } + + const int max_trans_beam_size = conf.count("max_translation_beam") ? + conf["max_translation_beam"].as() : 0; + if (max_trans_beam_size) { + Hack::MaxTrans(forest, max_trans_beam_size); + continue; + } + + if (graphviz && !has_ref) forest.PrintGraphviz(); + + // the following are only used if write_gradient is true! + SparseVector full_exp, ref_exp, gradient; + double log_z = 0, log_ref_z = 0; + if (write_gradient) + log_z = log( + InsideOutside, EdgeFeaturesWeightFunction>(forest, &full_exp)); + + if (has_ref) { + if (HG::Intersect(ref, &forest)) { + cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl; + forest.Reweight(wv); + cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; + if (hadoop_counters) + cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl; + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " Contst. partition log(Z): " << log(z) << endl; + } + //DumpKBest(sent_id, forest, 1000); + if (conf.count("forest_output")) { + ForestWriter writer(conf["forest_output"].as(), sent_id); + assert(writer.Write(forest, minimal_forests)); + } + if (aligner_mode && !output_training_vector) + AlignerTools::WriteAlignment(to_translate, ref, forest); + if (write_gradient) { + log_ref_z = log( + InsideOutside, EdgeFeaturesWeightFunction>(forest, &ref_exp)); + if (log_z < log_ref_z) { + cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl; + exit(1); + } + //cerr << "FULL: " << full_exp << endl; + //cerr << " REF: " << ref_exp << endl; + ref_exp -= full_exp; + acc_vec += ref_exp; + acc_obj += (log_z - log_ref_z); + } + if (feature_expectations) { + acc_obj += log( + InsideOutside, EdgeFeaturesWeightFunction>(forest, &ref_exp)); + acc_vec += ref_exp; + } + + if (output_training_vector) { + ++g_count; + if (g_count % combine_size == 0) { + if (encode_b64) { + cout << "0\t"; + B64::Encode(acc_obj, acc_vec, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + acc_vec.clear(); + acc_obj = 0; + } + } + if (conf.count("graphviz")) forest.PrintGraphviz(); + } else { + cerr << " REFERENCE UNREACHABLE.\n"; + if (write_gradient) { + if (hadoop_counters) + cerr << "reporter:counter:UserCounters,EFParseFailed,1" << endl; + cout << endl << flush; + } + } + } + } + if (output_training_vector && !acc_vec.empty()) { + if (encode_b64) { + cout << "0\t"; + B64::Encode(acc_obj, acc_vec, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + } +} + diff --git a/src/cdec_ff.cc b/src/cdec_ff.cc new file mode 100644 index 00000000..86abfb4a --- /dev/null +++ b/src/cdec_ff.cc @@ -0,0 +1,18 @@ +#include + +#include "ff.h" +#include "lm_ff.h" +#include "ff_factory.h" +#include "ff_wordalign.h" + +boost::shared_ptr global_ff_registry; + +void register_feature_functions() { + global_ff_registry->Register("LanguageModel", new FFFactory); + global_ff_registry->Register("WordPenalty", new FFFactory); + global_ff_registry->Register("RelativeSentencePosition", new FFFactory); + global_ff_registry->Register("MarkovJump", new FFFactory); + global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory); + global_ff_registry->Register("AlignerResults", new FFFactory); +}; + diff --git a/src/collapse_weights.cc b/src/collapse_weights.cc new file mode 100644 index 00000000..5e0f3f72 --- /dev/null +++ b/src/collapse_weights.cc @@ -0,0 +1,102 @@ +#include +#include +#include + +#include +#include +#include + +#include "prob.h" +#include "filelib.h" +#include "trule.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +typedef std::tr1::unordered_map, prob_t, boost::hash > > MarginalMap; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("grammar,g", po::value(), "Grammar file") + ("weights,w", po::value(), "Weights file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + const string cfg = (*conf)["config"].as(); + cerr << "Configuration file: " << cfg << endl; + ifstream config(cfg.c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("grammar") || !conf->count("weights")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string wfile = conf["weights"].as(); + const string gfile = conf["grammar"].as(); + Weights wm; + wm.InitFromFile(wfile); + vector w; + wm.InitVector(&w); + MarginalMap e_tots; + MarginalMap f_tots; + prob_t tot; + { + ReadFile rf(gfile); + assert(*rf.stream()); + istream& in = *rf.stream(); + cerr << "Computing marginals...\n"; + int lc = 0; + while(in) { + string line; + getline(in, line); + ++lc; + if (line.empty()) continue; + TRule tr(line, true); + if (tr.GetFeatureValues().empty()) + cerr << "Line " << lc << ": empty features - may introduce bias\n"; + prob_t prob; + prob.logeq(tr.GetFeatureValues().dot(w)); + e_tots[tr.e_] += prob; + f_tots[tr.f_] += prob; + tot += prob; + } + } + bool normalized = (fabs(log(tot)) < 0.001); + cerr << "Total: " << tot << (normalized ? " [normalized]" : " [scaled]") << endl; + ReadFile rf(gfile); + istream&in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + TRule tr(line, true); + const double lp = tr.GetFeatureValues().dot(w); + if (isinf(lp)) { continue; } + tr.scores_.clear(); + + cout << tr.AsString() << " ||| F_and_E=" << lp - log(tot); + if (!normalized) { + cout << ";ZF_and_E=" << lp; + } + cout << ";F_given_E=" << lp - log(e_tots[tr.e_]) + << ";E_given_F=" << lp - log(f_tots[tr.f_]) << endl; + } + return 0; +} + diff --git a/src/dict.h b/src/dict.h new file mode 100644 index 00000000..bae9debe --- /dev/null +++ b/src/dict.h @@ -0,0 +1,40 @@ +#ifndef DICT_H_ +#define DICT_H_ + +#include +#include +#include +#include +#include + +#include + +#include "wordid.h" + +class Dict { + typedef std::tr1::unordered_map > Map; + public: + Dict() : b0_("") { words_.reserve(1000); } + inline int max() const { return words_.size(); } + inline WordID Convert(const std::string& word) { + Map::iterator i = d_.find(word); + if (i == d_.end()) { + words_.push_back(word); + d_[word] = words_.size(); + return words_.size(); + } else { + return i->second; + } + } + inline const std::string& Convert(const WordID& id) const { + if (id == 0) return b0_; + assert(id <= words_.size()); + return words_[id-1]; + } + private: + const std::string b0_; + std::vector words_; + Map d_; +}; + +#endif diff --git a/src/dict_test.cc b/src/dict_test.cc new file mode 100644 index 00000000..5c5d84f0 --- /dev/null +++ b/src/dict_test.cc @@ -0,0 +1,30 @@ +#include "dict.h" + +#include +#include + +class DTest : public testing::Test { + public: + DTest() {} + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(DTest, Convert) { + Dict d; + WordID a = d.Convert("foo"); + WordID b = d.Convert("bar"); + std::string x = "foo"; + WordID c = d.Convert(x); + EXPECT_NE(a, b); + EXPECT_EQ(a, c); + EXPECT_EQ(d.Convert(a), "foo"); + EXPECT_EQ(d.Convert(b), "bar"); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/src/earley_composer.cc b/src/earley_composer.cc new file mode 100644 index 00000000..a59686e0 --- /dev/null +++ b/src/earley_composer.cc @@ -0,0 +1,726 @@ +#include "earley_composer.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "phrasetable_fst.h" +#include "sparse_vector.h" +#include "tdict.h" +#include "hg.h" + +using boost::shared_ptr; +namespace po = boost::program_options; +using namespace std; +using namespace std::tr1; + +// Define the following macro if you want to see lots of debugging output +// when you run the chart parser +#undef DEBUG_CHART_PARSER + +// A few constants used by the chart parser /////////////// +static const int kMAX_NODES = 2000000; +static const string kPHRASE_STRING = "X"; +static bool constants_need_init = true; +static WordID kUNIQUE_START; +static WordID kPHRASE; +static TRulePtr kX1X2; +static TRulePtr kX1; +static WordID kEPS; +static TRulePtr kEPSRule; + +static void InitializeConstants() { + if (constants_need_init) { + kPHRASE = TD::Convert(kPHRASE_STRING) * -1; + kUNIQUE_START = TD::Convert("S") * -1; + kX1X2.reset(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]")); + kX1.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + kEPSRule.reset(new TRule("[X] ||| ||| ")); + kEPS = TD::Convert(""); + constants_need_init = false; + } +} +//////////////////////////////////////////////////////////// + +class EGrammarNode { + friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); + friend void AddGrammarRule(const string& r, map* g); + public: +#ifdef DEBUG_CHART_PARSER + string hint; +#endif + EGrammarNode() : is_some_rule_complete(false), is_root(false) {} + const map& GetTerminals() const { return tptr; } + const map& GetNonTerminals() const { return ntptr; } + bool HasNonTerminals() const { return (!ntptr.empty()); } + bool HasTerminals() const { return (!tptr.empty()); } + bool RuleCompletes() const { + return (is_some_rule_complete || (ntptr.empty() && tptr.empty())); + } + bool GrammarContinues() const { + return !(ntptr.empty() && tptr.empty()); + } + bool IsRoot() const { + return is_root; + } + // these are the features associated with the rule from the start + // node up to this point. If you use these features, you must + // not Extend() this rule. + const SparseVector& GetCFGProductionFeatures() const { + return input_features; + } + + const EGrammarNode* Extend(const WordID& t) const { + if (t < 0) { + map::const_iterator it = ntptr.find(t); + if (it == ntptr.end()) return NULL; + return &it->second; + } else { + map::const_iterator it = tptr.find(t); + if (it == tptr.end()) return NULL; + return &it->second; + } + } + + private: + map tptr; + map ntptr; + SparseVector input_features; + bool is_some_rule_complete; + bool is_root; +}; +typedef map EGrammar; // indexed by the rule LHS + +// edges are immutable once created +struct Edge { +#ifdef DEBUG_CHART_PARSER + static int id_count; + const int id; +#endif + const WordID cat; // lhs side of rule proved/being proved + const EGrammarNode* const dot; // dot position + const FSTNode* const q; // start of span + const FSTNode* const r; // end of span + const Edge* const active_parent; // back pointer, NULL for PREDICT items + const Edge* const passive_parent; // back pointer, NULL for SCAN and PREDICT items + const TargetPhraseSet* const tps; // translations + shared_ptr > features; // features from CFG rule + + bool IsPassive() const { + // when a rule is completed, this value will be set + return static_cast(features); + } + bool IsActive() const { return !IsPassive(); } + bool IsInitial() const { + return !(active_parent || passive_parent); + } + bool IsCreatedByScan() const { + return active_parent && !passive_parent && !dot->IsRoot(); + } + bool IsCreatedByPredict() const { + return dot->IsRoot(); + } + bool IsCreatedByComplete() const { + return active_parent && passive_parent; + } + + // constructor for PREDICT + Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {} + Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r, const Edge* act_parent) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(act_parent), passive_parent(NULL), tps(NULL) {} + + // constructors for SCAN + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const TargetPhraseSet* translations) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations) {} + + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const TargetPhraseSet* translations, + const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations), + features(new SparseVector(feats)) {} + + // constructors for COMPLETE + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const Edge *pas_par) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL) { + assert(pas_par->IsPassive()); + assert(act_par->IsActive()); + } + + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const Edge *pas_par, const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL), + features(new SparseVector(feats)) { + assert(pas_par->IsPassive()); + assert(act_par->IsActive()); + } + + // constructor for COMPLETE query + Edge(const FSTNode* _r) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(NULL), + r(_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {} + // constructor for MERGE quere + Edge(const FSTNode* _q, int) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(_q), + r(NULL), active_parent(NULL), passive_parent(NULL), tps(NULL) {} +}; +#ifdef DEBUG_CHART_PARSER +int Edge::id_count = 0; +#endif + +ostream& operator<<(ostream& os, const Edge& e) { + string type = "PREDICT"; + if (e.IsCreatedByScan()) + type = "SCAN"; + else if (e.IsCreatedByComplete()) + type = "COMPLETE"; + os << "[" +#ifdef DEBUG_CHART_PARSER + << '(' << e.id << ") " +#else + << '(' << &e << ") " +#endif + << "q=" << e.q << ", r=" << e.r + << ", cat="<< TD::Convert(e.cat*-1) << ", dot=" + << e.dot +#ifdef DEBUG_CHART_PARSER + << e.dot->hint +#endif + << (e.IsActive() ? ", Active" : ", Passive") + << ", " << type; +#ifdef DEBUG_CHART_PARSER + if (e.active_parent) { os << ", act.parent=(" << e.active_parent->id << ')'; } + if (e.passive_parent) { os << ", psv.parent=(" << e.passive_parent->id << ')'; } +#endif + if (e.tps) { os << ", tps=" << e.tps; } + return os << ']'; +} + +struct Traversal { + const Edge* const edge; // result from the active / passive combination + const Edge* const active; + const Edge* const passive; + Traversal(const Edge* me, const Edge* a, const Edge* p) : edge(me), active(a), passive(p) {} +}; + +struct UniqueTraversalHash { + size_t operator()(const Traversal* t) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(t->active); + x = ((x << 5) + x) ^ reinterpret_cast(t->passive); + x = ((x << 5) + x) ^ t->edge->IsActive(); + return x; + } +}; + +struct UniqueTraversalEquals { + size_t operator()(const Traversal* a, const Traversal* b) const { + return (a->passive == b->passive && a->active == b->active && a->edge->IsActive() == b->edge->IsActive()); + } +}; + +struct UniqueEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + if (e->IsActive()) { + x = ((x << 5) + x) ^ reinterpret_cast(e->dot); + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + x += 13; + } else { // with passive edges, we don't care about the dot + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + } + return x; + } +}; + +struct UniqueEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + if (a->IsActive() != b->IsActive()) return false; + if (a->IsActive()) { + return (a->cat == b->cat) && (a->dot == b->dot) && (a->q == b->q) && (a->r == b->r); + } else { + return (a->cat == b->cat) && (a->q == b->q) && (a->r == b->r); + } + } +}; + +struct REdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + return x; + } +}; + +struct REdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->r == b->r); + } +}; + +struct QEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + return x; + } +}; + +struct QEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->q == b->q); + } +}; + +struct EdgeQueue { + queue q; + EdgeQueue() {} + void clear() { while(!q.empty()) q.pop(); } + bool HasWork() const { return !q.empty(); } + const Edge* Next() { const Edge* res = q.front(); q.pop(); return res; } + void AddEdge(const Edge* s) { q.push(s); } +}; + +class EarleyComposerImpl { + public: + EarleyComposerImpl(WordID start_cat, const FSTNode& q_0) : start_cat_(start_cat), q_0_(&q_0) {} + + // returns false if the intersection is empty + bool Compose(const EGrammar& g, Hypergraph* forest) { + goal_node = NULL; + EGrammar::const_iterator sit = g.find(start_cat_); + forest->ReserveNodes(kMAX_NODES); + assert(sit != g.end()); + Edge* init = new Edge(start_cat_, &sit->second, q_0_); + assert(IncorporateNewEdge(init)); + while (exp_agenda.HasWork() || agenda.HasWork()) { + while(exp_agenda.HasWork()) { + const Edge* edge = exp_agenda.Next(); + FinishEdge(edge, forest); + } + if (agenda.HasWork()) { + const Edge* edge = agenda.Next(); +#ifdef DEBUG_CHART_PARSER + cerr << "processing (" << edge->id << ')' << endl; +#endif + if (edge->IsActive()) { + if (edge->dot->HasTerminals()) + DoScan(edge); + if (edge->dot->HasNonTerminals()) { + DoMergeWithPassives(edge); + DoPredict(edge, g); + } + } else { + DoComplete(edge); + } + } + } + if (goal_node) { + forest->PruneUnreachable(goal_node->id_); + forest->EpsilonRemove(kEPS); + } + FreeAll(); + return goal_node; + } + + void FreeAll() { + for (int i = 0; i < free_list_.size(); ++i) + delete free_list_[i]; + free_list_.clear(); + for (int i = 0; i < traversal_free_list_.size(); ++i) + delete traversal_free_list_[i]; + traversal_free_list_.clear(); + all_traversals.clear(); + exp_agenda.clear(); + agenda.clear(); + tps2node.clear(); + edge2node.clear(); + all_edges.clear(); + passive_edges.clear(); + active_edges.clear(); + } + + ~EarleyComposerImpl() { + FreeAll(); + } + + // returns the total number of edges created during composition + int EdgesCreated() const { + return free_list_.size(); + } + + private: + void DoScan(const Edge* edge) { + // here, we assume that the FST will potentially have many more outgoing + // edges than the grammar, which will be just a couple. If you want to + // efficiently handle the case where both are relatively large, this code + // will need to change how the intersection is done. The best general + // solution would probably be the Baeza-Yates double binary search. + + const EGrammarNode* dot = edge->dot; + const FSTNode* r = edge->r; + const map& terms = dot->GetTerminals(); + for (map::const_iterator git = terms.begin(); + git != terms.end(); ++git) { + const FSTNode* next_r = r->Extend(git->first); + if (!next_r) continue; + const EGrammarNode* next_dot = &git->second; + const bool grammar_continues = next_dot->GrammarContinues(); + const bool rule_completes = next_dot->RuleCompletes(); + assert(grammar_continues || rule_completes); + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + // create up to 4 new edges! + if (next_r->HasOutgoingNonEpsilonEdges()) { // are there further symbols in the FST? + const TargetPhraseSet* translations = NULL; + if (rule_completes) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations, input_features)); + if (grammar_continues) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations)); + } + if (next_r->HasData()) { // indicates a loop back to q_0 in the FST + const TargetPhraseSet* translations = next_r->GetTranslations(); + if (rule_completes) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations, input_features)); + if (grammar_continues) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations)); + } + } + } + + void DoPredict(const Edge* edge, const EGrammar& g) { + const EGrammarNode* dot = edge->dot; + const map& non_terms = dot->GetNonTerminals(); + for (map::const_iterator git = non_terms.begin(); + git != non_terms.end(); ++git) { + const WordID nt_to_predict = git->first; + //cerr << edge->id << " -- " << TD::Convert(nt_to_predict*-1) << endl; + EGrammar::const_iterator egi = g.find(nt_to_predict); + if (egi == g.end()) { + cerr << "[ERROR] Can't find any grammar rules with a LHS of type " + << TD::Convert(-1*nt_to_predict) << '!' << endl; + continue; + } + assert(edge->IsActive()); + const EGrammarNode* new_dot = &egi->second; + Edge* new_edge = new Edge(nt_to_predict, new_dot, edge->r, edge); + IncorporateNewEdge(new_edge); + } + } + + void DoComplete(const Edge* passive) { +#ifdef DEBUG_CHART_PARSER + cerr << " complete: " << *passive << endl; +#endif + const WordID completed_nt = passive->cat; + const FSTNode* q = passive->q; + const FSTNode* next_r = passive->r; + const Edge query(q); + const pair::iterator, + unordered_multiset::iterator > p = + active_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* active = *it; +#ifdef DEBUG_CHART_PARSER + cerr << " pos: " << *active << endl; +#endif + const EGrammarNode* next_dot = active->dot->Extend(completed_nt); + if (!next_dot) continue; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + // add up to 2 rules + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + void DoMergeWithPassives(const Edge* active) { + // edge is active, has non-terminals, we need to find the passives that can extend it + assert(active->IsActive()); + assert(active->dot->HasNonTerminals()); +#ifdef DEBUG_CHART_PARSER + cerr << " merge active with passives: ACT=" << *active << endl; +#endif + const Edge query(active->r, 1); + const pair::iterator, + unordered_multiset::iterator > p = + passive_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* passive = *it; + const EGrammarNode* next_dot = active->dot->Extend(passive->cat); + if (!next_dot) continue; + const FSTNode* next_r = passive->r; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + // take ownership of edge memory, add to various indexes, etc + // returns true if this edge is new + bool IncorporateNewEdge(Edge* edge) { + free_list_.push_back(edge); + if (edge->passive_parent && edge->active_parent) { + Traversal* t = new Traversal(edge, edge->active_parent, edge->passive_parent); + traversal_free_list_.push_back(t); + if (all_traversals.find(t) != all_traversals.end()) { + return false; + } else { + all_traversals.insert(t); + } + } + exp_agenda.AddEdge(edge); + return true; + } + + bool FinishEdge(const Edge* edge, Hypergraph* hg) { + bool is_new = false; + if (all_edges.find(edge) == all_edges.end()) { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NEW\n"; +#endif + all_edges.insert(edge); + is_new = true; + if (edge->IsPassive()) passive_edges.insert(edge); + if (edge->IsActive()) active_edges.insert(edge); + agenda.AddEdge(edge); + } else { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NOT NEW.\n"; +#endif + } + AddEdgeToTranslationForest(edge, hg); + return is_new; + } + + // build the translation forest + void AddEdgeToTranslationForest(const Edge* edge, Hypergraph* hg) { + assert(hg->nodes_.size() < kMAX_NODES); + Hypergraph::Node* tps = NULL; + // first add any target language rules + if (edge->tps) { + Hypergraph::Node*& node = tps2node[(size_t)edge->tps]; + if (!node) { + // cerr << "Creating phrases for " << edge->tps << endl; + const vector& rules = edge->tps->GetRules(); + node = hg->AddNode(kPHRASE, ""); + for (int i = 0; i < rules.size(); ++i) { + Hypergraph::Edge* hg_edge = hg->AddEdge(rules[i], Hypergraph::TailNodeVector()); + hg_edge->feature_values_ += rules[i]->GetFeatureValues(); + hg->ConnectEdgeToHeadNode(hg_edge, node); + } + } + tps = node; + } + Hypergraph::Node*& head_node = edge2node[edge]; + if (!head_node) + head_node = hg->AddNode(kPHRASE, ""); + if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) { + assert(goal_node == NULL || goal_node == head_node); + goal_node = head_node; + } + Hypergraph::TailNodeVector tail; + SparseVector extra; + if (edge->IsCreatedByPredict()) { + // extra.set_value(FD::Convert("predict"), 1); + } else if (edge->IsCreatedByScan()) { + tail.push_back(edge2node[edge->active_parent]->id_); + if (tps) { + tail.push_back(tps->id_); + } + //extra.set_value(FD::Convert("scan"), 1); + } else if (edge->IsCreatedByComplete()) { + tail.push_back(edge2node[edge->active_parent]->id_); + tail.push_back(edge2node[edge->passive_parent]->id_); + //extra.set_value(FD::Convert("complete"), 1); + } else { + assert(!"unexpected edge type!"); + } + //cerr << head_node->id_ << "<--" << *edge << endl; + +#ifdef DEBUG_CHART_PARSER + for (int i = 0; i < tail.size(); ++i) + if (tail[i] == head_node->id_) { + cerr << "ERROR: " << *edge << "\n i=" << i << endl; + if (i == 1) { cerr << "\tP: " << *edge->passive_parent << endl; } + if (i == 0) { cerr << "\tA: " << *edge->active_parent << endl; } + assert(!"self-loop found!"); + } +#endif + Hypergraph::Edge* hg_edge = NULL; + if (tail.size() == 0) { + hg_edge = hg->AddEdge(kEPSRule, tail); + } else if (tail.size() == 1) { + hg_edge = hg->AddEdge(kX1, tail); + } else if (tail.size() == 2) { + hg_edge = hg->AddEdge(kX1X2, tail); + } + if (edge->features) + hg_edge->feature_values_ += *edge->features; + hg_edge->feature_values_ += extra; + hg->ConnectEdgeToHeadNode(hg_edge, head_node); + } + + Hypergraph::Node* goal_node; + EdgeQueue exp_agenda; + EdgeQueue agenda; + unordered_map tps2node; + unordered_map edge2node; + unordered_set all_traversals; + unordered_set all_edges; + unordered_multiset passive_edges; + unordered_multiset active_edges; + vector free_list_; + vector traversal_free_list_; + const WordID start_cat_; + const FSTNode* const q_0_; +}; + +#ifdef DEBUG_CHART_PARSER +static string TrimRule(const string& r) { + size_t start = r.find(" |||") + 5; + size_t end = r.rfind(" |||"); + return r.substr(start, end - start); +} +#endif + +void AddGrammarRule(const string& r, EGrammar* g) { + const size_t pos = r.find(" ||| "); + if (pos == string::npos || r[0] != '[') { + cerr << "Bad rule: " << r << endl; + return; + } + const size_t rpos = r.rfind(" ||| "); + string feats; + string rs = r; + if (rpos != pos) { + feats = r.substr(rpos + 5); + rs = r.substr(0, rpos); + } + string rhs = rs.substr(pos + 5); + string trule = rs + " ||| " + rhs + " ||| " + feats; + TRule tr(trule); +#ifdef DEBUG_CHART_PARSER + string hint_last_rule; +#endif + EGrammarNode* cur = &(*g)[tr.GetLHS()]; + cur->is_root = true; + for (int i = 0; i < tr.FLength(); ++i) { + WordID sym = tr.f()[i]; +#ifdef DEBUG_CHART_PARSER + hint_last_rule = TD::Convert(sym < 0 ? -sym : sym); + cur->hint += " <@@> (*" + hint_last_rule + ") " + TrimRule(tr.AsString()); +#endif + if (sym < 0) + cur = &cur->ntptr[sym]; + else + cur = &cur->tptr[sym]; + } +#ifdef DEBUG_CHART_PARSER + cur->hint += " <@@> (" + hint_last_rule + "*) " + TrimRule(tr.AsString()); +#endif + cur->is_some_rule_complete = true; + cur->input_features = tr.GetFeatureValues(); +} + +EarleyComposer::~EarleyComposer() { + delete pimpl_; +} + +EarleyComposer::EarleyComposer(const FSTNode* fst) { + InitializeConstants(); + pimpl_ = new EarleyComposerImpl(kUNIQUE_START, *fst); +} + +bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest) { + // first, convert the src forest into an EGrammar + EGrammar g; + const int nedges = src_forest.edges_.size(); + const int nnodes = src_forest.nodes_.size(); + vector cats(nnodes); + bool assign_cats = false; + for (int i = 0; i < nnodes; ++i) + if (assign_cats) { + cats[i] = TD::Convert("CAT_" + boost::lexical_cast(i)) * -1; + } else { + cats[i] = src_forest.nodes_[i].cat_; + } + // construct the grammar + for (int i = 0; i < nedges; ++i) { + const Hypergraph::Edge& edge = src_forest.edges_[i]; + const vector& src = edge.rule_->f(); + EGrammarNode* cur = &g[cats[edge.head_node_]]; + cur->is_root = true; + int ntc = 0; + for (int j = 0; j < src.size(); ++j) { + WordID sym = src[j]; + if (sym <= 0) { + sym = cats[edge.tail_nodes_[ntc]]; + ++ntc; + cur = &cur->ntptr[sym]; + } else { + cur = &cur->tptr[sym]; + } + } + cur->is_some_rule_complete = true; + cur->input_features = edge.feature_values_; + } + EGrammarNode& goal_rule = g[kUNIQUE_START]; + assert((goal_rule.ntptr.size() == 1 && goal_rule.tptr.size() == 0) || + (goal_rule.ntptr.size() == 0 && goal_rule.tptr.size() == 1)); + + return pimpl_->Compose(g, trg_forest); +} + +bool EarleyComposer::Compose(istream* in, Hypergraph* trg_forest) { + EGrammar g; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + AddGrammarRule(line, &g); + } + + return pimpl_->Compose(g, trg_forest); +} diff --git a/src/earley_composer.h b/src/earley_composer.h new file mode 100644 index 00000000..9f786bf6 --- /dev/null +++ b/src/earley_composer.h @@ -0,0 +1,29 @@ +#ifndef _EARLEY_COMPOSER_H_ +#define _EARLEY_COMPOSER_H_ + +#include + +class EarleyComposerImpl; +class FSTNode; +class Hypergraph; + +class EarleyComposer { + public: + ~EarleyComposer(); + EarleyComposer(const FSTNode* phrasetable_root); + bool Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); + + // reads the grammar from a file. There must be a single top-level + // S -> X rule. Anything else is possible. Format is: + // [S] ||| [SS,1] + // [SS] ||| [NP,1] [VP,2] ||| Feature1=0.2 Feature2=-2.3 + // [SS] ||| [VP,1] [NP,2] ||| Feature1=0.8 + // [NP] ||| [DET,1] [N,2] ||| Feature3=2 + // ... + bool Compose(std::istream* grammar_file, Hypergraph* trg_forest); + + private: + EarleyComposerImpl* pimpl_; +}; + +#endif diff --git a/src/exp_semiring.h b/src/exp_semiring.h new file mode 100644 index 00000000..f91beee4 --- /dev/null +++ b/src/exp_semiring.h @@ -0,0 +1,71 @@ +#ifndef _EXP_SEMIRING_H_ +#define _EXP_SEMIRING_H_ + +#include + +// this file implements the first-order expectation semiring described +// in Li & Eisner (EMNLP 2009) + +// requirements: +// RType * RType ==> RType +// PType * PType ==> PType +// RType * PType ==> RType +// good examples: +// PType scalar, RType vector +// BAD examples: +// PType vector, RType scalar +template +struct PRPair { + PRPair() : p(), r() {} + // Inside algorithm requires that T(0) and T(1) + // return the 0 and 1 values of the semiring + explicit PRPair(double x) : p(x), r() {} + PRPair(const PType& p, const RType& r) : p(p), r(r) {} + PRPair& operator+=(const PRPair& o) { + p += o.p; + r += o.r; + return *this; + } + PRPair& operator*=(const PRPair& o) { + r = (o.r * p) + (o.p * r); + p *= o.p; + return *this; + } + PType p; + RType r; +}; + +template +std::ostream& operator<<(std::ostream& o, const PRPair& x) { + return o << '<' << x.p << ", " << x.r << '>'; +} + +template +const PRPair operator+(const PRPair& a, const PRPair& b) { + PRPair result = a; + result += b; + return result; +} + +template +const PRPair operator*(const PRPair& a, const PRPair& b) { + PRPair result = a; + result *= b; + return result; +} + +template +struct PRWeightFunction { + explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), + const RWeightFunction& rwf = RWeightFunction()) : + pweight(pwf), rweight(rwf) {} + PRPair operator()(const Hypergraph::Edge& e) const { + const P p = pweight(e); + const R r = rweight(e); + return PRPair(p, r * p); + } + const PWeightFunction pweight; + const RWeightFunction rweight; +}; + +#endif diff --git a/src/fdict.cc b/src/fdict.cc new file mode 100644 index 00000000..83aa7cea --- /dev/null +++ b/src/fdict.cc @@ -0,0 +1,4 @@ +#include "fdict.h" + +Dict FD::dict_; + diff --git a/src/fdict.h b/src/fdict.h new file mode 100644 index 00000000..ff491cfb --- /dev/null +++ b/src/fdict.h @@ -0,0 +1,21 @@ +#ifndef _FDICT_H_ +#define _FDICT_H_ + +#include +#include +#include "dict.h" + +struct FD { + static Dict dict_; + static inline int NumFeats() { + return dict_.max() + 1; + } + static inline WordID Convert(const std::string& s) { + return dict_.Convert(s); + } + static inline const std::string& Convert(const WordID& w) { + return dict_.Convert(w); + } +}; + +#endif diff --git a/src/ff.cc b/src/ff.cc new file mode 100644 index 00000000..488e6468 --- /dev/null +++ b/src/ff.cc @@ -0,0 +1,93 @@ +#include "ff.h" + +#include "tdict.h" +#include "hg.h" + +using namespace std; + +FeatureFunction::~FeatureFunction() {} + + +void FeatureFunction::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + (void) ant_state; + (void) features; +} + +// Hiero and Joshua use log_10(e) as the value, so I do to +WordPenalty::WordPenalty(const string& param) : + fid_(FD::Convert("WordPenalty")), + value_(-1.0 / log(10)) { + if (!param.empty()) { + cerr << "Warning WordPenalty ignoring parameter: " << param << endl; + } +} + +void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + features->set_value(fid_, edge.rule_->EWords() * value_); +} + +ModelSet::ModelSet(const vector& w, const vector& models) : + models_(models), + weights_(w), + state_size_(0), + model_state_pos_(models.size()) { + for (int i = 0; i < models_.size(); ++i) { + model_state_pos_[i] = state_size_; + state_size_ += models_[i]->NumBytesContext(); + } +} + +void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& hg, + Hypergraph::Edge* edge, + string* context, + prob_t* combination_cost_estimate) const { + context->resize(state_size_); + memset(&(*context)[0], 0, state_size_); + SparseVector est_vals; // only computed if combination_cost_estimate is non-NULL + if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + void* cur_ff_context = NULL; + vector ants(edge->tail_nodes_.size()); + bool has_context = ff.NumBytesContext() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + cur_ff_context = &(*context)[spos]; + for (int i = 0; i < ants.size(); ++i) { + ants[i] = &hg.nodes_[edge->tail_nodes_[i]].state_[spos]; + } + } + ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context); + } + if (combination_cost_estimate) + combination_cost_estimate->logeq(est_vals.dot(weights_)); + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + +void ModelSet::AddFinalFeatures(const std::string& state, Hypergraph::Edge* edge) const { + assert(1 == edge->rule_->Arity()); + + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + const void* ant_state = NULL; + bool has_context = ff.NumBytesContext() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + ant_state = &state[spos]; + } + ff.FinalTraversalFeatures(ant_state, &edge->feature_values_); + } + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + diff --git a/src/ff.h b/src/ff.h new file mode 100644 index 00000000..c97e2fe2 --- /dev/null +++ b/src/ff.h @@ -0,0 +1,121 @@ +#ifndef _FF_H_ +#define _FF_H_ + +#include + +#include "fdict.h" +#include "hg.h" + +class SentenceMetadata; +class FeatureFunction; // see definition below + +// if you want to develop a new feature, inherit from this class and +// override TraversalFeaturesImpl(...). If it's a feature that returns / +// depends on context, you may also need to implement +// FinalTraversalFeatures(...) +class FeatureFunction { + public: + FeatureFunction() : state_size_() {} + explicit FeatureFunction(int state_size) : state_size_(state_size) {} + virtual ~FeatureFunction(); + + // returns the number of bytes of context that this feature function will + // (maximally) use. By default, 0 ("stateless" models in Hiero/Joshua). + // NOTE: this value is fixed for the instance of your class, you cannot + // use different amounts of memory for different nodes in the forest. + inline int NumBytesContext() const { return state_size_; } + + // Compute the feature values and (if this applies) the estimates of the + // feature values when this edge is used incorporated into a larger context + inline void TraversalFeatures(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_state) const { + TraversalFeaturesImpl(smeta, edge, ant_contexts, + features, estimated_features, out_state); + // TODO it's easy for careless feature function developers to overwrite + // the end of their state and clobber someone else's memory. These bugs + // will be horrendously painful to track down. There should be some + // optional strict mode that's enforced here that adds some kind of + // barrier between the blocks reserved for the residual contexts + } + + // if there's some state left when you transition to the goal state, score + // it here. For example, the language model computes the cost of adding + // and . + virtual void FinalTraversalFeatures(const void* residual_state, + SparseVector* final_features) const; + + protected: + // context is a pointer to a buffer of size NumBytesContext() that the + // feature function can write its state to. It's up to the feature function + // to determine how much space it needs and to determine how to encode its + // residual contextual information since it is OPAQUE to all clients outside + // of the particular FeatureFunction class. There is one exception: + // equality of the contents (i.e., memcmp) is required to determine whether + // two states can be combined. + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const = 0; + + // !!! ONLY call this from subclass *CONSTRUCTORS* !!! + void SetStateSize(size_t state_size) { + state_size_ = state_size; + } + + private: + int state_size_; +}; + +// word penalty feature, for each word on the E side of a rule, +// add value_ +class WordPenalty : public FeatureFunction { + public: + WordPenalty(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + const int fid_; + const double value_; +}; + +// this class is a set of FeatureFunctions that can be used to score, rescore, +// etc. a (translation?) forest +class ModelSet { + public: + ModelSet() : state_size_(0) {} + + ModelSet(const std::vector& weights, + const std::vector& models); + + // sets edge->feature_values_ and edge->edge_prob_ + // NOTE: edge must not necessarily be in hg.edges_ but its TAIL nodes + // must be. + void AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& hg, + Hypergraph::Edge* edge, + std::string* residual_context, + prob_t* combination_cost_estimate = NULL) const; + + void AddFinalFeatures(const std::string& residual_context, + Hypergraph::Edge* edge) const; + + bool empty() const { return models_.empty(); } + private: + std::vector models_; + std::vector weights_; + int state_size_; + std::vector model_state_pos_; +}; + +#endif diff --git a/src/ff_factory.cc b/src/ff_factory.cc new file mode 100644 index 00000000..1854e0bb --- /dev/null +++ b/src/ff_factory.cc @@ -0,0 +1,35 @@ +#include "ff_factory.h" + +#include "ff.h" + +using boost::shared_ptr; +using namespace std; + +FFFactoryBase::~FFFactoryBase() {} + +void FFRegistry::DisplayList() const { + for (map >::const_iterator it = reg_.begin(); + it != reg_.end(); ++it) { + cerr << " " << it->first << endl; + } +} + +shared_ptr FFRegistry::Create(const string& ffname, const string& param) const { + map >::const_iterator it = reg_.find(ffname); + shared_ptr res; + if (it == reg_.end()) { + cerr << "I don't know how to create feature " << ffname << endl; + } else { + res = it->second->Create(param); + } + return res; +} + +void FFRegistry::Register(const string& ffname, FFFactoryBase* factory) { + if (reg_.find(ffname) != reg_.end()) { + cerr << "Duplicate registration of FeatureFunction with name " << ffname << "!\n"; + abort(); + } + reg_[ffname].reset(factory); +} + diff --git a/src/ff_factory.h b/src/ff_factory.h new file mode 100644 index 00000000..bc586567 --- /dev/null +++ b/src/ff_factory.h @@ -0,0 +1,39 @@ +#ifndef _FF_FACTORY_H_ +#define _FF_FACTORY_H_ + +#include +#include +#include + +#include + +class FeatureFunction; +class FFRegistry; +class FFFactoryBase; +extern boost::shared_ptr global_ff_registry; + +class FFRegistry { + friend int main(int argc, char** argv); + friend class FFFactoryBase; + public: + boost::shared_ptr Create(const std::string& ffname, const std::string& param) const; + void DisplayList() const; + void Register(const std::string& ffname, FFFactoryBase* factory); + private: + FFRegistry() {} + std::map > reg_; +}; + +struct FFFactoryBase { + virtual ~FFFactoryBase(); + virtual boost::shared_ptr Create(const std::string& param) const = 0; +}; + +template +class FFFactory : public FFFactoryBase { + boost::shared_ptr Create(const std::string& param) const { + return boost::shared_ptr(new FF(param)); + } +}; + +#endif diff --git a/src/ff_itg_span.h b/src/ff_itg_span.h new file mode 100644 index 00000000..b990f86a --- /dev/null +++ b/src/ff_itg_span.h @@ -0,0 +1,7 @@ +#ifndef _FF_ITG_SPAN_H_ +#define _FF_ITG_SPAN_H_ + +class ITGSpanFeatures : public FeatureFunction { +}; + +#endif diff --git a/src/ff_test.cc b/src/ff_test.cc new file mode 100644 index 00000000..1c20f9ac --- /dev/null +++ b/src/ff_test.cc @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include +#include "hg.h" +#include "lm_ff.h" +#include "ff.h" +#include "trule.h" +#include "sentence_metadata.h" + +using namespace std; + +LanguageModel* lm_ = NULL; +LanguageModel* lm3_ = NULL; + +class FFTest : public testing::Test { + public: + FFTest() : smeta(0,Lattice()) { + if (!lm_) { + static LanguageModel slm("-o 2 ./test_data/brown.lm.gz"); + lm_ = &slm; + static LanguageModel slm3("./test_data/dummy.3gram.lm -o 3"); + lm3_ = &slm3; + } + } + protected: + virtual void SetUp() { } + virtual void TearDown() { } + SentenceMetadata smeta; +}; + +TEST_F(FFTest,LanguageModel) { + vector ms(1, lm_); + TRulePtr tr1(new TRule("[X] ||| [X,1] said")); + TRulePtr tr2(new TRule("[X] ||| the man said")); + TRulePtr tr3(new TRule("[X] ||| the fat man")); + Hypergraph hg; + const int lm_fid = FD::Convert("LanguageModel"); + vector w(lm_fid + 1,1); + ModelSet models(w, ms); + string state; + Hypergraph::Edge edge; + edge.rule_ = tr2; + models.AddFeaturesToEdge(smeta, hg, &edge, &state); + double lm1 = edge.feature_values_.dot(w); + cerr << "lm=" << edge.feature_values_[lm_fid] << endl; + + hg.nodes_.resize(1); + hg.edges_.resize(2); + hg.edges_[0].rule_ = tr3; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_); + hg.edges_[1].tail_nodes_.push_back(0); + hg.edges_[1].rule_ = tr1; + string state2; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2); + double tot = hg.edges_[1].feature_values_[lm_fid] + hg.edges_[0].feature_values_[lm_fid]; + cerr << "lm=" << tot << endl; + EXPECT_TRUE(state2 == state); + EXPECT_FALSE(state == hg.nodes_[0].state_); +} + +TEST_F(FFTest, Small) { + WordPenalty wp(""); + vector ms(2, lm_); + ms[1] = ℘ + TRulePtr tr1(new TRule("[X] ||| [X,1] said")); + TRulePtr tr2(new TRule("[X] ||| john said")); + TRulePtr tr3(new TRule("[X] ||| john")); + cerr << "RULE: " << tr1->AsString() << endl; + Hypergraph hg; + vector w(2); w[0]=1.0; w[1]=-2.0; + ModelSet models(w, ms); + string state; + Hypergraph::Edge edge; + edge.rule_ = tr2; + cerr << tr2->AsString() << endl; + models.AddFeaturesToEdge(smeta, hg, &edge, &state); + double s1 = edge.feature_values_.dot(w); + cerr << "lm=" << edge.feature_values_[0] << endl; + cerr << "wp=" << edge.feature_values_[1] << endl; + + hg.nodes_.resize(1); + hg.edges_.resize(2); + hg.edges_[0].rule_ = tr3; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_); + double acc = hg.edges_[0].feature_values_.dot(w); + cerr << hg.edges_[0].feature_values_[0] << endl; + hg.edges_[1].tail_nodes_.push_back(0); + hg.edges_[1].rule_ = tr1; + string state2; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2); + acc += hg.edges_[1].feature_values_.dot(w); + double tot = hg.edges_[1].feature_values_[0] + hg.edges_[0].feature_values_[0]; + cerr << "lm=" << tot << endl; + cerr << "acc=" << acc << endl; + cerr << " s1=" << s1 << endl; + EXPECT_TRUE(state2 == state); + EXPECT_FALSE(state == hg.nodes_[0].state_); + EXPECT_FLOAT_EQ(acc, s1); +} + +TEST_F(FFTest, LM3) { + int x = lm3_->NumBytesContext(); + Hypergraph::Edge edge1; + edge1.rule_.reset(new TRule("[X] ||| x y ||| one ||| 1.0 -2.4 3.0")); + Hypergraph::Edge edge2; + edge2.rule_.reset(new TRule("[X] ||| [X,1] a ||| [X,1] two ||| 1.0 -2.4 3.0")); + Hypergraph::Edge edge3; + edge3.rule_.reset(new TRule("[X] ||| [X,1] a ||| zero [X,1] two ||| 1.0 -2.4 3.0")); + vector ants1; + string state(x, '\0'); + SparseVector feats; + SparseVector est; + lm3_->TraversalFeatures(smeta, edge1, ants1, &feats, &est, (void *)&state[0]); + cerr << "returned " << feats << endl; + cerr << edge1.feature_values_ << endl; + cerr << lm3_->DebugStateToString((const void*)&state[0]) << endl; + EXPECT_EQ("[ one ]", lm3_->DebugStateToString((const void*)&state[0])); + ants1.push_back((const void*)&state[0]); + string state2(x, '\0'); + lm3_->TraversalFeatures(smeta, edge2, ants1, &feats, &est, (void *)&state2[0]); + cerr << lm3_->DebugStateToString((const void*)&state2[0]) << endl; + EXPECT_EQ("[ one two ]", lm3_->DebugStateToString((const void*)&state2[0])); + string state3(x, '\0'); + lm3_->TraversalFeatures(smeta, edge3, ants1, &feats, &est, (void *)&state3[0]); + cerr << lm3_->DebugStateToString((const void*)&state3[0]) << endl; + EXPECT_EQ("[ zero one <{STAR}> one two ]", lm3_->DebugStateToString((const void*)&state3[0])); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ff_wordalign.cc b/src/ff_wordalign.cc new file mode 100644 index 00000000..e605ac8d --- /dev/null +++ b/src/ff_wordalign.cc @@ -0,0 +1,221 @@ +#include "ff_wordalign.h" + +#include +#include + +#include "stringlib.h" +#include "sentence_metadata.h" +#include "hg.h" +#include "fdict.h" +#include "aligner.h" +#include "tdict.h" // Blunsom hack +#include "filelib.h" // Blunsom hack + +using namespace std; + +RelativeSentencePosition::RelativeSentencePosition(const string& param) : + fid_(FD::Convert("RelativeSentencePosition")) {} + +void RelativeSentencePosition::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + // if the source word is either null or the generated word + // has no position in the reference + if (edge.i_ == -1 || edge.prev_i_ == -1) + return; + + assert(smeta.GetTargetLength() > 0); + const double val = fabs(static_cast(edge.i_) / smeta.GetSourceLength() - + static_cast(edge.prev_i_) / smeta.GetTargetLength()); + features->set_value(fid_, val); +// cerr << f_len_ << " " << e_len_ << " [" << edge.i_ << "," << edge.j_ << "|" << edge.prev_i_ << "," << edge.prev_j_ << "]\t" << edge.rule_->AsString() << "\tVAL=" << val << endl; +} + +MarkovJump::MarkovJump(const string& param) : + FeatureFunction(1), + fid_(FD::Convert("MarkovJump")), + individual_params_per_jumpsize_(false), + condition_on_flen_(false) { + cerr << " MarkovJump: Blunsom&Cohn feature"; + vector argv; + int argc = SplitOnWhitespace(param, &argv); + if (argc > 0) { + if (argc != 1 || !(argv[0] == "-f" || argv[0] == "-i" || argv[0] == "-if")) { + cerr << "MarkovJump: expected parameters to be -f, -i, or -if\n"; + exit(1); + } + individual_params_per_jumpsize_ = (argv[0][1] == 'i'); + condition_on_flen_ = (argv[0][argv[0].size() - 1] == 'f'); + if (individual_params_per_jumpsize_) { + template_ = "Jump:000"; + cerr << ", individual jump parameters"; + if (condition_on_flen_) { + template_ += ":F00"; + cerr << " (split by f-length)"; + } + } + } + cerr << endl; +} + +void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + unsigned char& dpstate = *((unsigned char*)state); + if (edge.Arity() == 0) { + dpstate = static_cast(edge.i_); + } else if (edge.Arity() == 1) { + dpstate = *((unsigned char*)ant_states[0]); + } else if (edge.Arity() == 2) { + int left_index = *((unsigned char*)ant_states[0]); + int right_index = *((unsigned char*)ant_states[1]); + if (right_index == -1) + dpstate = static_cast(left_index); + else + dpstate = static_cast(right_index); + const int jumpsize = right_index - left_index; + features->set_value(fid_, fabs(jumpsize - 1)); // Blunsom and Cohn def + + if (individual_params_per_jumpsize_) { + string fname = template_; + int param = jumpsize; + if (jumpsize < 0) { + param *= -1; + fname[5]='L'; + } else if (jumpsize > 0) { + fname[5]='R'; + } + if (param) { + fname[6] = '0' + (param / 10); + fname[7] = '0' + (param % 10); + } + if (condition_on_flen_) { + const int flen = smeta.GetSourceLength(); + fname[10] = '0' + (flen / 10); + fname[11] = '0' + (flen % 10); + } + features->set_value(FD::Convert(fname), 1.0); + } + } else { + assert(!"something really unexpected is happening"); + } +} + +AlignerResults::AlignerResults(const std::string& param) : + cur_sent_(-1), + cur_grid_(NULL) { + vector argv; + int argc = SplitOnWhitespace(param, &argv); + if (argc != 2) { + cerr << "Required format: AlignerResults [FeatureName] [file.pharaoh]\n"; + exit(1); + } + cerr << " feature: " << argv[0] << "\talignments: " << argv[1] << endl; + fid_ = FD::Convert(argv[0]); + ReadFile rf(argv[1]); + istream& in = *rf.stream(); int lc = 0; + while(in) { + string line; + getline(in, line); + if (!in) break; + ++lc; + is_aligned_.push_back(AlignerTools::ReadPharaohAlignmentGrid(line)); + } + cerr << " Loaded " << lc << " refs\n"; +} + +void AlignerResults::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + if (edge.i_ == -1 || edge.prev_i_ == -1) + return; + + if (cur_sent_ != smeta.GetSentenceID()) { + assert(smeta.HasReference()); + cur_sent_ = smeta.GetSentenceID(); + assert(cur_sent_ < is_aligned_.size()); + cur_grid_ = is_aligned_[cur_sent_].get(); + } + + //cerr << edge.rule_->AsString() << endl; + + int j = edge.i_; // source side (f) + int i = edge.prev_i_; // target side (e) + if (j < cur_grid_->height() && i < cur_grid_->width() && (*cur_grid_)(i, j)) { +// if (edge.rule_->e_[0] == smeta.GetReference()[i][0].label) { + features->set_value(fid_, 1.0); +// cerr << edge.rule_->AsString() << " (" << i << "," << j << ")\n"; +// } + } +} + +BlunsomSynchronousParseHack::BlunsomSynchronousParseHack(const string& param) : + FeatureFunction((100 / 8) + 1), fid_(FD::Convert("NotRef")), cur_sent_(-1) { + ReadFile rf(param); + istream& in = *rf.stream(); int lc = 0; + while(in) { + string line; + getline(in, line); + if (!in) break; + ++lc; + refs_.push_back(vector()); + TD::ConvertSentence(line, &refs_.back()); + } + cerr << " Loaded " << lc << " refs\n"; +} + +void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + if (cur_sent_ != smeta.GetSentenceID()) { + // assert(smeta.HasReference()); + cur_sent_ = smeta.GetSentenceID(); + assert(cur_sent_ < refs_.size()); + cur_ref_ = &refs_[cur_sent_]; + cur_map_.clear(); + for (int i = 0; i < cur_ref_->size(); ++i) { + vector phrase; + for (int j = i; j < cur_ref_->size(); ++j) { + phrase.push_back((*cur_ref_)[j]); + cur_map_[phrase] = i; + } + } + } + //cerr << edge.rule_->AsString() << endl; + for (int i = 0; i < ant_states.size(); ++i) { + if (DoesNotBelong(ant_states[i])) { + //cerr << " ant " << i << " does not belong\n"; + return; + } + } + vector > ants(ant_states.size()); + vector* > pants(ant_states.size()); + for (int i = 0; i < ant_states.size(); ++i) { + AppendAntecedentString(ant_states[i], &ants[i]); + //cerr << " ant[" << i << "]: " << ((int)*(static_cast(ant_states[i]))) << " " << TD::GetString(ants[i]) << endl; + pants[i] = &ants[i]; + } + vector yield; + edge.rule_->ESubstitute(pants, &yield); + //cerr << "YIELD: " << TD::GetString(yield) << endl; + Vec2Int::iterator it = cur_map_.find(yield); + if (it == cur_map_.end()) { + features->set_value(fid_, 1); + //cerr << " BAD!\n"; + return; + } + SetStateMask(it->second, it->second + yield.size(), state); +} + diff --git a/src/ff_wordalign.h b/src/ff_wordalign.h new file mode 100644 index 00000000..1581641c --- /dev/null +++ b/src/ff_wordalign.h @@ -0,0 +1,133 @@ +#ifndef _FF_WORD_ALIGN_H_ +#define _FF_WORD_ALIGN_H_ + +#include "ff.h" +#include "array2d.h" + +class RelativeSentencePosition : public FeatureFunction { + public: + RelativeSentencePosition(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + const int fid_; +}; + +class MarkovJump : public FeatureFunction { + public: + MarkovJump(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + const int fid_; + bool individual_params_per_jumpsize_; + bool condition_on_flen_; + std::string template_; +}; + +class AlignerResults : public FeatureFunction { + public: + AlignerResults(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + int fid_; + std::vector > > is_aligned_; + mutable int cur_sent_; + const Array2D mutable* cur_grid_; +}; + +#include +#include +#include +class BlunsomSynchronousParseHack : public FeatureFunction { + public: + BlunsomSynchronousParseHack(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + inline bool DoesNotBelong(const void* state) const { + for (int i = 0; i < NumBytesContext(); ++i) { + if (*(static_cast(state) + i)) return false; + } + return true; + } + + inline void AppendAntecedentString(const void* state, std::vector* yield) const { + int i = 0; + int ind = 0; + while (i < NumBytesContext() && !(*(static_cast(state) + i))) { ++i; ind += 8; } + // std::cerr << i << " " << NumBytesContext() << std::endl; + assert(i != NumBytesContext()); + assert(ind < cur_ref_->size()); + int cur = *(static_cast(state) + i); + int comp = 1; + while (comp < 256 && (comp & cur) == 0) { comp <<= 1; ++ind; } + assert(ind < cur_ref_->size()); + assert(comp < 256); + do { + assert(ind < cur_ref_->size()); + yield->push_back((*cur_ref_)[ind]); + ++ind; + comp <<= 1; + if (comp == 256) { + comp = 1; + ++i; + cur = *(static_cast(state) + i); + } + } while (comp & cur); + } + + inline void SetStateMask(int start, int end, void* state) const { + assert((end / 8) < NumBytesContext()); + int i = 0; + int comp = 1; + for (int j = 0; j < start; ++j) { + comp <<= 1; + if (comp == 256) { + ++i; + comp = 1; + } + } + //std::cerr << "SM: " << i << "\n"; + for (int j = start; j < end; ++j) { + *(static_cast(state) + i) |= comp; + //std::cerr << " " << comp << "\n"; + comp <<= 1; + if (comp == 256) { + ++i; + comp = 1; + } + } + //std::cerr << " MASK: " << ((int)*(static_cast(state))) << "\n"; + } + + const int fid_; + mutable int cur_sent_; + typedef std::tr1::unordered_map, int, boost::hash > > Vec2Int; + mutable Vec2Int cur_map_; + const std::vector mutable * cur_ref_; + mutable std::vector > refs_; +}; + +#endif diff --git a/src/filelib.cc b/src/filelib.cc new file mode 100644 index 00000000..79ad2847 --- /dev/null +++ b/src/filelib.cc @@ -0,0 +1,22 @@ +#include "filelib.h" + +#include +#include + +using namespace std; + +bool FileExists(const std::string& fn) { + struct stat info; + int s = stat(fn.c_str(), &info); + return (s==0); +} + +bool DirectoryExists(const string& dir) { + if (access(dir.c_str(),0) == 0) { + struct stat status; + stat(dir.c_str(), &status); + if (status.st_mode & S_IFDIR) return true; + } + return false; +} + diff --git a/src/filelib.h b/src/filelib.h new file mode 100644 index 00000000..62cb9427 --- /dev/null +++ b/src/filelib.h @@ -0,0 +1,66 @@ +#ifndef _FILELIB_H_ +#define _FILELIB_H_ + +#include +#include +#include +#include +#include "gzstream.h" + +// reads from standard in if filename is - +// uncompresses if file ends with .gz +// otherwise, reads from a normal file +class ReadFile { + public: + ReadFile(const std::string& filename) : + no_delete_on_exit_(filename == "-"), + in_(no_delete_on_exit_ ? static_cast(&std::cin) : + (EndsWith(filename, ".gz") ? + static_cast(new igzstream(filename.c_str())) : + static_cast(new std::ifstream(filename.c_str())))) { + if (!*in_) { + std::cerr << "Failed to open " << filename << std::endl; + abort(); + } + } + ~ReadFile() { + if (!no_delete_on_exit_) delete in_; + } + + inline std::istream* stream() { return in_; } + + private: + static bool EndsWith(const std::string& f, const std::string& suf) { + return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size()); + } + const bool no_delete_on_exit_; + std::istream* const in_; +}; + +class WriteFile { + public: + WriteFile(const std::string& filename) : + no_delete_on_exit_(filename == "-"), + out_(no_delete_on_exit_ ? static_cast(&std::cout) : + (EndsWith(filename, ".gz") ? + static_cast(new ogzstream(filename.c_str())) : + static_cast(new std::ofstream(filename.c_str())))) {} + ~WriteFile() { + (*out_) << std::flush; + if (!no_delete_on_exit_) delete out_; + } + + inline std::ostream* stream() { return out_; } + + private: + static bool EndsWith(const std::string& f, const std::string& suf) { + return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size()); + } + const bool no_delete_on_exit_; + std::ostream* const out_; +}; + +bool FileExists(const std::string& file_name); +bool DirectoryExists(const std::string& dir_name); + +#endif diff --git a/src/forest_writer.cc b/src/forest_writer.cc new file mode 100644 index 00000000..a9117d18 --- /dev/null +++ b/src/forest_writer.cc @@ -0,0 +1,23 @@ +#include "forest_writer.h" + +#include + +#include + +#include "filelib.h" +#include "hg_io.h" +#include "hg.h" + +using namespace std; + +ForestWriter::ForestWriter(const std::string& path, int num) : + fname_(path + '/' + boost::lexical_cast(num) + ".json.gz"), used_(false) {} + +bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) { + assert(!used_); + used_ = true; + cerr << " Writing forest to " << fname_ << endl; + WriteFile wf(fname_); + return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream()); +} + diff --git a/src/forest_writer.h b/src/forest_writer.h new file mode 100644 index 00000000..819a8940 --- /dev/null +++ b/src/forest_writer.h @@ -0,0 +1,16 @@ +#ifndef _FOREST_WRITER_H_ +#define _FOREST_WRITER_H_ + +#include + +class Hypergraph; + +struct ForestWriter { + ForestWriter(const std::string& path, int num); + bool Write(const Hypergraph& forest, bool minimal_rules); + + const std::string fname_; + bool used_; +}; + +#endif diff --git a/src/freqdict.cc b/src/freqdict.cc new file mode 100644 index 00000000..4cfffe58 --- /dev/null +++ b/src/freqdict.cc @@ -0,0 +1,23 @@ +#include +#include +#include +#include "freqdict.h" + +void FreqDict::load(const std::string& fname) { + std::ifstream ifs(fname.c_str()); + int cc=0; + while (!ifs.eof()) { + std::string word; + ifs >> word; + if (word.size() == 0) continue; + if (word[0] == '#') continue; + double count = 0; + ifs >> count; + assert(count > 0.0); // use -log(f) + counts_[word]=count; + ++cc; + if (cc % 10000 == 0) { std::cerr << "."; } + } + std::cerr << "\n"; + std::cerr << "Loaded " << cc << " words\n"; +} diff --git a/src/freqdict.h b/src/freqdict.h new file mode 100644 index 00000000..c9bb4c42 --- /dev/null +++ b/src/freqdict.h @@ -0,0 +1,19 @@ +#ifndef _FREQDICT_H_ +#define _FREQDICT_H_ + +#include +#include + +class FreqDict { + public: + void load(const std::string& fname); + float frequency(const std::string& word) const { + std::map::const_iterator i = counts_.find(word); + if (i == counts_.end()) return 0; + return i->second; + } + private: + std::map counts_; +}; + +#endif diff --git a/src/fst_translator.cc b/src/fst_translator.cc new file mode 100644 index 00000000..57feb227 --- /dev/null +++ b/src/fst_translator.cc @@ -0,0 +1,91 @@ +#include "translator.h" + +#include +#include + +#include "sentence_metadata.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "earley_composer.h" +#include "phrasetable_fst.h" +#include "tdict.h" + +using namespace std; + +struct FSTTranslatorImpl { + FSTTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1), + add_pass_through_rules(conf.count("add_pass_through_rules")) { + fst.reset(LoadTextPhrasetable(conf["grammar"].as >())); + ec.reset(new EarleyComposer(fst.get())); + } + + bool Translate(const string& input, + const vector& weights, + Hypergraph* forest) { + bool composed = false; + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + assert(HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)); + if (add_pass_through_rules) { + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < src_cfg_hg.edges_.size(); ++i) { + const vector& f = src_cfg_hg.edges_[i].rule_->f_; + for (int j = 0; j < f.size(); ++j) { + if (f[j] > 0) { + fst->AddPassThroughTranslation(f[j], feats); + } + } + } + } + composed = ec->Compose(src_cfg_hg, forest); + } else { + const string dummy_grammar("[" + goal_sym + "] ||| " + input + " ||| TOP=1"); + cerr << " Dummy grammar: " << dummy_grammar << endl; + istringstream is(dummy_grammar); + if (add_pass_through_rules) { + vector words; + TD::ConvertSentence(input, &words); + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < words.size(); ++i) + fst->AddPassThroughTranslation(words[i], feats); + } + composed = ec->Compose(&is, forest); + } + if (composed) { + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1, ""); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + } + if (add_pass_through_rules) + fst->ClearPassThroughTranslations(); + return composed; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; + const bool add_pass_through_rules; + boost::shared_ptr ec; + boost::shared_ptr fst; +}; + +FSTTranslator::FSTTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new FSTTranslatorImpl(conf)) {} + +bool FSTTranslator::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + diff --git a/src/grammar.cc b/src/grammar.cc new file mode 100644 index 00000000..69e38320 --- /dev/null +++ b/src/grammar.cc @@ -0,0 +1,163 @@ +#include "grammar.h" + +#include +#include +#include + +#include "filelib.h" +#include "tdict.h" + +using namespace std; + +const vector Grammar::NO_RULES; + +RuleBin::~RuleBin() {} +GrammarIter::~GrammarIter() {} +Grammar::~Grammar() {} + +bool Grammar::HasRuleForSpan(int i, int j) const { + (void) i; + (void) j; + return true; // always true by default +} + +struct TextRuleBin : public RuleBin { + int GetNumRules() const { + return rules_.size(); + } + TRulePtr GetIthRule(int i) const { + return rules_[i]; + } + void AddRule(TRulePtr t) { + rules_.push_back(t); + } + int Arity() const { + return rules_.front()->Arity(); + } + void Dump() const { + for (int i = 0; i < rules_.size(); ++i) + VLOG(1) << rules_[i]->AsString() << endl; + } + private: + vector rules_; +}; + +struct TextGrammarNode : public GrammarIter { + TextGrammarNode() : rb_(NULL) {} + ~TextGrammarNode() { + delete rb_; + } + const GrammarIter* Extend(int symbol) const { + map::const_iterator i = tree_.find(symbol); + if (i == tree_.end()) return NULL; + return &i->second; + } + + const RuleBin* GetRules() const { + if (rb_) { + //rb_->Dump(); + } + return rb_; + } + + map tree_; + TextRuleBin* rb_; +}; + +struct TGImpl { + TextGrammarNode root_; +}; + +TextGrammar::TextGrammar() : max_span_(10), pimpl_(new TGImpl) {} +TextGrammar::TextGrammar(const string& file) : + max_span_(10), + pimpl_(new TGImpl) { + ReadFromFile(file); +} + +const GrammarIter* TextGrammar::GetRoot() const { + return &pimpl_->root_; +} + +void TextGrammar::AddRule(const TRulePtr& rule) { + if (rule->IsUnary()) { + rhs2unaries_[rule->f().front()].push_back(rule); + unaries_.push_back(rule); + } else { + TextGrammarNode* cur = &pimpl_->root_; + for (int i = 0; i < rule->f_.size(); ++i) + cur = &cur->tree_[rule->f_[i]]; + if (cur->rb_ == NULL) + cur->rb_ = new TextRuleBin; + cur->rb_->AddRule(rule); + } +} + +void TextGrammar::ReadFromFile(const string& filename) { + ReadFile in(filename); + istream& in_file = *in.stream(); + assert(in_file); + long long int rule_count = 0; + bool fl = false; + while(in_file) { + string line; + getline(in_file, line); + if (line.empty()) continue; + ++rule_count; + if (rule_count % 50000 == 0) { cerr << '.' << flush; fl = true; } + if (rule_count % 2000000 == 0) { cerr << " [" << rule_count << "]\n"; fl = false; } + TRulePtr rule(TRule::CreateRuleSynchronous(line)); + if (rule) { + AddRule(rule); + } else { + if (fl) { cerr << endl; } + cerr << "Skipping badly formatted rule in line " << rule_count << " of " << filename << endl; + fl = false; + } + } + if (fl) cerr << endl; + cerr << " " << rule_count << " rules read.\n"; +} + +bool TextGrammar::HasRuleForSpan(int i, int j) const { + return (max_span_ >= (j - i)); +} + +GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {} + +GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) { + TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]")); + TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] [" + + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1")); + + AddRule(stop_glue); + AddRule(glue); + //cerr << "GLUE: " << stop_glue->AsString() << endl; + //cerr << "GLUE: " << glue->AsString() << endl; +} + +bool GlueGrammar::HasRuleForSpan(int i, int j) const { + (void) j; + return (i == 0); +} + +PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) : + has_rule_(input.size() + 1) { + for (int i = 0; i < input.size(); ++i) { + const vector& alts = input[i]; + for (int k = 0; k < alts.size(); ++k) { + const int j = alts[k].dist2next + i; + has_rule_[i].insert(j); + const string& src = TD::Convert(alts[k].label); + TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1")); + AddRule(pt); +// cerr << "PT: " << pt->AsString() << endl; + } + } +} + +bool PassThroughGrammar::HasRuleForSpan(int i, int j) const { + const set& hr = has_rule_[i]; + if (i == j) { return !hr.empty(); } + return (hr.find(j) != hr.end()); +} diff --git a/src/grammar.h b/src/grammar.h new file mode 100644 index 00000000..4a03c505 --- /dev/null +++ b/src/grammar.h @@ -0,0 +1,83 @@ +#ifndef GRAMMAR_H_ +#define GRAMMAR_H_ + +#include +#include +#include +#include + +#include "lattice.h" +#include "trule.h" + +struct RuleBin { + virtual ~RuleBin(); + virtual int GetNumRules() const = 0; + virtual TRulePtr GetIthRule(int i) const = 0; + virtual int Arity() const = 0; +}; + +struct GrammarIter { + virtual ~GrammarIter(); + virtual const RuleBin* GetRules() const = 0; + virtual const GrammarIter* Extend(int symbol) const = 0; +}; + +struct Grammar { + typedef std::map > Cat2Rules; + static const std::vector NO_RULES; + + virtual ~Grammar(); + virtual const GrammarIter* GetRoot() const = 0; + virtual bool HasRuleForSpan(int i, int j) const; + + // cat is the category to be rewritten + inline const std::vector& GetAllUnaryRules() const { + return unaries_; + } + + // get all the unary rules that rewrite category cat + inline const std::vector& GetUnaryRulesForRHS(const WordID& cat) const { + Cat2Rules::const_iterator found = rhs2unaries_.find(cat); + if (found == rhs2unaries_.end()) + return NO_RULES; + else + return found->second; + } + + protected: + Cat2Rules rhs2unaries_; // these must be filled in by subclasses! + std::vector unaries_; +}; + +typedef boost::shared_ptr GrammarPtr; + +class TGImpl; +struct TextGrammar : public Grammar { + TextGrammar(); + TextGrammar(const std::string& file); + void SetMaxSpan(int m) { max_span_ = m; } + virtual const GrammarIter* GetRoot() const; + void AddRule(const TRulePtr& rule); + void ReadFromFile(const std::string& filename); + virtual bool HasRuleForSpan(int i, int j) const; + const std::vector& GetUnaryRules(const WordID& cat) const; + private: + int max_span_; + boost::shared_ptr pimpl_; +}; + +struct GlueGrammar : public TextGrammar { + // read glue grammar from file + explicit GlueGrammar(const std::string& file); + GlueGrammar(const std::string& goal_nt, const std::string& default_nt); // "S", "X" + virtual bool HasRuleForSpan(int i, int j) const; +}; + +struct PassThroughGrammar : public TextGrammar { + PassThroughGrammar(const Lattice& input, const std::string& cat); + virtual bool HasRuleForSpan(int i, int j) const; + private: + std::vector > has_rule_; // index by [i][j] +}; + +#endif diff --git a/src/grammar_test.cc b/src/grammar_test.cc new file mode 100644 index 00000000..62b8f958 --- /dev/null +++ b/src/grammar_test.cc @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include "trule.h" +#include "tdict.h" +#include "grammar.h" +#include "bottom_up_parser.h" +#include "ff.h" +#include "weights.h" + +using namespace std; + +class GrammarTest : public testing::Test { + public: + GrammarTest() { + wts.InitFromFile("test_data/weights.gt"); + } + protected: + virtual void SetUp() { } + virtual void TearDown() { } + Weights wts; +}; + +TEST_F(GrammarTest,TestTextGrammar) { + vector w; + vector ms; + ModelSet models(w, ms); + + TextGrammar g; + TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3", true)); + TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4", true)); + TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3", true)); + cerr << r1->AsString() << endl; + g.AddRule(r1); + g.AddRule(r2); + g.AddRule(r3); +} + +TEST_F(GrammarTest,TestTextGrammarFile) { + GrammarPtr g(new TextGrammar("./test_data/grammar.prune")); + vector grammars(1, g); + + LatticeArc a(TD::Convert("ein"), 0.0, 1); + LatticeArc b(TD::Convert("haus"), 0.0, 1); + Lattice lattice(2); + lattice[0].push_back(a); + lattice[1].push_back(b); + Hypergraph forest; + ExhaustiveBottomUpParser parser("PHRASE", grammars); + parser.Parse(lattice, &forest); + forest.PrintGraphviz(); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/gzstream.cc b/src/gzstream.cc new file mode 100644 index 00000000..9703e6ad --- /dev/null +++ b/src/gzstream.cc @@ -0,0 +1,165 @@ +// ============================================================================ +// gzstream, C++ iostream classes wrapping the zlib compression library. +// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2.1 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// ============================================================================ +// +// File : gzstream.C +// Revision : $Revision: 1.1 $ +// Revision_date : $Date: 2006/03/30 04:05:52 $ +// Author(s) : Deepak Bandyopadhyay, Lutz Kettner +// +// Standard streambuf implementation following Nicolai Josuttis, "The +// Standard C++ Library". +// ============================================================================ + +#include "gzstream.h" +#include +#include + +#ifdef GZSTREAM_NAMESPACE +namespace GZSTREAM_NAMESPACE { +#endif + +// ---------------------------------------------------------------------------- +// Internal classes to implement gzstream. See header file for user classes. +// ---------------------------------------------------------------------------- + +// -------------------------------------- +// class gzstreambuf: +// -------------------------------------- + +gzstreambuf* gzstreambuf::open( const char* name, int open_mode) { + if ( is_open()) + return (gzstreambuf*)0; + mode = open_mode; + // no append nor read/write mode + if ((mode & std::ios::ate) || (mode & std::ios::app) + || ((mode & std::ios::in) && (mode & std::ios::out))) + return (gzstreambuf*)0; + char fmode[10]; + char* fmodeptr = fmode; + if ( mode & std::ios::in) + *fmodeptr++ = 'r'; + else if ( mode & std::ios::out) + *fmodeptr++ = 'w'; + *fmodeptr++ = 'b'; + *fmodeptr = '\0'; + file = gzopen( name, fmode); + if (file == 0) + return (gzstreambuf*)0; + opened = 1; + return this; +} + +gzstreambuf * gzstreambuf::close() { + if ( is_open()) { + sync(); + opened = 0; + if ( gzclose( file) == Z_OK) + return this; + } + return (gzstreambuf*)0; +} + +int gzstreambuf::underflow() { // used for input buffer only + if ( gptr() && ( gptr() < egptr())) + return * reinterpret_cast( gptr()); + + if ( ! (mode & std::ios::in) || ! opened) + return EOF; + // Josuttis' implementation of inbuf + int n_putback = gptr() - eback(); + if ( n_putback > 4) + n_putback = 4; + memcpy( buffer + (4 - n_putback), gptr() - n_putback, n_putback); + + int num = gzread( file, buffer+4, bufferSize-4); + if (num <= 0) // ERROR or EOF + return EOF; + + // reset buffer pointers + setg( buffer + (4 - n_putback), // beginning of putback area + buffer + 4, // read position + buffer + 4 + num); // end of buffer + + // return next character + return * reinterpret_cast( gptr()); +} + +int gzstreambuf::flush_buffer() { + // Separate the writing of the buffer from overflow() and + // sync() operation. + int w = pptr() - pbase(); + if ( gzwrite( file, pbase(), w) != w) + return EOF; + pbump( -w); + return w; +} + +int gzstreambuf::overflow( int c) { // used for output buffer only + if ( ! ( mode & std::ios::out) || ! opened) + return EOF; + if (c != EOF) { + *pptr() = c; + pbump(1); + } + if ( flush_buffer() == EOF) + return EOF; + return c; +} + +int gzstreambuf::sync() { + // Changed to use flush_buffer() instead of overflow( EOF) + // which caused improper behavior with std::endl and flush(), + // bug reported by Vincent Ricard. + if ( pptr() && pptr() > pbase()) { + if ( flush_buffer() == EOF) + return -1; + } + return 0; +} + +// -------------------------------------- +// class gzstreambase: +// -------------------------------------- + +gzstreambase::gzstreambase( const char* name, int mode) { + init( &buf); + open( name, mode); +} + +gzstreambase::~gzstreambase() { + buf.close(); +} + +void gzstreambase::open( const char* name, int open_mode) { + if ( ! buf.open( name, open_mode)) + clear( rdstate() | std::ios::badbit); +} + +void gzstreambase::close() { + if ( buf.is_open()) + if ( ! buf.close()) + clear( rdstate() | std::ios::badbit); +} + +#ifdef GZSTREAM_NAMESPACE +} // namespace GZSTREAM_NAMESPACE +#endif + +// ============================================================================ +// EOF // diff --git a/src/gzstream.h b/src/gzstream.h new file mode 100644 index 00000000..ad9785fd --- /dev/null +++ b/src/gzstream.h @@ -0,0 +1,121 @@ +// ============================================================================ +// gzstream, C++ iostream classes wrapping the zlib compression library. +// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2.1 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// ============================================================================ +// +// File : gzstream.h +// Revision : $Revision: 1.1 $ +// Revision_date : $Date: 2006/03/30 04:05:52 $ +// Author(s) : Deepak Bandyopadhyay, Lutz Kettner +// +// Standard streambuf implementation following Nicolai Josuttis, "The +// Standard C++ Library". +// ============================================================================ + +#ifndef GZSTREAM_H +#define GZSTREAM_H 1 + +// standard C++ with new header file names and std:: namespace +#include +#include +#include + +#ifdef GZSTREAM_NAMESPACE +namespace GZSTREAM_NAMESPACE { +#endif + +// ---------------------------------------------------------------------------- +// Internal classes to implement gzstream. See below for user classes. +// ---------------------------------------------------------------------------- + +class gzstreambuf : public std::streambuf { +private: + static const int bufferSize = 47+256; // size of data buff + // totals 512 bytes under g++ for igzstream at the end. + + gzFile file; // file handle for compressed file + char buffer[bufferSize]; // data buffer + char opened; // open/close state of stream + int mode; // I/O mode + + int flush_buffer(); +public: + gzstreambuf() : opened(0) { + setp( buffer, buffer + (bufferSize-1)); + setg( buffer + 4, // beginning of putback area + buffer + 4, // read position + buffer + 4); // end position + // ASSERT: both input & output capabilities will not be used together + } + int is_open() { return opened; } + gzstreambuf* open( const char* name, int open_mode); + gzstreambuf* close(); + ~gzstreambuf() { close(); } + + virtual int overflow( int c = EOF); + virtual int underflow(); + virtual int sync(); +}; + +class gzstreambase : virtual public std::ios { +protected: + gzstreambuf buf; +public: + gzstreambase() { init(&buf); } + gzstreambase( const char* name, int open_mode); + ~gzstreambase(); + void open( const char* name, int open_mode); + void close(); + gzstreambuf* rdbuf() { return &buf; } +}; + +// ---------------------------------------------------------------------------- +// User classes. Use igzstream and ogzstream analogously to ifstream and +// ofstream respectively. They read and write files based on the gz* +// function interface of the zlib. Files are compatible with gzip compression. +// ---------------------------------------------------------------------------- + +class igzstream : public gzstreambase, public std::istream { +public: + igzstream() : std::istream( &buf) {} + igzstream( const char* name, int open_mode = std::ios::in) + : gzstreambase( name, open_mode), std::istream( &buf) {} + gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); } + void open( const char* name, int open_mode = std::ios::in) { + gzstreambase::open( name, open_mode); + } +}; + +class ogzstream : public gzstreambase, public std::ostream { +public: + ogzstream() : std::ostream( &buf) {} + ogzstream( const char* name, int mode = std::ios::out) + : gzstreambase( name, mode), std::ostream( &buf) {} + gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); } + void open( const char* name, int open_mode = std::ios::out) { + gzstreambase::open( name, open_mode); + } +}; + +#ifdef GZSTREAM_NAMESPACE +} // namespace GZSTREAM_NAMESPACE +#endif + +#endif // GZSTREAM_H +// ============================================================================ +// EOF // + diff --git a/src/hg.cc b/src/hg.cc new file mode 100644 index 00000000..dd8f8eba --- /dev/null +++ b/src/hg.cc @@ -0,0 +1,483 @@ +#include "hg.h" + +#include +#include +#include +#include +#include + +#include "viterbi.h" +#include "inside_outside.h" +#include "tdict.h" + +using namespace std; + +double Hypergraph::NumberOfPaths() const { + return Inside(*this); +} + +prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector* posts) const { + const ScaledEdgeProb weight(scale); + SparseVector pv; + const double inside = InsideOutside, + EdgeFeaturesWeightFunction>(*this, &pv, weight); + posts->resize(edges_.size()); + for (int i = 0; i < edges_.size(); ++i) + (*posts)[i] = prob_t(pv.value(i)); + return prob_t(inside); +} + +prob_t Hypergraph::ComputeBestPathThroughEdges(vector* post) const { + vector in(edges_.size()); + vector out(edges_.size()); + post->resize(edges_.size()); + + vector ins_node_best(nodes_.size()); + for (int i = 0; i < nodes_.size(); ++i) { + const Node& node = nodes_[i]; + prob_t& node_ins_best = ins_node_best[i]; + if (node.in_edges_.empty()) node_ins_best = prob_t::One(); + for (int j = 0; j < node.in_edges_.size(); ++j) { + const Edge& edge = edges_[node.in_edges_[j]]; + prob_t& in_edge_sco = in[node.in_edges_[j]]; + in_edge_sco = edge.edge_prob_; + for (int k = 0; k < edge.tail_nodes_.size(); ++k) + in_edge_sco *= ins_node_best[edge.tail_nodes_[k]]; + if (in_edge_sco > node_ins_best) node_ins_best = in_edge_sco; + } + } + const prob_t ins_sco = ins_node_best[nodes_.size() - 1]; + + // sanity check + int tots = 0; + for (int i = 0; i < nodes_.size(); ++i) { if (nodes_[i].out_edges_.empty()) tots++; } + assert(tots == 1); + + // compute outside scores, potentially using inside scores + vector out_node_best(nodes_.size()); + for (int i = nodes_.size() - 1; i >= 0; --i) { + const Node& node = nodes_[i]; + prob_t& node_out_best = out_node_best[node.id_]; + if (node.out_edges_.empty()) node_out_best = prob_t::One(); + for (int j = 0; j < node.out_edges_.size(); ++j) { + const Edge& edge = edges_[node.out_edges_[j]]; + prob_t sco = edge.edge_prob_ * out_node_best[edge.head_node_]; + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + if (edge.tail_nodes_[k] != i) + sco *= ins_node_best[edge.tail_nodes_[k]]; + } + if (sco > node_out_best) node_out_best = sco; + } + for (int j = 0; j < node.in_edges_.size(); ++j) { + out[node.in_edges_[j]] = node_out_best; + } + } + + for (int i = 0; i < in.size(); ++i) + (*post)[i] = in[i] * out[i]; + + return ins_sco; +} + +void Hypergraph::PushWeightsToSource(double scale) { + vector posts; + ComputeEdgePosteriors(scale, &posts); + for (int i = 0; i < nodes_.size(); ++i) { + const Hypergraph::Node& node = nodes_[i]; + prob_t z = prob_t::Zero(); + for (int j = 0; j < node.out_edges_.size(); ++j) + z += posts[node.out_edges_[j]]; + for (int j = 0; j < node.out_edges_.size(); ++j) { + edges_[node.out_edges_[j]].edge_prob_ = posts[node.out_edges_[j]] / z; + } + } +} + +void Hypergraph::PushWeightsToGoal(double scale) { + vector posts; + ComputeEdgePosteriors(scale, &posts); + for (int i = 0; i < nodes_.size(); ++i) { + const Hypergraph::Node& node = nodes_[i]; + prob_t z = prob_t::Zero(); + for (int j = 0; j < node.in_edges_.size(); ++j) + z += posts[node.in_edges_[j]]; + for (int j = 0; j < node.in_edges_.size(); ++j) { + edges_[node.in_edges_[j]].edge_prob_ = posts[node.in_edges_[j]] / z; + } + } +} + +void Hypergraph::PruneEdges(const std::vector& prune_edge) { + assert(prune_edge.size() == edges_.size()); + TopologicallySortNodesAndEdges(nodes_.size() - 1, &prune_edge); +} + +void Hypergraph::DensityPruneInsideOutside(const double scale, + const bool use_sum_prod_semiring, + const double density, + const vector* preserve_mask) { + assert(density >= 1.0); + const int plen = ViterbiPathLength(*this); + vector bp; + int rnum = min(static_cast(edges_.size()), static_cast(density * static_cast(plen))); + if (rnum == edges_.size()) { + cerr << "No pruning required: denisty already sufficient"; + return; + } + vector io(edges_.size()); + if (use_sum_prod_semiring) + ComputeEdgePosteriors(scale, &io); + else + ComputeBestPathThroughEdges(&io); + assert(edges_.size() == io.size()); + vector sorted = io; + nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater()); + const double cutoff = sorted[rnum]; + vector prune(edges_.size()); + for (int i = 0; i < edges_.size(); ++i) { + prune[i] = (io[i] < cutoff); + if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; + } + PruneEdges(prune); +} + +void Hypergraph::BeamPruneInsideOutside( + const double scale, + const bool use_sum_prod_semiring, + const double alpha, + const vector* preserve_mask) { + assert(alpha > 0.0); + assert(scale > 0.0); + vector io(edges_.size()); + if (use_sum_prod_semiring) + ComputeEdgePosteriors(scale, &io); + else + ComputeBestPathThroughEdges(&io); + assert(edges_.size() == io.size()); + prob_t best; // initializes to zero + for (int i = 0; i < io.size(); ++i) + if (io[i] > best) best = io[i]; + const prob_t aprob(exp(-alpha)); + const prob_t cutoff = best * aprob; + vector prune(edges_.size()); + //cerr << preserve_mask.size() << " " << edges_.size() << endl; + int pc = 0; + for (int i = 0; i < io.size(); ++i) { + const bool prune_edge = (io[i] < cutoff); + if (prune_edge) ++pc; + prune[i] = (io[i] < cutoff); + if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; + } + cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n"; + PruneEdges(prune); +} + +void Hypergraph::PrintGraphviz() const { + int ei = 0; + cerr << "digraph G {\n rankdir=LR;\n nodesep=.05;\n"; + for (vector::const_iterator i = edges_.begin(); + i != edges_.end(); ++i) { + const Edge& edge=*i; + ++ei; + static const string none = ""; + string rule = (edge.rule_ ? edge.rule_->AsString(false) : none); + + cerr << " A_" << ei << " [label=\"" << rule << " p=" << edge.edge_prob_ + << " F:" << edge.feature_values_ + << "\" shape=\"rect\"];\n"; + for (int i = 0; i < edge.tail_nodes_.size(); ++i) { + cerr << " " << edge.tail_nodes_[i] << " -> A_" << ei << ";\n"; + } + cerr << " A_" << ei << " -> " << edge.head_node_ << ";\n"; + } + for (vector::const_iterator ni = nodes_.begin(); + ni != nodes_.end(); ++ni) { + cerr << " " << ni->id_ << "[label=\"" << (ni->cat_ < 0 ? TD::Convert(ni->cat_ * -1) : "") + //cerr << " " << ni->id_ << "[label=\"" << ni->cat_ + << " n=" << ni->id_ +// << ",x=" << &*ni +// << ",in=" << ni->in_edges_.size() +// << ",out=" << ni->out_edges_.size() + << "\"];\n"; + } + cerr << "}\n"; +} + +void Hypergraph::Union(const Hypergraph& other) { + if (&other == this) return; + if (nodes_.empty()) { nodes_ = other.nodes_; edges_ = other.edges_; return; } + int noff = nodes_.size(); + int eoff = edges_.size(); + int ogoal = other.nodes_.size() - 1; + int cgoal = noff - 1; + // keep a single goal node, so add nodes.size - 1 + nodes_.resize(nodes_.size() + ogoal); + // add all edges + edges_.resize(edges_.size() + other.edges_.size()); + + for (int i = 0; i < ogoal; ++i) { + const Node& on = other.nodes_[i]; + Node& cn = nodes_[i + noff]; + cn.id_ = i + noff; + cn.in_edges_.resize(on.in_edges_.size()); + for (int j = 0; j < on.in_edges_.size(); ++j) + cn.in_edges_[j] = on.in_edges_[j] + eoff; + + cn.out_edges_.resize(on.out_edges_.size()); + for (int j = 0; j < on.out_edges_.size(); ++j) + cn.out_edges_[j] = on.out_edges_[j] + eoff; + } + + for (int i = 0; i < other.edges_.size(); ++i) { + const Edge& oe = other.edges_[i]; + Edge& ce = edges_[i + eoff]; + ce.id_ = i + eoff; + ce.rule_ = oe.rule_; + ce.feature_values_ = oe.feature_values_; + if (oe.head_node_ == ogoal) { + ce.head_node_ = cgoal; + nodes_[cgoal].in_edges_.push_back(ce.id_); + } else { + ce.head_node_ = oe.head_node_ + noff; + } + ce.tail_nodes_.resize(oe.tail_nodes_.size()); + for (int j = 0; j < oe.tail_nodes_.size(); ++j) + ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff; + } + + TopologicallySortNodesAndEdges(cgoal); +} + +int Hypergraph::MarkReachable(const Node& node, + vector* rmap, + const vector* prune_edges) const { + int total = 0; + if (!(*rmap)[node.id_]) { + total = 1; + (*rmap)[node.id_] = true; + for (int i = 0; i < node.in_edges_.size(); ++i) { + if (!(prune_edges && (*prune_edges)[node.in_edges_[i]])) { + for (int j = 0; j < edges_[node.in_edges_[i]].tail_nodes_.size(); ++j) + total += MarkReachable(nodes_[edges_[node.in_edges_[i]].tail_nodes_[j]], rmap, prune_edges); + } + } + } + return total; +} + +void Hypergraph::PruneUnreachable(int goal_node_id) { + TopologicallySortNodesAndEdges(goal_node_id, NULL); +} + +void Hypergraph::RemoveNoncoaccessibleStates(int goal_node_id) { + if (goal_node_id < 0) goal_node_id += nodes_.size(); + assert(goal_node_id >= 0); + assert(goal_node_id < nodes_.size()); + + // TODO finish implementation + abort(); +} + +void Hypergraph::TopologicallySortNodesAndEdges(int goal_index, + const vector* prune_edges) { + vector sedges(edges_.size()); + // figure out which nodes are reachable from the goal + vector reachable(nodes_.size(), false); + int num_reachable = MarkReachable(nodes_[goal_index], &reachable, prune_edges); + vector snodes(num_reachable); snodes.clear(); + + // enumerate all reachable nodes in topologically sorted order + vector old_node_to_new_id(nodes_.size(), -1); + vector node_to_incount(nodes_.size(), -1); + vector node_processed(nodes_.size(), false); + typedef map > PQueue; + PQueue pri_q; + for (int i = 0; i < nodes_.size(); ++i) { + if (!reachable[i]) + continue; + const int inedges = nodes_[i].in_edges_.size(); + int incount = inedges; + for (int j = 0; j < inedges; ++j) + if (edges_[nodes_[i].in_edges_[j]].tail_nodes_.size() == 0 || + (prune_edges && (*prune_edges)[nodes_[i].in_edges_[j]])) + --incount; + // cerr << &nodes_[i] <<" : incount=" << incount << "\tout=" << nodes_[i].out_edges_.size() << "\t(in-edges=" << inedges << ")\n"; + assert(node_to_incount[i] == -1); + node_to_incount[i] = incount; + pri_q[incount].insert(i); + } + + int edge_count = 0; + while (!pri_q.empty()) { + PQueue::iterator iter = pri_q.find(0); + assert(iter != pri_q.end()); + assert(!iter->second.empty()); + + // get first node with incount = 0 + const int cur_index = *iter->second.begin(); + const Node& node = nodes_[cur_index]; + assert(reachable[cur_index]); + //cerr << "node: " << node << endl; + const int new_node_index = snodes.size(); + old_node_to_new_id[cur_index] = new_node_index; + snodes.push_back(node); + Node& new_node = snodes.back(); + new_node.id_ = new_node_index; + new_node.out_edges_.clear(); + + // fix up edges - we can now process the in edges and + // the out edges of their tails + int oi = 0; + for (int i = 0; i < node.in_edges_.size(); ++i, ++oi) { + if (prune_edges && (*prune_edges)[node.in_edges_[i]]) { + --oi; + continue; + } + new_node.in_edges_[oi] = edge_count; + Edge& edge = sedges[edge_count]; + edge.id_ = edge_count; + ++edge_count; + const Edge& old_edge = edges_[node.in_edges_[i]]; + edge.rule_ = old_edge.rule_; + edge.feature_values_ = old_edge.feature_values_; + edge.head_node_ = new_node_index; + edge.tail_nodes_.resize(old_edge.tail_nodes_.size()); + edge.edge_prob_ = old_edge.edge_prob_; + edge.i_ = old_edge.i_; + edge.j_ = old_edge.j_; + edge.prev_i_ = old_edge.prev_i_; + edge.prev_j_ = old_edge.prev_j_; + for (int j = 0; j < old_edge.tail_nodes_.size(); ++j) { + const Node& old_tail_node = nodes_[old_edge.tail_nodes_[j]]; + edge.tail_nodes_[j] = old_node_to_new_id[old_tail_node.id_]; + snodes[edge.tail_nodes_[j]].out_edges_.push_back(edge_count-1); + assert(edge.tail_nodes_[j] != new_node_index); + } + } + assert(oi <= new_node.in_edges_.size()); + new_node.in_edges_.resize(oi); + + for (int i = 0; i < node.out_edges_.size(); ++i) { + const Edge& edge = edges_[node.out_edges_[i]]; + const int next_index = edge.head_node_; + assert(cur_index != next_index); + if (!reachable[next_index]) continue; + if (prune_edges && (*prune_edges)[edge.id_]) continue; + + bool dontReduce = false; + for (int j = 0; j < edge.tail_nodes_.size() && !dontReduce; ++j) { + int tail_index = edge.tail_nodes_[j]; + dontReduce = (tail_index != cur_index) && !node_processed[tail_index]; + } + if (dontReduce) + continue; + + const int incount = node_to_incount[next_index]; + if (incount <= 0) { + cerr << "incount = " << incount << ", should be > 0!\n"; + cerr << "do you have a cycle in your hypergraph?\n"; + abort(); + } + PQueue::iterator it = pri_q.find(incount); + assert(it != pri_q.end()); + it->second.erase(next_index); + if (it->second.empty()) pri_q.erase(it); + + // reinsert node with reduced incount + pri_q[incount-1].insert(next_index); + --node_to_incount[next_index]; + } + + // remove node from set + iter->second.erase(cur_index); + if (iter->second.empty()) + pri_q.erase(iter); + node_processed[cur_index] = true; + } + + sedges.resize(edge_count); + nodes_.swap(snodes); + edges_.swap(sedges); + assert(nodes_.back().out_edges_.size() == 0); +} + +TRulePtr Hypergraph::kEPSRule; +TRulePtr Hypergraph::kUnaryRule; + +void Hypergraph::EpsilonRemove(WordID eps) { + if (!kEPSRule) { + kEPSRule.reset(new TRule("[X] ||| ||| ")); + kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + } + vector kill(edges_.size(), false); + for (int i = 0; i < edges_.size(); ++i) { + const Edge& edge = edges_[i]; + if (edge.tail_nodes_.empty() && + edge.rule_->f_.size() == 1 && + edge.rule_->f_[0] == eps) { + kill[i] = true; + if (!edge.feature_values_.empty()) { + Node& node = nodes_[edge.head_node_]; + if (node.in_edges_.size() != 1) { + cerr << "[WARNING] edge with features going into non-empty node - can't promote\n"; + // this *probably* means that there are multiple derivations of the + // same sequence via different paths through the input forest + // this needs to be investigated and fixed + } else { + for (int j = 0; j < node.out_edges_.size(); ++j) + edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; + // cerr << "PROMOTED " << edge.feature_values_ << endl; + } + } + } + } + bool created_eps = false; + PruneEdges(kill); + for (int i = 0; i < nodes_.size(); ++i) { + const Node& node = nodes_[i]; + if (node.in_edges_.empty()) { + for (int j = 0; j < node.out_edges_.size(); ++j) { + Edge& edge = edges_[node.out_edges_[j]]; + if (edge.rule_->Arity() == 2) { + assert(edge.rule_->f_.size() == 2); + assert(edge.rule_->e_.size() == 2); + edge.rule_ = kUnaryRule; + int cur = node.id_; + int t = -1; + assert(edge.tail_nodes_.size() == 2); + for (int i = 0; i < 2; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; } + assert(t != -1); + edge.tail_nodes_.resize(1); + edge.tail_nodes_[0] = t; + } else { + edge.rule_ = kEPSRule; + edge.rule_->f_[0] = eps; + edge.rule_->e_[0] = eps; + edge.tail_nodes_.clear(); + created_eps = true; + } + } + } + } + vector k2(edges_.size(), false); + PruneEdges(k2); + if (created_eps) EpsilonRemove(eps); +} + +struct EdgeWeightSorter { + const Hypergraph& hg; + EdgeWeightSorter(const Hypergraph& h) : hg(h) {} + bool operator()(int a, int b) const { + return hg.edges_[a].edge_prob_ > hg.edges_[b].edge_prob_; + } +}; + +void Hypergraph::SortInEdgesByEdgeWeights() { + for (int i = 0; i < nodes_.size(); ++i) { + Node& node = nodes_[i]; + sort(node.in_edges_.begin(), node.in_edges_.end(), EdgeWeightSorter(*this)); + } +} + diff --git a/src/hg.h b/src/hg.h new file mode 100644 index 00000000..7a2658b8 --- /dev/null +++ b/src/hg.h @@ -0,0 +1,225 @@ +#ifndef _HG_H_ +#define _HG_H_ + +#include +#include + +#include "small_vector.h" +#include "sparse_vector.h" +#include "wordid.h" +#include "trule.h" +#include "prob.h" + +// class representing an acyclic hypergraph +// - edges have 1 head, 0..n tails +class Hypergraph { + public: + Hypergraph() {} + + // SmallVector is a fast, small vector implementation for sizes <= 2 + typedef SmallVector TailNodeVector; + + // TODO get rid of state_ and cat_? + struct Node { + Node() : id_(), cat_() {} + int id_; // equal to this object's position in the nodes_ vector + WordID cat_; // non-terminal category if <0, 0 if not set + std::vector in_edges_; // contents refer to positions in edges_ + std::vector out_edges_; // contents refer to positions in edges_ + std::string state_; // opaque state + }; + + // TODO get rid of edge_prob_? (can be computed on the fly as the dot + // product of the weight vector and the feature values) + struct Edge { + Edge() : i_(-1), j_(-1), prev_i_(-1), prev_j_(-1) {} + inline int Arity() const { return tail_nodes_.size(); } + int head_node_; // refers to a position in nodes_ + TailNodeVector tail_nodes_; // contents refer to positions in nodes_ + TRulePtr rule_; + SparseVector feature_values_; + prob_t edge_prob_; // dot product of weights and feat_values + int id_; // equal to this object's position in the edges_ vector + + // span info. typically, i_ and j_ refer to indices in the source sentence + // if a synchronous parse has been executed i_ and j_ will refer to indices + // in the target sentence / lattice and prev_i_ prev_j_ will refer to + // positions in the source. Note: it is up to the translator implementation + // to properly set these values. For some models (like the Forest-input + // phrase based model) it may not be straightforward to do. if these values + // are not properly set, most things will work but alignment and any features + // that depend on them will be broken. + short int i_; + short int j_; + short int prev_i_; + short int prev_j_; + }; + + void swap(Hypergraph& other) { + other.nodes_.swap(nodes_); + other.edges_.swap(edges_); + } + + void ResizeNodes(int size) { + nodes_.resize(size); + for (int i = 0; i < size; ++i) nodes_[i].id_ = i; + } + + // reserves space in the nodes vector to prevent memory locations + // from changing + void ReserveNodes(size_t n, size_t e = 0) { + nodes_.reserve(n); + if (e) edges_.reserve(e); + } + + Edge* AddEdge(const TRulePtr& rule, const TailNodeVector& tail) { + edges_.push_back(Edge()); + Edge* edge = &edges_.back(); + edge->rule_ = rule; + edge->tail_nodes_ = tail; + edge->id_ = edges_.size() - 1; + for (int i = 0; i < edge->tail_nodes_.size(); ++i) + nodes_[edge->tail_nodes_[i]].out_edges_.push_back(edge->id_); + return edge; + } + + Node* AddNode(const WordID& cat, const std::string& state = "") { + nodes_.push_back(Node()); + nodes_.back().cat_ = cat; + nodes_.back().state_ = state; + nodes_.back().id_ = nodes_.size() - 1; + return &nodes_.back(); + } + + void ConnectEdgeToHeadNode(const int edge_id, const int head_id) { + edges_[edge_id].head_node_ = head_id; + nodes_[head_id].in_edges_.push_back(edge_id); + } + + // TODO remove this - use the version that takes indices + void ConnectEdgeToHeadNode(Edge* edge, Node* head) { + edge->head_node_ = head->id_; + head->in_edges_.push_back(edge->id_); + } + + // merge the goal node from other with this goal node + void Union(const Hypergraph& other); + + void PrintGraphviz() const; + + // compute the total number of paths in the forest + double NumberOfPaths() const; + + // BEWARE. this assumes that the source and target language + // strings are identical and that there are no loops. + // It assumes a bunch of other things about where the + // epsilons will be. It tries to assert failure if you + // break these assumptions, but it may not. + // TODO - make this work + void EpsilonRemove(WordID eps); + + // multiple the weights vector by the edge feature vector + // (inner product) to set the edge probabilities + template + void Reweight(const V& weights) { + for (int i = 0; i < edges_.size(); ++i) { + Edge& e = edges_[i]; + e.edge_prob_.logeq(e.feature_values_.dot(weights)); + } + } + + // computes inside and outside scores for each + // edge in the hypergraph + // alpha->size = edges_.size = beta->size + // returns inside prob of goal node + prob_t ComputeEdgePosteriors(double scale, + std::vector* posts) const; + + // find the score of the very best path passing through each edge + prob_t ComputeBestPathThroughEdges(std::vector* posts) const; + + // move weights as near to the source as possible, resulting in a + // stochastic automaton. ONLY FUNCTIONAL FOR *LATTICES*. + // See M. Mohri and M. Riley. A Weight Pushing Algorithm for Large + // Vocabulary Speech Recognition. 2001. + // the log semiring (NOT tropical) is used + void PushWeightsToSource(double scale = 1.0); + // same, except weights are pushed to the goal, works for HGs, + // not just lattices + void PushWeightsToGoal(double scale = 1.0); + + void SortInEdgesByEdgeWeights(); + + void PruneUnreachable(int goal_node_id); // DEPRECATED + + void RemoveNoncoaccessibleStates(int goal_node_id = -1); + + // remove edges from the hypergraph if prune_edge[edge_id] is true + void PruneEdges(const std::vector& prune_edge); + + // if you don't know, use_sum_prod_semiring should be false + void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, + const std::vector* preserve_mask = NULL); + + // prunes any edge whose score on the best path taking that edge is more than alpha away + // from the score of the global best past (or the highest edge posterior) + void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha, + const std::vector* preserve_mask = NULL); + + void clear() { + nodes_.clear(); + edges_.clear(); + } + + inline size_t NumberOfEdges() const { return edges_.size(); } + inline size_t NumberOfNodes() const { return nodes_.size(); } + inline bool empty() const { return nodes_.empty(); } + + // nodes_ is sorted in topological order + std::vector nodes_; + // edges_ is not guaranteed to be in any particular order + std::vector edges_; + + // reorder nodes_ so they are in topological order + // source nodes at 0 sink nodes at size-1 + void TopologicallySortNodesAndEdges(int goal_idx, + const std::vector* prune_edges = NULL); + private: + // returns total nodes reachable + int MarkReachable(const Node& node, + std::vector* rmap, + const std::vector* prune_edges) const; + + static TRulePtr kEPSRule; + static TRulePtr kUnaryRule; +}; + +// common WeightFunctions, map an edge -> WeightType +// for generic Viterbi/Inside algorithms +struct EdgeProb { + inline const prob_t& operator()(const Hypergraph::Edge& e) const { return e.edge_prob_; } +}; + +struct ScaledEdgeProb { + ScaledEdgeProb(const double& alpha) : alpha_(alpha) {} + inline prob_t operator()(const Hypergraph::Edge& e) const { return e.edge_prob_.pow(alpha_); } + const double alpha_; +}; + +struct EdgeFeaturesWeightFunction { + inline const SparseVector& operator()(const Hypergraph::Edge& e) const { return e.feature_values_; } +}; + +struct TransitionEventWeightFunction { + inline SparseVector operator()(const Hypergraph::Edge& e) const { + SparseVector result; + result.set_value(e.id_, prob_t::One()); + return result; + } +}; + +struct TransitionCountWeightFunction { + inline double operator()(const Hypergraph::Edge& e) const { (void)e; return 1.0; } +}; + +#endif diff --git a/src/hg_intersect.cc b/src/hg_intersect.cc new file mode 100644 index 00000000..a5e8913a --- /dev/null +++ b/src/hg_intersect.cc @@ -0,0 +1,121 @@ +#include "hg_intersect.h" + +#include +#include +#include +#include + +#include "tdict.h" +#include "hg.h" +#include "trule.h" +#include "wordid.h" +#include "bottom_up_parser.h" + +using boost::lexical_cast; +using namespace std::tr1; +using namespace std; + +struct RuleFilter { + unordered_map, bool, boost::hash > > exists_; + bool true_lattice; + RuleFilter(const Lattice& target, int max_phrase_size) { + true_lattice = false; + for (int i = 0; i < target.size(); ++i) { + vector phrase; + int lim = min(static_cast(target.size()), i + max_phrase_size); + for (int j = i; j < lim; ++j) { + if (target[j].size() > 1) { true_lattice = true; break; } + phrase.push_back(target[j][0].label); + exists_[phrase] = true; + } + } + vector sos(1, TD::Convert("")); + exists_[sos] = true; + } + bool operator()(const TRule& r) const { + // TODO do some smarter filtering for lattices + if (true_lattice) return false; // don't filter "true lattice" input + const vector& e = r.e(); + for (int i = 0; i < e.size(); ++i) { + if (e[i] <= 0) continue; + vector phrase; + for (int j = i; j < e.size(); ++j) { + if (e[j] <= 0) break; + phrase.push_back(e[j]); + if (exists_.count(phrase) == 0) return true; + } + } + return false; + } +}; + +bool HG::Intersect(const Lattice& target, Hypergraph* hg) { + vector rem(hg->edges_.size(), false); + const RuleFilter filter(target, 15); // TODO make configurable + for (int i = 0; i < rem.size(); ++i) + rem[i] = filter(*hg->edges_[i].rule_); + hg->PruneEdges(rem); + + const int nedges = hg->edges_.size(); + const int nnodes = hg->nodes_.size(); + + TextGrammar* g = new TextGrammar; + GrammarPtr gp(g); + vector cats(nnodes); + // each node in the translation forest becomes a "non-terminal" in the new + // grammar, create the labels here + for (int i = 0; i < nnodes; ++i) + cats[i] = TD::Convert("CAT_" + lexical_cast(i)) * -1; + + // construct the grammar + for (int i = 0; i < nedges; ++i) { + const Hypergraph::Edge& edge = hg->edges_[i]; + const vector& tgt = edge.rule_->e(); + const vector& src = edge.rule_->f(); + TRulePtr rule(new TRule); + rule->prev_i = edge.i_; + rule->prev_j = edge.j_; + rule->lhs_ = cats[edge.head_node_]; + vector& f = rule->f_; + vector& e = rule->e_; + f.resize(tgt.size()); // swap source and target, since the parser + e.resize(src.size()); // parses using the source side! + Hypergraph::TailNodeVector tn(edge.tail_nodes_.size()); + int ntc = 0; + for (int j = 0; j < tgt.size(); ++j) { + const WordID& cur = tgt[j]; + if (cur > 0) { + f[j] = cur; + } else { + tn[ntc++] = cur; + f[j] = cats[edge.tail_nodes_[-cur]]; + } + } + ntc = 0; + for (int j = 0; j < src.size(); ++j) { + const WordID& cur = src[j]; + if (cur > 0) { + e[j] = cur; + } else { + e[j] = tn[ntc++]; + } + } + rule->scores_ = edge.feature_values_; + rule->parent_rule_ = edge.rule_; + rule->ComputeArity(); + //cerr << "ADD: " << rule->AsString() << endl; + + g->AddRule(rule); + } + g->SetMaxSpan(target.size() + 1); + const string& new_goal = TD::Convert(cats.back() * -1); + vector grammars(1, gp); + Hypergraph tforest; + ExhaustiveBottomUpParser parser(new_goal, grammars); + if (!parser.Parse(target, &tforest)) + return false; + else + hg->swap(tforest); + return true; +} + diff --git a/src/hg_intersect.h b/src/hg_intersect.h new file mode 100644 index 00000000..826bdaae --- /dev/null +++ b/src/hg_intersect.h @@ -0,0 +1,13 @@ +#ifndef _HG_INTERSECT_H_ +#define _HG_INTERSECT_H_ + +#include + +#include "lattice.h" + +class Hypergraph; +struct HG { + static bool Intersect(const Lattice& target, Hypergraph* hg); +}; + +#endif diff --git a/src/hg_io.cc b/src/hg_io.cc new file mode 100644 index 00000000..629e65f1 --- /dev/null +++ b/src/hg_io.cc @@ -0,0 +1,585 @@ +#include "hg_io.h" + +#include +#include + +#include + +#include "tdict.h" +#include "json_parse.h" +#include "hg.h" + +using namespace std; + +struct HGReader : public JSONParser { + HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } + + void CreateNode(const string& cat, const vector& in_edges) { + WordID c = TD::Convert("X") * -1; + if (!cat.empty()) c = TD::Convert(cat) * -1; + Hypergraph::Node* node = hg.AddNode(c, ""); + for (int i = 0; i < in_edges.size(); ++i) { + if (in_edges[i] >= hg.edges_.size()) { + cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] + << ", but hg only has " << hg.edges_.size() << " edges!\n"; + abort(); + } + hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); + } + } + void CreateEdge(const TRulePtr& rule, SparseVector* feats, const SmallVector& tail) { + Hypergraph::Edge* edge = hg.AddEdge(rule, tail); + feats->swap(edge->feature_values_); + } + + bool HandleJSONEvent(int type, const JSON_value* value) { + switch(state) { + case -1: + assert(type == JSON_T_OBJECT_BEGIN); + state = 0; + break; + case 0: + if (type == JSON_T_OBJECT_END) { + //cerr << "HG created\n"; // TODO, signal some kind of callback + } else if (type == JSON_T_KEY) { + string val = value->vu.str.value; + if (val == "features") { assert(fdict.empty()); state = 1; } + else if (val == "is_sorted") { state = 3; } + else if (val == "rules") { assert(rules.empty()); state = 4; } + else if (val == "node") { state = 8; } + else if (val == "edges") { state = 13; } + else { cerr << "Unexpected key: " << val << endl; return false; } + } + break; + + // features + case 1: + if(type == JSON_T_NULL) { state = 0; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 2; + break; + case 2: + if(type == JSON_T_ARRAY_END) { state = 0; break; } + assert(type == JSON_T_STRING); + fdict.push_back(FD::Convert(value->vu.str.value)); + break; + + // is_sorted + case 3: + assert(type == JSON_T_TRUE || type == JSON_T_FALSE); + is_sorted = (type == JSON_T_TRUE); + if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; } + state = 0; + break; + + // rules + case 4: + if(type == JSON_T_NULL) { state = 0; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 5; + break; + case 5: + if(type == JSON_T_ARRAY_END) { state = 0; break; } + assert(type == JSON_T_INTEGER); + state = 6; + rule_id = value->vu.integer_value; + break; + case 6: + assert(type == JSON_T_STRING); + rules[rule_id] = TRulePtr(new TRule(value->vu.str.value)); + state = 5; + break; + + // Nodes + case 8: + assert(type == JSON_T_OBJECT_BEGIN); + ++nodes; + in_edges.clear(); + cat.clear(); + state = 9; break; + case 9: + if (type == JSON_T_OBJECT_END) { + //cerr << "Creating NODE\n"; + CreateNode(cat, in_edges); + state = 0; break; + } + assert(type == JSON_T_KEY); + cur_key = value->vu.str.value; + if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } + if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } + cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; + return false; + case 10: + assert(type == JSON_T_STRING || type == JSON_T_NULL); + cat = value->vu.str.value; + state = 9; break; + case 11: + if (type == JSON_T_NULL) { state = 9; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 12; break; + case 12: + if (type == JSON_T_ARRAY_END) { state = 9; break; } + assert(type == JSON_T_INTEGER); + //cerr << "in_edges: " << value->vu.integer_value << endl; + in_edges.push_back(value->vu.integer_value); + break; + + // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12}, + // { "tail": null, "feats" : [0,0.87,1,0.02], "rule": 17}, + // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}] + case 13: + assert(type == JSON_T_ARRAY_BEGIN); + state = 14; + break; + case 14: + if (type == JSON_T_ARRAY_END) { state = 0; break; } + assert(type == JSON_T_OBJECT_BEGIN); + //cerr << "New edge\n"; + ++edges; + cur_rule.reset(); feats.clear(); tail.clear(); + state = 15; break; + case 15: + if (type == JSON_T_OBJECT_END) { + CreateEdge(cur_rule, &feats, tail); + state = 14; break; + } + assert(type == JSON_T_KEY); + cur_key = value->vu.str.value; + //cerr << "edge key " << cur_key << endl; + if (cur_key == "rule") { assert(!cur_rule); state = 16; break; } + if (cur_key == "feats") { assert(feats.empty()); state = 17; break; } + if (cur_key == "tail") { assert(tail.empty()); state = 20; break; } + cerr << "Unexpected key " << cur_key << " in edge specification\n"; + return false; + case 16: // edge.rule + if (type == JSON_T_INTEGER) { + int rule_id = value->vu.integer_value; + if (rules.find(rule_id) == rules.end()) { + // rules list must come before the edge definitions! + cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n"; + return false; + } + cur_rule = rules[rule_id]; + } else if (type == JSON_T_STRING) { + cur_rule.reset(new TRule(value->vu.str.value)); + } else { + cerr << "Rule must be either a rule id or a rule string" << endl; + return false; + } + // cerr << "Edge: rule=" << cur_rule->AsString() << endl; + state = 15; + break; + case 17: // edge.feats + if (type == JSON_T_NULL) { state = 15; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 18; break; + case 18: + if (type == JSON_T_ARRAY_END) { state = 15; break; } + if (type != JSON_T_INTEGER && type != JSON_T_STRING) { + cerr << "Unexpected feature id type\n"; return false; + } + if (type == JSON_T_INTEGER) { + fid = value->vu.integer_value; + assert(fid < fdict.size()); + fid = fdict[fid]; + } else if (JSON_T_STRING) { + fid = FD::Convert(value->vu.str.value); + } else { abort(); } + state = 19; + break; + case 19: + { + assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT); + double val = (type == JSON_T_INTEGER ? static_cast(value->vu.integer_value) : + strtod(value->vu.str.value, NULL)); + feats.set_value(fid, val); + state = 18; + break; + } + case 20: // edge.tail + if (type == JSON_T_NULL) { state = 15; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 21; break; + case 21: + if (type == JSON_T_ARRAY_END) { state = 15; break; } + assert(type == JSON_T_INTEGER); + tail.push_back(value->vu.integer_value); + break; + } + return true; + } + string rp; + string cat; + SmallVector tail; + vector in_edges; + TRulePtr cur_rule; + map rules; + vector fdict; + SparseVector feats; + int state; + int fid; + int nodes; + int edges; + string cur_key; + Hypergraph& hg; + int rule_id; + bool nodes_needed; + bool edges_needed; + bool is_sorted; +}; + +bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { + hg->clear(); + HGReader reader(hg); + return reader.Parse(in); +} + +static void WriteRule(const TRule& r, ostream* out) { + if (!r.lhs_) { (*out) << "[X] ||| "; } + JSONParser::WriteEscapedString(r.AsString(), out); +} + +bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { + map rid; + ostream& o = *out; + rid[NULL] = 0; + o << '{'; + if (!remove_rules) { + o << "\"rules\":["; + for (int i = 0; i < hg.edges_.size(); ++i) { + const TRule* r = hg.edges_[i].rule_.get(); + int &id = rid[r]; + if (!id) { + id=rid.size() - 1; + if (id > 1) o << ','; + o << id << ','; + WriteRule(*r, &o); + }; + } + o << "],"; + } + const bool use_fdict = FD::NumFeats() < 1000; + if (use_fdict) { + o << "\"features\":["; + for (int i = 1; i < FD::NumFeats(); ++i) { + o << (i==1 ? "":",") << '"' << FD::Convert(i) << '"'; + } + o << "],"; + } + vector edgemap(hg.edges_.size(), -1); // edges may be in non-topo order + int edge_count = 0; + for (int i = 0; i < hg.nodes_.size(); ++i) { + const Hypergraph::Node& node = hg.nodes_[i]; + if (i > 0) { o << ","; } + o << "\"edges\":["; + for (int j = 0; j < node.in_edges_.size(); ++j) { + const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; + edgemap[edge.id_] = edge_count; + ++edge_count; + o << (j == 0 ? "" : ",") << "{"; + + o << "\"tail\":["; + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + o << (k > 0 ? "," : "") << edge.tail_nodes_[k]; + } + o << "],"; + + o << "\"feats\":["; + bool first = true; + for (SparseVector::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) { + if (!it->second) continue; + if (!first) o << ','; + if (use_fdict) + o << (it->first - 1); + else + o << '"' << FD::Convert(it->first) << '"'; + o << ',' << it->second; + first = false; + } + o << "]"; + if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; } + o << "}"; + } + o << "],"; + + o << "\"node\":{\"in_edges\":["; + for (int j = 0; j < node.in_edges_.size(); ++j) { + int mapped_edge = edgemap[node.in_edges_[j]]; + assert(mapped_edge >= 0); + o << (j == 0 ? "" : ",") << mapped_edge; + } + o << "]"; + if (node.cat_ < 0) { o << ",\"cat\":\"" << TD::Convert(node.cat_ * -1) << '"'; } + o << "}"; + } + o << "}\n"; + return true; +} + +bool needs_escape[128]; +void InitEscapes() { + memset(needs_escape, false, 128); + needs_escape[static_cast('\'')] = true; + needs_escape[static_cast('\\')] = true; +} + +string HypergraphIO::Escape(const string& s) { + size_t len = s.size(); + for (int i = 0; i < s.size(); ++i) { + unsigned char c = s[i]; + if (c < 128 && needs_escape[c]) ++len; + } + if (len == s.size()) return s; + string res(len, ' '); + size_t o = 0; + for (int i = 0; i < s.size(); ++i) { + unsigned char c = s[i]; + if (c < 128 && needs_escape[c]) + res[o++] = '\\'; + res[o++] = c; + } + assert(o == len); + return res; +} + +string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses) { + static bool first = true; + if (first) { InitEscapes(); first = false; } + if (hg.nodes_.empty()) return "()"; + ostringstream os; + if (include_global_parentheses) os << '('; + static const string EPS="*EPS*"; + for (int i = 0; i < hg.nodes_.size()-1; ++i) { + os << '('; + if (hg.nodes_[i].out_edges_.empty()) abort(); + for (int j = 0; j < hg.nodes_[i].out_edges_.size(); ++j) { + const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]]; + const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS; + double prob = log(e.edge_prob_); + if (isinf(prob)) { prob = -9e20; } + if (isnan(prob)) { prob = 0; } + os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),"; + } + os << "),"; + } + if (include_global_parentheses) os << ')'; + return os.str(); +} + +namespace PLF { + +const string chars = "'\\"; +const char& quote = chars[0]; +const char& slash = chars[1]; + +// safe get +inline char get(const std::string& in, int c) { + if (c < 0 || c >= (int)in.size()) return 0; + else return in[(size_t)c]; +} + +// consume whitespace +inline void eatws(const std::string& in, int& c) { + while (get(in,c) == ' ') { c++; } +} + +// from 'foo' return foo +std::string getEscapedString(const std::string& in, int &c) +{ + eatws(in,c); + if (get(in,c++) != quote) return "ERROR"; + std::string res; + char cur = 0; + do { + cur = get(in,c++); + if (cur == slash) { res += get(in,c++); } + else if (cur != quote) { res += cur; } + } while (get(in,c) != quote && (c < (int)in.size())); + c++; + eatws(in,c); + return res; +} + +// basically atof +float getFloat(const std::string& in, int &c) +{ + std::string tmp; + eatws(in,c); + while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { + tmp += get(in,c++); + } + eatws(in,c); + return atof(tmp.c_str()); +} + +// basically atoi +int getInt(const std::string& in, int &c) +{ + std::string tmp; + eatws(in,c); + while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { + tmp += get(in,c++); + } + eatws(in,c); + return atoi(tmp.c_str()); +} + +// maximum number of nodes permitted +#define MAX_NODES 100000000 +// parse ('foo', 0.23) +void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { + if (get(in,c++) != '(') { assert(!"PCN/PLF parse error: expected ( at start of cn alt block\n"); } + vector ewords(2, 0); + ewords[1] = TD::Convert(getEscapedString(in,c)); + TRulePtr r(new TRule(ewords)); + //cerr << "RULE: " << r->AsString() << endl; + if (get(in,c++) != ',') { assert(!"PCN/PLF parse error: expected , after string\n"); } + size_t cnNext = 1; + std::vector probs; + probs.push_back(getFloat(in,c)); + while (get(in,c) == ',') { + c++; + float val = getFloat(in,c); + probs.push_back(val); + } + //if we read more than one prob, this was a lattice, last item was column increment + if (probs.size()>1) { + cnNext = static_cast(probs.back()); + probs.pop_back(); + if (cnNext < 1) { assert(!"PCN/PLF parse error: bad link length at last element of cn alt block\n"); } + } + if (get(in,c++) != ')') { assert(!"PCN/PLF parse error: expected ) at end of cn alt block\n"); } + eatws(in,c); + Hypergraph::TailNodeVector tail(1, cur_node); + Hypergraph::Edge* edge = hg->AddEdge(r, tail); + //cerr << " <--" << cur_node << endl; + int head_node = cur_node + cnNext; + assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory + if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); } + hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]); + for (int i = 0; i < probs.size(); ++i) + edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast(i)), probs[i]); +} + +// parse (('foo', 0.23), ('bar', 0.77)) +void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) { + //cerr << "PLF READING NODE " << cur_node << endl; + if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); } + if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); } + eatws(in,c); + while (1) { + if (c > (int)in.size()) { break; } + if (get(in,c) == ')') { + c++; + eatws(in,c); + break; + } + if (get(in,c) == ',' && get(in,c+1) == ')') { + c+=2; + eatws(in,c); + break; + } + if (get(in,c) == ',') { c++; eatws(in,c); } + ReadPLFEdge(in, c, cur_node, hg); + } +} + +} // namespace PLF + +void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) { + hg->clear(); + int c = 0; + int cur_node = 0; + if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); } + while (1) { + if (c > (int)in.size()) { break; } + if (PLF::get(in,c) == ')') { + c++; + PLF::eatws(in,c); + break; + } + if (PLF::get(in,c) == ',' && PLF::get(in,c+1) == ')') { + c+=2; + PLF::eatws(in,c); + break; + } + if (PLF::get(in,c) == ',') { c++; PLF::eatws(in,c); } + PLF::ReadPLFNode(in, c, cur_node, line, hg); + ++cur_node; + } + assert(cur_node == hg->nodes_.size() - 1); +} + +void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) { + Lattice& l = *pl; + Hypergraph g; + ReadFromPLF(plf, &g, 0); + const int num_nodes = g.nodes_.size() - 1; + l.resize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + vector& alts = l[i]; + const Hypergraph::Node& node = g.nodes_[i]; + const int num_alts = node.out_edges_.size(); + alts.resize(num_alts); + for (int j = 0; j < num_alts; ++j) { + const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]]; + alts[j].label = edge.rule_->e_[1]; + alts[j].cost = edge.feature_values_.value(FD::Convert("Feature_0")); + alts[j].dist2next = edge.head_node_ - node.id_; + } + } +} + +namespace B64 { + +static const char cb64[]="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static const char cd64[]="|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`abcdefghijklmnopq"; + +static void encodeblock(const unsigned char* in, ostream* os, int len) { + char out[4]; + out[0] = cb64[ in[0] >> 2 ]; + out[1] = cb64[ ((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4) ]; + out[2] = (len > 1 ? cb64[ ((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6) ] : '='); + out[3] = (len > 2 ? cb64[ in[2] & 0x3f ] : '='); + os->write(out, 4); +} + +void b64encode(const char* data, const size_t size, ostream* out) { + size_t cur = 0; + while(cur < size) { + int len = min(static_cast(3), size - cur); + encodeblock(reinterpret_cast(&data[cur]), out, len); + cur += len; + } +} + +static void decodeblock(const unsigned char* in, unsigned char* out) { + out[0] = (unsigned char ) (in[0] << 2 | in[1] >> 4); + out[1] = (unsigned char ) (in[1] << 4 | in[2] >> 2); + out[2] = (unsigned char ) (((in[2] << 6) & 0xc0) | in[3]); +} + +bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize) { + size_t cur = 0; + size_t ocur = 0; + unsigned char in[4]; + while(cur < insize) { + assert(ocur < outsize); + for (int i = 0; i < 4; ++i) { + unsigned char v = data[cur]; + v = (unsigned char) ((v < 43 || v > 122) ? '\0' : cd64[ v - 43 ]); + if (!v) { + cerr << "B64 decode error at offset " << cur << " offending character: " << (int)data[cur] << endl; + return false; + } + v = (unsigned char) ((v == '$') ? '\0' : v - 61); + if (v) in[i] = v - 1; else in[i] = 0; + ++cur; + } + decodeblock(in, reinterpret_cast(&out[ocur])); + ocur += 3; + } + return true; +} +} + diff --git a/src/hg_io.h b/src/hg_io.h new file mode 100644 index 00000000..69a516c1 --- /dev/null +++ b/src/hg_io.h @@ -0,0 +1,37 @@ +#ifndef _HG_IO_H_ +#define _HG_IO_H_ + +#include + +#include "lattice.h" +class Hypergraph; + +struct HypergraphIO { + + // the format is basically a list of nodes and edges in topological order + // any edge you read, you must have already read its tail nodes + // any node you read, you must have already read its incoming edges + // this may make writing a bit more challenging if your forest is not + // topologically sorted (but that probably doesn't happen very often), + // but it makes reading much more memory efficient. + // see test_data/small.json.gz for an email encoding + static bool ReadFromJSON(std::istream* in, Hypergraph* out); + + // if remove_rules is used, the hypergraph is serialized without rule information + // (so it only contains structure and feature information) + static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out); + + // serialization utils + static void ReadFromPLF(const std::string& in, Hypergraph* out, int line = 0); + // return PLF string representation (undefined behavior on non-lattices) + static std::string AsPLF(const Hypergraph& hg, bool include_global_parentheses = true); + static void PLFtoLattice(const std::string& plf, Lattice* pl); + static std::string Escape(const std::string& s); // PLF helper +}; + +namespace B64 { + bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize); + void b64encode(const char* data, const size_t size, std::ostream* out); +} + +#endif diff --git a/src/hg_test.cc b/src/hg_test.cc new file mode 100644 index 00000000..ecd97508 --- /dev/null +++ b/src/hg_test.cc @@ -0,0 +1,441 @@ +#include +#include +#include +#include +#include +#include "tdict.h" + +#include "json_parse.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "hg_intersect.h" +#include "viterbi.h" +#include "kbest.h" +#include "inside_outside.h" + +using namespace std; + +class HGTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } + void CreateHG(Hypergraph* hg) const; + void CreateHG_int(Hypergraph* hg) const; + void CreateHG_tiny(Hypergraph* hg) const; + void CreateHGBalanced(Hypergraph* hg) const; + void CreateLatticeHG(Hypergraph* hg) const; + void CreateTinyLatticeHG(Hypergraph* hg) const; +}; + +void HGTest::CreateTinyLatticeHG(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] b\",4,\"[X] ||| [1] B'\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.2],\"rule\":1},{\"tail\":[0],\"feats\":[0,-0.6],\"rule\":2}],\"node\":{\"in_edges\":[0,1]},\"edges\":[{\"tail\":[1],\"feats\":[0,-0.1],\"rule\":3},{\"tail\":[1],\"feats\":[0,-0.9],\"rule\":4}],\"node\":{\"in_edges\":[2,3]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateLatticeHG(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] A A\",4,\"[X] ||| [1] b\",5,\"[X] ||| [1] c\",6,\"[X] ||| [1] B C\",7,\"[X] ||| [1] A B C\",8,\"[X] ||| [1] CC\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[2,-0.3],\"rule\":1},{\"tail\":[0],\"feats\":[2,-0.6],\"rule\":2},{\"tail\":[0],\"feats\":[2,-1.7],\"rule\":3}],\"node\":{\"in_edges\":[0,1,2]},\"edges\":[{\"tail\":[1],\"feats\":[2,-0.5],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[2],\"feats\":[2,-0.6],\"rule\":5},{\"tail\":[1],\"feats\":[2,-0.8],\"rule\":6},{\"tail\":[0],\"feats\":[2,-0.01],\"rule\":7},{\"tail\":[2],\"feats\":[2,-0.8],\"rule\":8}],\"node\":{\"in_edges\":[4,5,6,7]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHG_tiny(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| \",2,\"[X] ||| X [1]\",3,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,-2,1,-99],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.5,1,-0.8],\"rule\":2},{\"tail\":[0],\"feats\":[0,-0.7,1,-0.9],\"rule\":3}],\"node\":{\"in_edges\":[1,2]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHG_int(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| b\",3,\"[X] ||| a [1]\",4,\"[X] ||| [1] b\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,0.1],\"rule\":1},{\"tail\":[],\"feats\":[0,0.1],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X\"},\"edges\":[{\"tail\":[0],\"feats\":[0,0.3],\"rule\":3},{\"tail\":[0],\"feats\":[0,0.2],\"rule\":4}],\"node\":{\"in_edges\":[2,3],\"cat\":\"Goal\"}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHG(Hypergraph* hg) const { + string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHGBalanced(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| i\",2,\"[X] ||| a\",3,\"[X] ||| b\",4,\"[X] ||| [1] [2]\",5,\"[X] ||| [1] [2]\",6,\"[X] ||| c\",7,\"[X] ||| d\",8,\"[X] ||| [1] [2]\",9,\"[X] ||| [1] [2]\",10,\"[X] ||| [1] [2]\",11,\"[X] ||| [1] [2]\",12,\"[X] ||| [1] [2]\",13,\"[X] ||| [1] [2]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[1,2],\"feats\":[],\"rule\":4},{\"tail\":[2,1],\"feats\":[],\"rule\":5}],\"node\":{\"in_edges\":[3,4]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":6}],\"node\":{\"in_edges\":[5]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":7}],\"node\":{\"in_edges\":[6]},\"edges\":[{\"tail\":[4,5],\"feats\":[],\"rule\":8},{\"tail\":[5,4],\"feats\":[],\"rule\":9}],\"node\":{\"in_edges\":[7,8]},\"edges\":[{\"tail\":[3,6],\"feats\":[],\"rule\":10},{\"tail\":[6,3],\"feats\":[],\"rule\":11}],\"node\":{\"in_edges\":[9,10]},\"edges\":[{\"tail\":[7,0],\"feats\":[],\"rule\":12},{\"tail\":[0,7],\"feats\":[],\"rule\":13}],\"node\":{\"in_edges\":[11,12]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +TEST_F(HGTest,Controlled) { + Hypergraph hg; + CreateHG_tiny(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + vector trans; + prob_t prob = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "prob: " << prob << "\n"; + EXPECT_FLOAT_EQ(-80.839996, log(prob)); + EXPECT_EQ("X ", TD::GetString(trans)); + vector post; + hg.PrintGraphviz(); + prob_t c2 = Inside(hg, NULL, ScaledEdgeProb(0.6)); + EXPECT_FLOAT_EQ(-47.8577, log(c2)); +} + +TEST_F(HGTest,Union) { + Hypergraph hg1; + Hypergraph hg2; + CreateHG_tiny(&hg1); + CreateHG(&hg2); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 1.0); + hg1.Reweight(wts); + hg2.Reweight(wts); + prob_t c1,c2,c3,c4; + vector t1,t2,t3,t4; + c1 = ViterbiESentence(hg1, &t1); + c2 = ViterbiESentence(hg2, &t2); + int l2 = ViterbiPathLength(hg2); + cerr << c1 << "\t" << TD::GetString(t1) << endl; + cerr << c2 << "\t" << TD::GetString(t2) << endl; + hg1.Union(hg2); + hg1.Reweight(wts); + c3 = ViterbiESentence(hg1, &t3); + int l3 = ViterbiPathLength(hg1); + cerr << c3 << "\t" << TD::GetString(t3) << endl; + EXPECT_FLOAT_EQ(c2, c3); + EXPECT_EQ(TD::GetString(t2), TD::GetString(t3)); + EXPECT_EQ(l2, l3); + + wts.set_value(FD::Convert("f2"), -1); + hg1.Reweight(wts); + c4 = ViterbiESentence(hg1, &t4); + cerr << c4 << "\t" << TD::GetString(t4) << endl; + EXPECT_EQ("Z ", TD::GetString(t4)); + EXPECT_FLOAT_EQ(98.82, log(c4)); + + vector, prob_t> > list; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg1, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg1.nodes_.size() - 1, i); + if (!d) break; + list.push_back(make_pair(d->yield, d->score)); + } + EXPECT_TRUE(list[0].first == t4); + EXPECT_FLOAT_EQ(log(list[0].second), log(c4)); + EXPECT_EQ(list.size(), 6); + EXPECT_FLOAT_EQ(log(list.back().second / list.front().second), -97.7); +} + +TEST_F(HGTest,ControlledKBest) { + Hypergraph hg; + CreateHG(&hg); + vector w(2); w[0]=0.4; w[1]=0.8; + hg.Reweight(w); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + + int best = 0; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << endl; + ++best; + } + EXPECT_EQ(4, best); +} + + +TEST_F(HGTest,InsideScore) { + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + Hypergraph hg; + CreateTinyLatticeHG(&hg); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); + prob_t inside = Inside(hg); + EXPECT_FLOAT_EQ(1.7934048, inside); // computed by hand + vector post; + inside = hg.ComputeBestPathThroughEdges(&post); + EXPECT_FLOAT_EQ(-0.3, log(inside)); // computed by hand + EXPECT_EQ(post.size(), 4); + for (int i = 0; i < 4; ++i) { + cerr << "edge post: " << log(post[i]) << '\t' << hg.edges_[i].rule_->AsString() << endl; + } +} + + +TEST_F(HGTest,PruneInsideOutside) { + SparseVector wts; + wts.set_value(FD::Convert("Feature_1"), 1.0); + Hypergraph hg; + CreateLatticeHG(&hg); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); + //hg.DensityPruneInsideOutside(0.5, false, 2.0); + hg.BeamPruneInsideOutside(0.5, false, 0.5); + cost = ViterbiESentence(hg, &trans); + cerr << "Ncst: " << cost << endl; + cerr << TD::GetString(trans) << "\n"; + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestPruneEdges) { + Hypergraph hg; + CreateLatticeHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + vector prune(hg.edges_.size(), true); + prune[6] = false; + hg.PruneEdges(prune); + cerr << "Pruned:\n"; + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestIntersect) { + Hypergraph hg; + CreateHG_int(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + + int best = 0; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << endl; + ++best; + } + EXPECT_EQ(4, best); + + Lattice target(2); + target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1)); + target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1)); + HG::Intersect(target, &hg); + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestPrune2) { + Hypergraph hg; + CreateHG_int(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + vector rem(hg.edges_.size(), false); + rem[0] = true; + rem[1] = true; + hg.PruneEdges(rem); + hg.PrintGraphviz(); + cerr << "TODO: fix this pruning behavior-- the resulting HG should be empty!\n"; +} + +TEST_F(HGTest,Sample) { + Hypergraph hg; + CreateLatticeHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("Feature_1"), 0.0); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); +} + +TEST_F(HGTest,PLF) { + Hypergraph hg; + string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; + HypergraphIO::ReadFromPLF(inplf, &hg); + SparseVector wts; + wts.set_value(FD::Convert("Feature_0"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + string outplf = HypergraphIO::AsPLF(hg); + cerr << " IN: " << inplf << endl; + cerr << "OUT: " << outplf << endl; + assert(inplf == outplf); +} + +TEST_F(HGTest,PushWeightsToGoal) { + Hypergraph hg; + CreateHG(&hg); + vector w(2); w[0]=0.4; w[1]=0.8; + hg.Reweight(w); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); + hg.PushWeightsToGoal(); + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestSpecialKBest) { + Hypergraph hg; + CreateHGBalanced(&hg); + vector w(1); w[0]=0; + hg.Reweight(w); + vector, prob_t> > list; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 100000); + for (int i = 0; i < 100000; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << endl; + } + hg.PrintGraphviz(); +} + +TEST_F(HGTest, TestGenericViterbi) { + Hypergraph hg; + CreateHG_tiny(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + vector trans; + const prob_t prob = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "prob: " << prob << "\n"; + EXPECT_FLOAT_EQ(-80.839996, log(prob)); + EXPECT_EQ("X ", TD::GetString(trans)); +} + +TEST_F(HGTest, TestGenericInside) { + Hypergraph hg; + CreateTinyLatticeHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + vector inside; + prob_t ins = Inside(hg, &inside); + EXPECT_FLOAT_EQ(1.7934048, ins); // computed by hand + vector outside; + Outside(hg, inside, &outside); + EXPECT_EQ(3, outside.size()); + EXPECT_FLOAT_EQ(1.7934048, outside[0]); + EXPECT_FLOAT_EQ(1.3114071, outside[1]); + EXPECT_FLOAT_EQ(1.0, outside[2]); +} + +TEST_F(HGTest,TestGenericInside2) { + Hypergraph hg; + CreateHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + vector inside, outside; + prob_t ins = Inside(hg, &inside); + Outside(hg, inside, &outside); + for (int i = 0; i < hg.nodes_.size(); ++i) + cerr << i << "\t" << log(inside[i]) << "\t" << log(outside[i]) << endl; + EXPECT_FLOAT_EQ(0, log(inside[0])); + EXPECT_FLOAT_EQ(-1.7861683, log(outside[0])); + EXPECT_FLOAT_EQ(-0.4, log(inside[1])); + EXPECT_FLOAT_EQ(-1.3861683, log(outside[1])); + EXPECT_FLOAT_EQ(-0.8, log(inside[2])); + EXPECT_FLOAT_EQ(-0.986168, log(outside[2])); + EXPECT_FLOAT_EQ(-0.96, log(inside[3])); + EXPECT_FLOAT_EQ(-0.8261683, log(outside[3])); + EXPECT_FLOAT_EQ(-1.562512, log(inside[4])); + EXPECT_FLOAT_EQ(-0.22365622, log(outside[4])); + EXPECT_FLOAT_EQ(-1.7861683, log(inside[5])); + EXPECT_FLOAT_EQ(0, log(outside[5])); +} + +TEST_F(HGTest,TestAddExpectations) { + Hypergraph hg; + CreateHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + SparseVector feat_exps; + InsideOutside, EdgeFeaturesWeightFunction>(hg, &feat_exps); + EXPECT_FLOAT_EQ(-2.5439765, feat_exps[FD::Convert("f1")]); + EXPECT_FLOAT_EQ(-2.6357865, feat_exps[FD::Convert("f2")]); + cerr << feat_exps << endl; + SparseVector posts; + InsideOutside, TransitionEventWeightFunction>(hg, &posts); +} + +TEST_F(HGTest, Small) { + ReadFile rf("test_data/small.json.gz"); + Hypergraph hg; + assert(HypergraphIO::ReadFromJSON(rf.stream(), &hg)); + SparseVector wts; + wts.set_value(FD::Convert("Model_0"), -2.0); + wts.set_value(FD::Convert("Model_1"), -0.5); + wts.set_value(FD::Convert("Model_2"), -1.1); + wts.set_value(FD::Convert("Model_3"), -1.0); + wts.set_value(FD::Convert("Model_4"), -1.0); + wts.set_value(FD::Convert("Model_5"), 0.5); + wts.set_value(FD::Convert("Model_6"), 0.2); + wts.set_value(FD::Convert("Model_7"), -3.0); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + vector post; + prob_t c2 = Inside(hg, NULL, ScaledEdgeProb(0.6)); + EXPECT_FLOAT_EQ(2.1431036, log(c2)); +} + +TEST_F(HGTest, JSONTest) { + ostringstream os; + JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); + EXPECT_EQ("\"\\\"I don't know\\\", she said.\"", os.str()); + ostringstream os2; + JSONParser::WriteEscapedString("yes", &os2); + EXPECT_EQ("\"yes\"", os2.str()); +} + +TEST_F(HGTest, TestGenericKBest) { + Hypergraph hg; + CreateHG(&hg); + //CreateHGBalanced(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 1.0); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 1000); + for (int i = 0; i < 1000; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << " F:" << d->feature_values << endl; + } +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ibm_model1.cc b/src/ibm_model1.cc new file mode 100644 index 00000000..eb9c617c --- /dev/null +++ b/src/ibm_model1.cc @@ -0,0 +1,4 @@ +#include + + + diff --git a/src/inside_outside.h b/src/inside_outside.h new file mode 100644 index 00000000..9114c9d7 --- /dev/null +++ b/src/inside_outside.h @@ -0,0 +1,111 @@ +#ifndef _INSIDE_H_ +#define _INSIDE_H_ + +#include +#include +#include "hg.h" + +// run the inside algorithm and return the inside score +// if result is non-NULL, result will contain the inside +// score for each node +// NOTE: WeightType(0) must construct the semiring's additive identity +// WeightType(1) must construct the semiring's multiplicative identity +template +WeightType Inside(const Hypergraph& hg, + std::vector* result = NULL, + const WeightFunction& weight = WeightFunction()) { + const int num_nodes = hg.nodes_.size(); + std::vector dummy; + std::vector& inside_score = result ? *result : dummy; + inside_score.resize(num_nodes); + std::fill(inside_score.begin(), inside_score.end(), WeightType()); + for (int i = 0; i < num_nodes; ++i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + WeightType* const cur_node_inside_score = &inside_score[i]; + const int num_in_edges = cur_node.in_edges_.size(); + if (num_in_edges == 0) { + *cur_node_inside_score = WeightType(1); + continue; + } + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + WeightType score = weight(edge); + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + const int tail_node_index = edge.tail_nodes_[k]; + score *= inside_score[tail_node_index]; + } + *cur_node_inside_score += score; + } + } + return inside_score.back(); +} + +template +void Outside(const Hypergraph& hg, + std::vector& inside_score, + std::vector* result, + const WeightFunction& weight = WeightFunction()) { + assert(result); + const int num_nodes = hg.nodes_.size(); + assert(inside_score.size() == num_nodes); + std::vector& outside_score = *result; + outside_score.resize(num_nodes); + std::fill(outside_score.begin(), outside_score.end(), WeightType(0)); + outside_score.back() = WeightType(1); + for (int i = num_nodes - 1; i >= 0; --i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + const WeightType& head_node_outside_score = outside_score[i]; + const int num_in_edges = cur_node.in_edges_.size(); + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + const WeightType head_and_edge_weight = weight(edge) * head_node_outside_score; + const int num_tail_nodes = edge.tail_nodes_.size(); + for (int k = 0; k < num_tail_nodes; ++k) { + const int update_tail_node_index = edge.tail_nodes_[k]; + WeightType* const tail_outside_score = &outside_score[update_tail_node_index]; + WeightType inside_contribution = WeightType(1); + for (int l = 0; l < num_tail_nodes; ++l) { + const int other_tail_node_index = edge.tail_nodes_[l]; + if (update_tail_node_index != other_tail_node_index) + inside_contribution *= inside_score[other_tail_node_index]; + } + *tail_outside_score += head_and_edge_weight * inside_contribution; + } + } + } +} + +// this is the Inside-Outside optimization described in Li et al. (EMNLP 2009) +// for computing the inside algorithm over expensive semirings +// (such as expectations over features). See Figure 4. It is slightly different +// in that x/k is returned not (k,x) +// NOTE: RType * PType must be valid (and yield RType) +template +PType InsideOutside(const Hypergraph& hg, + RType* result_x, + const WeightFunction& weight1 = WeightFunction(), + const WeightFunction2& weight2 = WeightFunction2()) { + const int num_nodes = hg.nodes_.size(); + std::vector inside, outside; + const PType z = Inside(hg, &inside, weight1); + Outside(hg, inside, &outside, weight1); + RType& x = *result_x; + x = RType(); + for (int i = 0; i < num_nodes; ++i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + const int num_in_edges = cur_node.in_edges_.size(); + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + PType prob = outside[i]; + prob *= weight1(edge); + const int num_tail_nodes = edge.tail_nodes_.size(); + for (int k = 0; k < num_tail_nodes; ++k) + prob *= inside[edge.tail_nodes_[k]]; + prob /= z; + x += weight2(edge) * prob; + } + } + return z; +} + +#endif diff --git a/src/json_parse.cc b/src/json_parse.cc new file mode 100644 index 00000000..f6fdfea8 --- /dev/null +++ b/src/json_parse.cc @@ -0,0 +1,50 @@ +#include "json_parse.h" + +#include +#include + +using namespace std; + +static const char *json_hex_chars = "0123456789abcdef"; + +void JSONParser::WriteEscapedString(const string& in, ostream* out) { + int pos = 0; + int start_offset = 0; + unsigned char c = 0; + (*out) << '"'; + while(pos < in.size()) { + c = in[pos]; + switch(c) { + case '\b': + case '\n': + case '\r': + case '\t': + case '"': + case '\\': + case '/': + if(pos - start_offset > 0) + (*out) << in.substr(start_offset, pos - start_offset); + if(c == '\b') (*out) << "\\b"; + else if(c == '\n') (*out) << "\\n"; + else if(c == '\r') (*out) << "\\r"; + else if(c == '\t') (*out) << "\\t"; + else if(c == '"') (*out) << "\\\""; + else if(c == '\\') (*out) << "\\\\"; + else if(c == '/') (*out) << "\\/"; + start_offset = ++pos; + break; + default: + if(c < ' ') { + cerr << "Warning, bad character (" << static_cast(c) << ") in string\n"; + if(pos - start_offset > 0) + (*out) << in.substr(start_offset, pos - start_offset); + (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf]; + start_offset = ++pos; + } else pos++; + } + } + if(pos - start_offset > 0) + (*out) << in.substr(start_offset, pos - start_offset); + (*out) << '"'; +} + diff --git a/src/json_parse.h b/src/json_parse.h new file mode 100644 index 00000000..c3cba954 --- /dev/null +++ b/src/json_parse.h @@ -0,0 +1,58 @@ +#ifndef _JSON_WRAPPER_H_ +#define _JSON_WRAPPER_H_ + +#include +#include +#include "JSON_parser.h" + +class JSONParser { + public: + JSONParser() { + init_JSON_config(&config); + hack.mf = &JSONParser::Callback; + config.depth = 10; + config.callback_ctx = reinterpret_cast(this); + config.callback = hack.cb; + config.allow_comments = 1; + config.handle_floats_manually = 1; + jc = new_JSON_parser(&config); + } + virtual ~JSONParser() { + delete_JSON_parser(jc); + } + bool Parse(std::istream* in) { + int count = 0; + int lc = 1; + for (; in ; ++count) { + int next_char = in->get(); + if (!in->good()) break; + if (lc == '\n') { ++lc; } + if (!JSON_parser_char(jc, next_char)) { + std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl; + return false; + } + } + if (!JSON_parser_done(jc)) { + std::cerr << "JSON_parser_done: syntax error\n"; + return false; + } + return true; + } + static void WriteEscapedString(const std::string& in, std::ostream* out); + protected: + virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0; + private: + int Callback(int type, const JSON_value* value) { + if (HandleJSONEvent(type, value)) return 1; + return 0; + } + JSON_parser_struct* jc; + JSON_config config; + typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value); + union CBHack { + JSON_parser_callback cb; + MF mf; + } hack; +}; + +#endif diff --git a/src/kbest.h b/src/kbest.h new file mode 100644 index 00000000..cd9b6c2b --- /dev/null +++ b/src/kbest.h @@ -0,0 +1,207 @@ +#ifndef _HG_KBEST_H_ +#define _HG_KBEST_H_ + +#include +#include +#include + +#include + +#include "wordid.h" +#include "hg.h" + +namespace KBest { + // default, don't filter any derivations from the k-best list + struct NoFilter { + bool operator()(const std::vector& yield) { + (void) yield; + return false; + } + }; + + // optional, filter unique yield strings + struct FilterUnique { + std::tr1::unordered_set, boost::hash > > unique; + + bool operator()(const std::vector& yield) { + return !unique.insert(yield).second; + } + }; + + // utility class to lazily create the k-best derivations from a forest, uses + // the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005) + template + struct KBestDerivations { + KBestDerivations(const Hypergraph& hg, + const size_t k, + const Traversal& tf = Traversal(), + const WeightFunction& wf = WeightFunction()) : + traverse(tf), w(wf), g(hg), nds(g.nodes_.size()), k_prime(k) {} + + ~KBestDerivations() { + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + struct Derivation { + Derivation(const Hypergraph::Edge& e, + const SmallVector& jv, + const WeightType& w, + const SparseVector& f) : + edge(&e), + j(jv), + score(w), + feature_values(f) {} + + // dummy constructor, just for query + Derivation(const Hypergraph::Edge& e, + const SmallVector& jv) : edge(&e), j(jv) {} + + T yield; + const Hypergraph::Edge* const edge; + const SmallVector j; + const WeightType score; + const SparseVector feature_values; + }; + struct HeapCompare { + bool operator()(const Derivation* a, const Derivation* b) const { + return a->score < b->score; + } + }; + struct DerivationCompare { + bool operator()(const Derivation* a, const Derivation* b) const { + return a->score > b->score; + } + }; + struct DerivationUniquenessHash { + size_t operator()(const Derivation* d) const { + size_t x = 5381; + x = ((x << 5) + x) ^ d->edge->id_; + for (int i = 0; i < d->j.size(); ++i) + x = ((x << 5) + x) ^ d->j[i]; + return x; + } + }; + struct DerivationUniquenessEquals { + bool operator()(const Derivation* a, const Derivation* b) const { + return (a->edge == b->edge) && (a->j == b->j); + } + }; + typedef std::vector CandidateHeap; + typedef std::vector DerivationList; + typedef std::tr1::unordered_set< + const Derivation*, DerivationUniquenessHash, DerivationUniquenessEquals> UniqueDerivationSet; + + struct NodeDerivationState { + CandidateHeap cand; + DerivationList D; + DerivationFilter filter; + UniqueDerivationSet ds; + explicit NodeDerivationState(const DerivationFilter& f = DerivationFilter()) : filter(f) {} + }; + + Derivation* LazyKthBest(int v, int k) { + NodeDerivationState& s = GetCandidates(v); + CandidateHeap& cand = s.cand; + DerivationList& D = s.D; + DerivationFilter& filter = s.filter; + bool add_next = true; + while (D.size() <= k) { + if (add_next && D.size() > 0) { + const Derivation* d = D.back(); + LazyNext(d, &cand, &s.ds); + } + add_next = false; + + if (cand.size() > 0) { + std::pop_heap(cand.begin(), cand.end(), HeapCompare()); + Derivation* d = cand.back(); + cand.pop_back(); + std::vector ants(d->edge->Arity()); + for (int j = 0; j < ants.size(); ++j) + ants[j] = &LazyKthBest(d->edge->tail_nodes_[j], d->j[j])->yield; + traverse(*d->edge, ants, &d->yield); + if (!filter(d->yield)) { + D.push_back(d); + add_next = true; + } + } else { + break; + } + } + if (k < D.size()) return D[k]; else return NULL; + } + + private: + // creates a derivation object with all fields set but the yield + // the yield is computed in LazyKthBest before the derivation is added to D + // returns NULL if j refers to derivation numbers larger than the + // antecedent structure define + Derivation* CreateDerivation(const Hypergraph::Edge& e, const SmallVector& j) { + WeightType score = w(e); + SparseVector feats = e.feature_values_; + for (int i = 0; i < e.Arity(); ++i) { + const Derivation* ant = LazyKthBest(e.tail_nodes_[i], j[i]); + if (!ant) { return NULL; } + score *= ant->score; + feats += ant->feature_values; + } + freelist.push_back(new Derivation(e, j, score, feats)); + return freelist.back(); + } + + NodeDerivationState& GetCandidates(int v) { + NodeDerivationState& s = nds[v]; + if (!s.D.empty() || !s.cand.empty()) return s; + + const Hypergraph::Node& node = g.nodes_[v]; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const Hypergraph::Edge& edge = g.edges_[node.in_edges_[i]]; + SmallVector jv(edge.Arity(), 0); + Derivation* d = CreateDerivation(edge, jv); + assert(d); + s.cand.push_back(d); + } + + const int effective_k = std::min(k_prime, s.cand.size()); + const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k; + std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare()); + s.cand.resize(effective_k); + std::make_heap(s.cand.begin(), s.cand.end(), HeapCompare()); + + return s; + } + + void LazyNext(const Derivation* d, CandidateHeap* cand, UniqueDerivationSet* ds) { + for (int i = 0; i < d->j.size(); ++i) { + SmallVector j = d->j; + ++j[i]; + const Derivation* ant = LazyKthBest(d->edge->tail_nodes_[i], j[i]); + if (ant) { + Derivation query_unique(*d->edge, j); + if (ds->count(&query_unique) == 0) { + Derivation* new_d = CreateDerivation(*d->edge, j); + if (new_d) { + cand->push_back(new_d); + std::push_heap(cand->begin(), cand->end(), HeapCompare()); + assert(ds->insert(new_d).second); // insert into uniqueness set, sanity check + } + } + } + } + } + + const Traversal traverse; + const WeightFunction w; + const Hypergraph& g; + std::vector nds; + std::vector freelist; + const size_t k_prime; + }; +} + +#endif diff --git a/src/lattice.cc b/src/lattice.cc new file mode 100644 index 00000000..aa1df3db --- /dev/null +++ b/src/lattice.cc @@ -0,0 +1,27 @@ +#include "lattice.h" + +#include "tdict.h" +#include "hg_io.h" + +using namespace std; + +bool LatticeTools::LooksLikePLF(const string &line) { + return (line.size() > 5) && (line.substr(0,4) == "((('"); +} + +void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { + Lattice& l = *pl; + vector ids; + TD::ConvertSentence(text, &ids); + l.resize(ids.size()); + for (int i = 0; i < l.size(); ++i) + l[i].push_back(LatticeArc(ids[i], 0.0, 1)); +} + +void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { + if (LooksLikePLF(text_or_plf)) + HypergraphIO::PLFtoLattice(text_or_plf, pl); + else + ConvertTextToLattice(text_or_plf, pl); +} + diff --git a/src/lattice.h b/src/lattice.h new file mode 100644 index 00000000..1177e768 --- /dev/null +++ b/src/lattice.h @@ -0,0 +1,31 @@ +#ifndef __LATTICE_H_ +#define __LATTICE_H_ + +#include +#include +#include "wordid.h" + +struct LatticeArc { + WordID label; + double cost; + int dist2next; + LatticeArc() : label(), cost(), dist2next() {} + LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {} +}; + +class Lattice : public std::vector > { + public: + Lattice() {} + explicit Lattice(size_t t, const std::vector& v = std::vector()) : + std::vector >(t, v) {} + + // TODO add distance functions +}; + +struct LatticeTools { + static bool LooksLikePLF(const std::string &line); + static void ConvertTextToLattice(const std::string& text, Lattice* pl); + static void ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl); +}; + +#endif diff --git a/src/lexcrf.cc b/src/lexcrf.cc new file mode 100644 index 00000000..33455a3d --- /dev/null +++ b/src/lexcrf.cc @@ -0,0 +1,112 @@ +#include "lexcrf.h" + +#include + +#include "filelib.h" +#include "hg.h" +#include "tdict.h" +#include "grammar.h" +#include "sentence_metadata.h" + +using namespace std; + +struct LexicalCRFImpl { + LexicalCRFImpl(const boost::program_options::variables_map& conf) : + use_null(false), + kXCAT(TD::Convert("X")*-1), + kNULL(TD::Convert("")), + kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), + kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { + vector gfiles = conf["grammar"].as >(); + assert(gfiles.size() == 1); + ReadFile rf(gfiles.front()); + TextGrammar *tg = new TextGrammar; + grammar.reset(tg); + istream* in = rf.stream(); + int lc = 0; + bool flag = false; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + ++lc; + TRulePtr r(TRule::CreateRulePhrasetable(line)); + tg->AddRule(r); + if (lc % 50000 == 0) { cerr << '.'; flag = true; } + if (lc % 2000000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + } + if (flag) cerr << endl; + cerr << "Loaded " << lc << " rules\n"; + } + + void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + const int e_len = smeta.GetTargetLength(); + assert(e_len > 0); + const int f_len = lattice.size(); + // hack to tell the feature function system how big the sentence pair is + const int f_start = (use_null ? -1 : 0); + int prev_node_id = -1; + for (int i = 0; i < e_len; ++i) { // for each word in the *ref* + Hypergraph::Node* node = forest->AddNode(kXCAT); + const int new_node_id = node->id_; + for (int j = f_start; j < f_len; ++j) { // for each word in the source + const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label); + const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym); + if (!gi) { + cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; + abort(); + } + const RuleBin* rb = gi->GetRules(); + assert(rb); + for (int k = 0; k < rb->GetNumRules(); ++k) { + TRulePtr rule = rb->GetIthRule(k); + Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); + edge->i_ = j; + edge->j_ = j+1; + edge->prev_i_ = i; + edge->prev_j_ = i+1; + edge->feature_values_ += edge->rule_->GetFeatureValues(); + forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); + } + } + if (prev_node_id >= 0) { + const int comb_node_id = forest->AddNode(kXCAT)->id_; + Hypergraph::TailNodeVector tail(2, prev_node_id); + tail[1] = new_node_id; + const int edge_id = forest->AddEdge(kBINARY, tail)->id_; + forest->ConnectEdgeToHeadNode(edge_id, comb_node_id); + prev_node_id = comb_node_id; + } else { + prev_node_id = new_node_id; + } + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("[Goal]")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + } + + private: + const bool use_null; + const WordID kXCAT; + const WordID kNULL; + const TRulePtr kBINARY; + const TRulePtr kGOAL_RULE; + GrammarPtr grammar; +}; + +LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) : + pimpl_(new LexicalCRFImpl(conf)) {} + +bool LexicalCRF::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* forest) { + Lattice lattice; + LatticeTools::ConvertTextToLattice(input, &lattice); + smeta->SetSourceLength(lattice.size()); + pimpl_->BuildTrellis(lattice, *smeta, forest); + forest->Reweight(weights); + return true; +} + diff --git a/src/lexcrf.h b/src/lexcrf.h new file mode 100644 index 00000000..99362c81 --- /dev/null +++ b/src/lexcrf.h @@ -0,0 +1,18 @@ +#ifndef _LEXCRF_H_ +#define _LEXCRF_H_ + +#include "translator.h" +#include "lattice.h" + +struct LexicalCRFImpl; +struct LexicalCRF : public Translator { + LexicalCRF(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/src/lm_ff.cc b/src/lm_ff.cc new file mode 100644 index 00000000..f95140de --- /dev/null +++ b/src/lm_ff.cc @@ -0,0 +1,328 @@ +#include "lm_ff.h" + +#include +#include +#include +#include +#include +#include + +#include "tdict.h" +#include "Vocab.h" +#include "Ngram.h" +#include "hg.h" +#include "stringlib.h" + +using namespace std; + +struct LMClient { + struct Cache { + map tree; + float prob; + Cache() : prob() {} + }; + + LMClient(const char* host) : port(6666) { + s = strchr(host, ':'); + if (s != NULL) { + *s = '\0'; + ++s; + port = atoi(s); + } + sock = socket(AF_INET, SOCK_STREAM, 0); + hp = gethostbyname(host); + if (hp == NULL) { + cerr << "unknown host " << host << endl; + abort(); + } + bzero((char *)&server, sizeof(server)); + bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length); + server.sin_family = hp->h_addrtype; + server.sin_port = htons(port); + + int errors = 0; + while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { + cerr << "Error: connect()\n"; + sleep(1); + errors++; + if (errors > 3) exit(1); + } + cerr << "Connected to LM on " << host << " on port " << port << endl; + } + + float wordProb(int word, int* context) { + Cache* cur = &cache; + int i = 0; + while (context[i] > 0) { + cur = &cur->tree[context[i++]]; + } + cur = &cur->tree[word]; + if (cur->prob) { return cur->prob; } + + i = 0; + ostringstream os; + os << "prob " << TD::Convert(word); + while (context[i] > 0) { + os << ' ' << TD::Convert(context[i++]); + } + os << endl; + string out = os.str(); + write(sock, out.c_str(), out.size()); + int r = read(sock, res, 6); + int errors = 0; + int cnt = 0; + while (1) { + if (r < 0) { + errors++; sleep(1); + cerr << "Error: read()\n"; + if (errors > 5) exit(1); + } else if (r==0 || res[cnt] == '\n') { break; } + else { + cnt += r; + if (cnt==6) break; + read(sock, &res[cnt], 6-cnt); + } + } + cur->prob = *reinterpret_cast(res); + return cur->prob; + } + + void clear() { + cache.tree.clear(); + } + + private: + Cache cache; + int sock, port; + char *s; + struct hostent *hp; + struct sockaddr_in server; + char res[8]; +}; + +class LanguageModelImpl { + public: + LanguageModelImpl(int order, const string& f) : + ngram_(*TD::dict_), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), + floor_(-100.0), + client_(NULL), + kSTART(TD::Convert("")), + kSTOP(TD::Convert("")), + kUNKNOWN(TD::Convert("")), + kNONE(-1), + kSTAR(TD::Convert("<{STAR}>")) { + if (f.find("lm://") == 0) { + client_ = new LMClient(f.substr(5).c_str()); + } else { + File file(f.c_str(), "r", 0); + assert(file); + cerr << "Reading " << order_ << "-gram LM from " << f << endl; + ngram_.read(file, false); + } + } + + ~LanguageModelImpl() { + delete client_; + } + + inline int StateSize(const void* state) const { + return *(static_cast(state) + state_size_); + } + + inline void SetStateSize(int size, void* state) const { + *(static_cast(state) + state_size_) = size; + } + + inline double LookupProbForBufferContents(int i) { + double p = client_ ? + client_->wordProb(buffer_[i], &buffer_[i+1]) + : ngram_.wordProb(buffer_[i], (VocabIndex*)&buffer_[i+1]); + if (p < floor_) p = floor_; + return p; + } + + string DebugStateToString(const void* state) const { + int len = StateSize(state); + const int* astate = reinterpret_cast(state); + string res = "["; + for (int i = 0; i < len; ++i) { + res += " "; + res += TD::Convert(astate[i]); + } + res += " ]"; + return res; + } + + inline double ProbNoRemnant(int i, int len) { + int edge = len; + bool flag = true; + double sum = 0.0; + while (i >= 0) { + if (buffer_[i] == kSTAR) { + edge = i; + flag = false; + } else if (buffer_[i] <= 0) { + edge = i; + flag = true; + } else { + if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART))) + sum += LookupProbForBufferContents(i); + } + --i; + } + return sum; + } + + double EstimateProb(const vector& phrase) { + int len = phrase.size(); + buffer_.resize(len + 1); + buffer_[len] = kNONE; + int i = len - 1; + for (int j = 0; j < len; ++j,--i) + buffer_[i] = phrase[j]; + return ProbNoRemnant(len - 1, len); + } + + double EstimateProb(const void* state) { + int len = StateSize(state); + // cerr << "residual len: " << len << endl; + buffer_.resize(len + 1); + buffer_[len] = kNONE; + const int* astate = reinterpret_cast(state); + int i = len - 1; + for (int j = 0; j < len; ++j,--i) + buffer_[i] = astate[j]; + return ProbNoRemnant(len - 1, len); + } + + double FinalTraversalCost(const void* state) { + int slen = StateSize(state); + int len = slen + 2; + // cerr << "residual len: " << len << endl; + buffer_.resize(len + 1); + buffer_[len] = kNONE; + buffer_[len-1] = kSTART; + const int* astate = reinterpret_cast(state); + int i = len - 2; + for (int j = 0; j < slen; ++j,--i) + buffer_[i] = astate[j]; + buffer_[i] = kSTOP; + assert(i == 0); + return ProbNoRemnant(len - 1, len); + } + + double LookupWords(const TRule& rule, const vector& ant_states, void* vstate) { + int len = rule.ELength() - rule.Arity(); + for (int i = 0; i < ant_states.size(); ++i) + len += StateSize(ant_states[i]); + buffer_.resize(len + 1); + buffer_[len] = kNONE; + int i = len - 1; + const vector& e = rule.e(); + for (int j = 0; j < e.size(); ++j) { + if (e[j] < 1) { + const int* astate = reinterpret_cast(ant_states[-e[j]]); + int slen = StateSize(astate); + for (int k = 0; k < slen; ++k) + buffer_[i--] = astate[k]; + } else { + buffer_[i--] = e[j]; + } + } + + double sum = 0.0; + int* remnant = reinterpret_cast(vstate); + int j = 0; + i = len - 1; + int edge = len; + + while (i >= 0) { + if (buffer_[i] == kSTAR) { + edge = i; + } else if (edge-i >= order_) { + sum += LookupProbForBufferContents(i); + } else if (edge == len && remnant) { + remnant[j++] = buffer_[i]; + } + --i; + } + if (!remnant) return sum; + + if (edge != len || len >= order_) { + remnant[j++] = kSTAR; + if (order_-1 < edge) edge = order_-1; + for (int i = edge-1; i >= 0; --i) + remnant[j++] = buffer_[i]; + } + + SetStateSize(j, vstate); + return sum; + } + + static int OrderToStateSize(int order) { + return ((order-1) * 2 + 1) * sizeof(WordID) + 1; + } + + private: + Ngram ngram_; + vector buffer_; + const int order_; + const int state_size_; + const double floor_; + LMClient* client_; + + public: + const WordID kSTART; + const WordID kSTOP; + const WordID kUNKNOWN; + const WordID kNONE; + const WordID kSTAR; +}; + +LanguageModel::LanguageModel(const string& param) : + fid_(FD::Convert("LanguageModel")) { + vector argv; + int argc = SplitOnWhitespace(param, &argv); + int order = 3; + // TODO add support for -n FeatureName + string filename; + if (argc < 1) { cerr << "LanguageModel requires a filename, minimally!\n"; abort(); } + else if (argc == 1) { filename = argv[0]; } + else if (argc == 2 || argc > 3) { cerr << "Don't understand 'LanguageModel " << param << "'\n"; } + else if (argc == 3) { + if (argv[0] == "-o") { + order = atoi(argv[1].c_str()); + filename = argv[2]; + } else if (argv[1] == "-o") { + order = atoi(argv[2].c_str()); + filename = argv[0]; + } + } + SetStateSize(LanguageModelImpl::OrderToStateSize(order)); + pimpl_ = new LanguageModelImpl(order, filename); +} + +LanguageModel::~LanguageModel() { + delete pimpl_; +} + +string LanguageModel::DebugStateToString(const void* state) const{ + return pimpl_->DebugStateToString(state); +} + +void LanguageModel::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + (void) smeta; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state)); + estimated_features->set_value(fid_, pimpl_->EstimateProb(state)); +} + +void LanguageModel::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); +} + diff --git a/src/lm_ff.h b/src/lm_ff.h new file mode 100644 index 00000000..cd717360 --- /dev/null +++ b/src/lm_ff.h @@ -0,0 +1,32 @@ +#ifndef _LM_FF_H_ +#define _LM_FF_H_ + +#include +#include + +#include "hg.h" +#include "ff.h" + +class LanguageModelImpl; + +class LanguageModel : public FeatureFunction { + public: + // param = "filename.lm [-o n]" + LanguageModel(const std::string& param); + ~LanguageModel(); + virtual void FinalTraversalFeatures(const void* context, + SparseVector* features) const; + std::string DebugStateToString(const void* state) const; + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + const int fid_; + mutable LanguageModelImpl* pimpl_; +}; + +#endif diff --git a/src/logval.h b/src/logval.h new file mode 100644 index 00000000..a8ca620c --- /dev/null +++ b/src/logval.h @@ -0,0 +1,136 @@ +#ifndef LOGVAL_H_ +#define LOGVAL_H_ + +#include +#include + +template +class LogVal { + public: + LogVal() : v_(-std::numeric_limits::infinity()) {} + explicit LogVal(double x) : v_(std::log(x)) {} + LogVal(const LogVal& o) : v_(o.v_) {} + static LogVal One() { return LogVal(1); } + static LogVal Zero() { return LogVal(); } + + void logeq(const T& v) { v_ = v; } + + LogVal& operator+=(const LogVal& a) { + if (a.v_ == -std::numeric_limits::infinity()) return *this; + if (a.v_ < v_) { + v_ = v_ + log1p(std::exp(a.v_ - v_)); + } else { + v_ = a.v_ + log1p(std::exp(v_ - a.v_)); + } + return *this; + } + + LogVal& operator*=(const LogVal& a) { + v_ += a.v_; + return *this; + } + + LogVal& operator*=(const T& a) { + v_ += log(a); + return *this; + } + + LogVal& operator/=(const LogVal& a) { + v_ -= a.v_; + return *this; + } + + LogVal& poweq(const T& power) { + if (power == 0) v_ = 0; else v_ *= power; + return *this; + } + + LogVal pow(const T& power) const { + LogVal res = *this; + res.poweq(power); + return res; + } + + operator T() const { + return std::exp(v_); + } + + T v_; +}; + +template +LogVal operator+(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res += o2; + return res; +} + +template +LogVal operator*(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res *= o2; + return res; +} + +template +LogVal operator*(const LogVal& o1, const T& o2) { + LogVal res(o1); + res *= o2; + return res; +} + +template +LogVal operator*(const T& o1, const LogVal& o2) { + LogVal res(o2); + res *= o1; + return res; +} + +template +LogVal operator/(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res /= o2; + return res; +} + +template +T log(const LogVal& o) { + return o.v_; +} + +template +LogVal pow(const LogVal& b, const T& e) { + return b.pow(e); +} + +template +bool operator<(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ < rhs.v_); +} + +template +bool operator<=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ <= rhs.v_); +} + +template +bool operator>(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ > rhs.v_); +} + +template +bool operator>=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ >= rhs.v_); +} + +template +bool operator==(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ == rhs.v_); +} + +template +bool operator!=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ != rhs.v_); +} + +#endif diff --git a/src/maxtrans_blunsom.cc b/src/maxtrans_blunsom.cc new file mode 100644 index 00000000..4a6680e0 --- /dev/null +++ b/src/maxtrans_blunsom.cc @@ -0,0 +1,287 @@ +#include "apply_models.h" + +#include +#include +#include +#include + +#include +#include + +#include "tdict.h" +#include "hg.h" +#include "ff.h" + +using boost::tuple; +using namespace std; +using namespace std::tr1; + +namespace Hack { + +struct Candidate; +typedef SmallVector JVector; +typedef vector CandidateHeap; +typedef vector CandidateList; + +// life cycle: candidates are created, placed on the heap +// and retrieved by their estimated cost, when they're +// retrieved, they're incorporated into the +LM hypergraph +// where they also know the head node index they are +// attached to. After they are added to the +LM hypergraph +// inside_prob_ and est_prob_ fields may be updated as better +// derivations are found (this happens since the successor's +// of derivation d may have a better score- they are +// explored lazily). However, the updates don't happen +// when a candidate is in the heap so maintaining the heap +// property is not an issue. +struct Candidate { + int node_index_; // -1 until incorporated + // into the +LM forest + const Hypergraph::Edge* in_edge_; // in -LM forest + Hypergraph::Edge out_edge_; + vector state_; + const JVector j_; + prob_t inside_prob_; // these are fixed until the cand + // is popped, then they may be updated + prob_t est_prob_; + + Candidate(const Hypergraph::Edge& e, + const JVector& j, + const vector& D, + bool is_goal) : + node_index_(-1), + in_edge_(&e), + j_(j) { + InitializeCandidate(D, is_goal); + } + + // used to query uniqueness + Candidate(const Hypergraph::Edge& e, + const JVector& j) : in_edge_(&e), j_(j) {} + + bool IsIncorporatedIntoHypergraph() const { + return node_index_ >= 0; + } + + void InitializeCandidate(const vector >& D, + const bool is_goal) { + const Hypergraph::Edge& in_edge = *in_edge_; + out_edge_.rule_ = in_edge.rule_; + out_edge_.feature_values_ = in_edge.feature_values_; + Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_; + tail.resize(j_.size()); + prob_t p = prob_t::One(); + // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl; + vector* > ants(tail.size()); + for (int i = 0; i < tail.size(); ++i) { + const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]]; + ants[i] = &ant.state_; + assert(ant.IsIncorporatedIntoHypergraph()); + tail[i] = ant.node_index_; + p *= ant.inside_prob_; + } + prob_t edge_estimate = prob_t::One(); + if (is_goal) { + assert(tail.size() == 1); + out_edge_.edge_prob_ = in_edge.edge_prob_; + } else { + in_edge.rule_->ESubstitute(ants, &state_); + out_edge_.edge_prob_ = in_edge.edge_prob_; + } + inside_prob_ = out_edge_.edge_prob_ * p; + est_prob_ = inside_prob_ * edge_estimate; + } +}; + +ostream& operator<<(ostream& os, const Candidate& cand) { + os << "CAND["; + if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; } + else { os << "+LM_node=" << cand.node_index_; } + os << " edge=" << cand.in_edge_->id_; + os << " j=<"; + for (int i = 0; i < cand.j_.size(); ++i) + os << (i==0 ? "" : " ") << cand.j_[i]; + os << "> vit=" << log(cand.inside_prob_); + os << " est=" << log(cand.est_prob_); + return os << ']'; +} + +struct HeapCandCompare { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ < r->est_prob_; + } +}; + +struct EstProbSorter { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ > r->est_prob_; + } +}; + +// the same candidate can be added multiple times if +// j is multidimensional (if you're going NW in Manhattan, you +// can first go north, then west, or you can go west then north) +// this is a hash function on the relevant variables from +// Candidate to enforce this. +struct CandidateUniquenessHash { + size_t operator()(const Candidate* c) const { + size_t x = 5381; + x = ((x << 5) + x) ^ c->in_edge_->id_; + for (int i = 0; i < c->j_.size(); ++i) + x = ((x << 5) + x) ^ c->j_[i]; + return x; + } +}; + +struct CandidateUniquenessEquals { + bool operator()(const Candidate* a, const Candidate* b) const { + return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_); + } +}; + +typedef unordered_set UniqueCandidateSet; +typedef unordered_map, Candidate*, boost::hash > > State2Node; + +class MaxTransBeamSearch { + +public: + MaxTransBeamSearch(const Hypergraph& i, int pop_limit, Hypergraph* o) : + in(i), + out(*o), + D(in.nodes_.size()), + pop_limit_(pop_limit) { + cerr << " Finding max translation (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; + } + + void Apply() { + int num_nodes = in.nodes_.size(); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + cerr << " "; + for (int i = 0; i < in.nodes_.size(); ++i) { + cerr << '.'; + KBest(i, i == goal_id); + } + cerr << endl; + int best_node = D[goal_id].front()->in_edge_->tail_nodes_.front(); + Candidate& best = *D[best_node].front(); + cerr << " Best path: " << log(best.inside_prob_) + << "\t" << log(best.est_prob_) << endl; + cout << TD::GetString(D[best_node].front()->state_) << endl; + FreeAll(); + } + + private: + void FreeAll() { + for (int i = 0; i < D.size(); ++i) { + CandidateList& D_i = D[i]; + for (int j = 0; j < D_i.size(); ++j) + delete D_i[j]; + } + D.clear(); + } + + void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_); + new_edge->feature_values_ = item->out_edge_.feature_values_; + new_edge->edge_prob_ = item->out_edge_.edge_prob_; + Candidate*& o_item = (*s2n)[item->state_]; + if (!o_item) o_item = item; + + int& node_id = o_item->node_index_; + if (node_id < 0) { + Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, ""); + node_id = new_node->id_; + } + Hypergraph::Node* node = &out.nodes_[node_id]; + out.ConnectEdgeToHeadNode(new_edge, node); + + if (item != o_item) { + assert(o_item->state_ == item->state_); // sanity check! + o_item->est_prob_ += item->est_prob_; + o_item->inside_prob_ += item->inside_prob_; + freelist->push_back(item); + } + } + + void KBest(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_cands; + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, D, is_goal)); + assert(unique_cands.insert(cand.back()).second); // these should all be unique! + } +// cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + PushSucc(*item, is_goal, &cand, &unique_cands); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i) + D_v[c++] = i->second; + sort(D_v.begin(), D_v.end(), EstProbSorter()); + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (cs->count(&query_unique) == 0) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, D, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check + } + } + } + } + + const Hypergraph& in; + Hypergraph& out; + + vector D; // maps nodes in in-HG to the + // equivalent nodes (many due to state + // splits) in the out-HG. + const int pop_limit_; +}; + +// each node in the graph has one of these, it keeps track of +void MaxTrans(const Hypergraph& in, + int beam_size) { + Hypergraph out; + MaxTransBeamSearch ma(in, beam_size, &out); + ma.Apply(); +} + +} diff --git a/src/parser_test.cc b/src/parser_test.cc new file mode 100644 index 00000000..da1fbd89 --- /dev/null +++ b/src/parser_test.cc @@ -0,0 +1,35 @@ +#include +#include +#include +#include +#include +#include "hg.h" +#include "trule.h" +#include "bottom_up_parser.h" +#include "tdict.h" + +using namespace std; + +class ChartTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(ChartTest,LanguageModel) { + LatticeArc a(TD::Convert("ein"), 0.0, 1); + LatticeArc b(TD::Convert("haus"), 0.0, 1); + Lattice lattice(2); + lattice[0].push_back(a); + lattice[1].push_back(b); + Hypergraph forest; + GrammarPtr g(new TextGrammar); + vector grammars(1, g); + ExhaustiveBottomUpParser parser("PHRASE", grammars); + parser.Parse(lattice, &forest); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/phrasebased_translator.cc b/src/phrasebased_translator.cc new file mode 100644 index 00000000..5eb70876 --- /dev/null +++ b/src/phrasebased_translator.cc @@ -0,0 +1,206 @@ +#include "phrasebased_translator.h" + +#include +#include +#include +#include + +#include +#include + +#include "sentence_metadata.h" +#include "tdict.h" +#include "hg.h" +#include "filelib.h" +#include "lattice.h" +#include "phrasetable_fst.h" +#include "array2d.h" + +using namespace std; +using namespace std::tr1; +using namespace boost::tuples; + +struct Coverage : public vector { + explicit Coverage(int n, bool v = false) : vector(n, v), first_gap() {} + void Cover(int i, int j) { + vector::iterator it = this->begin() + i; + vector::iterator end = this->begin() + j; + while (it != end) + *it++ = true; + if (first_gap == i) { + first_gap = j; + it = end; + while (*it && it != this->end()) { + ++it; + ++first_gap; + } + } + } + bool Collides(int i, int j) const { + vector::const_iterator it = this->begin() + i; + vector::const_iterator end = this->begin() + j; + while (it != end) + if (*it++) return true; + return false; + } + int GetFirstGap() const { return first_gap; } + private: + int first_gap; +}; +struct CoverageHash { + size_t operator()(const Coverage& cov) const { + return hasher_(static_cast&>(cov)); + } + private: + boost::hash > hasher_; +}; +ostream& operator<<(ostream& os, const Coverage& cov) { + os << '['; + for (int i = 0; i < cov.size(); ++i) + os << (cov[i] ? '*' : '.'); + return os << " gap=" << cov.GetFirstGap() << ']'; +} + +typedef unordered_map CoverageNodeMap; +typedef unordered_set UniqueCoverageSet; + +struct PhraseBasedTranslatorImpl { + PhraseBasedTranslatorImpl(const boost::program_options::variables_map& conf) : + add_pass_through_rules(conf.count("add_pass_through_rules")), + max_distortion(conf["pb_max_distortion"].as()), + kSOURCE_RULE(new TRule("[X] ||| [X,1] ||| [X,1]", true)), + kCONCAT_RULE(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]", true)), + kNT_TYPE(TD::Convert("X") * -1) { + assert(max_distortion >= 0); + vector gfiles = conf["grammar"].as >(); + assert(gfiles.size() == 1); + cerr << "Reading phrasetable from " << gfiles.front() << endl; + ReadFile in(gfiles.front()); + fst.reset(LoadTextPhrasetable(in.stream())); + } + + struct State { + State(const Coverage& c, int _i, int _j, const FSTNode* q) : + coverage(c), i(_i), j(_j), fst(q) {} + Coverage coverage; + int i; + int j; + const FSTNode* fst; + }; + + // we keep track of unique coverages that have been extended since it's + // possible to "extend" the same coverage twice, e.g. translate "a b c" + // with phrases "a" "b" "a b" and "c". There are two ways to cover "a b" + void EnqueuePossibleContinuations(const Coverage& coverage, queue* q, UniqueCoverageSet* ucs) { + if (ucs->insert(coverage).second) { + const int gap = coverage.GetFirstGap(); + const int end = min(static_cast(coverage.size()), gap + max_distortion + 1); + for (int i = gap; i < end; ++i) + if (!coverage[i]) q->push(State(coverage, i, i, fst.get())); + } + } + + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest) { + Lattice lattice; + LatticeTools::ConvertTextOrPLF(input, &lattice); + smeta->SetSourceLength(lattice.size()); + size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion); + minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100); + if (add_pass_through_rules) { + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < lattice.size(); ++i) { + const vector& arcs = lattice[i]; + for (int j = 0; j < arcs.size(); ++j) { + fst->AddPassThroughTranslation(arcs[j].label, feats); + // TODO handle lattice edge features + } + } + } + CoverageNodeMap c; + queue q; + UniqueCoverageSet ucs; + const Coverage empty_cov(lattice.size(), false); + const Coverage goal_cov(lattice.size(), true); + EnqueuePossibleContinuations(empty_cov, &q, &ucs); + c[empty_cov] = 0; // have to handle the left edge specially + while(!q.empty()) { + const State s = q.front(); + q.pop(); + // cerr << "(" << s.i << "," << s.j << " ptr=" << s.fst << ") cov=" << s.coverage << endl; + const vector& arcs = lattice[s.j]; + if (s.fst->HasData()) { + Coverage new_cov = s.coverage; + new_cov.Cover(s.i, s.j); + EnqueuePossibleContinuations(new_cov, &q, &ucs); + const vector& phrases = s.fst->GetTranslations()->GetRules(); + const int phrase_head_index = minus_lm_forest->AddNode(kNT_TYPE)->id_; + for (int i = 0; i < phrases.size(); ++i) { + Hypergraph::Edge* edge = minus_lm_forest->AddEdge(phrases[i], Hypergraph::TailNodeVector()); + edge->feature_values_ = edge->rule_->scores_; + minus_lm_forest->ConnectEdgeToHeadNode(edge->id_, phrase_head_index); + } + CoverageNodeMap::iterator cit = c.find(s.coverage); + assert(cit != c.end()); + const int tail_node_plus1 = cit->second; + if (tail_node_plus1 == 0) { // left edge + c[new_cov] = phrase_head_index + 1; + } else { // not left edge + int& head_node_plus1 = c[new_cov]; + if (!head_node_plus1) + head_node_plus1 = minus_lm_forest->AddNode(kNT_TYPE)->id_ + 1; + Hypergraph::TailNodeVector tail(2, tail_node_plus1 - 1); + tail[1] = phrase_head_index; + const int concat_edge = minus_lm_forest->AddEdge(kCONCAT_RULE, tail)->id_; + minus_lm_forest->ConnectEdgeToHeadNode(concat_edge, head_node_plus1 - 1); + } + } + if (s.j == lattice.size()) continue; + for (int l = 0; l < arcs.size(); ++l) { + const LatticeArc& arc = arcs[l]; + + const FSTNode* next_fst_state = s.fst->Extend(arc.label); + const int next_j = s.j + arc.dist2next; + if (next_fst_state && + !s.coverage.Collides(s.i, next_j)) { + q.push(State(s.coverage, s.i, next_j, next_fst_state)); + } + } + } + if (add_pass_through_rules) + fst->ClearPassThroughTranslations(); + int pregoal_plus1 = c[goal_cov]; + if (pregoal_plus1 > 0) { + TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [X,1]")); + int goal = minus_lm_forest->AddNode(TD::Convert("Goal") * -1)->id_; + int gedge = minus_lm_forest->AddEdge(kGOAL_RULE, Hypergraph::TailNodeVector(1, pregoal_plus1 - 1))->id_; + minus_lm_forest->ConnectEdgeToHeadNode(gedge, goal); + // they are almost topo, but not quite always + minus_lm_forest->TopologicallySortNodesAndEdges(goal); + minus_lm_forest->Reweight(weights); + return true; + } else { + return false; // composition failed + } + } + + const bool add_pass_through_rules; + const int max_distortion; + TRulePtr kSOURCE_RULE; + const TRulePtr kCONCAT_RULE; + const WordID kNT_TYPE; + boost::shared_ptr fst; +}; + +PhraseBasedTranslator::PhraseBasedTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new PhraseBasedTranslatorImpl(conf)) {} + +bool PhraseBasedTranslator::Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest) { + return pimpl_->Translate(input, smeta, weights, minus_lm_forest); +} diff --git a/src/phrasebased_translator.h b/src/phrasebased_translator.h new file mode 100644 index 00000000..d42ce79c --- /dev/null +++ b/src/phrasebased_translator.h @@ -0,0 +1,18 @@ +#ifndef _PHRASEBASED_TRANSLATOR_H_ +#define _PHRASEBASED_TRANSLATOR_H_ + +#include "translator.h" + +class PhraseBasedTranslatorImpl; +class PhraseBasedTranslator : public Translator { + public: + PhraseBasedTranslator(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/src/phrasetable_fst.cc b/src/phrasetable_fst.cc new file mode 100644 index 00000000..f421e941 --- /dev/null +++ b/src/phrasetable_fst.cc @@ -0,0 +1,141 @@ +#include "phrasetable_fst.h" + +#include +#include +#include + +#include + +#include "filelib.h" +#include "tdict.h" + +using boost::shared_ptr; +using namespace std; + +TargetPhraseSet::~TargetPhraseSet() {} +FSTNode::~FSTNode() {} + +class TextTargetPhraseSet : public TargetPhraseSet { + public: + void AddRule(TRulePtr rule) { + rules_.push_back(rule); + } + const vector& GetRules() const { + return rules_; + } + + private: + // all rules must have arity 0 + vector rules_; +}; + +class TextFSTNode : public FSTNode { + public: + const TargetPhraseSet* GetTranslations() const { return data.get(); } + bool HasData() const { return (bool)data; } + bool HasOutgoingNonEpsilonEdges() const { return !ptr.empty(); } + const FSTNode* Extend(const WordID& t) const { + map::const_iterator it = ptr.find(t); + if (it == ptr.end()) return NULL; + return &it->second; + } + + void AddPhrase(const string& phrase); + + void AddPassThroughTranslation(const WordID& w, const SparseVector& feats); + void ClearPassThroughTranslations(); + private: + vector passthroughs; + shared_ptr data; + map ptr; +}; + +#ifdef DEBUG_CHART_PARSER +static string TrimRule(const string& r) { + size_t start = r.find(" |||") + 5; + size_t end = r.rfind(" |||"); + return r.substr(start, end - start); +} +#endif + +void TextFSTNode::AddPhrase(const string& phrase) { + vector words; + TRulePtr rule(TRule::CreateRulePhrasetable(phrase)); + if (!rule) { + static int err = 0; + ++err; + if (err > 2) { cerr << "TOO MANY PHRASETABLE ERRORS\n"; exit(1); } + return; + } + + TextFSTNode* fsa = this; + for (int i = 0; i < rule->FLength(); ++i) + fsa = &fsa->ptr[rule->f_[i]]; + + if (!fsa->data) + fsa->data.reset(new TextTargetPhraseSet); + static_cast(fsa->data.get())->AddRule(rule); +} + +void TextFSTNode::AddPassThroughTranslation(const WordID& w, const SparseVector& feats) { + TextFSTNode* next = &ptr[w]; + // current, rules are only added if the symbol is completely missing as a + // word starting the phrase. As a result, it is possible that some sentences + // won't parse. If this becomes a problem, fix it here. + if (!next->data) { + TextTargetPhraseSet* tps = new TextTargetPhraseSet; + next->data.reset(tps); + TRule* rule = new TRule; + rule->e_.resize(1, w); + rule->f_.resize(1, w); + rule->lhs_ = TD::Convert("___PHRASE") * -1; + rule->scores_ = feats; + rule->arity_ = 0; + tps->AddRule(TRulePtr(rule)); + passthroughs.push_back(w); + } +} + +void TextFSTNode::ClearPassThroughTranslations() { + for (int i = 0; i < passthroughs.size(); ++i) + ptr.erase(passthroughs[i]); + passthroughs.clear(); +} + +static void AddPhrasetableToFST(istream* in, TextFSTNode* fst) { + int lc = 0; + bool flag = false; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + ++lc; + fst->AddPhrase(line); + if (lc % 10000 == 0) { flag = true; cerr << '.' << flush; } + if (lc % 500000 == 0) { flag = false; cerr << " [" << lc << ']' << endl << flush; } + } + if (flag) cerr << endl; + cerr << "Loaded " << lc << " source phrases\n"; +} + +FSTNode* LoadTextPhrasetable(istream* in) { + TextFSTNode *fst = new TextFSTNode; + AddPhrasetableToFST(in, fst); + return fst; +} + +FSTNode* LoadTextPhrasetable(const vector& filenames) { + TextFSTNode* fst = new TextFSTNode; + for (int i = 0; i < filenames.size(); ++i) { + ReadFile rf(filenames[i]); + cerr << "Reading phrase from " << filenames[i] << endl; + AddPhrasetableToFST(rf.stream(), fst); + } + return fst; +} + +FSTNode* LoadBinaryPhrasetable(const string& fname_prefix) { + (void) fname_prefix; + assert(!"not implemented yet"); +} + diff --git a/src/phrasetable_fst.h b/src/phrasetable_fst.h new file mode 100644 index 00000000..477de1f7 --- /dev/null +++ b/src/phrasetable_fst.h @@ -0,0 +1,34 @@ +#ifndef _PHRASETABLE_FST_H_ +#define _PHRASETABLE_FST_H_ + +#include +#include + +#include "sparse_vector.h" +#include "trule.h" + +class TargetPhraseSet { + public: + virtual ~TargetPhraseSet(); + virtual const std::vector& GetRules() const = 0; +}; + +class FSTNode { + public: + virtual ~FSTNode(); + virtual const TargetPhraseSet* GetTranslations() const = 0; + virtual bool HasData() const = 0; + virtual bool HasOutgoingNonEpsilonEdges() const = 0; + virtual const FSTNode* Extend(const WordID& t) const = 0; + + // these should only be called on q_0: + virtual void AddPassThroughTranslation(const WordID& w, const SparseVector& feats) = 0; + virtual void ClearPassThroughTranslations() = 0; +}; + +// attn caller: you own the memory +FSTNode* LoadTextPhrasetable(const std::vector& filenames); +FSTNode* LoadTextPhrasetable(std::istream* in); +FSTNode* LoadBinaryPhrasetable(const std::string& fname_prefix); + +#endif diff --git a/src/prob.h b/src/prob.h new file mode 100644 index 00000000..bc297870 --- /dev/null +++ b/src/prob.h @@ -0,0 +1,8 @@ +#ifndef _PROB_H_ +#define _PROB_H_ + +#include "logval.h" + +typedef LogVal prob_t; + +#endif diff --git a/src/sampler.h b/src/sampler.h new file mode 100644 index 00000000..e5840f41 --- /dev/null +++ b/src/sampler.h @@ -0,0 +1,136 @@ +#ifndef SAMPLER_H_ +#define SAMPLER_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "prob.h" + +struct SampleSet; + +template +struct RandomNumberGenerator { + static uint32_t GetTrulyRandomSeed() { + uint32_t seed; + std::ifstream r("/dev/urandom"); + if (r) { + r.read((char*)&seed,sizeof(uint32_t)); + } + if (r.fail() || !r) { + std::cerr << "Warning: could not read from /dev/urandom. Seeding from clock" << std::endl; + seed = time(NULL); + } + std::cerr << "Seeding random number sequence to " << seed << std::endl; + return seed; + } + + RandomNumberGenerator() : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) { + uint32_t seed = GetTrulyRandomSeed(); + m_generator.seed(seed); + } + explicit RandomNumberGenerator(uint32_t seed) : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) { + if (!seed) seed = GetTrulyRandomSeed(); + m_generator.seed(seed); + } + + size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) { + if (T == 1.0) { + if (this->next() > (a / (a + b))) return 1; else return 0; + } else { + assert(!"not implemented"); + } + } + + // T is the annealing temperature, if desired + size_t SelectSample(const SampleSet& ss, double T = 1.0); + + // draw a value from U(0,1) + double next() {return m_random();} + + // draw a value from N(mean,var) + double NextNormal(double mean, double var) { + return boost::normal_distribution(mean, var)(m_random); + } + + // draw a value from a Poisson distribution + // lambda must be greater than 0 + int NextPoisson(int lambda) { + return boost::poisson_distribution(lambda)(m_random); + } + + bool AcceptMetropolisHastings(const prob_t& p_cur, + const prob_t& p_prev, + const prob_t& q_cur, + const prob_t& q_prev) { + const prob_t a = (p_cur / p_prev) * (q_prev / q_cur); + if (log(a) >= 0.0) return true; + return (prob_t(this->next()) < a); + } + + private: + boost::uniform_real<> m_dist; + RNG m_generator; + boost::variate_generator > m_random; +}; + +typedef RandomNumberGenerator MT19937; + +class SampleSet { + public: + const prob_t& operator[](int i) const { return m_scores[i]; } + bool empty() const { return m_scores.empty(); } + void add(const prob_t& s) { m_scores.push_back(s); } + void clear() { m_scores.clear(); } + size_t size() const { return m_scores.size(); } + std::vector m_scores; +}; + +template +size_t RandomNumberGenerator::SelectSample(const SampleSet& ss, double T) { + assert(T > 0.0); + assert(ss.m_scores.size() > 0); + if (ss.m_scores.size() == 1) return 0; + const prob_t annealing_factor(1.0 / T); + const bool anneal = (annealing_factor != prob_t::One()); + prob_t sum = prob_t::Zero(); + if (anneal) { + for (int i = 0; i < ss.m_scores.size(); ++i) + sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T) + } else { + sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero()); + } + //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ","; + //std::cerr << std::endl; + + prob_t random(this->next()); // random number between 0 and 1 + random *= sum; // scale with normalization factor + //std::cerr << "Random number " << random << std::endl; + + //now figure out which sample + size_t position = 1; + sum = ss.m_scores[0]; + if (anneal) { + sum.poweq(annealing_factor); + for (; position < ss.m_scores.size() && sum < random; ++position) + sum += ss.m_scores[position].pow(annealing_factor); + } else { + for (; position < ss.m_scores.size() && sum < random; ++position) + sum += ss.m_scores[position]; + } + //std::cout << "random: " << random << " sample: " << position << std::endl; + //std::cerr << "Sample: " << position-1 << std::endl; + //exit(1); + return position-1; +} + +#endif diff --git a/src/scfg_translator.cc b/src/scfg_translator.cc new file mode 100644 index 00000000..03602c6b --- /dev/null +++ b/src/scfg_translator.cc @@ -0,0 +1,66 @@ +#include "translator.h" + +#include + +#include "hg.h" +#include "grammar.h" +#include "bottom_up_parser.h" +#include "sentence_metadata.h" + +using namespace std; + +Translator::~Translator() {} + +struct SCFGTranslatorImpl { + SCFGTranslatorImpl(const boost::program_options::variables_map& conf) : + max_span_limit(conf["scfg_max_span_limit"].as()), + add_pass_through_rules(conf.count("add_pass_through_rules")), + goal(conf["goal"].as()), + default_nt(conf["scfg_default_nt"].as()) { + vector gfiles = conf["grammar"].as >(); + for (int i = 0; i < gfiles.size(); ++i) { + cerr << "Reading SCFG grammar from " << gfiles[i] << endl; + TextGrammar* g = new TextGrammar(gfiles[i]); + g->SetMaxSpan(max_span_limit); + grammars.push_back(GrammarPtr(g)); + } + if (!conf.count("scfg_no_hiero_glue_grammar")) + grammars.push_back(GrammarPtr(new GlueGrammar(goal, default_nt))); + if (conf.count("scfg_extra_glue_grammar")) + grammars.push_back(GrammarPtr(new GlueGrammar(conf["scfg_extra_glue_grammar"].as()))); + } + + const int max_span_limit; + const bool add_pass_through_rules; + const string goal; + const string default_nt; + vector grammars; + + bool Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* forest) { + vector glist = grammars; + Lattice lattice; + LatticeTools::ConvertTextOrPLF(input, &lattice); + smeta->SetSourceLength(lattice.size()); + if (add_pass_through_rules) + glist.push_back(GrammarPtr(new PassThroughGrammar(lattice, default_nt))); + ExhaustiveBottomUpParser parser(goal, glist); + if (!parser.Parse(lattice, forest)) + return false; + forest->Reweight(weights); + return true; + } +}; + +SCFGTranslator::SCFGTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new SCFGTranslatorImpl(conf)) {} + +bool SCFGTranslator::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + return pimpl_->Translate(input, smeta, weights, minus_lm_forest); +} + diff --git a/src/sentence_metadata.h b/src/sentence_metadata.h new file mode 100644 index 00000000..0178f1f5 --- /dev/null +++ b/src/sentence_metadata.h @@ -0,0 +1,42 @@ +#ifndef _SENTENCE_METADATA_H_ +#define _SENTENCE_METADATA_H_ + +#include +#include "lattice.h" + +struct SentenceMetadata { + SentenceMetadata(int id, const Lattice& ref) : + sent_id_(id), + src_len_(-1), + has_reference_(ref.size() > 0), + trg_len_(ref.size()), + ref_(has_reference_ ? &ref : NULL) {} + + // this should be called by the Translator object after + // it has parsed the source + void SetSourceLength(int sl) { src_len_ = sl; } + + // this should be called if a separate model needs to + // specify how long the target sentence should be + void SetTargetLength(int tl) { + assert(!has_reference_); + trg_len_ = tl; + } + bool HasReference() const { return has_reference_; } + const Lattice& GetReference() const { return *ref_; } + int GetSourceLength() const { return src_len_; } + int GetTargetLength() const { return trg_len_; } + int GetSentenceID() const { return sent_id_; } + + private: + const int sent_id_; + int src_len_; + + // you need to be very careful when depending on these values + // they will only be set during training / alignment contexts + const bool has_reference_; + int trg_len_; + const Lattice* const ref_; +}; + +#endif diff --git a/src/small_vector.h b/src/small_vector.h new file mode 100644 index 00000000..800c1df1 --- /dev/null +++ b/src/small_vector.h @@ -0,0 +1,187 @@ +#ifndef _SMALL_VECTOR_H_ + +#include // std::max - where to get this? +#include +#include + +#define __SV_MAX_STATIC 2 + +class SmallVector { + + public: + SmallVector() : size_(0) {} + + explicit SmallVector(size_t s, int v = 0) : size_(s) { + assert(s < 0x80); + if (s <= __SV_MAX_STATIC) { + for (int i = 0; i < s; ++i) data_.vals[i] = v; + } else { + capacity_ = s; + size_ = s; + data_.ptr = new int[s]; + for (int i = 0; i < size_; ++i) data_.ptr[i] = v; + } + } + + SmallVector(const SmallVector& o) : size_(o.size_) { + if (size_ <= __SV_MAX_STATIC) { + for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i]; + } else { + capacity_ = size_ = o.size_; + data_.ptr = new int[capacity_]; + std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int)); + } + } + + const SmallVector& operator=(const SmallVector& o) { + if (size_ <= __SV_MAX_STATIC) { + if (o.size_ <= __SV_MAX_STATIC) { + size_ = o.size_; + for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i]; + } else { + capacity_ = size_ = o.size_; + data_.ptr = new int[capacity_]; + std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int)); + } + } else { + if (o.size_ <= __SV_MAX_STATIC) { + delete[] data_.ptr; + size_ = o.size_; + for (int i = 0; i < size_; ++i) data_.vals[i] = o.data_.vals[i]; + } else { + if (capacity_ < o.size_) { + delete[] data_.ptr; + capacity_ = o.size_; + data_.ptr = new int[capacity_]; + } + size_ = o.size_; + for (int i = 0; i < size_; ++i) + data_.ptr[i] = o.data_.ptr[i]; + } + } + return *this; + } + + ~SmallVector() { + if (size_ <= __SV_MAX_STATIC) return; + delete[] data_.ptr; + } + + void clear() { + if (size_ > __SV_MAX_STATIC) { + delete[] data_.ptr; + } + size_ = 0; + } + + bool empty() const { return size_ == 0; } + size_t size() const { return size_; } + + inline void ensure_capacity(unsigned char min_size) { + assert(min_size > __SV_MAX_STATIC); + if (min_size < capacity_) return; + unsigned char new_cap = std::max(static_cast(capacity_ << 1), min_size); + int* tmp = new int[new_cap]; + std::memcpy(tmp, data_.ptr, capacity_ * sizeof(int)); + delete[] data_.ptr; + data_.ptr = tmp; + capacity_ = new_cap; + } + + inline void copy_vals_to_ptr() { + capacity_ = __SV_MAX_STATIC * 2; + int* tmp = new int[capacity_]; + for (int i = 0; i < __SV_MAX_STATIC; ++i) tmp[i] = data_.vals[i]; + data_.ptr = tmp; + } + + inline void push_back(int v) { + if (size_ < __SV_MAX_STATIC) { + data_.vals[size_] = v; + ++size_; + return; + } else if (size_ == __SV_MAX_STATIC) { + copy_vals_to_ptr(); + } else if (size_ == capacity_) { + ensure_capacity(size_ + 1); + } + data_.ptr[size_] = v; + ++size_; + } + + int& back() { return this->operator[](size_ - 1); } + const int& back() const { return this->operator[](size_ - 1); } + int& front() { return this->operator[](0); } + const int& front() const { return this->operator[](0); } + + void resize(size_t s, int v = 0) { + if (s <= __SV_MAX_STATIC) { + if (size_ > __SV_MAX_STATIC) { + int tmp[__SV_MAX_STATIC]; + for (int i = 0; i < s; ++i) tmp[i] = data_.ptr[i]; + delete[] data_.ptr; + for (int i = 0; i < s; ++i) data_.vals[i] = tmp[i]; + size_ = s; + return; + } + if (s <= size_) { + size_ = s; + return; + } else { + for (int i = size_; i < s; ++i) + data_.vals[i] = v; + size_ = s; + return; + } + } else { + if (size_ <= __SV_MAX_STATIC) + copy_vals_to_ptr(); + if (s > capacity_) + ensure_capacity(s); + if (s > size_) { + for (int i = size_; i < s; ++i) + data_.ptr[i] = v; + } + size_ = s; + } + } + + int& operator[](size_t i) { + if (size_ <= __SV_MAX_STATIC) return data_.vals[i]; + return data_.ptr[i]; + } + + const int& operator[](size_t i) const { + if (size_ <= __SV_MAX_STATIC) return data_.vals[i]; + return data_.ptr[i]; + } + + bool operator==(const SmallVector& o) const { + if (size_ != o.size_) return false; + if (size_ <= __SV_MAX_STATIC) { + for (size_t i = 0; i < size_; ++i) + if (data_.vals[i] != o.data_.vals[i]) return false; + return true; + } else { + for (size_t i = 0; i < size_; ++i) + if (data_.ptr[i] != o.data_.ptr[i]) return false; + return true; + } + } + + private: + unsigned char capacity_; // only defined when size_ >= __SV_MAX_STATIC + unsigned char size_; + union StorageType { + int vals[__SV_MAX_STATIC]; + int* ptr; + }; + StorageType data_; + +}; + +inline bool operator!=(const SmallVector& a, const SmallVector& b) { + return !(a==b); +} + +#endif diff --git a/src/small_vector_test.cc b/src/small_vector_test.cc new file mode 100644 index 00000000..84237791 --- /dev/null +++ b/src/small_vector_test.cc @@ -0,0 +1,129 @@ +#include "small_vector.h" + +#include +#include +#include +#include + +using namespace std; + +class SVTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(SVTest, LargerThan2) { + SmallVector v; + SmallVector v2; + v.push_back(0); + v.push_back(1); + v.push_back(2); + assert(v.size() == 3); + assert(v[2] == 2); + assert(v[1] == 1); + assert(v[0] == 0); + v2 = v; + SmallVector copy(v); + assert(copy.size() == 3); + assert(copy[0] == 0); + assert(copy[1] == 1); + assert(copy[2] == 2); + assert(copy == v2); + copy[1] = 99; + assert(copy != v2); + assert(v2.size() == 3); + assert(v2[2] == 2); + assert(v2[1] == 1); + assert(v2[0] == 0); + v2[0] = -2; + v2[1] = -1; + v2[2] = 0; + assert(v2[2] == 0); + assert(v2[1] == -1); + assert(v2[0] == -2); + SmallVector v3(1,1); + assert(v3[0] == 1); + v2 = v3; + assert(v2.size() == 1); + assert(v2[0] == 1); + SmallVector v4(10, 1); + assert(v4.size() == 10); + assert(v4[5] == 1); + assert(v4[9] == 1); + v4 = v; + assert(v4.size() == 3); + assert(v4[2] == 2); + assert(v4[1] == 1); + assert(v4[0] == 0); + SmallVector v5(10, 2); + assert(v5.size() == 10); + assert(v5[7] == 2); + assert(v5[0] == 2); + assert(v.size() == 3); + v = v5; + assert(v.size() == 10); + assert(v[2] == 2); + assert(v[9] == 2); + SmallVector cc; + for (int i = 0; i < 33; ++i) + cc.push_back(i); + for (int i = 0; i < 33; ++i) + assert(cc[i] == i); + cc.resize(20); + assert(cc.size() == 20); + for (int i = 0; i < 20; ++i) + assert(cc[i] == i); + cc[0]=-1; + cc.resize(1, 999); + assert(cc.size() == 1); + assert(cc[0] == -1); + cc.resize(99, 99); + for (int i = 1; i < 99; ++i) { + cerr << i << " " << cc[i] << endl; + assert(cc[i] == 99); + } + cc.clear(); + assert(cc.size() == 0); +} + +TEST_F(SVTest, Small) { + SmallVector v; + SmallVector v1(1,0); + SmallVector v2(2,10); + SmallVector v1a(2,0); + EXPECT_TRUE(v1 != v1a); + EXPECT_TRUE(v1 == v1); + EXPECT_EQ(v1[0], 0); + EXPECT_EQ(v2[1], 10); + EXPECT_EQ(v2[0], 10); + ++v2[1]; + --v2[0]; + EXPECT_EQ(v2[0], 9); + EXPECT_EQ(v2[1], 11); + SmallVector v3(v2); + assert(v3[0] == 9); + assert(v3[1] == 11); + assert(!v3.empty()); + assert(v3.size() == 2); + v3.clear(); + assert(v3.empty()); + assert(v3.size() == 0); + assert(v3 != v2); + assert(v2 != v3); + v3 = v2; + assert(v3 == v2); + assert(v2 == v3); + assert(v3[0] == 9); + assert(v3[1] == 11); + assert(!v3.empty()); + assert(v3.size() == 2); + cerr << sizeof(SmallVector) << endl; + cerr << sizeof(vector) << endl; +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/src/sparse_vector.cc b/src/sparse_vector.cc new file mode 100644 index 00000000..4035b9ef --- /dev/null +++ b/src/sparse_vector.cc @@ -0,0 +1,98 @@ +#include "sparse_vector.h" + +#include +#include + +#include "hg_io.h" + +using namespace std; + +namespace B64 { + +void Encode(double objective, const SparseVector& v, ostream* out) { + const int num_feats = v.num_active(); + size_t tot_size = 0; + const size_t off_objective = tot_size; + tot_size += sizeof(double); // objective + const size_t off_num_feats = tot_size; + tot_size += sizeof(int); // num_feats + const size_t off_data = tot_size; + tot_size += sizeof(unsigned char) * num_feats; // lengths of feature names; + typedef SparseVector::const_iterator const_iterator; + for (const_iterator it = v.begin(); it != v.end(); ++it) + tot_size += FD::Convert(it->first).size(); // feature names; + tot_size += sizeof(double) * num_feats; // gradient + const size_t off_magic = tot_size; + tot_size += 4; // magic + + // size_t b64_size = tot_size * 4 / 3; + // cerr << "Sparse vector binary size: " << tot_size << " (b64 size=" << b64_size << ")\n"; + char* data = new char[tot_size]; + *reinterpret_cast(&data[off_objective]) = objective; + *reinterpret_cast(&data[off_num_feats]) = num_feats; + char* cur = &data[off_data]; + assert(cur - data == off_data); + for (const_iterator it = v.begin(); it != v.end(); ++it) { + const string& fname = FD::Convert(it->first); + *cur++ = static_cast(fname.size()); // name len + memcpy(cur, &fname[0], fname.size()); + cur += fname.size(); + *reinterpret_cast(cur) = it->second; + cur += sizeof(double); + } + assert(cur - data == off_magic); + *reinterpret_cast(cur) = 0xBAABABBAu; + cur += sizeof(unsigned int); + assert(cur - data == tot_size); + b64encode(data, tot_size, out); + delete[] data; +} + +bool Decode(double* objective, SparseVector* v, const char* in, size_t size) { + v->clear(); + if (size % 4 != 0) { + cerr << "B64 error - line % 4 != 0\n"; + return false; + } + const size_t decoded_size = size * 3 / 4 - sizeof(unsigned int); + const size_t buf_size = decoded_size + sizeof(unsigned int); + if (decoded_size < 6) { cerr << "SparseVector decoding error: too short!\n"; return false; } + char* data = new char[buf_size]; + if (!b64decode(reinterpret_cast(in), size, data, buf_size)) { + delete[] data; + return false; + } + size_t cur = 0; + *objective = *reinterpret_cast(data); + cur += sizeof(double); + const int num_feats = *reinterpret_cast(&data[cur]); + cur += sizeof(int); + int fc = 0; + while(fc < num_feats && cur < decoded_size) { + ++fc; + const int fname_len = data[cur++]; + assert(fname_len > 0); + assert(fname_len < 256); + string fname(fname_len, '\0'); + memcpy(&fname[0], &data[cur], fname_len); + cur += fname_len; + const double val = *reinterpret_cast(&data[cur]); + cur += sizeof(double); + int fid = FD::Convert(fname); + v->set_value(fid, val); + } + if(num_feats != fc) { + cerr << "Expected " << num_feats << " but only decoded " << fc << "!\n"; + delete[] data; + return false; + } + if (*reinterpret_cast(&data[cur]) != 0xBAABABBAu) { + cerr << "SparseVector decodeding error : magic does not match!\n"; + delete[] data; + return false; + } + delete[] data; + return true; +} + +} diff --git a/src/sparse_vector.h b/src/sparse_vector.h new file mode 100644 index 00000000..6a8c9bf4 --- /dev/null +++ b/src/sparse_vector.h @@ -0,0 +1,264 @@ +#ifndef _SPARSE_VECTOR_H_ +#define _SPARSE_VECTOR_H_ + +// this is a modified version of code originally written +// by Phil Blunsom + +#include +#include +#include +#include + +#include "fdict.h" + +template +class SparseVector { +public: + SparseVector() {} + + const T operator[](int index) const { + typename std::map::const_iterator found = _values.find(index); + if (found == _values.end()) + return T(0); + else + return found->second; + } + + void set_value(int index, const T &value) { + _values[index] = value; + } + + void add_value(int index, const T &value) { + _values[index] += value; + } + + T value(int index) const { + typename std::map::const_iterator found = _values.find(index); + if (found != _values.end()) + return found->second; + else + return T(0); + } + + void store(std::valarray* target) const { + (*target) *= 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) { + if (it->first >= target->size()) break; + (*target)[it->first] = it->second; + } + } + + int max_index() const { + if (_values.empty()) return 0; + typename std::map::const_iterator found =_values.end(); + --found; + return found->first; + } + + // dot product with a unit vector of the same length + // as the sparse vector + T dot() const { + T sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += it->second; + return sum; + } + + template + S dot(const SparseVector &vec) const { + S sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + { + typename std::map::const_iterator + found = vec._values.find(it->first); + if (found != vec._values.end()) + sum += it->second * found->second; + } + return sum; + } + + template + S dot(const std::vector &vec) const { + S sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + { + if (it->first < static_cast(vec.size())) + sum += it->second * vec[it->first]; + } + return sum; + } + + template + S dot(const S *vec) const { + // this is not range checked! + S sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += it->second * vec[it->first]; + std::cout << "dot(*vec) " << sum << std::endl; + return sum; + } + + T l1norm() const { + T sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += fabs(it->second); + return sum; + } + + T l2norm() const { + T sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += it->second * it->second; + return sqrt(sum); + } + + SparseVector &operator+=(const SparseVector &other) { + for (typename std::map::const_iterator + it = other._values.begin(); it != other._values.end(); ++it) + { + T v = (_values[it->first] += it->second); + if (v == 0) + _values.erase(it->first); + } + return *this; + } + + SparseVector &operator-=(const SparseVector &other) { + for (typename std::map::const_iterator + it = other._values.begin(); it != other._values.end(); ++it) + { + T v = (_values[it->first] -= it->second); + if (v == 0) + _values.erase(it->first); + } + return *this; + } + + SparseVector &operator-=(const double &x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second -= x; + return *this; + } + + SparseVector &operator+=(const double &x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second += x; + return *this; + } + + SparseVector &operator/=(const double &x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second /= x; + return *this; + } + + SparseVector &operator*=(const T& x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second *= x; + return *this; + } + + SparseVector operator+(const double &x) const { + SparseVector result = *this; + return result += x; + } + + SparseVector operator-(const double &x) const { + SparseVector result = *this; + return result -= x; + } + + SparseVector operator/(const double &x) const { + SparseVector result = *this; + return result /= x; + } + + std::ostream &operator<<(std::ostream &out) const { + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + out << (it == _values.begin() ? "" : ";") + << FD::Convert(it->first) << '=' << it->second; + return out; + } + + bool operator<(const SparseVector &other) const { + typename std::map::const_iterator it = _values.begin(); + typename std::map::const_iterator other_it = other._values.begin(); + + for (; it != _values.end() && other_it != other._values.end(); ++it, ++other_it) + { + if (it->first < other_it->first) return true; + if (it->first > other_it->first) return false; + if (it->second < other_it->second) return true; + if (it->second > other_it->second) return false; + } + return _values.size() < other._values.size(); + } + + int num_active() const { return _values.size(); } + bool empty() const { return _values.empty(); } + + typedef typename std::map::const_iterator const_iterator; + const_iterator begin() const { return _values.begin(); } + const_iterator end() const { return _values.end(); } + + void clear() { + _values.clear(); + } + + void swap(SparseVector& other) { + _values.swap(other._values); + } + +private: + std::map _values; +}; + +template +SparseVector operator+(const SparseVector& a, const SparseVector& b) { + SparseVector result = a; + return result += b; +} + +template +SparseVector operator*(const SparseVector& a, const double& b) { + SparseVector result = a; + return result *= b; +} + +template +SparseVector operator*(const SparseVector& a, const T& b) { + SparseVector result = a; + return result *= b; +} + +template +SparseVector operator*(const double& a, const SparseVector& b) { + SparseVector result = b; + return result *= a; +} + +template +std::ostream &operator<<(std::ostream &out, const SparseVector &vec) +{ + return vec.operator<<(out); +} + +namespace B64 { + void Encode(double objective, const SparseVector& v, std::ostream* out); + // returns false if failed to decode + bool Decode(double* objective, SparseVector* v, const char* data, size_t size); +} + +#endif diff --git a/src/stringlib.cc b/src/stringlib.cc new file mode 100644 index 00000000..3ed74bef --- /dev/null +++ b/src/stringlib.cc @@ -0,0 +1,97 @@ +#include "stringlib.h" + +#include +#include +#include +#include + +#include "lattice.h" + +using namespace std; + +void ParseTranslatorInput(const string& line, string* input, string* ref) { + size_t hint = 0; + if (line.find("{\"rules\":") == 0) { + hint = line.find("}}"); + if (hint == string::npos) { + cerr << "Syntax error: " << line << endl; + abort(); + } + hint += 2; + } + size_t pos = line.find("|||", hint); + if (pos == string::npos) { *input = line; return; } + ref->clear(); + *input = line.substr(0, pos - 1); + string rline = line.substr(pos + 4); + if (rline.size() > 0) { + assert(ref); + *ref = rline; + } +} + +void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) { + string sref; + ParseTranslatorInput(line, input, &sref); + if (sref.size() > 0) { + assert(ref); + LatticeTools::ConvertTextOrPLF(sref, ref); + } +} + +void ProcessAndStripSGML(string* pline, map* out) { + map& meta = *out; + string& line = *pline; + string lline = LowercaseString(line); + if (lline.find(""); + if (close == string::npos) return; // error + size_t end = lline.find(""); + string seg = Trim(lline.substr(4, close-4)); + string text = line.substr(close+1, end - close - 1); + for (size_t i = 1; i < seg.size(); i++) { + if (seg[i] == '=' && seg[i-1] == ' ') { + string less = seg.substr(0, i-1) + seg.substr(i); + seg = less; i = 0; continue; + } + if (seg[i] == '=' && seg[i+1] == ' ') { + string less = seg.substr(0, i+1); + if (i+2 < seg.size()) less += seg.substr(i+2); + seg = less; i = 0; continue; + } + } + line = Trim(text); + if (seg == "") return; + for (size_t i = 1; i < seg.size(); i++) { + if (seg[i] == '=') { + string label = seg.substr(0, i); + string val = seg.substr(i+1); + if (val[0] == '"') { + val = val.substr(1); + size_t close = val.find('"'); + if (close == string::npos) { + cerr << "SGML parse error: missing \"\n"; + seg = ""; + i = 0; + } else { + seg = val.substr(close+1); + val = val.substr(0, close); + i = 0; + } + } else { + size_t close = val.find(' '); + if (close == string::npos) { + seg = ""; + i = 0; + } else { + seg = val.substr(close+1); + val = val.substr(0, close); + } + } + label = Trim(label); + seg = Trim(seg); + meta[label] = val; + } + } +} + diff --git a/src/stringlib.h b/src/stringlib.h new file mode 100644 index 00000000..d26952c7 --- /dev/null +++ b/src/stringlib.h @@ -0,0 +1,91 @@ +#ifndef _STRINGLIB_H_ + +#include +#include +#include +#include + +// read line in the form of either: +// source +// source ||| target +// source will be returned as a string, target must be a sentence or +// a lattice (in PLF format) and will be returned as a Lattice object +void ParseTranslatorInput(const std::string& line, std::string* input, std::string* ref); +struct Lattice; +void ParseTranslatorInputLattice(const std::string& line, std::string* input, Lattice* ref); + +inline const std::string Trim(const std::string& str, const std::string& dropChars = " \t") { + std::string res = str; + res.erase(str.find_last_not_of(dropChars)+1); + return res.erase(0, res.find_first_not_of(dropChars)); +} + +inline void Tokenize(const std::string& str, char delimiter, std::vector* res) { + std::string s = str; + int last = 0; + res->clear(); + for (int i=0; i < s.size(); ++i) + if (s[i] == delimiter) { + s[i]=0; + if (last != i) { + res->push_back(&s[last]); + } + last = i + 1; + } + if (last != s.size()) + res->push_back(&s[last]); +} + +inline std::string LowercaseString(const std::string& in) { + std::string res(in.size(),' '); + for (int i = 0; i < in.size(); ++i) + res[i] = tolower(in[i]); + return res; +} + +inline int CountSubstrings(const std::string& str, const std::string& sub) { + size_t p = 0; + int res = 0; + while (p < str.size()) { + p = str.find(sub, p); + if (p == std::string::npos) break; + ++res; + p += sub.size(); + } + return res; +} + +inline int SplitOnWhitespace(const std::string& in, std::vector* out) { + out->clear(); + int i = 0; + int start = 0; + std::string cur; + while(i < in.size()) { + if (in[i] == ' ' || in[i] == '\t') { + if (i - start > 0) + out->push_back(in.substr(start, i - start)); + start = i + 1; + } + ++i; + } + if (i > start) + out->push_back(in.substr(start, i - start)); + return out->size(); +} + +inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::string* param) { + cmd->clear(); + param->clear(); + std::vector x; + SplitOnWhitespace(in, &x); + if (x.size() == 0) return; + *cmd = x[0]; + for (int i = 1; i < x.size(); ++i) { + if (i > 1) { *param += " "; } + *param += x[i]; + } +} + +void ProcessAndStripSGML(std::string* line, std::map* out); + +#endif diff --git a/src/synparse.cc b/src/synparse.cc new file mode 100644 index 00000000..96588f1e --- /dev/null +++ b/src/synparse.cc @@ -0,0 +1,212 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "prob.h" +#include "tdict.h" +#include "filelib.h" + +using namespace std; +using namespace __gnu_cxx; +namespace po = boost::program_options; + +const prob_t kMONO(1.0); // 0.6 +const prob_t kINV(1.0); // 0.1 +const prob_t kLEX(1.0); // 0.3 + +typedef hash_map, hash_map, prob_t, boost::hash > >, boost::hash > > PTable; +typedef boost::multi_array CChart; +typedef pair SpanType; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("phrasetable,p",po::value(), "[REQD] Phrase pairs for ITG alignment") + ("input,i",po::value()->default_value("-"), "Input file") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("phrasetable")) { + cerr << "Please specify a grammar file with -p \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void LoadITGPhrasetable(const string& fname, PTable* ptable) { + const WordID sep = TD::Convert("|||"); + ReadFile rf(fname); + istream& in = *rf.stream(); + assert(in); + int lc = 0; + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + ++lc; + vector full, f, e; + TD::ConvertSentence(line, &full); + int i = 0; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + f.push_back(full[i]); + } + ++i; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + e.push_back(full[i]); + } + ++i; + prob_t prob(0.000001); + if (i < full.size()) { prob = prob_t(atof(TD::Convert(full[i]))); ++i; } + + if (i < full.size()) { cerr << "Warning line " << lc << " has extra stuff.\n"; } + assert(f.size() > 0); + assert(e.size() > 0); + (*ptable)[f][e] = prob; + } + cerr << "Read " << lc << " phrase pairs\n"; +} + +void FindPhrases(const vector& e, const vector& f, const PTable& pt, CChart* pcc) { + CChart& cc = *pcc; + const size_t n = f.size(); + const size_t m = e.size(); + typedef hash_map, vector, boost::hash > > PhraseToSpan; + PhraseToSpan e_locations; + for (int i = 0; i < m; ++i) { + const int mel = m - i; + vector e_phrase; + for (int el = 0; el < mel; ++el) { + e_phrase.push_back(e[i + el]); + e_locations[e_phrase].push_back(make_pair(i, i + el + 1)); + } + } + //cerr << "Cached the locations of " << e_locations.size() << " e-phrases\n"; + + for (int s = 0; s < n; ++s) { + const int mfl = n - s; + vector f_phrase; + for (int fl = 0; fl < mfl; ++fl) { + f_phrase.push_back(f[s + fl]); + PTable::const_iterator it = pt.find(f_phrase); + if (it == pt.end()) continue; + const hash_map, prob_t, boost::hash > >& es = it->second; + for (hash_map, prob_t, boost::hash > >::const_iterator eit = es.begin(); eit != es.end(); ++eit) { + PhraseToSpan::iterator loc = e_locations.find(eit->first); + if (loc == e_locations.end()) continue; + const vector& espans = loc->second; + for (int j = 0; j < espans.size(); ++j) { + cc[s][s + fl + 1][espans[j].first][espans[j].second] = eit->second; + //cerr << '[' << s << ',' << (s + fl + 1) << ',' << espans[j].first << ',' << espans[j].second << "] is C\n"; + } + } + } + } +} + +long long int evals = 0; + +void ProcessSynchronousCell(const int s, + const int t, + const int u, + const int v, + const prob_t& lex, + const prob_t& mono, + const prob_t& inv, + const CChart& tc, CChart* ntc) { + prob_t& inside = (*ntc)[s][t][u][v]; + // cerr << log(tc[s][t][u][v]) << " + " << log(lex) << endl; + inside = tc[s][t][u][v] * lex; + // cerr << " terminal span: " << log(inside) << endl; + if (t - s == 1) return; + if (v - u == 1) return; + for (int x = s+1; x < t; ++x) { + for (int y = u+1; y < v; ++y) { + const prob_t m = (*ntc)[s][x][u][y] * (*ntc)[x][t][y][v] * mono; + const prob_t i = (*ntc)[s][x][y][v] * (*ntc)[x][t][u][y] * inv; + // cerr << log(i) << "\t" << log(m) << endl; + inside += m; + inside += i; + evals++; + } + } + // cerr << " span: " << log(inside) << endl; +} + +prob_t SynchronousParse(const int n, const int m, const prob_t& lex, const prob_t& mono, const prob_t& inv, const CChart& tc, CChart* ntc) { + for (int fl = 0; fl < n; ++fl) { + for (int el = 0; el < m; ++el) { + const int ms = n - fl; + for (int s = 0; s < ms; ++s) { + const int t = s + fl + 1; + const int mu = m - el; + for (int u = 0; u < mu; ++u) { + const int v = u + el + 1; + //cerr << "Processing cell [" << s << ',' << t << ',' << u << ',' << v << "]\n"; + ProcessSynchronousCell(s, t, u, v, lex, mono, inv, tc, ntc); + } + } + } + } + return (*ntc)[0][n][0][m]; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + PTable ptable; + LoadITGPhrasetable(conf["phrasetable"].as(), &ptable); + ReadFile rf(conf["input"].as()); + istream& in = *rf.stream(); + int lc = 0; + const WordID sep = TD::Convert("|||"); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + ++lc; + vector full, f, e; + TD::ConvertSentence(line, &full); + int i = 0; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + f.push_back(full[i]); + } + ++i; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + e.push_back(full[i]); + } + if (e.empty()) cerr << "E is empty!\n"; + if (f.empty()) cerr << "F is empty!\n"; + if (e.empty() || f.empty()) continue; + int n = f.size(); + int m = e.size(); + cerr << "Synchronous chart has " << (n * n * m * m) << " cells\n"; + clock_t start = clock(); + CChart cc(boost::extents[n+1][n+1][m+1][m+1]); + FindPhrases(e, f, ptable, &cc); + CChart ntc(boost::extents[n+1][n+1][m+1][m+1]); + prob_t likelihood = SynchronousParse(n, m, kLEX, kMONO, kINV, cc, &ntc); + clock_t end = clock(); + cerr << "log Z: " << log(likelihood) << endl; + cerr << " Z: " << likelihood << endl; + double etime = (end - start) / 1000000.0; + cout << " time: " << etime << endl; + cout << "evals: " << evals << endl; + } + return 0; +} + diff --git a/src/tdict.cc b/src/tdict.cc new file mode 100644 index 00000000..c00d20b8 --- /dev/null +++ b/src/tdict.cc @@ -0,0 +1,49 @@ +#include "Ngram.h" +#include "dict.h" +#include "tdict.h" +#include "Vocab.h" + +using namespace std; + +Vocab* TD::dict_ = new Vocab; + +static const string empty; +static const string space = " "; + +WordID TD::Convert(const std::string& s) { + return dict_->addWord((VocabString)s.c_str()); +} + +const char* TD::Convert(const WordID& w) { + return dict_->getWord((VocabIndex)w); +} + +void TD::GetWordIDs(const std::vector& strings, std::vector* ids) { + ids->clear(); + for (vector::const_iterator i = strings.begin(); i != strings.end(); ++i) + ids->push_back(TD::Convert(*i)); +} + +std::string TD::GetString(const std::vector& str) { + string res; + for (vector::const_iterator i = str.begin(); i != str.end(); ++i) + res += (i == str.begin() ? empty : space) + TD::Convert(*i); + return res; +} + +void TD::ConvertSentence(const std::string& sent, std::vector* ids) { + string s = sent; + int last = 0; + ids->clear(); + for (int i=0; i < s.size(); ++i) + if (s[i] == 32 || s[i] == '\t') { + s[i]=0; + if (last != i) { + ids->push_back(Convert(&s[last])); + } + last = i + 1; + } + if (last != s.size()) + ids->push_back(Convert(&s[last])); +} + diff --git a/src/tdict.h b/src/tdict.h new file mode 100644 index 00000000..9d4318fe --- /dev/null +++ b/src/tdict.h @@ -0,0 +1,19 @@ +#ifndef _TDICT_H_ +#define _TDICT_H_ + +#include +#include +#include "wordid.h" + +class Vocab; + +struct TD { + static Vocab* dict_; + static void ConvertSentence(const std::string& sent, std::vector* ids); + static void GetWordIDs(const std::vector& strings, std::vector* ids); + static std::string GetString(const std::vector& str); + static WordID Convert(const std::string& s); + static const char* Convert(const WordID& w); +}; + +#endif diff --git a/src/timing_stats.cc b/src/timing_stats.cc new file mode 100644 index 00000000..85b95de5 --- /dev/null +++ b/src/timing_stats.cc @@ -0,0 +1,24 @@ +#include "timing_stats.h" + +#include + +using namespace std; + +map Timer::stats; + +Timer::Timer(const string& timername) : start_t(clock()), cur(stats[timername]) {} + +Timer::~Timer() { + ++cur.calls; + const clock_t end_t = clock(); + const double elapsed = (end_t - start_t) / 1000000.0; + cur.total_time += elapsed; +} + +void Timer::Summarize() { + for (map::iterator it = stats.begin(); it != stats.end(); ++it) { + cerr << it->first << ": " << it->second.total_time << " secs (" << it->second.calls << " calls)\n"; + } + stats.clear(); +} + diff --git a/src/timing_stats.h b/src/timing_stats.h new file mode 100644 index 00000000..0a9f7656 --- /dev/null +++ b/src/timing_stats.h @@ -0,0 +1,25 @@ +#ifndef _TIMING_STATS_H_ +#define _TIMING_STATS_H_ + +#include +#include + +struct TimerInfo { + int calls; + double total_time; + TimerInfo() : calls(), total_time() {} +}; + +struct Timer { + Timer(const std::string& info); + ~Timer(); + static void Summarize(); + private: + static std::map stats; + clock_t start_t; + TimerInfo& cur; + Timer(const Timer& other); + const Timer& operator=(const Timer& other); +}; + +#endif diff --git a/src/translator.h b/src/translator.h new file mode 100644 index 00000000..194efbaa --- /dev/null +++ b/src/translator.h @@ -0,0 +1,54 @@ +#ifndef _TRANSLATOR_H_ +#define _TRANSLATOR_H_ + +#include +#include +#include +#include + +class Hypergraph; +class SentenceMetadata; + +class Translator { + public: + virtual ~Translator(); + // returns true if goal reached, false otherwise + // minus_lm_forest will contain the unpruned forest. the + // feature values from the phrase table / grammar / etc + // should be in the forest already - the "late" features + // should not just copy values that are available without + // any context or computation. + // SentenceMetadata contains information about the sentence, + // but it is an input/output parameter since the Translator + // is also responsible for setting the value of src_len. + virtual bool Translate(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest) = 0; +}; + +class SCFGTranslatorImpl; +class SCFGTranslator : public Translator { + public: + SCFGTranslator(const boost::program_options::variables_map& conf); + bool Translate(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + +class FSTTranslatorImpl; +class FSTTranslator : public Translator { + public: + FSTTranslator(const boost::program_options::variables_map& conf); + bool Translate(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/src/trule.cc b/src/trule.cc new file mode 100644 index 00000000..b8f6995e --- /dev/null +++ b/src/trule.cc @@ -0,0 +1,237 @@ +#include "trule.h" + +#include + +#include "stringlib.h" +#include "tdict.h" + +using namespace std; + +static WordID ConvertTrgString(const string& w) { + int len = w.size(); + WordID id = 0; + // [X,0] or [0] + // for target rules, we ignore the category, just keep the index + if (len > 2 && w[0]=='[' && w[len-1]==']' && w[len-2] > '0' && w[len-2] <= '9' && + (len == 3 || (len > 4 && w[len-3] == ','))) { + id = w[len-2] - '0'; + id = 1 - id; + } else { + id = TD::Convert(w); + } + return id; +} + +static WordID ConvertSrcString(const string& w, bool mono = false) { + int len = w.size(); + // [X,0] + // for source rules, we keep the category and ignore the index (source rules are + // always numbered 1, 2, 3... + if (mono) { + if (len > 2 && w[0]=='[' && w[len-1]==']') { + if (len > 4 && w[len-3] == ',') { + cerr << "[ERROR] Monolingual rules mut not have non-terminal indices:\n " + << w << endl; + exit(1); + } + // TODO check that source indices go 1,2,3,etc. + return TD::Convert(w.substr(1, len-2)) * -1; + } else { + return TD::Convert(w); + } + } else { + if (len > 4 && w[0]=='[' && w[len-1]==']' && w[len-3] == ',' && w[len-2] > '0' && w[len-2] <= '9') { + return TD::Convert(w.substr(1, len-4)) * -1; + } else { + return TD::Convert(w); + } + } +} + +static WordID ConvertLHS(const string& w) { + if (w[0] == '[') { + int len = w.size(); + if (len < 3) { cerr << "Format error: " << w << endl; exit(1); } + return TD::Convert(w.substr(1, len-2)) * -1; + } else { + return TD::Convert(w) * -1; + } +} + +TRule* TRule::CreateRuleSynchronous(const std::string& rule) { + TRule* res = new TRule; + if (res->ReadFromString(rule, true, false)) return res; + cerr << "[ERROR] Failed to creating rule from: " << rule << endl; + delete res; + return NULL; +} + +TRule* TRule::CreateRulePhrasetable(const string& rule) { + // TODO make this faster + // TODO add configuration for default NT type + if (rule[0] == '[') { + cerr << "Phrasetable rules shouldn't have a LHS / non-terminals:\n " << rule << endl; + return NULL; + } + TRule* res = new TRule("[X] ||| " + rule, true, false); + if (res->Arity() != 0) { + cerr << "Phrasetable rules should have arity 0:\n " << rule << endl; + delete res; + return NULL; + } + return res; +} + +TRule* TRule::CreateRuleMonolingual(const string& rule) { + return new TRule(rule, false, true); +} + +bool TRule::ReadFromString(const string& line, bool strict, bool mono) { + e_.clear(); + f_.clear(); + scores_.clear(); + + string w; + istringstream is(line); + int format = CountSubstrings(line, "|||"); + if (strict && format < 2) { + cerr << "Bad rule format in strict mode:\n" << line << endl; + return false; + } + if (format >= 2 || (mono && format == 1)) { + while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } + while(is>>w && w!="|||") { f_.push_back(ConvertSrcString(w, mono)); } + if (!mono) { + while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } + } + int fv = 0; + if (is) { + string ss; + getline(is, ss); + //cerr << "L: " << ss << endl; + int start = 0; + const int len = ss.size(); + while (start < len) { + while(start < len && (ss[start] == ' ' || ss[start] == ';')) + ++start; + if (start == len) break; + int end = start + 1; + while(end < len && (ss[end] != '=' && ss[end] != ' ' && ss[end] != ';')) + ++end; + if (end == len || ss[end] == ' ' || ss[end] == ';') { + //cerr << "PROC: '" << ss.substr(start, end - start) << "'\n"; + // non-named features + if (end != len) { ss[end] = 0; } + string fname = "PhraseModel_X"; + if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); } + fname[12]='0' + fv; + ++fv; + scores_.set_value(FD::Convert(fname), atof(&ss[start])); + //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl; + } else { + const int fid = FD::Convert(ss.substr(start, end - start)); + start = end + 1; + end = start + 1; + while(end < len && (ss[end] != ' ' && ss[end] != ';')) + ++end; + if (end < len) { ss[end] = 0; } + assert(start < len); + scores_.set_value(fid, atof(&ss[start])); + //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl; + } + start = end + 1; + } + } + } else if (format == 1) { + while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } + while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } + f_ = e_; + int x = ConvertLHS("[X]"); + for (int i = 0; i < f_.size(); ++i) + if (f_[i] <= 0) { f_[i] = x; } + } else { + cerr << "F: " << format << endl; + cerr << "[ERROR] Don't know how to read:\n" << line << endl; + } + if (mono) { + e_ = f_; + int ci = 0; + for (int i = 0; i < e_.size(); ++i) + if (e_[i] < 0) + e_[i] = ci--; + } + ComputeArity(); + return SanityCheck(); +} + +bool TRule::SanityCheck() const { + vector used(f_.size(), 0); + int ac = 0; + for (int i = 0; i < e_.size(); ++i) { + int ind = e_[i]; + if (ind > 0) continue; + ind = -ind; + if ((++used[ind]) != 1) { + cerr << "[ERROR] e-side variable index " << (ind+1) << " used more than once!\n"; + return false; + } + ac++; + } + if (ac != Arity()) { + cerr << "[ERROR] e-side arity mismatches f-side\n"; + return false; + } + return true; +} + +void TRule::ComputeArity() { + int min = 1; + for (vector::const_iterator i = e_.begin(); i != e_.end(); ++i) + if (*i < min) min = *i; + arity_ = 1 - min; +} + +static string AnonymousStrVar(int i) { + string res("[v]"); + if(!(i <= 0 && i >= -8)) { + cerr << "Can't handle more than 9 non-terminals: index=" << (-i) << endl; + abort(); + } + res[1] = '1' - i; + return res; +} + +string TRule::AsString(bool verbose) const { + ostringstream os; + int idx = 0; + if (lhs_ && verbose) { + os << '[' << TD::Convert(lhs_ * -1) << "] |||"; + for (int i = 0; i < f_.size(); ++i) { + const WordID& w = f_[i]; + if (w < 0) { + int wi = w * -1; + ++idx; + os << " [" << TD::Convert(wi) << ',' << idx << ']'; + } else { + os << ' ' << TD::Convert(w); + } + } + os << " ||| "; + } + if (idx > 9) { + cerr << "Too many non-terminals!\n partial: " << os.str() << endl; + exit(1); + } + for (int i =0; i +#include +#include +#include + +#include "sparse_vector.h" +#include "wordid.h" + +class TRule; +typedef boost::shared_ptr TRulePtr; +struct SpanInfo; + +// Translation rule +class TRule { + public: + TRule() : lhs_(0), prev_i(-1), prev_j(-1) { } + explicit TRule(const std::vector& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {} + TRule(const std::vector& e, const std::vector& f, const WordID& lhs) : + e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {} + + // deprecated - this will be private soon + explicit TRule(const std::string& text, bool strict = false, bool mono = false) { + ReadFromString(text, strict, mono); + } + + // make a rule from a hiero-like rule table, e.g. + // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1] + // if misformatted, returns NULL + static TRule* CreateRuleSynchronous(const std::string& rule); + + // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g: + // el gato ||| the cat ||| Feature_2=0.34 + static TRule* CreateRulePhrasetable(const std::string& rule); + + // make a rule from a non-synchrnous CFG representation, e.g.: + // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT] + static TRule* CreateRuleMonolingual(const std::string& rule); + + void ESubstitute(const std::vector* >& var_values, + std::vector* result) const { + int vc = 0; + result->clear(); + for (std::vector::const_iterator i = e_.begin(); i != e_.end(); ++i) { + const WordID& c = *i; + if (c < 1) { + ++vc; + const std::vector& var_value = *var_values[-c]; + std::copy(var_value.begin(), + var_value.end(), + std::back_inserter(*result)); + } else { + result->push_back(c); + } + } + assert(vc == var_values.size()); + } + + void FSubstitute(const std::vector* >& var_values, + std::vector* result) const { + int vc = 0; + result->clear(); + for (std::vector::const_iterator i = f_.begin(); i != f_.end(); ++i) { + const WordID& c = *i; + if (c < 1) { + const std::vector& var_value = *var_values[vc++]; + std::copy(var_value.begin(), + var_value.end(), + std::back_inserter(*result)); + } else { + result->push_back(c); + } + } + assert(vc == var_values.size()); + } + + bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false); + + bool Initialized() const { return e_.size(); } + + std::string AsString(bool verbose = true) const; + + static TRule DummyRule() { + TRule res; + res.e_.resize(1, 0); + return res; + } + + const std::vector& f() const { return f_; } + const std::vector& e() const { return e_; } + + int EWords() const { return ELength() - Arity(); } + int FWords() const { return FLength() - Arity(); } + int FLength() const { return f_.size(); } + int ELength() const { return e_.size(); } + int Arity() const { return arity_; } + bool IsUnary() const { return (Arity() == 1) && (f_.size() == 1); } + const SparseVector& GetFeatureValues() const { return scores_; } + double Score(int i) const { return scores_[i]; } + WordID GetLHS() const { return lhs_; } + void ComputeArity(); + + // 0 = first variable, -1 = second variable, -2 = third ... + std::vector e_; + // < 0: *-1 = encoding of category of variable + std::vector f_; + WordID lhs_; + SparseVector scores_; + char arity_; + TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding + + // this is only used when doing synchronous parsing + short int prev_i; + short int prev_j; + + private: + bool SanityCheck() const; +}; + +#endif diff --git a/src/trule_test.cc b/src/trule_test.cc new file mode 100644 index 00000000..02a70764 --- /dev/null +++ b/src/trule_test.cc @@ -0,0 +1,65 @@ +#include "trule.h" + +#include +#include +#include +#include "tdict.h" + +using namespace std; + +class TRuleTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(TRuleTest,TestFSubstitute) { + TRule r1("[X] ||| ob [X,1] [X,2] sah . ||| whether [X,1] saw [X,2] . ||| 0.99"); + TRule r2("[X] ||| ich ||| i ||| 1.0"); + TRule r3("[X] ||| ihn ||| him ||| 1.0"); + vector*> ants; + vector res2; + r2.FSubstitute(ants, &res2); + assert(TD::GetString(res2) == "ich"); + vector res3; + r3.FSubstitute(ants, &res3); + assert(TD::GetString(res3) == "ihn"); + ants.push_back(&res2); + ants.push_back(&res3); + vector res; + r1.FSubstitute(ants, &res); + cerr << TD::GetString(res) << endl; + assert(TD::GetString(res) == "ob ich ihn sah ."); +} + +TEST_F(TRuleTest,TestPhrasetableRule) { + TRulePtr t(TRule::CreateRulePhrasetable("gato ||| cat ||| PhraseModel_0=-23.2;Foo=1;Bar=12")); + cerr << t->AsString() << endl; + assert(t->scores_.num_active() == 3); +}; + + +TEST_F(TRuleTest,TestMonoRule) { + TRulePtr m(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3]")); + assert(m->Arity() == 3); + cerr << m->AsString() << endl; + TRulePtr m2(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3] ||| Feature1=0.23")); + assert(m2->Arity() == 3); + cerr << m2->AsString() << endl; + EXPECT_FLOAT_EQ(m2->scores_.value(FD::Convert("Feature1")), 0.23); +} + +TEST_F(TRuleTest,TestRuleR) { + TRule t6; + t6.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [X,2] saw the [X,1] . ||| 0.12321 0.23232 0.121"); + cerr << "TEXT: " << t6.AsString() << endl; + EXPECT_EQ(t6.Arity(), 2); + EXPECT_EQ(t6.e_[0], -1); + EXPECT_EQ(t6.e_[3], 0); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/src/ttables.cc b/src/ttables.cc new file mode 100644 index 00000000..2ea960f0 --- /dev/null +++ b/src/ttables.cc @@ -0,0 +1,31 @@ +#include "ttables.h" + +#include + +#include "dict.h" + +using namespace std; +using namespace std::tr1; + +void TTable::DeserializeProbsFromText(std::istream* in) { + int c = 0; + while(*in) { + string e; + string f; + double p; + (*in) >> e >> f >> p; + if (e.empty()) break; + ++c; + ttable[TD::Convert(e)][TD::Convert(f)] = prob_t(p); + } + cerr << "Loaded " << c << " translation parameters.\n"; +} + +void TTable::SerializeHelper(string* out, const Word2Word2Double& o) { + assert(!"not implemented"); +} + +void TTable::DeserializeHelper(const string& in, Word2Word2Double* o) { + assert(!"not implemented"); +} + diff --git a/src/ttables.h b/src/ttables.h new file mode 100644 index 00000000..3ffc238a --- /dev/null +++ b/src/ttables.h @@ -0,0 +1,87 @@ +#ifndef _TTABLES_H_ +#define _TTABLES_H_ + +#include +#include + +#include "wordid.h" +#include "prob.h" +#include "tdict.h" + +class TTable { + public: + TTable() {} + typedef std::map Word2Double; + typedef std::map Word2Word2Double; + inline const prob_t prob(const int& e, const int& f) const { + const Word2Word2Double::const_iterator cit = ttable.find(e); + if (cit != ttable.end()) { + const Word2Double& cpd = cit->second; + const Word2Double::const_iterator it = cpd.find(f); + if (it == cpd.end()) return prob_t(0.00001); + return prob_t(it->second); + } else { + return prob_t(0.00001); + } + } + inline void Increment(const int& e, const int& f) { + counts[e][f] += 1.0; + } + inline void Increment(const int& e, const int& f, double x) { + counts[e][f] += x; + } + void Normalize() { + ttable.swap(counts); + for (Word2Word2Double::iterator cit = ttable.begin(); + cit != ttable.end(); ++cit) { + double tot = 0; + Word2Double& cpd = cit->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + tot += it->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + it->second /= tot; + } + counts.clear(); + } + // adds counts from another TTable - probabilities remain unchanged + TTable& operator+=(const TTable& rhs) { + for (Word2Word2Double::const_iterator it = rhs.counts.begin(); + it != rhs.counts.end(); ++it) { + const Word2Double& cpd = it->second; + Word2Double& tgt = counts[it->first]; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + tgt[j->first] += j->second; + } + } + return *this; + } + void ShowTTable() { + for (Word2Word2Double::iterator it = ttable.begin(); it != ttable.end(); ++it) { + Word2Double& cpd = it->second; + for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void ShowCounts() { + for (Word2Word2Double::iterator it = counts.begin(); it != counts.end(); ++it) { + Word2Double& cpd = it->second; + for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void DeserializeProbsFromText(std::istream* in); + void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); } + void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); } + void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } + void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); } + private: + static void SerializeHelper(std::string*, const Word2Word2Double& o); + static void DeserializeHelper(const std::string&, Word2Word2Double* o); + public: + Word2Word2Double ttable; + Word2Word2Double counts; +}; + +#endif diff --git a/src/viterbi.cc b/src/viterbi.cc new file mode 100644 index 00000000..82b2ce6d --- /dev/null +++ b/src/viterbi.cc @@ -0,0 +1,39 @@ +#include "viterbi.h" + +#include +#include "hg.h" + +using namespace std; + +string ViterbiETree(const Hypergraph& hg) { + vector tmp; + const prob_t p = Viterbi, ETreeTraversal, prob_t, EdgeProb>(hg, &tmp); + return TD::GetString(tmp); +} + +string ViterbiFTree(const Hypergraph& hg) { + vector tmp; + const prob_t p = Viterbi, FTreeTraversal, prob_t, EdgeProb>(hg, &tmp); + return TD::GetString(tmp); +} + +prob_t ViterbiESentence(const Hypergraph& hg, vector* result) { + return Viterbi, ESentenceTraversal, prob_t, EdgeProb>(hg, result); +} + +prob_t ViterbiFSentence(const Hypergraph& hg, vector* result) { + return Viterbi, FSentenceTraversal, prob_t, EdgeProb>(hg, result); +} + +int ViterbiELength(const Hypergraph& hg) { + int len = -1; + Viterbi(hg, &len); + return len; +} + +int ViterbiPathLength(const Hypergraph& hg) { + int len = -1; + Viterbi(hg, &len); + return len; +} + diff --git a/src/viterbi.h b/src/viterbi.h new file mode 100644 index 00000000..46a4f528 --- /dev/null +++ b/src/viterbi.h @@ -0,0 +1,130 @@ +#ifndef _VITERBI_H_ +#define _VITERBI_H_ + +#include +#include "prob.h" +#include "hg.h" +#include "tdict.h" + +// V must implement: +// void operator()(const vector& ants, T* result); +template +WeightType Viterbi(const Hypergraph& hg, + T* result, + const Traversal& traverse = Traversal(), + const WeightFunction& weight = WeightFunction()) { + const int num_nodes = hg.nodes_.size(); + std::vector vit_result(num_nodes); + std::vector vit_weight(num_nodes, WeightType::Zero()); + + for (int i = 0; i < num_nodes; ++i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + WeightType* const cur_node_best_weight = &vit_weight[i]; + T* const cur_node_best_result = &vit_result[i]; + + const int num_in_edges = cur_node.in_edges_.size(); + if (num_in_edges == 0) { + *cur_node_best_weight = WeightType(1); + continue; + } + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + WeightType score = weight(edge); + std::vector ants(edge.tail_nodes_.size()); + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + const int tail_node_index = edge.tail_nodes_[k]; + score *= vit_weight[tail_node_index]; + ants[k] = &vit_result[tail_node_index]; + } + if (*cur_node_best_weight < score) { + *cur_node_best_weight = score; + traverse(edge, ants, cur_node_best_result); + } + } + } + std::swap(*result, vit_result.back()); + return vit_weight.back(); +} + +struct PathLengthTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector& ants, + int* result) const { + (void) edge; + *result = 1; + for (int i = 0; i < ants.size(); ++i) *result += *ants[i]; + } +}; + +struct ESentenceTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + edge.rule_->ESubstitute(ants, result); + } +}; + +struct ELengthTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector& ants, + int* result) const { + *result = edge.rule_->ELength() - edge.rule_->Arity(); + for (int i = 0; i < ants.size(); ++i) *result += *ants[i]; + } +}; + +struct FSentenceTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + edge.rule_->FSubstitute(ants, result); + } +}; + +// create a strings of the form (S (X the man) (X said (X he (X would (X go))))) +struct ETreeTraversal { + ETreeTraversal() : left("("), space(" "), right(")") {} + const std::string left; + const std::string space; + const std::string right; + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + std::vector tmp; + edge.rule_->ESubstitute(ants, &tmp); + const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1); + if (cat == "Goal") + result->swap(tmp); + else + TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right, + result); + } +}; + +struct FTreeTraversal { + FTreeTraversal() : left("("), space(" "), right(")") {} + const std::string left; + const std::string space; + const std::string right; + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + std::vector tmp; + edge.rule_->FSubstitute(ants, &tmp); + const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1); + if (cat == "Goal") + result->swap(tmp); + else + TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right, + result); + } +}; + +prob_t ViterbiESentence(const Hypergraph& hg, std::vector* result); +std::string ViterbiETree(const Hypergraph& hg); +prob_t ViterbiFSentence(const Hypergraph& hg, std::vector* result); +std::string ViterbiFTree(const Hypergraph& hg); +int ViterbiELength(const Hypergraph& hg); +int ViterbiPathLength(const Hypergraph& hg); + +#endif diff --git a/src/weights.cc b/src/weights.cc new file mode 100644 index 00000000..bb0a878f --- /dev/null +++ b/src/weights.cc @@ -0,0 +1,73 @@ +#include "weights.h" + +#include + +#include "fdict.h" +#include "filelib.h" + +using namespace std; + +void Weights::InitFromFile(const std::string& filename, vector* feature_list) { + cerr << "Reading weights from " << filename << endl; + ReadFile in_file(filename); + istream& in = *in_file.stream(); + assert(in); + int weight_count = 0; + bool fl = false; + while (in) { + double val = 0; + string buf; + getline(in, buf); + if (buf.size() == 0) continue; + if (buf[0] == '#') continue; + for (int i = 0; i < buf.size(); ++i) + if (buf[i] == '=') buf[i] = ' '; + int start = 0; + while(start < buf.size() && buf[start] == ' ') ++start; + int end = 0; + while(end < buf.size() && buf[end] != ' ') ++end; + int fid = FD::Convert(buf.substr(start, end - start)); + while(end < buf.size() && buf[end] == ' ') ++end; + val = strtod(&buf.c_str()[end], NULL); + if (wv_.size() <= fid) + wv_.resize(fid + 1); + wv_[fid] = val; + if (feature_list) { feature_list->push_back(FD::Convert(fid)); } + ++weight_count; + if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } + if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + } + if (fl) { cerr << endl; } + cerr << "Loaded " << weight_count << " feature weights\n"; +} + +void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features) const { + WriteFile out(fname); + ostream& o = *out.stream(); + assert(o); + o.precision(17); + const int num_feats = FD::NumFeats(); + for (int i = 1; i < num_feats; ++i) { + const double val = (i < wv_.size() ? wv_[i] : 0.0); + if (hide_zero_value_features && val == 0.0) continue; + o << FD::Convert(i) << ' ' << val << endl; + } +} + +void Weights::InitVector(std::vector* w) const { + *w = wv_; +} + +void Weights::InitSparseVector(SparseVector* w) const { + for (int i = 1; i < wv_.size(); ++i) { + const double& weight = wv_[i]; + if (weight) w->set_value(i, weight); + } +} + +void Weights::InitFromVector(const std::vector& w) { + wv_ = w; + if (wv_.size() > FD::NumFeats()) + cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; + wv_.resize(FD::NumFeats(), 0); +} diff --git a/src/weights.h b/src/weights.h new file mode 100644 index 00000000..f19aa3ce --- /dev/null +++ b/src/weights.h @@ -0,0 +1,21 @@ +#ifndef _WEIGHTS_H_ +#define _WEIGHTS_H_ + +#include +#include +#include +#include "sparse_vector.h" + +class Weights { + public: + Weights() {} + void InitFromFile(const std::string& fname, std::vector* feature_list = NULL); + void WriteToFile(const std::string& fname, bool hide_zero_value_features = true) const; + void InitVector(std::vector* w) const; + void InitSparseVector(SparseVector* w) const; + void InitFromVector(const std::vector& w); + private: + std::vector wv_; +}; + +#endif diff --git a/src/weights_test.cc b/src/weights_test.cc new file mode 100644 index 00000000..aa6b3db2 --- /dev/null +++ b/src/weights_test.cc @@ -0,0 +1,28 @@ +#include +#include +#include +#include +#include +#include "weights.h" +#include "tdict.h" +#include "hg.h" + +using namespace std; + +class WeightsTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + + +TEST_F(WeightsTest,Load) { + Weights w; + w.InitFromFile("test_data/weights"); + w.WriteToFile("-"); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/wordid.h b/src/wordid.h new file mode 100644 index 00000000..fb50bcc1 --- /dev/null +++ b/src/wordid.h @@ -0,0 +1,6 @@ +#ifndef _WORD_ID_H_ +#define _WORD_ID_H_ + +typedef int WordID; + +#endif diff --git a/training/Makefile.am b/training/Makefile.am new file mode 100644 index 00000000..a2888f7a --- /dev/null +++ b/training/Makefile.am @@ -0,0 +1,39 @@ +bin_PROGRAMS = \ + model1 \ + mr_optimize_reduce \ + grammar_convert \ + atools \ + plftools \ + lbfgs_test \ + mr_em_train \ + collapse_weights \ + optimize_test + +atools_SOURCES = atools.cc + +model1_SOURCES = model1.cc +model1_LDADD = libhg.a + +grammar_convert_SOURCES = grammar_convert.cc + +optimize_test_SOURCES = optimize_test.cc + +collapse_weights_SOURCES = collapse_weights.cc + +lbfgs_test_SOURCES = lbfgs_test.cc + +mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc +mr_optimize_reduce_LDADD = libhg.a + +mr_em_train_SOURCES = mr_em_train.cc +mr_em_train_LDADD = libhg.a + +plftools_SOURCES = plftools.cc +plftools_LDADD = libhg.a + +LDADD = libhg.a + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) +AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB) -lz + +noinst_LIBRARIES = libhg.a diff --git a/training/atools.cc b/training/atools.cc new file mode 100644 index 00000000..bac73859 --- /dev/null +++ b/training/atools.cc @@ -0,0 +1,207 @@ +#include +#include +#include + +#include +#include +#include + +#include "filelib.h" +#include "aligner.h" + +namespace po = boost::program_options; +using namespace std; +using boost::shared_ptr; + +struct Command { + virtual ~Command() {} + virtual string Name() const = 0; + + // returns 1 for alignment grid output [default] + // returns 2 if Summary() should be called [for AER, etc] + virtual int Result() const { return 1; } + + virtual bool RequiresTwoOperands() const { return true; } + virtual void Apply(const Array2D& a, const Array2D& b, Array2D* x) = 0; + void EnsureSize(const Array2D& a, const Array2D& b, Array2D* x) { + x->resize(max(a.width(), b.width()), max(a.height(), b.width())); + } + bool Safe(const Array2D& a, int i, int j) const { + if (i < a.width() && j < a.height()) + return a(i,j); + else + return false; + } + virtual void Summary() { assert(!"Summary should have been overridden"); } +}; + +// compute fmeasure, second alignment is reference, first is hyp +struct FMeasureCommand : public Command { + FMeasureCommand() : matches(), num_predicted(), num_in_ref() {} + int Result() const { return 2; } + string Name() const { return "f"; } + bool RequiresTwoOperands() const { return true; } + void Apply(const Array2D& hyp, const Array2D& ref, Array2D* x) { + int i_len = ref.width(); + int j_len = ref.height(); + for (int i = 0; i < i_len; ++i) { + for (int j = 0; j < j_len; ++j) { + if (ref(i,j)) { + ++num_in_ref; + if (Safe(hyp, i, j)) ++matches; + } + } + } + for (int i = 0; i < hyp.width(); ++i) + for (int j = 0; j < hyp.height(); ++j) + if (hyp(i,j)) ++num_predicted; + } + void Summary() { + if (num_predicted == 0 || num_in_ref == 0) { + cerr << "Insufficient statistics to compute f-measure!\n"; + abort(); + } + const double prec = static_cast(matches) / num_predicted; + const double rec = static_cast(matches) / num_in_ref; + cout << "P: " << prec << endl; + cout << "R: " << rec << endl; + const double f = (2.0 * prec * rec) / (rec + prec); + cout << "F: " << f << endl; + } + int matches; + int num_predicted; + int num_in_ref; +}; + +struct ConvertCommand : public Command { + string Name() const { return "convert"; } + bool RequiresTwoOperands() const { return false; } + void Apply(const Array2D& in, const Array2D¬_used, Array2D* x) { + *x = in; + } +}; + +struct InvertCommand : public Command { + string Name() const { return "invert"; } + bool RequiresTwoOperands() const { return false; } + void Apply(const Array2D& in, const Array2D¬_used, Array2D* x) { + Array2D& res = *x; + res.resize(in.height(), in.width()); + for (int i = 0; i < in.height(); ++i) + for (int j = 0; j < in.width(); ++j) + res(i, j) = in(j, i); + } +}; + +struct IntersectCommand : public Command { + string Name() const { return "intersect"; } + bool RequiresTwoOperands() const { return true; } + void Apply(const Array2D& a, const Array2D& b, Array2D* x) { + EnsureSize(a, b, x); + Array2D& res = *x; + for (int i = 0; i < a.width(); ++i) + for (int j = 0; j < a.height(); ++j) + res(i, j) = Safe(a, i, j) && Safe(b, i, j); + } +}; + +map > commands; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + ostringstream os; + os << "[REQ] Operation to perform:"; + for (map >::iterator it = commands.begin(); + it != commands.end(); ++it) { + os << ' ' << it->first; + } + string cstr = os.str(); + opts.add_options() + ("input_1,i", po::value(), "[REQ] Alignment 1 file, - for STDIN") + ("input_2,j", po::value(), "[OPT] Alignment 2 file, - for STDIN") + ("command,c", po::value()->default_value("convert"), cstr.c_str()) + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input_1") == 0 || conf->count("command") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } + const string cmd = (*conf)["command"].as(); + if (commands.count(cmd) == 0) { + cerr << "Don't understand command: " << cmd << endl; + exit(1); + } + if (commands[cmd]->RequiresTwoOperands()) { + if (conf->count("input_2") == 0) { + cerr << "Command '" << cmd << "' requires two alignment files\n"; + exit(1); + } + if ((*conf)["input_1"].as() == "-" && (*conf)["input_2"].as() == "-") { + cerr << "Both inputs cannot be STDIN\n"; + exit(1); + } + } else { + if (conf->count("input_2") != 0) { + cerr << "Command '" << cmd << "' requires only one alignment file\n"; + exit(1); + } + } +} + +template static void AddCommand() { + C* c = new C; + commands[c->Name()].reset(c); +} + +int main(int argc, char **argv) { + AddCommand(); + AddCommand(); + AddCommand(); + AddCommand(); + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + Command& cmd = *commands[conf["command"].as()]; + boost::shared_ptr rf1(new ReadFile(conf["input_1"].as())); + boost::shared_ptr rf2; + if (cmd.RequiresTwoOperands()) + rf2.reset(new ReadFile(conf["input_2"].as())); + istream* in1 = rf1->stream(); + istream* in2 = NULL; + if (rf2) in2 = rf2->stream(); + while(*in1) { + string line1; + string line2; + getline(*in1, line1); + if (in2) { + getline(*in2, line2); + if ((*in1 && !*in2) || (*in2 && !*in1)) { + cerr << "Mismatched number of lines!\n"; + exit(1); + } + } + if (line1.empty() && !*in1) break; + shared_ptr > out(new Array2D); + shared_ptr > a1 = AlignerTools::ReadPharaohAlignmentGrid(line1); + if (in2) { + shared_ptr > a2 = AlignerTools::ReadPharaohAlignmentGrid(line2); + cmd.Apply(*a1, *a2, out.get()); + } else { + Array2D dummy; + cmd.Apply(*a1, dummy, out.get()); + } + + if (cmd.Result() == 1) { + AlignerTools::SerializePharaohFormat(*out, &cout); + } + } + if (cmd.Result() == 2) + cmd.Summary(); + return 0; +} + diff --git a/training/cluster-em.pl b/training/cluster-em.pl new file mode 100755 index 00000000..175870da --- /dev/null +++ b/training/cluster-em.pl @@ -0,0 +1,110 @@ +#!/usr/bin/perl -w + +use strict; +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR; } +use Getopt::Long; +my $parallel = 1; + +my $CWD=`pwd`; chomp $CWD; +my $BIN_DIR = "/chomes/redpony/cdyer-svn-repo/cdec/src"; +my $OPTIMIZER = "$BIN_DIR/mr_em_train"; +my $DECODER = "$BIN_DIR/cdec"; +my $COMBINER_CACHE_SIZE = 150; +my $PARALLEL = "/chomes/redpony/svn-trunk/sa-utils/parallelize.pl"; +die "Can't find $OPTIMIZER" unless -f $OPTIMIZER; +die "Can't execute $OPTIMIZER" unless -x $OPTIMIZER; +die "Can't find $DECODER" unless -f $DECODER; +die "Can't execute $DECODER" unless -x $DECODER; +die "Can't find $PARALLEL" unless -f $PARALLEL; +die "Can't execute $PARALLEL" unless -x $PARALLEL; +my $restart = ''; +if ($ARGV[0] && $ARGV[0] eq '--restart') { shift @ARGV; $restart = 1; } + +die "Usage: $0 [--restart] training.corpus weights.init grammar.file [grammar2.file] ...\n" unless (scalar @ARGV >= 3); + +my $training_corpus = shift @ARGV; +my $initial_weights = shift @ARGV; +my @in_grammar_files = @ARGV; +my $pmem="2500mb"; +my $nodes = 40; +my $max_iteration = 1000; +my $CFLAG = "-C 1"; +unless ($parallel) { $CFLAG = "-C 500"; } +my @grammar_files; +for my $g (@in_grammar_files) { + unless ($g =~ /^\//) { $g = $CWD . '/' . $g; } + die "Can't find $g" unless -f $g; + push @grammar_files, $g; +} + +print STDERR < \$DECODER, + "run_locally" => \$LOCAL, + "gaussian_prior" => \$PRIOR, + "sigma_squared=f" => \$sigsq, + "means=s" => \$means_file, + "optimizer=s" => \$OALG, + "pmem=s" => \$pmem + ) or usage(); +usage() unless scalar @ARGV==3; +my $config_file = shift @ARGV; +my $training_corpus = shift @ARGV; +my $initial_weights = shift @ARGV; +die "Can't find $config_file" unless -f $config_file; +die "Can't find $DECODER" unless -f $DECODER; +die "Can't execute $DECODER" unless -x $DECODER; +if ($LOCAL) { print STDERR "Will running LOCALLY.\n"; $parallel = 0; } +if ($PRIOR) { + $PRIOR_FLAG="-p --sigma_squared $sigsq"; + if ($means_file) { $PRIOR_FLAG .= " -u $means_file"; } +} + +if ($parallel) { + die "Can't find $PARALLEL" unless -f $PARALLEL; + die "Can't execute $PARALLEL" unless -x $PARALLEL; +} +unless ($parallel) { $CFLAG = "-C 500"; } +unless ($config_file =~ /^\//) { $config_file = $CWD . '/' . $config_file; } + +print STDERR <> 8; + if ($exit_code == 99) { + $iter_attempts++; + if ($iter_attempts > $MAX_ITER_ATTEMPTS) { + die "Received restart request $iter_attempts times from optimizer, giving up\n"; + } + print STDERR "Function evaluation failed, retrying (attempt $iter_attempts)\n"; + next; + } + if ($? != 0) { + die "Error running iteration $iter: $!"; + } + chomp $result; + my $end = time; + my $diff = ($end - $start); + print STDERR " ITERATION $iter TOOK $diff SECONDS\n"; + $iter = $next_iter; + if ($result =~ /1$/) { + print STDERR "Training converged.\n"; + last; + } + $iter_attempts = 1; +} + +print "FINAL WEIGHTS: $dir/weights.$iter\n"; + +sub usage { + die "Usage: $0 [OPTIONS] cdec.ini training.corpus weights.init\n"; +} diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc new file mode 100644 index 00000000..22ba0f46 --- /dev/null +++ b/training/grammar_convert.cc @@ -0,0 +1,316 @@ +#include +#include +#include + +#include +#include + +#include "tdict.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +WordID kSTART; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value()->default_value("-"), "Input file") + ("format,f", po::value()->default_value("cfg"), "Input format. Values: cfg, json, split") + ("output,o", po::value()->default_value("json"), "Output command. Values: json, 1best") + ("reorder,r", "Add Yamada & Knight (2002) reorderings") + ("weights,w", po::value(), "Feature weights for k-best derivations [optional]") + ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights") + ("k_derivations,k", po::value(), "Show k derivations and their features") + ("max_reorder,m", po::value()->default_value(999), "Move a constituent at most this far") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +int GetOrCreateNode(const WordID& lhs, map* lhs2node, Hypergraph* hg) { + int& node_id = (*lhs2node)[lhs]; + if (!node_id) + node_id = hg->AddNode(lhs)->id_ + 1; + return node_id - 1; +} + +void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { + if (goal < 0) { + cerr << "Error! [S] not found in grammar!\n"; + exit(1); + } + if (hg->nodes_[goal].in_edges_.size() != 1) { + cerr << "Error! [S] has more than one rewrite!\n"; + exit(1); + } + int old_size = hg->nodes_.size(); + hg->TopologicallySortNodesAndEdges(goal); + if (hg->nodes_.size() != old_size) { + cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; + } +} + +void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { + Hypergraph::Edge* new_edge = hg->AddEdge(r, tail); + hg->ConnectEdgeToHeadNode(new_edge, head_node); + new_edge->feature_values_ = r->scores_; +} + +// from a category label like "NP_2", return "NP" +string PureCategory(WordID cat) { + assert(cat < 0); + string c = TD::Convert(cat*-1); + size_t p = c.find("_"); + if (p == string::npos) return c; + return c.substr(0, p); +}; + +string ConstituentOrderFeature(const TRule& rule, const vector& pi) { + const static string kTERM_VAR = "x"; + const vector& f = rule.f(); + map used; + vector terms(f.size()); + for (int i = 0; i < f.size(); ++i) { + const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR); + int& count = used[term]; + if (!count) { + terms[i] = term; + } else { + ostringstream os; + os << term << count; + terms[i] = os.str(); + } + ++count; + } + ostringstream os; + os << PureCategory(rule.GetLHS()) << ':'; + for (int i = 0; i < f.size(); ++i) { + if (i > 0) os << '_'; + os << terms[pi[i]]; + } + return os.str(); +} + +bool CheckPermutationMask(const vector& mask, const vector& pi) { + assert(mask.size() == pi.size()); + + int req_min = -1; + int cur_max = 0; + int cur_mask = -1; + for (int i = 0; i < mask.size(); ++i) { + if (mask[i] != cur_mask) { + cur_mask = mask[i]; + req_min = cur_max - 1; + } + if (pi[i] > req_min) { + if (pi[i] > cur_max) cur_max = pi[i]; + } else { + return false; + } + } + + return true; +} + +void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) { + Hypergraph::Node* node = &hg->nodes_[nodeid]; + if (node->in_edges_.size() != 1) { + cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n"; + cerr << " not recursing!\n"; + return; + } + const int oe_index = node->in_edges_.front(); + const TRule& rule = *hg->edges_[oe_index].rule_; + const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_; + const int tail_size = orig_tail.size(); + for (int i = 0; i < tail_size; ++i) { + PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg); + } + const vector& of = rule.f_; + if (of.size() == 1) return; +// cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n"; +// cerr << "ORIG: " << rule.AsString() << endl; + vector pi(of.size(), 0); + for (int i = 0; i < pi.size(); ++i) pi[i] = i; + + vector permutation_mask(of.size(), 0); + const bool dont_reorder_across_PU = true; // TODO add configuration + if (dont_reorder_across_PU) { + int cur = 0; + for (int i = 0; i < pi.size(); ++i) { + if (of[i] >= 0) continue; + const string cat = PureCategory(of[i]); + if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") { + ++cur; + permutation_mask[i] = cur; + ++cur; + } else { + permutation_mask[i] = cur; + } + } + } + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + hg->edges_[oe_index].feature_values_.set_value(fid, 1.0); + while (next_permutation(pi.begin(), pi.end())) { + if (!CheckPermutationMask(permutation_mask, pi)) + continue; + vector nf(pi.size(), 0); + Hypergraph::TailNodeVector tail(pi.size(), 0); + bool skip = false; + for (int i = 0; i < pi.size(); ++i) { + int dist = pi[i] - i; if (dist < 0) dist *= -1; + if (dist > max_reorder) { skip = true; break; } + nf[i] = of[pi[i]]; + tail[i] = orig_tail[pi[i]]; + } + if (skip) continue; + TRulePtr nr(new TRule(rule)); + nr->f_ = nf; + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + nr->scores_.set_value(fid, 1.0); +// cerr << "PERM: " << nr->AsString() << endl; + CreateEdge(nr, tail, node, hg); + } +} + +void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) { + assert(hg->nodes_.back().cat_ == kSTART); + assert(hg->nodes_.back().in_edges_.size() == 1); + PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg); +} + +void CollapseWeights(Hypergraph* hg) { + int fid = FD::Convert("Reordering"); + for (int i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + edge.feature_values_.clear(); + if (edge.edge_prob_ != prob_t::Zero()) { + edge.feature_values_.set_value(fid, log(edge.edge_prob_)); + } + } +} + +void ProcessHypergraph(const vector& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) { + if (conf.count("reorder")) + PermuteYamadaAndKnight(hg, conf["max_reorder"].as()); + if (w.size() > 0) { hg->Reweight(w); } + if (conf.count("collapse_weights")) CollapseWeights(hg); + if (conf["output"].as() == "json") { + HypergraphIO::WriteToJSON(*hg, false, &cout); + if (!ref.empty()) { cerr << "REF: " << ref << endl; } + } else { + vector onebest; + ViterbiESentence(*hg, &onebest); + if (ref.empty()) { + cout << TD::GetString(onebest) << endl; + } else { + cout << TD::GetString(onebest) << " ||| " << ref << endl; + } + } + if (conf.count("k_derivations")) { + const int k = conf["k_derivations"].as(); + KBest::KBestDerivations, ESentenceTraversal> kbest(*hg, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } +} + +int main(int argc, char **argv) { + kSTART = TD::Convert("S") * -1; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string infile = conf["input"].as(); + const bool is_split_input = (conf["format"].as() == "split"); + const bool is_json_input = is_split_input || (conf["format"].as() == "json"); + const bool collapse_weights = conf.count("collapse_weights"); + Weights wts; + vector w; + if (conf.count("weights")) { + wts.InitFromFile(conf["weights"].as()); + wts.InitVector(&w); + } + if (collapse_weights && !w.size()) { + cerr << "--collapse_weights requires a weights file to be specified!\n"; + exit(1); + } + ReadFile rf(infile); + istream* in = rf.stream(); + assert(*in); + int lc = 0; + Hypergraph hg; + map lhs2node; + while(*in) { + string line; + ++lc; + getline(*in, line); + if (is_json_input) { + if (line.empty() || line[0] == '#') continue; + string ref; + if (is_split_input) { + size_t pos = line.rfind("}}"); + assert(pos != string::npos); + size_t rstart = line.find("||| ", pos); + assert(rstart != string::npos); + ref = line.substr(rstart + 4); + line = line.substr(0, pos + 2); + } + istringstream is(line); + if (HypergraphIO::ReadFromJSON(&is, &hg)) { + ProcessHypergraph(w, conf, ref, &hg); + hg.clear(); + } else { + cerr << "Error reading grammar from JSON: line " << lc << endl; + exit(1); + } + } else { + if (line.empty()) { + int goal = lhs2node[kSTART] - 1; + FilterAndCheckCorrectness(goal, &hg); + ProcessHypergraph(w, conf, "", &hg); + hg.clear(); + lhs2node.clear(); + continue; + } + if (line[0] == '#') continue; + if (line[0] != '[') { + cerr << "Line " << lc << ": bad format\n"; + exit(1); + } + TRulePtr tr(TRule::CreateRuleMonolingual(line)); + Hypergraph::TailNodeVector tail; + for (int i = 0; i < tr->f_.size(); ++i) { + WordID var_cat = tr->f_[i]; + if (var_cat < 0) + tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg)); + } + const WordID lhs = tr->GetLHS(); + int head = GetOrCreateNode(lhs, &lhs2node, &hg); + Hypergraph::Edge* edge = hg.AddEdge(tr, tail); + edge->feature_values_ = tr->scores_; + Hypergraph::Node* node = &hg.nodes_[head]; + hg.ConnectEdgeToHeadNode(edge, node); + } + } +} + diff --git a/training/lbfgs.h b/training/lbfgs.h new file mode 100644 index 00000000..e8baecab --- /dev/null +++ b/training/lbfgs.h @@ -0,0 +1,1459 @@ +#ifndef SCITBX_LBFGS_H +#define SCITBX_LBFGS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace scitbx { + +//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer. +/*! Implementation of the + Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) + algorithm for large-scale multidimensional minimization + problems. + + This code was manually derived from Java code which was + in turn derived from the Fortran program + lbfgs.f. The Java translation was + effected mostly mechanically, with some manual + clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran + code have been pasted in. + + Information on the original LBFGS Fortran source code is + available at + http://www.netlib.org/opt/lbfgs_um.shar . The following + information is taken verbatim from the Netlib documentation + for the Fortran source. + +
+    file    opt/lbfgs_um.shar
+    for     unconstrained optimization problems
+    alg     limited memory BFGS method
+    by      J. Nocedal
+    contact nocedal@eecs.nwu.edu
+    ref     D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for
+    ,       large scale optimization methods'' Mathematical Programming 45
+    ,       (1989), pp. 503-528.
+    ,       (Postscript file of this paper is available via anonymous ftp
+    ,       to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.)
+    
+ + @author Jorge Nocedal: original Fortran version, including comments + (July 1990).
+ Robert Dodier: Java translation, August 1997.
+ Ralf W. Grosse-Kunstleve: C++ port, March 2002.
+ Chris Dyer: serialize/deserialize functionality + */ +namespace lbfgs { + + //! Generic exception class for %lbfgs %error messages. + /*! All exceptions thrown by the minimizer are derived from this class. + */ + class error : public std::exception { + public: + //! Constructor. + error(std::string const& msg) throw() + : msg_("lbfgs error: " + msg) + {} + //! Access to error message. + virtual const char* what() const throw() { return msg_.c_str(); } + protected: + virtual ~error() throw() {} + std::string msg_; + public: + static std::string itoa(unsigned long i) { + std::ostringstream os; + os << i; + return os.str(); + } + }; + + //! Specific exception class. + class error_internal_error : public error { + public: + //! Constructor. + error_internal_error(const char* file, unsigned long line) throw() + : error( + "Internal Error: " + std::string(file) + "(" + itoa(line) + ")") + {} + }; + + //! Specific exception class. + class error_improper_input_parameter : public error { + public: + //! Constructor. + error_improper_input_parameter(std::string const& msg) throw() + : error("Improper input parameter: " + msg) + {} + }; + + //! Specific exception class. + class error_improper_input_data : public error { + public: + //! Constructor. + error_improper_input_data(std::string const& msg) throw() + : error("Improper input data: " + msg) + {} + }; + + //! Specific exception class. + class error_search_direction_not_descent : public error { + public: + //! Constructor. + error_search_direction_not_descent() throw() + : error("The search direction is not a descent direction.") + {} + }; + + //! Specific exception class. + class error_line_search_failed : public error { + public: + //! Constructor. + error_line_search_failed(std::string const& msg) throw() + : error("Line search failed: " + msg) + {} + }; + + //! Specific exception class. + class error_line_search_failed_rounding_errors + : public error_line_search_failed { + public: + //! Constructor. + error_line_search_failed_rounding_errors(std::string const& msg) throw() + : error_line_search_failed(msg) + {} + }; + + namespace detail { + + template + inline + NumType + pow2(NumType const& x) { return x * x; } + + template + inline + NumType + abs(NumType const& x) { + if (x < NumType(0)) return -x; + return x; + } + + // This class implements an algorithm for multi-dimensional line search. + template + class mcsrch + { + protected: + int infoc; + FloatType dginit; + bool brackt; + bool stage1; + FloatType finit; + FloatType dgtest; + FloatType width; + FloatType width1; + FloatType stx; + FloatType fx; + FloatType dgx; + FloatType sty; + FloatType fy; + FloatType dgy; + FloatType stmin; + FloatType stmax; + + static FloatType const& max3( + FloatType const& x, + FloatType const& y, + FloatType const& z) + { + return x < y ? (y < z ? z : y ) : (x < z ? z : x ); + } + + public: + /* Minimize a function along a search direction. This code is + a Java translation of the function MCSRCH from + lbfgs.f, which in turn is a slight modification + of the subroutine CSRCH of More' and Thuente. + The changes are to allow reverse communication, and do not + affect the performance of the routine. This function, in turn, + calls mcstep.

+ + The Java translation was effected mostly mechanically, with + some manual clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran code have + been pasted in here as well.

+ + The purpose of mcsrch is to find a step which + satisfies a sufficient decrease condition and a curvature + condition.

+ + At each stage this function updates an interval of uncertainty + with endpoints stx and sty. The + interval of uncertainty is initially chosen so that it + contains a minimizer of the modified function +

+                f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s).
+           
+ If a step is obtained for which the modified function has a + nonpositive function value and nonnegative derivative, then + the interval of uncertainty is chosen so that it contains a + minimizer of f(x+stp*s).

+ + The algorithm is designed to find a step which satisfies + the sufficient decrease condition +

+                 f(x+stp*s) <= f(X) + ftol*stp*(gradf(x)'s),
+           
+ and the curvature condition +
+                 abs(gradf(x+stp*s)'s)) <= gtol*abs(gradf(x)'s).
+           
+ If ftol is less than gtol and if, + for example, the function is bounded below, then there is + always a step which satisfies both conditions. If no step can + be found which satisfies both conditions, then the algorithm + usually stops when rounding errors prevent further progress. + In this case stp only satisfies the sufficient + decrease condition.

+ + @author Original Fortran version by Jorge J. More' and + David J. Thuente as part of the Minpack project, June 1983, + Argonne National Laboratory. Java translation by Robert + Dodier, August 1997. + + @param n The number of variables. + + @param x On entry this contains the base point for the line + search. On exit it contains x + stp*s. + + @param f On entry this contains the value of the objective + function at x. On exit it contains the value + of the objective function at x + stp*s. + + @param g On entry this contains the gradient of the objective + function at x. On exit it contains the gradient + at x + stp*s. + + @param s The search direction. + + @param stp On entry this contains an initial estimate of a + satifactory step length. On exit stp contains + the final estimate. + + @param ftol Tolerance for the sufficient decrease condition. + + @param xtol Termination occurs when the relative width of the + interval of uncertainty is at most xtol. + + @param maxfev Termination occurs when the number of evaluations + of the objective function is at least maxfev by + the end of an iteration. + + @param info This is an output variable, which can have these + values: +

    +
  • info = -1 A return is made to compute + the function and gradient. +
  • info = 1 The sufficient decrease condition + and the directional derivative condition hold. +
+ + @param nfev On exit, this is set to the number of function + evaluations. + + @param wa Temporary storage array, of length n. + */ + void run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa); + + /* The purpose of this function is to compute a safeguarded step + for a linesearch and to update an interval of uncertainty for + a minimizer of the function.

+ + The parameter stx contains the step with the + least function value. The parameter stp contains + the current step. It is assumed that the derivative at + stx is negative in the direction of the step. If + brackt is true when + mcstep returns then a minimizer has been + bracketed in an interval of uncertainty with endpoints + stx and sty.

+ + Variables that must be modified by mcstep are + implemented as 1-element arrays. + + @param stx Step at the best step obtained so far. + This variable is modified by mcstep. + @param fx Function value at the best step obtained so far. + This variable is modified by mcstep. + @param dx Derivative at the best step obtained so far. + The derivative must be negative in the direction of the + step, that is, dx and stp-stx must + have opposite signs. This variable is modified by + mcstep. + + @param sty Step at the other endpoint of the interval of + uncertainty. This variable is modified by mcstep. + @param fy Function value at the other endpoint of the interval + of uncertainty. This variable is modified by + mcstep. + + @param dy Derivative at the other endpoint of the interval of + uncertainty. This variable is modified by mcstep. + + @param stp Step at the current step. If brackt is set + then on input stp must be between stx + and sty. On output stp is set to the + new step. + @param fp Function value at the current step. + @param dp Derivative at the current step. + + @param brackt Tells whether a minimizer has been bracketed. + If the minimizer has not been bracketed, then on input this + variable must be set false. If the minimizer has + been bracketed, then on output this variable is + true. + + @param stpmin Lower bound for the step. + @param stpmax Upper bound for the step. + + If the return value is 1, 2, 3, or 4, then the step has + been computed successfully. A return value of 0 indicates + improper input parameters. + + @author Jorge J. More, David J. Thuente: original Fortran version, + as part of Minpack project. Argonne Nat'l Laboratory, June 1983. + Robert Dodier: Java translation, August 1997. + */ + static int mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax); + + void serialize(std::ostream* out) const { + out->write((const char*)&infoc,sizeof(infoc)); + out->write((const char*)&dginit,sizeof(dginit)); + out->write((const char*)&brackt,sizeof(brackt)); + out->write((const char*)&stage1,sizeof(stage1)); + out->write((const char*)&finit,sizeof(finit)); + out->write((const char*)&dgtest,sizeof(dgtest)); + out->write((const char*)&width,sizeof(width)); + out->write((const char*)&width1,sizeof(width1)); + out->write((const char*)&stx,sizeof(stx)); + out->write((const char*)&fx,sizeof(fx)); + out->write((const char*)&dgx,sizeof(dgx)); + out->write((const char*)&sty,sizeof(sty)); + out->write((const char*)&fy,sizeof(fy)); + out->write((const char*)&dgy,sizeof(dgy)); + out->write((const char*)&stmin,sizeof(stmin)); + out->write((const char*)&stmax,sizeof(stmax)); + } + + void deserialize(std::istream* in) const { + in->read((char*)&infoc, sizeof(infoc)); + in->read((char*)&dginit, sizeof(dginit)); + in->read((char*)&brackt, sizeof(brackt)); + in->read((char*)&stage1, sizeof(stage1)); + in->read((char*)&finit, sizeof(finit)); + in->read((char*)&dgtest, sizeof(dgtest)); + in->read((char*)&width, sizeof(width)); + in->read((char*)&width1, sizeof(width1)); + in->read((char*)&stx, sizeof(stx)); + in->read((char*)&fx, sizeof(fx)); + in->read((char*)&dgx, sizeof(dgx)); + in->read((char*)&sty, sizeof(sty)); + in->read((char*)&fy, sizeof(fy)); + in->read((char*)&dgy, sizeof(dgy)); + in->read((char*)&stmin, sizeof(stmin)); + in->read((char*)&stmax, sizeof(stmax)); + } + }; + + template + void mcsrch::run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa) + { + if (info != -1) { + infoc = 1; + if ( n == 0 + || maxfev == 0 + || gtol < FloatType(0) + || xtol < FloatType(0) + || stpmin < FloatType(0) + || stpmax < stpmin) { + throw error_internal_error(__FILE__, __LINE__); + } + if (stp <= FloatType(0) || ftol < FloatType(0)) { + throw error_internal_error(__FILE__, __LINE__); + } + // Compute the initial gradient in the search direction + // and check that s is a descent direction. + dginit = FloatType(0); + for (SizeType j = 0; j < n; j++) { + dginit += g[j] * s[is0+j]; + } + if (dginit >= FloatType(0)) { + throw error_search_direction_not_descent(); + } + brackt = false; + stage1 = true; + nfev = 0; + finit = f; + dgtest = ftol*dginit; + width = stpmax - stpmin; + width1 = FloatType(2) * width; + std::copy(x, x+n, wa); + // The variables stx, fx, dgx contain the values of the step, + // function, and directional derivative at the best step. + // The variables sty, fy, dgy contain the value of the step, + // function, and derivative at the other endpoint of + // the interval of uncertainty. + // The variables stp, f, dg contain the values of the step, + // function, and derivative at the current step. + stx = FloatType(0); + fx = finit; + dgx = dginit; + sty = FloatType(0); + fy = finit; + dgy = dginit; + } + for (;;) { + if (info != -1) { + // Set the minimum and maximum steps to correspond + // to the present interval of uncertainty. + if (brackt) { + stmin = std::min(stx, sty); + stmax = std::max(stx, sty); + } + else { + stmin = stx; + stmax = stp + FloatType(4) * (stp - stx); + } + // Force the step to be within the bounds stpmax and stpmin. + stp = std::max(stp, stpmin); + stp = std::min(stp, stpmax); + // If an unusual termination is to occur then let + // stp be the lowest point obtained so far. + if ( (brackt && (stp <= stmin || stp >= stmax)) + || nfev >= maxfev - 1 || infoc == 0 + || (brackt && stmax - stmin <= xtol * stmax)) { + stp = stx; + } + // Evaluate the function and gradient at stp + // and compute the directional derivative. + // We return to main program to obtain F and G. + for (SizeType j = 0; j < n; j++) { + x[j] = wa[j] + stp * s[is0+j]; + } + info=-1; + break; + } + info = 0; + nfev++; + FloatType dg(0); + for (SizeType j = 0; j < n; j++) { + dg += g[j] * s[is0+j]; + } + FloatType ftest1 = finit + stp*dgtest; + // Test for convergence. + if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) { + throw error_line_search_failed_rounding_errors( + "Rounding errors prevent further progress." + " There may not be a step which satisfies the" + " sufficient decrease and curvature conditions." + " Tolerances may be too small."); + } + if (stp == stpmax && f <= ftest1 && dg <= dgtest) { + throw error_line_search_failed( + "The step is at the upper bound stpmax()."); + } + if (stp == stpmin && (f > ftest1 || dg >= dgtest)) { + throw error_line_search_failed( + "The step is at the lower bound stpmin()."); + } + if (nfev >= maxfev) { + throw error_line_search_failed( + "Number of function evaluations has reached maxfev()."); + } + if (brackt && stmax - stmin <= xtol * stmax) { + throw error_line_search_failed( + "Relative width of the interval of uncertainty" + " is at most xtol()."); + } + // Check for termination. + if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) { + info = 1; + break; + } + // In the first stage we seek a step for which the modified + // function has a nonpositive value and nonnegative derivative. + if ( stage1 && f <= ftest1 + && dg >= std::min(ftol, gtol) * dginit) { + stage1 = false; + } + // A modified function is used to predict the step only if + // we have not obtained a step for which the modified + // function has a nonpositive function value and nonnegative + // derivative, and if a lower function value has been + // obtained but the decrease is not sufficient. + if (stage1 && f <= fx && f > ftest1) { + // Define the modified function and derivative values. + FloatType fm = f - stp*dgtest; + FloatType fxm = fx - stx*dgtest; + FloatType fym = fy - sty*dgtest; + FloatType dgm = dg - dgtest; + FloatType dgxm = dgx - dgtest; + FloatType dgym = dgy - dgtest; + // Call cstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm, + brackt, stmin, stmax); + // Reset the function and gradient values for f. + fx = fxm + stx*dgtest; + fy = fym + sty*dgtest; + dgx = dgxm + dgtest; + dgy = dgym + dgtest; + } + else { + // Call mcstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg, + brackt, stmin, stmax); + } + // Force a sufficient decrease in the size of the + // interval of uncertainty. + if (brackt) { + if (abs(sty - stx) >= FloatType(0.66) * width1) { + stp = stx + FloatType(0.5) * (sty - stx); + } + width1 = width; + width = abs(sty - stx); + } + } + } + + template + int mcsrch::mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax) + { + bool bound; + FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta; + int info = 0; + if ( ( brackt && (stp <= std::min(stx, sty) + || stp >= std::max(stx, sty))) + || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) { + return 0; + } + // Determine if the derivatives have opposite sign. + sgnd = dp * (dx / abs(dx)); + if (fp > fx) { + // First case. A higher function value. + // The minimum is bracketed. If the cubic step is closer + // to stx than the quadratic step, the cubic step is taken, + // else the average of the cubic and quadratic steps is taken. + info = 1; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp < stx) gamma = - gamma; + p = (gamma - dx) + theta; + q = ((gamma - dx) + gamma) + dp; + r = p/q; + stpc = stx + r * (stp - stx); + stpq = stx + + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2)) + * (stp - stx); + if (abs(stpc - stx) < abs(stpq - stx)) { + stpf = stpc; + } + else { + stpf = stpc + (stpq - stpc) / FloatType(2); + } + brackt = true; + } + else if (sgnd < FloatType(0)) { + // Second case. A lower function value and derivatives of + // opposite sign. The minimum is bracketed. If the cubic + // step is closer to stx than the quadratic (secant) step, + // the cubic step is taken, else the quadratic step is taken. + info = 2; + bound = false; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp > stx) gamma = - gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dx; + r = p/q; + stpc = stp + r * (stx - stp); + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (abs(stpc - stp) > abs(stpq - stp)) { + stpf = stpc; + } + else { + stpf = stpq; + } + brackt = true; + } + else if (abs(dp) < abs(dx)) { + // Third case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative decreases. + // The cubic step is only used if the cubic tends to infinity + // in the direction of the step or if the minimum of the cubic + // is beyond stp. Otherwise the cubic step is defined to be + // either stpmin or stpmax. The quadratic (secant) step is also + // computed and if the minimum is bracketed then the the step + // closest to stx is taken, else the step farthest away is taken. + info = 3; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt( + std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s))); + if (stp > stx) gamma = -gamma; + p = (gamma - dp) + theta; + q = (gamma + (dx - dp)) + gamma; + r = p/q; + if (r < FloatType(0) && gamma != FloatType(0)) { + stpc = stp + r * (stx - stp); + } + else if (stp > stx) { + stpc = stpmax; + } + else { + stpc = stpmin; + } + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (brackt) { + if (abs(stp - stpc) < abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + else { + if (abs(stp - stpc) > abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + } + else { + // Fourth case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative does + // not decrease. If the minimum is not bracketed, the step + // is either stpmin or stpmax, else the cubic step is taken. + info = 4; + bound = false; + if (brackt) { + theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp; + s = max3(abs(theta), abs(dy), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s)); + if (stp > sty) gamma = -gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dy; + r = p/q; + stpc = stp + r * (sty - stp); + stpf = stpc; + } + else if (stp > stx) { + stpf = stpmax; + } + else { + stpf = stpmin; + } + } + // Update the interval of uncertainty. This update does not + // depend on the new step or the case analysis above. + if (fp > fx) { + sty = stp; + fy = fp; + dy = dp; + } + else { + if (sgnd < FloatType(0)) { + sty = stx; + fy = fx; + dy = dx; + } + stx = stp; + fx = fp; + dx = dp; + } + // Compute the new step and safeguard it. + stpf = std::min(stpmax, stpf); + stpf = std::max(stpmin, stpf); + stp = stpf; + if (brackt && bound) { + if (sty > stx) { + stp = std::min(stx + FloatType(0.66) * (sty - stx), stp); + } + else { + stp = std::max(stx + FloatType(0.66) * (sty - stx), stp); + } + } + return info; + } + + /* Compute the sum of a vector times a scalar plus another vector. + Adapted from the subroutine daxpy in + lbfgs.f. + */ + template + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + SizeType incx, + FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + if (n == 0) return; + if (da == FloatType(0)) return; + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dy[iy0+iy] += da * dx[ix0+ix]; + ix += incx; + iy += incy; + } + return; + } + m = n % 4; + for (i = 0; i < m; i++) { + dy[iy0+i] += da * dx[ix0+i]; + } + for (; i < n;) { + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + } + } + + template + inline + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + FloatType* dy) + { + daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1)); + } + + /* Compute the dot product of two vectors. + Adapted from the subroutine ddot + in lbfgs.f. + */ + template + FloatType ddot( + SizeType n, + const FloatType* dx, + SizeType ix0, + SizeType incx, + const FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + FloatType dtemp(0); + if (n == 0) return FloatType(0); + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dtemp += dx[ix0+ix] * dy[iy0+iy]; + ix += incx; + iy += incy; + } + return dtemp; + } + m = n % 5; + for (i = 0; i < m; i++) { + dtemp += dx[ix0+i] * dy[iy0+i]; + } + for (; i < n;) { + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + } + return dtemp; + } + + template + inline + FloatType ddot( + SizeType n, + const FloatType* dx, + const FloatType* dy) + { + return ddot( + n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1)); + } + + } // namespace detail + + //! Interface to the LBFGS %minimizer. + /*! This class solves the unconstrained minimization problem +

+          min f(x),  x = (x1,x2,...,x_n),
+      
+ using the limited-memory BFGS method. The routine is + especially effective on problems involving a large number of + variables. In a typical iteration of this method an + approximation Hk to the inverse of the Hessian + is obtained by applying m BFGS updates to a + diagonal matrix Hk0, using information from the + previous m steps. The user specifies the number + m, which determines the amount of storage + required by the routine. The user may also provide the + diagonal matrices Hk0 (parameter diag in the run() + function) if not satisfied with the default choice. The + algorithm is described in "On the limited memory BFGS method for + large scale optimization", by D. Liu and J. Nocedal, Mathematical + Programming B 45 (1989) 503-528. + + The user is required to calculate the function value + f and its gradient g. In order to + allow the user complete control over these computations, + reverse communication is used. The routine must be called + repeatedly under the control of the member functions + requests_f_and_g(), + requests_diag(). + If neither requests_f_and_g() nor requests_diag() is + true the user should check for convergence + (using class traditional_convergence_test or any + other custom test). If the convergence test is negative, + the minimizer may be called again for the next iteration. + + The steplength (stp()) is determined at each iteration + by means of the line search routine mcsrch, which is + a slight modification of the routine CSRCH written + by More' and Thuente. + + The only variables that are machine-dependent are + xtol, + stpmin and + stpmax. + + Fatal errors cause error exceptions to be thrown. + The generic class error is sub-classed (e.g. + class error_line_search_failed) to facilitate + granular %error handling. + + A note on performance: Using Compaq Fortran V5.4 and + Compaq C++ V6.5, the C++ implementation is about 15% slower + than the Fortran implementation. + */ + template + class minimizer + { + public: + //! Default constructor. Some members are not initialized! + minimizer() + : n_(0), m_(0), maxfev_(0), + gtol_(0), xtol_(0), + stpmin_(0), stpmax_(0), + ispt(0), iypt(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: n > 0. + + @param m The number of corrections used in the BFGS update. + Values of m less than 3 are not recommended; + large values of m will result in excessive + computing time. 3 <= m <= 7 is + recommended. + Restriction: m > 0. + + @param maxfev Maximum number of function evaluations + per line search. + Termination occurs when the number of evaluations + of the objective function is at least maxfev by + the end of an iteration. + + @param gtol Controls the accuracy of the line search. + If the function and gradient evaluations are inexpensive with + respect to the cost of the iteration (which is sometimes the + case when solving very large problems) it may be advantageous + to set gtol to a small value. A typical small + value is 0.1. + Restriction: gtol should be greater than 1e-4. + + @param xtol An estimate of the machine precision (e.g. 10e-16 + on a SUN station 3/60). The line search routine will + terminate if the relative width of the interval of + uncertainty is less than xtol. + + @param stpmin Specifies the lower bound for the step + in the line search. + The default value is 1e-20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + + @param stpmax specifies the upper bound for the step + in the line search. + The default value is 1e20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + */ + explicit + minimizer( + SizeType n, + SizeType m = 5, + SizeType maxfev = 20, + FloatType gtol = FloatType(0.9), + FloatType xtol = FloatType(1.e-16), + FloatType stpmin = FloatType(1.e-20), + FloatType stpmax = FloatType(1.e20)) + : n_(n), m_(m), maxfev_(maxfev), + gtol_(gtol), xtol_(xtol), + stpmin_(stpmin), stpmax_(stpmax), + iflag_(0), requests_f_and_g_(false), requests_diag_(false), + iter_(0), nfun_(0), stp_(0), + stp1(0), ftol(0.0001), ys(0), point(0), npt(0), + ispt(n+2*m), iypt((n+2*m)+n*m), + info(0), bound(0), nfev(0) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (m_ == 0) { + throw error_improper_input_parameter("m = 0."); + } + if (maxfev_ == 0) { + throw error_improper_input_parameter("maxfev = 0."); + } + if (gtol_ <= FloatType(1.e-4)) { + throw error_improper_input_parameter("gtol <= 1.e-4."); + } + if (xtol_ < FloatType(0)) { + throw error_improper_input_parameter("xtol < 0."); + } + if (stpmin_ < FloatType(0)) { + throw error_improper_input_parameter("stpmin < 0."); + } + if (stpmax_ < stpmin) { + throw error_improper_input_parameter("stpmax < stpmin"); + } + w_.resize(n_*(2*m_+1)+2*m_); + scratch_array_.resize(n_); + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + //! Number of corrections kept (as passed to the constructor). + SizeType m() const { return m_; } + + /*! \brief Maximum number of evaluations of the objective function + per line search (as passed to the constructor). + */ + SizeType maxfev() const { return maxfev_; } + + /*! \brief Control of the accuracy of the line search. + (as passed to the constructor). + */ + FloatType gtol() const { return gtol_; } + + //! Estimate of the machine precision (as passed to the constructor). + FloatType xtol() const { return xtol_; } + + /*! \brief Lower bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmin() const { return stpmin_; } + + /*! \brief Upper bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmax() const { return stpmax_; } + + //! Status indicator for reverse communication. + /*! true if the run() function returns to request + evaluation of the objective function (f) and + gradients (g) for the current point + (x). To continue the minimization the + run() function is called again with the updated values for + f and g. +

+ See also: requests_diag() + */ + bool requests_f_and_g() const { return requests_f_and_g_; } + + //! Status indicator for reverse communication. + /*! true if the run() function returns to request + evaluation of the diagonal matrix (diag) + for the current point (x). + To continue the minimization the run() function is called + again with the updated values for diag. +

+ See also: requests_f_and_g() + */ + bool requests_diag() const { return requests_diag_; } + + //! Number of iterations so far. + /*! Note that one iteration may involve multiple evaluations + of the objective function. +

+ See also: nfun() + */ + SizeType iter() const { return iter_; } + + //! Total number of evaluations of the objective function so far. + /*! The total number of function evaluations increases by the + number of evaluations required for the line search. The total + is only increased after a successful line search. +

+ See also: iter() + */ + SizeType nfun() const { return nfun_; } + + //! Norm of gradient given gradient array of length n(). + FloatType euclidean_norm(const FloatType* a) const { + return std::sqrt(detail::ddot(n_, a, a)); + } + + //! Current stepsize. + FloatType stp() const { return stp_; } + + //! Execution of one step of the minimization. + /*! @param x On initial entry this must be set by the user to + the values of the initial estimate of the solution vector. + + @param f Before initial entry or on re-entry under the + control of requests_f_and_g(), f must be set + by the user to contain the value of the objective function + at the current point x. + + @param g Before initial entry or on re-entry under the + control of requests_f_and_g(), g must be set + by the user to contain the components of the gradient at + the current point x. + + The return value is true if either + requests_f_and_g() or requests_diag() is true. + Otherwise the user should check for convergence + (e.g. using class traditional_convergence_test) and + call the run() function again to continue the minimization. + If the return value is false the user + should not update f, g or + diag (other overload) before calling + the run() function again. + + Note that x is always modified by the run() + function. Depending on the situation it can therefore be + necessary to evaluate the objective function one more time + after the minimization is terminated. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g) + { + return generic_run(x, f, g, false, 0); + } + + //! Execution of one step of the minimization. + /*! @param x See other overload. + + @param f See other overload. + + @param g See other overload. + + @param diag On initial entry or on re-entry under the + control of requests_diag(), diag must be set by + the user to contain the values of the diagonal matrix Hk0. + The routine will return at each iteration of the algorithm + with requests_diag() set to true. +

+ Restriction: all elements of diag must be + positive. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g, + const FloatType* diag) + { + return generic_run(x, f, g, true, diag); + } + + void serialize(std::ostream* out) const { + out->write((const char*)&n_, sizeof(n_)); // sanity check + out->write((const char*)&m_, sizeof(m_)); // sanity check + SizeType fs = sizeof(FloatType); + out->write((const char*)&fs, sizeof(fs)); // sanity check + + mcsrch_instance.serialize(out); + out->write((const char*)&iflag_, sizeof(iflag_)); + out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + out->write((const char*)&requests_diag_, sizeof(requests_diag_)); + out->write((const char*)&iter_, sizeof(iter_)); + out->write((const char*)&nfun_, sizeof(nfun_)); + out->write((const char*)&stp_, sizeof(stp_)); + out->write((const char*)&stp1, sizeof(stp1)); + out->write((const char*)&ftol, sizeof(ftol)); + out->write((const char*)&ys, sizeof(ys)); + out->write((const char*)&point, sizeof(point)); + out->write((const char*)&npt, sizeof(npt)); + out->write((const char*)&info, sizeof(info)); + out->write((const char*)&bound, sizeof(bound)); + out->write((const char*)&nfev, sizeof(nfev)); + out->write((const char*)&w_[0], sizeof(FloatType) * w_.size()); + out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + void deserialize(std::istream* in) { + SizeType n, m, fs; + in->read((char*)&n, sizeof(n)); + in->read((char*)&m, sizeof(m)); + in->read((char*)&fs, sizeof(fs)); + assert(n == n_); + assert(m == m_); + assert(fs == sizeof(FloatType)); + + mcsrch_instance.deserialize(in); + in->read((char*)&iflag_, sizeof(iflag_)); + in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + in->read((char*)&requests_diag_, sizeof(requests_diag_)); + in->read((char*)&iter_, sizeof(iter_)); + in->read((char*)&nfun_, sizeof(nfun_)); + in->read((char*)&stp_, sizeof(stp_)); + in->read((char*)&stp1, sizeof(stp1)); + in->read((char*)&ftol, sizeof(ftol)); + in->read((char*)&ys, sizeof(ys)); + in->read((char*)&point, sizeof(point)); + in->read((char*)&npt, sizeof(npt)); + in->read((char*)&info, sizeof(info)); + in->read((char*)&bound, sizeof(bound)); + in->read((char*)&nfev, sizeof(nfev)); + in->read((char*)&w_[0], sizeof(FloatType) * w_.size()); + in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + protected: + static void throw_diagonal_element_not_positive(SizeType i) { + throw error_improper_input_data( + "The " + error::itoa(i) + ". diagonal element of the" + " inverse Hessian approximation is not positive."); + } + + bool generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag); + + detail::mcsrch mcsrch_instance; + const SizeType n_; + const SizeType m_; + const SizeType maxfev_; + const FloatType gtol_; + const FloatType xtol_; + const FloatType stpmin_; + const FloatType stpmax_; + int iflag_; + bool requests_f_and_g_; + bool requests_diag_; + SizeType iter_; + SizeType nfun_; + FloatType stp_; + FloatType stp1; + FloatType ftol; + FloatType ys; + SizeType point; + SizeType npt; + const SizeType ispt; + const SizeType iypt; + int info; + SizeType bound; + SizeType nfev; + std::vector w_; + std::vector scratch_array_; + }; + + template + bool minimizer::generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag) + { + bool execute_entire_while_loop = false; + if (!(requests_f_and_g_ || requests_diag_)) { + execute_entire_while_loop = true; + } + requests_f_and_g_ = false; + requests_diag_ = false; + FloatType* w = &(*(w_.begin())); + if (iflag_ == 0) { // Initialize. + nfun_ = 1; + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + else { + std::fill_n(scratch_array_.begin(), n_, FloatType(1)); + diag = &(*(scratch_array_.begin())); + } + for (SizeType i = 0; i < n_; i++) { + w[ispt + i] = -g[i] * diag[i]; + } + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm == FloatType(0)) return false; + stp1 = FloatType(1) / gnorm; + execute_entire_while_loop = true; + } + if (execute_entire_while_loop) { + bound = iter_; + iter_++; + info = 0; + if (iter_ != 1) { + if (iter_ > m_) bound = m_; + ys = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1)); + if (!diagco) { + FloatType yy = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1)); + std::fill_n(scratch_array_.begin(), n_, ys / yy); + diag = &(*(scratch_array_.begin())); + } + else { + iflag_ = 2; + requests_diag_ = true; + return true; + } + } + } + if (execute_entire_while_loop || iflag_ == 2) { + if (iter_ != 1) { + if (diag == 0) { + throw error_internal_error(__FILE__, __LINE__); + } + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + SizeType cp = point; + if (point == 0) cp = m_; + w[n_ + cp -1] = 1 / ys; + SizeType i; + for (i = 0; i < n_; i++) { + w[i] = -g[i]; + } + cp = point; + for (i = 0; i < bound; i++) { + if (cp == 0) cp = m_; + cp--; + FloatType sq = detail::ddot( + n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + SizeType inmc=n_+m_+cp; + SizeType iycn=iypt+cp*n_; + w[inmc] = w[n_ + cp] * sq; + detail::daxpy(n_, -w[inmc], w, iycn, w); + } + for (i = 0; i < n_; i++) { + w[i] *= diag[i]; + } + for (i = 0; i < bound; i++) { + FloatType yr = detail::ddot( + n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + FloatType beta = w[n_ + cp] * yr; + SizeType inmc=n_+m_+cp; + beta = w[inmc] - beta; + SizeType iscn=ispt+cp*n_; + detail::daxpy(n_, beta, w, iscn, w); + cp++; + if (cp == m_) cp = 0; + } + std::copy(w, w+n_, w+(ispt + point * n_)); + } + stp_ = FloatType(1); + if (iter_ == 1) stp_ = stp1; + std::copy(g, g+n_, w); + } + mcsrch_instance.run( + gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_, + stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin()))); + if (info == -1) { + iflag_ = 1; + requests_f_and_g_ = true; + return true; + } + if (info != 1) { + throw error_internal_error(__FILE__, __LINE__); + } + nfun_ += nfev; + npt = point*n_; + for (SizeType i = 0; i < n_; i++) { + w[ispt + npt + i] = stp_ * w[ispt + npt + i]; + w[iypt + npt + i] = g[i] - w[i]; + } + point++; + if (point == m_) point = 0; + return false; + } + + //! Traditional LBFGS convergence test. + /*! This convergence test is equivalent to the test embedded + in the lbfgs.f Fortran code. The test assumes that + there is a meaningful relation between the Euclidean norm of the + parameter vector x and the norm of the gradient + vector g. Therefore this test should not be used if + this assumption is not correct for a given problem. + */ + template + class traditional_convergence_test + { + public: + //! Default constructor. + traditional_convergence_test() + : n_(0), eps_(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: n > 0. + + @param eps Determines the accuracy with which the solution + is to be found. + */ + explicit + traditional_convergence_test( + SizeType n, + FloatType eps = FloatType(1.e-5)) + : n_(n), eps_(eps) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (eps_ < FloatType(0)) { + throw error_improper_input_parameter("eps < 0."); + } + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + /*! \brief Accuracy with which the solution is to be found + (as passed to the constructor). + */ + FloatType eps() const { return eps_; } + + //! Execution of the convergence test for the given parameters. + /*! Returns true if +

+            ||g|| < eps * max(1,||x||),
+          
+ where ||.|| denotes the Euclidean norm. + + @param x Current solution vector. + + @param g Components of the gradient at the current + point x. + */ + bool + operator()(const FloatType* x, const FloatType* g) const + { + FloatType xnorm = std::sqrt(detail::ddot(n_, x, x)); + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true; + return false; + } + protected: + const SizeType n_; + const FloatType eps_; + }; + +}} // namespace scitbx::lbfgs + +template +std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer& min) { + return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g(); +} + + +#endif // SCITBX_LBFGS_H diff --git a/training/lbfgs_test.cc b/training/lbfgs_test.cc new file mode 100644 index 00000000..4171c118 --- /dev/null +++ b/training/lbfgs_test.cc @@ -0,0 +1,112 @@ +#include +#include +#include +#include "lbfgs.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer() { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::minimizer opt(3); + scitbx::lbfgs::traditional_convergence_test converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt.run(x, obj, g); + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + cerr << opt << endl; + } while (!converged(x, g)); + return obj; +} + +double TestPersistentOptimizer() { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::traditional_convergence_test converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + scitbx::lbfgs::minimizer opt(3); + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt.deserialize(&is); + } + opt.run(x, obj, g); + ostringstream os(ios::binary); opt.serialize(&os); state = os.str(); + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!converged(x, g)); + return obj; +} + +void TestSparseVector() { + cerr << "Testing SparseVector serialization.\n"; + int f1 = FD::Convert("Feature_1"); + int f2 = FD::Convert("Feature_2"); + FD::Convert("LanguageModel"); + int f4 = FD::Convert("SomeFeature"); + int f5 = FD::Convert("SomeOtherFeature"); + SparseVector g; + g.set_value(f2, log(0.5)); + g.set_value(f4, log(0.125)); + g.set_value(f1, 0); + g.set_value(f5, 23.777); + ostringstream os; + double iobj = 1.5; + B64::Encode(iobj, g, &os); + cerr << iobj << "\t" << g << endl; + string data = os.str(); + cout << data << endl; + SparseVector v; + double obj; + assert(B64::Decode(&obj, &v, &data[0], data.size())); + cerr << obj << "\t" << v << endl; + assert(obj == iobj); + assert(g.num_active() == v.num_active()); +} + +int main() { + double o1 = TestOptimizer(); + double o2 = TestPersistentOptimizer(); + if (o1 != o2) { + cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + return 1; + } + TestSparseVector(); + cerr << "SUCCESS\n"; + return 0; +} + diff --git a/training/make-lexcrf-grammar.pl b/training/make-lexcrf-grammar.pl new file mode 100755 index 00000000..0e290492 --- /dev/null +++ b/training/make-lexcrf-grammar.pl @@ -0,0 +1,236 @@ +#!/usr/bin/perl -w +use utf8; +use strict; +my ($effile, $model1) = @ARGV; +die "Usage: $0 corpus.fr-en corpus.model1\n" unless $effile && -f $effile && $model1 && -f $model1; + +open EF, "<$effile" or die; +open M1, "<$model1" or die; +binmode(EF,":utf8"); +binmode(M1,":utf8"); +binmode(STDOUT,":utf8"); +my %model1; +while() { + chomp; + my ($f, $e, $lp) = split /\s+/; + $model1{$f}->{$e} = $lp; +} + +my $ADD_MODEL1 = 0; # found that model1 hurts performance +my $IS_FRENCH_F = 0; # indicates that the f language is french +my $IS_ARABIC_F = 1; # indicates that the f language is arabic +my $ADD_PREFIX_ID = 0; +my $ADD_LEN = 1; +my $ADD_LD = 0; +my $ADD_DICE = 1; +my $ADD_111 = 1; +my $ADD_ID = 1; +my $ADD_PUNC = 1; +my $ADD_NUM_MM = 1; +my $ADD_NULL = 1; +my $BEAM_RATIO = 50; + +my %fdict; +my %fcounts; +my %ecounts; + +while() { + chomp; + my ($f, $e) = split /\s*\|\|\|\s*/; + my @es = split /\s+/, $e; + my @fs = split /\s+/, $f; + for my $ew (@es){ $ecounts{$ew}++; } + push @fs, '' if $ADD_NULL; + for my $fw (@fs){ $fcounts{$fw}++; } + for my $fw (@fs){ + for my $ew (@es){ + $fdict{$fw}->{$ew}++; + } + } +} + +print STDERR "Dice 0\n" if $ADD_DICE; +print STDERR "OneOneOne 0\nId_OneOneOne 0\n" if $ADD_111; +print STDERR "Identical 0\n" if $ADD_ID; +print STDERR "PuncMiss 0\n" if $ADD_PUNC; +print STDERR "IsNull 0\n" if $ADD_NULL; +print STDERR "Model1 0\n" if $ADD_MODEL1; +print STDERR "DLen 0\n" if $ADD_LEN; +print STDERR "NumMM 0\n" if $ADD_NUM_MM; +print STDERR "Level 0\n" if $ADD_LD; +print STDERR "PfxIdentical 0\n" if ($ADD_PREFIX_ID); +my $fc = 1000000; +for my $f (sort keys %fdict) { + my $re = $fdict{$f}; + my $max; + for my $e (sort {$re->{$b} <=> $re->{$a}} keys %$re) { + my $efcount = $re->{$e}; + unless (defined $max) { $max = $efcount; } + my $m1 = $model1{$f}->{$e}; + unless (defined $m1) { next; } + $fc++; + my $dice = 2 * $efcount / ($ecounts{$e} + $fcounts{$f}); + my $feats = "F$fc=1"; + my $oe = $e; + my $len_e = length($oe); + my $of = $f; # normalized form + if ($IS_FRENCH_F) { + # see http://en.wikipedia.org/wiki/Use_of_the_circumflex_in_French + $of =~ s/â/as/g; + $of =~ s/ê/es/g; + $of =~ s/î/is/g; + $of =~ s/ô/os/g; + $of =~ s/û/us/g; + } elsif ($IS_ARABIC_F) { + if (length($of) > 1 && !($of =~ /\d/)) { + $of =~ s/\$/sh/g; + } + } + my $len_f = length($of); + $feats .= " Model1=$m1" if ($ADD_MODEL1); + $feats .= " Dice=$dice" if $ADD_DICE; + my $is_null = undef; + if ($ADD_NULL && $f eq '') { + $feats .= " IsNull=1"; + $is_null = 1; + } + if ($ADD_LEN) { + if (!$is_null) { + my $dlen = abs($len_e - $len_f); + $feats .= " DLen=$dlen"; + } + } + my $f_num = ($of =~ /^-?\d[0-9\.\,]+%?$/); # this matches *two digit* and more numbers + my $e_num = ($oe =~ /^-?\d[0-9\.\,]+%?$/); + my $both_non_numeric = (!$e_num && !$f_num); + if ($ADD_NUM_MM && (($f_num && !$e_num) || ($e_num && !$f_num))) { + $feats .= " NumMM=1"; + } + if ($ADD_PREFIX_ID) { + if ($len_e > 3 && $len_f > 3 && $both_non_numeric) { + my $pe = substr $oe, 0, 3; + my $pf = substr $of, 0, 3; + if ($pe eq $pf) { $feats .= " PfxIdentical=1"; } + } + } + if ($ADD_LD) { + my $ld = 0; + if ($is_null) { $ld = length($e); } else { + $ld = levenshtein($e, $f); + } + $feats .= " Leven=$ld"; + } + my $ident = ($e eq $f); + if ($ident && $ADD_ID) { $feats .= " Identical=1"; } + if ($ADD_111 && ($efcount == 1 && $ecounts{$e} == 1 && $fcounts{$f} == 1)) { + if ($ident && $ADD_ID) { + $feats .= " Id_OneOneOne=1"; + } + $feats .= " OneOneOne=1"; + } + if ($ADD_PUNC) { + if (($f =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $e =~ /[a-z]+/) || + ($e =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $f =~ /[a-z]+/)) { + $feats .= " PuncMiss=1"; + } + } + my $r = (0.5 - rand)/5; + print STDERR "F$fc $r\n"; + print "$f ||| $e ||| $feats\n"; + } +} + +sub levenshtein +{ + # $s1 and $s2 are the two strings + # $len1 and $len2 are their respective lengths + # + my ($s1, $s2) = @_; + my ($len1, $len2) = (length $s1, length $s2); + + # If one of the strings is empty, the distance is the length + # of the other string + # + return $len2 if ($len1 == 0); + return $len1 if ($len2 == 0); + + my %mat; + + # Init the distance matrix + # + # The first row to 0..$len1 + # The first column to 0..$len2 + # The rest to 0 + # + # The first row and column are initialized so to denote distance + # from the empty string + # + for (my $i = 0; $i <= $len1; ++$i) + { + for (my $j = 0; $j <= $len2; ++$j) + { + $mat{$i}{$j} = 0; + $mat{0}{$j} = $j; + } + + $mat{$i}{0} = $i; + } + + # Some char-by-char processing is ahead, so prepare + # array of chars from the strings + # + my @ar1 = split(//, $s1); + my @ar2 = split(//, $s2); + + for (my $i = 1; $i <= $len1; ++$i) + { + for (my $j = 1; $j <= $len2; ++$j) + { + # Set the cost to 1 iff the ith char of $s1 + # equals the jth of $s2 + # + # Denotes a substitution cost. When the char are equal + # there is no need to substitute, so the cost is 0 + # + my $cost = ($ar1[$i-1] eq $ar2[$j-1]) ? 0 : 1; + + # Cell $mat{$i}{$j} equals the minimum of: + # + # - The cell immediately above plus 1 + # - The cell immediately to the left plus 1 + # - The cell diagonally above and to the left plus the cost + # + # We can either insert a new char, delete a char or + # substitute an existing char (with an associated cost) + # + $mat{$i}{$j} = min([$mat{$i-1}{$j} + 1, + $mat{$i}{$j-1} + 1, + $mat{$i-1}{$j-1} + $cost]); + } + } + + # Finally, the Levenshtein distance equals the rightmost bottom cell + # of the matrix + # + # Note that $mat{$x}{$y} denotes the distance between the substrings + # 1..$x and 1..$y + # + return $mat{$len1}{$len2}; +} + + +# minimal element of a list +# +sub min +{ + my @list = @{$_[0]}; + my $min = $list[0]; + + foreach my $i (@list) + { + $min = $i if ($i < $min); + } + + return $min; +} + diff --git a/training/model1.cc b/training/model1.cc new file mode 100644 index 00000000..f571700f --- /dev/null +++ b/training/model1.cc @@ -0,0 +1,103 @@ +#include + +#include "lattice.h" +#include "stringlib.h" +#include "filelib.h" +#include "ttables.h" +#include "tdict.h" + +using namespace std; + +int main(int argc, char** argv) { + if (argc != 2) { + cerr << "Usage: " << argv[0] << " corpus.fr-en\n"; + return 1; + } + const int ITERATIONS = 5; + const prob_t BEAM_THRESHOLD(0.0001); + TTable tt; + const WordID kNULL = TD::Convert(""); + bool use_null = true; + TTable::Word2Word2Double was_viterbi; + for (int iter = 0; iter < ITERATIONS; ++iter) { + const bool final_iteration = (iter == (ITERATIONS - 1)); + cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; + ReadFile rf(argv[1]); + istream& in = *rf.stream(); + prob_t likelihood = prob_t::One(); + double denom = 0.0; + int lc = 0; + bool flag = false; + while(true) { + string line; + getline(in, line); + if (!in) break; + ++lc; + if (lc % 1000 == 0) { cerr << '.'; flag = true; } + if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } + string ssrc, strg; + ParseTranslatorInput(line, &ssrc, &strg); + Lattice src, trg; + LatticeTools::ConvertTextToLattice(ssrc, &src); + LatticeTools::ConvertTextToLattice(strg, &trg); + assert(src.size() > 0); + assert(trg.size() > 0); + denom += 1.0; + vector probs(src.size() + 1); + for (int j = 0; j < trg.size(); ++j) { + const WordID& f_j = trg[j][0].label; + prob_t sum = prob_t::Zero(); + if (use_null) { + probs[0] = tt.prob(kNULL, f_j); + sum += probs[0]; + } + for (int i = 1; i <= src.size(); ++i) { + probs[i] = tt.prob(src[i-1][0].label, f_j); + sum += probs[i]; + } + if (final_iteration) { + WordID max_i = 0; + prob_t max_p = prob_t::Zero(); + if (use_null) { + max_i = kNULL; + max_p = probs[0]; + } + for (int i = 1; i <= src.size(); ++i) { + if (probs[i] > max_p) { + max_p = probs[i]; + max_i = src[i-1][0].label; + } + } + was_viterbi[max_i][f_j] = 1.0; + } else { + if (use_null) + tt.Increment(kNULL, f_j, probs[0] / sum); + for (int i = 1; i <= src.size(); ++i) + tt.Increment(src[i-1][0].label, f_j, probs[i] / sum); + } + likelihood *= sum; + } + } + if (flag) { cerr << endl; } + cerr << " log likelihood: " << log(likelihood) << endl; + cerr << " cross entopy: " << (-log(likelihood) / denom) << endl; + cerr << " perplexity: " << pow(2.0, -log(likelihood) / denom) << endl; + if (!final_iteration) tt.Normalize(); + } + for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) { + const TTable::Word2Double& cpd = ei->second; + const TTable::Word2Double& vit = was_viterbi[ei->first]; + const string& esym = TD::Convert(ei->first); + prob_t max_p = prob_t::Zero(); + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) + if (fi->second > max_p) max_p = prob_t(fi->second); + const prob_t threshold = max_p * BEAM_THRESHOLD; + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { + if (fi->second > threshold || (vit.count(fi->first) > 0)) { + cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; + } + } + } + return 0; +} + diff --git a/training/mr_em_train.cc b/training/mr_em_train.cc new file mode 100644 index 00000000..a15fbe4c --- /dev/null +++ b/training/mr_em_train.cc @@ -0,0 +1,270 @@ +#include +#include +#include +#include + +#include +#include + +#include "config.h" +#ifdef HAVE_BOOST_DIGAMMA +#include +using boost::math::digamma; +#endif + +#include "tdict.h" +#include "filelib.h" +#include "trule.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +using boost::shared_ptr; +namespace po = boost::program_options; + +#ifndef HAVE_BOOST_DIGAMMA +#warning Using Mark Johnson's digamma() +double digamma(double x) { + double result = 0, xx, xx2, xx4; + assert(x > 0); + for ( ; x < 7; ++x) + result -= 1/x; + x -= 1.0/2.0; + xx = 1.0/x; + xx2 = xx*xx; + xx4 = xx2*xx2; + result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; + return result; +} +#endif + +void SanityCheck(const vector& w) { + for (int i = 0; i < w.size(); ++i) { + assert(!isnan(w[i])); + } +} + +struct FComp { + const vector& w_; + FComp(const vector& w) : w_(w) {} + bool operator()(int a, int b) const { + return w_[a] > w_[b]; + } +}; + +void ShowLargestFeatures(const vector& w) { + vector fnums(w.size() - 1); + for (int i = 1; i < w.size(); ++i) + fnums[i-1] = i; + vector::iterator mid = fnums.begin(); + mid += (w.size() > 10 ? 10 : w.size()) - 1; + partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); + cerr << "MOST PROBABLE:"; + for (vector::iterator i = fnums.begin(); i != mid; ++i) { + cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; + } + cerr << endl; +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("output,o",po::value()->default_value("-"),"Output log probs file") + ("grammar,g",po::value >()->composing(),"SCFG grammar file(s)") + ("optimization_method,m", po::value()->default_value("em"), "Optimization method (em, vb)") + ("input_format,f",po::value()->default_value("b64"),"Encoding of the input (b64 or text)"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("grammar")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +// describes a multinomial or multinomial with a prior +// does not contain the parameters- just the list of events +// and any hyperparameters +struct MultinomialInfo { + MultinomialInfo() : alpha(1.0) {} + vector events; // the events that this multinomial generates + double alpha; // hyperparameter for (optional) Dirichlet prior +}; + +typedef map ModelDefinition; + +void LoadModelEvents(const po::variables_map& conf, ModelDefinition* pm) { + ModelDefinition& m = *pm; + m.clear(); + vector gfiles = conf["grammar"].as >(); + for (int i = 0; i < gfiles.size(); ++i) { + ReadFile rf(gfiles[i]); + istream& in = *rf.stream(); + int lc = 0; + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + ++lc; + TRule r(line, true); + const SparseVector& f = r.GetFeatureValues(); + if (f.num_active() == 0) { + cerr << "[WARNING] no feature found in " << gfiles[i] << ':' << lc << endl; + continue; + } + if (f.num_active() > 1) { + cerr << "[ERROR] more than one feature found in " << gfiles[i] << ':' << lc << endl; + exit(1); + } + SparseVector::const_iterator it = f.begin(); + if (it->second != 1.0) { + cerr << "[ERROR] feature with value != 1 found in " << gfiles[i] << ':' << lc << endl; + exit(1); + } + m[r.GetLHS()].events.push_back(it->first); + } + } + for (ModelDefinition::iterator it = m.begin(); it != m.end(); ++it) { + const vector& v = it->second.events; + cerr << "Multinomial [" << TD::Convert(it->first*-1) << "]\n"; + if (v.size() < 1000) { + cerr << " generates:"; + for (int i = 0; i < v.size(); ++i) { + cerr << " " << FD::Convert(v[i]); + } + cerr << endl; + } + } +} + +void Maximize(const ModelDefinition& m, const bool use_vb, vector* counts) { + for (ModelDefinition::const_iterator it = m.begin(); it != m.end(); ++it) { + const MultinomialInfo& mult_info = it->second; + const vector& events = mult_info.events; + cerr << "Multinomial [" << TD::Convert(it->first*-1) << "]"; + double tot = 0; + for (int i = 0; i < events.size(); ++i) + tot += (*counts)[events[i]]; + cerr << " = " << tot << endl; + assert(tot > 0.0); + double ltot = log(tot); + if (use_vb) + ltot = digamma(tot + events.size() * mult_info.alpha); + for (int i = 0; i < events.size(); ++i) { + if (use_vb) { + (*counts)[events[i]] = digamma((*counts)[events[i]] + mult_info.alpha) - ltot; + } else { + (*counts)[events[i]] = log((*counts)[events[i]]) - ltot; + } + } + if (events.size() < 50) { + for (int i = 0; i < events.size(); ++i) { + cerr << " p(" << FD::Convert(events[i]) << ")=" << exp((*counts)[events[i]]); + } + cerr << endl; + } + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + const bool use_b64 = conf["input_format"].as() == "b64"; + const bool use_vb = conf["optimization_method"].as() == "vb"; + if (use_vb) + cerr << "Using variational Bayes, make sure alphas are set\n"; + + ModelDefinition model_def; + LoadModelEvents(conf, &model_def); + + const string s_obj = "**OBJ**"; + int num_feats = FD::NumFeats(); + cerr << "Number of features: " << num_feats << endl; + + vector counts(num_feats, 0); + double logprob = 0; + // 0**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; + // 0**OBJ**=1.1;Feat1=1.0; + + // E-step + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + int feat; + double val; + size_t i = line.find("\t"); + assert(i != string::npos); + ++i; + if (use_b64) { + SparseVector g; + double obj; + if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { + cerr << "B64 decoder returned error, skipping!\n"; + continue; + } + logprob += obj; + const SparseVector& cg = g; + for (SparseVector::const_iterator it = cg.begin(); it != cg.end(); ++it) { + if (it->first >= num_feats) { + cerr << "Unexpected feature: " << FD::Convert(it->first) << endl; + abort(); + } + counts[it->first] += it->second; + } + } else { // text encoding - your counts will not be accurate! + while (i < line.size()) { + size_t start = i; + while (line[i] != '=' && i < line.size()) ++i; + if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } + string fname = line.substr(start, i - start); + if (fname == s_obj) { + feat = -1; + } else { + feat = FD::Convert(line.substr(start, i - start)); + if (feat >= num_feats) { + cerr << "Unexpected feature: " << line.substr(start, i - start) << endl; + abort(); + } + } + ++i; + start = i; + while (line[i] != ';' && i < line.size()) ++i; + if (i - start == 0) continue; + val = atof(line.substr(start, i - start).c_str()); + ++i; + if (feat == -1) { + logprob += val; + } else { + counts[feat] += val; + } + } + } + } + + cerr << "LOGPROB: " << logprob << endl; + // M-step + Maximize(model_def, use_vb, &counts); + + SanityCheck(counts); + ShowLargestFeatures(counts); + Weights weights; + weights.InitFromVector(counts); + weights.WriteToFile(conf["output"].as(), false); + + return 0; +} diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc new file mode 100644 index 00000000..56b73c30 --- /dev/null +++ b/training/mr_optimize_reduce.cc @@ -0,0 +1,243 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +using boost::shared_ptr; +namespace po = boost::program_options; + +void SanityCheck(const vector& w) { + for (int i = 0; i < w.size(); ++i) { + assert(!isnan(w[i])); + assert(!isinf(w[i])); + } +} + +struct FComp { + const vector& w_; + FComp(const vector& w) : w_(w) {} + bool operator()(int a, int b) const { + return fabs(w_[a]) > fabs(w_[b]); + } +}; + +void ShowLargestFeatures(const vector& w) { + vector fnums(w.size()); + for (int i = 0; i < w.size(); ++i) + fnums[i] = i; + vector::iterator mid = fnums.begin(); + mid += (w.size() > 10 ? 10 : w.size()); + partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); + cerr << "TOP FEATURES:"; + for (vector::iterator i = fnums.begin(); i != mid; ++i) { + cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; + } + cerr << endl; +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,i",po::value(),"Input feature weights file") + ("output_weights,o",po::value()->default_value("-"),"Output feature weights file") + ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") + ("state,s",po::value(),"Read (and write if output_state is not set) optimizer state from this state file. In the first iteration, the file should not exist.") + ("input_format,f",po::value()->default_value("b64"),"Encoding of the input (b64 or text)") + ("output_state,S", po::value(), "Output state file (optional override)") + ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") + ("eta,e", po::value()->default_value(0.1), "Learning rate for SGD (eta)") + ("gaussian_prior,p","Use a Gaussian prior on the weights") + ("means,u", po::value(), "File containing the means for Gaussian prior") + ("sigma_squared", po::value()->default_value(1.0), "Sigma squared term for spherical Gaussian prior"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("input_weights") || !conf->count("state")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + const bool use_b64 = conf["input_format"].as() == "b64"; + + Weights weights; + weights.InitFromFile(conf["input_weights"].as()); + const string s_obj = "**OBJ**"; + int num_feats = FD::NumFeats(); + cerr << "Number of features: " << num_feats << endl; + const bool gaussian_prior = conf.count("gaussian_prior"); + vector means(num_feats, 0); + if (conf.count("means")) { + if (!gaussian_prior) { + cerr << "Don't use --means without --gaussian_prior!\n"; + exit(1); + } + Weights wm; + wm.InitFromFile(conf["means"].as()); + if (num_feats != FD::NumFeats()) { + cerr << "[ERROR] Means file had unexpected features!\n"; + exit(1); + } + wm.InitVector(&means); + } + shared_ptr o; + const string omethod = conf["optimization_method"].as(); + if (omethod == "sgd") + o.reset(new SGDOptimizer(conf["eta"].as())); + else if (omethod == "rprop") + o.reset(new RPropOptimizer(num_feats)); // TODO add configuration + else + o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as())); + cerr << "Optimizer: " << o->Name() << endl; + string state_file = conf["state"].as(); + { + ifstream in(state_file.c_str(), ios::binary); + if (in) + o->Load(&in); + else + cerr << "No state file found, assuming ITERATION 1\n"; + } + + vector lambdas(num_feats, 0); + weights.InitVector(&lambdas); + double objective = 0; + vector gradient(num_feats, 0); + // 0**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; + // 0**OBJ**=1.1;Feat1=1.0; + int total_lines = 0; // TODO - this should be a count of the + // training instances!! + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + ++total_lines; + int feat; + double val; + size_t i = line.find("\t"); + assert(i != string::npos); + ++i; + if (use_b64) { + SparseVector g; + double obj; + if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { + cerr << "B64 decoder returned error, skipping gradient!\n"; + cerr << " START: " << line.substr(0,min(200ul, line.size())) << endl; + if (line.size() > 200) + cerr << " END: " << line.substr(line.size() - 200, 200) << endl; + cout << "-1\tRESTART\n"; + exit(99); + } + objective += obj; + const SparseVector& cg = g; + for (SparseVector::const_iterator it = cg.begin(); it != cg.end(); ++it) { + if (it->first >= num_feats) { + cerr << "Unexpected feature in gradient: " << FD::Convert(it->first) << endl; + abort(); + } + gradient[it->first] -= it->second; + } + } else { // text encoding - your gradients will not be accurate! + while (i < line.size()) { + size_t start = i; + while (line[i] != '=' && i < line.size()) ++i; + if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } + string fname = line.substr(start, i - start); + if (fname == s_obj) { + feat = -1; + } else { + feat = FD::Convert(line.substr(start, i - start)); + if (feat >= num_feats) { + cerr << "Unexpected feature in gradient: " << line.substr(start, i - start) << endl; + abort(); + } + } + ++i; + start = i; + while (line[i] != ';' && i < line.size()) ++i; + if (i - start == 0) continue; + val = atof(line.substr(start, i - start).c_str()); + ++i; + if (feat == -1) { + objective += val; + } else { + gradient[feat] -= val; + } + } + } + } + + if (gaussian_prior) { + const double sigsq = conf["sigma_squared"].as(); + double norm = 0; + for (int k = 1; k < lambdas.size(); ++k) { + const double& lambda_k = lambdas[k]; + if (lambda_k) { + const double param = (lambda_k - means[k]); + norm += param * param; + gradient[k] += param / sigsq; + } + } + const double reg = norm / (2.0 * sigsq); + cerr << "REGULARIZATION TERM: " << reg << endl; + objective += reg; + } + cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl; + double gnorm = 0; + for (int i = 0; i < gradient.size(); ++i) + gnorm += gradient[i] * gradient[i]; + cerr << " GNORM=" << sqrt(gnorm) << endl; + vector old = lambdas; + int c = 0; + while (old == lambdas) { + ++c; + if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; } + o->Optimize(objective, gradient, &lambdas); + assert(c < 5); + } + old.clear(); + SanityCheck(lambdas); + ShowLargestFeatures(lambdas); + weights.InitFromVector(lambdas); + weights.WriteToFile(conf["output_weights"].as(), false); + + const bool conv = o->HasConverged(); + if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } + + if (conf.count("output_state")) + state_file = conf["output_state"].as(); + ofstream out(state_file.c_str(), ios::binary); + cerr << "Writing state to: " << state_file << endl; + o->Save(&out); + out.close(); + + cout << o->EvaluationCount() << "\t" << conv << endl; + return 0; +} diff --git a/training/optimize.cc b/training/optimize.cc new file mode 100644 index 00000000..5194752e --- /dev/null +++ b/training/optimize.cc @@ -0,0 +1,114 @@ +#include "optimize.h" + +#include +#include + +#include "lbfgs.h" + +using namespace std; + +Optimizer::~Optimizer() {} + +void Optimizer::Save(ostream* out) const { + out->write((const char*)&eval_, sizeof(eval_)); + out->write((const char*)&has_converged_, sizeof(has_converged_)); + SaveImpl(out); + unsigned int magic = 0xABCDDCBA; // should be uint32_t + out->write((const char*)&magic, sizeof(magic)); +} + +void Optimizer::Load(istream* in) { + in->read((char*)&eval_, sizeof(eval_)); + ++eval_; + in->read((char*)&has_converged_, sizeof(has_converged_)); + LoadImpl(in); + unsigned int magic = 0; // should be uint32_t + in->read((char*)&magic, sizeof(magic)); + assert(magic == 0xABCDDCBA); + cerr << Name() << " EVALUATION #" << eval_ << endl; +} + +void Optimizer::SaveImpl(ostream* out) const { + (void)out; +} + +void Optimizer::LoadImpl(istream* in) { + (void)in; +} + +string RPropOptimizer::Name() const { + return "RPropOptimizer"; +} + +void RPropOptimizer::OptimizeImpl(const double& obj, + const vector& g, + vector* x) { + for (int i = 0; i < g.size(); ++i) { + const double g_i = g[i]; + const double sign_i = (signbit(g_i) ? -1.0 : 1.0); + const double prod = g_i * prev_g_[i]; + if (prod > 0.0) { + const double dij = min(delta_ij_[i] * eta_plus_, delta_max_); + (*x)[i] -= dij * sign_i; + delta_ij_[i] = dij; + prev_g_[i] = g_i; + } else if (prod < 0.0) { + delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_); + prev_g_[i] = 0.0; + } else { + (*x)[i] -= delta_ij_[i] * sign_i; + prev_g_[i] = g_i; + } + } +} + +void RPropOptimizer::SaveImpl(ostream* out) const { + const size_t n = prev_g_.size(); + out->write((const char*)&n, sizeof(n)); + out->write((const char*)&prev_g_[0], sizeof(double) * n); + out->write((const char*)&delta_ij_[0], sizeof(double) * n); +} + +void RPropOptimizer::LoadImpl(istream* in) { + size_t n; + in->read((char*)&n, sizeof(n)); + assert(n == prev_g_.size()); + assert(n == delta_ij_.size()); + in->read((char*)&prev_g_[0], sizeof(double) * n); + in->read((char*)&delta_ij_[0], sizeof(double) * n); +} + +string SGDOptimizer::Name() const { + return "SGDOptimizer"; +} + +void SGDOptimizer::OptimizeImpl(const double& obj, + const vector& g, + vector* x) { + (void)obj; + for (int i = 0; i < g.size(); ++i) + (*x)[i] -= g[i] * eta_; +} + +string LBFGSOptimizer::Name() const { + return "LBFGSOptimizer"; +} + +LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) : + opt_(num_feats, memory_buffers) {} + +void LBFGSOptimizer::SaveImpl(ostream* out) const { + opt_.serialize(out); +} + +void LBFGSOptimizer::LoadImpl(istream* in) { + opt_.deserialize(in); +} + +void LBFGSOptimizer::OptimizeImpl(const double& obj, + const vector& g, + vector* x) { + opt_.run(&(*x)[0], obj, &g[0]); + cerr << opt_ << endl; +} + diff --git a/training/optimize.h b/training/optimize.h new file mode 100644 index 00000000..eddceaad --- /dev/null +++ b/training/optimize.h @@ -0,0 +1,104 @@ +#ifndef _OPTIMIZE_H_ +#define _OPTIMIZE_H_ + +#include +#include +#include +#include + +#include "lbfgs.h" + +// abstract base class for first order optimizers +// order of invocation: new, Load(), Optimize(), Save(), delete +class Optimizer { + public: + Optimizer() : eval_(1), has_converged_(false) {} + virtual ~Optimizer(); + virtual std::string Name() const = 0; + int EvaluationCount() const { return eval_; } + bool HasConverged() const { return has_converged_; } + + void Optimize(const double& obj, + const std::vector& g, + std::vector* x) { + assert(g.size() == x->size()); + OptimizeImpl(obj, g, x); + scitbx::lbfgs::traditional_convergence_test converged(g.size()); + has_converged_ = converged(&(*x)[0], &g[0]); + } + + void Save(std::ostream* out) const; + void Load(std::istream* in); + protected: + virtual void SaveImpl(std::ostream* out) const; + virtual void LoadImpl(std::istream* in); + virtual void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x) = 0; + + int eval_; + private: + bool has_converged_; +}; + +class RPropOptimizer : public Optimizer { + public: + explicit RPropOptimizer(int num_vars, + double eta_plus = 1.2, + double eta_minus = 0.5, + double delta_0 = 0.1, + double delta_max = 50.0, + double delta_min = 1e-6) : + prev_g_(num_vars, 0.0), + delta_ij_(num_vars, delta_0), + eta_plus_(eta_plus), + eta_minus_(eta_minus), + delta_max_(delta_max), + delta_min_(delta_min) { + assert(eta_plus > 1.0); + assert(eta_minus > 0.0 && eta_minus < 1.0); + assert(delta_max > 0.0); + assert(delta_min > 0.0); + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x); + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + private: + std::vector prev_g_; + std::vector delta_ij_; + const double eta_plus_; + const double eta_minus_; + const double delta_max_; + const double delta_min_; +}; + +class SGDOptimizer : public Optimizer { + public: + explicit SGDOptimizer(int num_vars, double eta = 0.1) : eta_(eta) { + (void) num_vars; + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x); + private: + const double eta_; +}; + +class LBFGSOptimizer : public Optimizer { + public: + explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); + std::string Name() const; + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x); + private: + scitbx::lbfgs::minimizer opt_; +}; + +#endif diff --git a/training/optimize_test.cc b/training/optimize_test.cc new file mode 100644 index 00000000..0ada7cbb --- /dev/null +++ b/training/optimize_test.cc @@ -0,0 +1,105 @@ +#include +#include +#include +#include +#include "optimize.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer(Optimizer* opt) { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector x(3); + vector g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt->Optimize(obj, g, &x); + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!opt->HasConverged()); + return obj; +} + +double TestPersistentOptimizer(Optimizer* opt) { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector x(3); + vector g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + bool converged = false; + while (!converged) { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt->Load(&is); + } + opt->Optimize(obj, g, &x); + ostringstream os(ios::binary); opt->Save(&os); state = os.str(); + + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + converged = opt->HasConverged(); + if (!converged) { + // now screw up the state (should be undone by Load) + obj += 2.0; + g[1] = -g[2]; + vector x2 = x; + try { + opt->Optimize(obj, g, &x2); + } catch (...) { } + } + } + return obj; +} + +template +void TestOptimizerVariants(int num_vars) { + O oa(num_vars); + cerr << "-------------------------------------------------------------------------\n"; + cerr << "TESTING: " << oa.Name() << endl; + double o1 = TestOptimizer(&oa); + O ob(num_vars); + double o2 = TestPersistentOptimizer(&ob); + if (o1 != o2) { + cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + exit(1); + } + cerr << oa.Name() << " SUCCESS\n"; +} + +int main() { + int n = 3; + TestOptimizerVariants(n); + TestOptimizerVariants(n); + TestOptimizerVariants(n); + return 0; +} + diff --git a/training/plftools.cc b/training/plftools.cc new file mode 100644 index 00000000..903ec54f --- /dev/null +++ b/training/plftools.cc @@ -0,0 +1,93 @@ +#include +#include +#include + +#include +#include + +#include "filelib.h" +#include "tdict.h" +#include "prob.h" +#include "hg.h" +#include "hg_io.h" +#include "viterbi.h" +#include "kbest.h" + +namespace po = boost::program_options; +using namespace std; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value(), "REQ. Lattice input file (PLF), - for STDIN") + ("prior_scale,p", po::value()->default_value(1.0), "Scale path probabilities by this amount < 1 flattens, > 1 sharpens") + ("weight,w", po::value >(), "Weight(s) for arc features") + ("output,o", po::value()->default_value("plf"), "Output format (text, plf)") + ("command,c", po::value()->default_value("push"), "Operation to perform: push, graphviz, 1best, 2best ...") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char **argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string infile = conf["input"].as(); + ReadFile rf(infile); + istream* in = rf.stream(); + assert(*in); + SparseVector wts; + vector wv; + if (conf.count("weight") > 0) wv = conf["weight"].as >(); + if (wv.empty()) wv.push_back(1.0); + for (int i = 0; i < wv.size(); ++i) { + const string fname = "Feature_" + boost::lexical_cast(i); + cerr << "[INFO] Arc weight " << (i+1) << " = " << wv[i] << endl; + wts.set_value(FD::Convert(fname), wv[i]); + } + const string cmd = conf["command"].as(); + const bool push_weights = cmd == "push"; + const bool output_plf = cmd == "plf"; + const bool graphviz = cmd == "graphviz"; + const bool kbest = cmd.rfind("best") == (cmd.size() - 4) && cmd.size() > 4; + int k = 1; + if (kbest) { + k = boost::lexical_cast(cmd.substr(0, cmd.size() - 4)); + cerr << "KBEST = " << k << endl; + } + const double scale = conf["prior_scale"].as(); + int lc = 0; + while(*in) { + ++lc; + string plf; + getline(*in, plf); + if (plf.empty()) continue; + Hypergraph hg; + HypergraphIO::ReadFromPLF(plf, &hg); + hg.Reweight(wts); + if (graphviz) hg.PrintGraphviz(); + if (push_weights) hg.PushWeightsToSource(scale); + if (output_plf) { + cout << HypergraphIO::AsPLF(hg) << endl; + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cout << lc << " ||| " << TD::GetString(d->yield) << " ||| " << d->score << endl; + } + } + } + return 0; +} + diff --git a/vest/Makefile.am b/vest/Makefile.am new file mode 100644 index 00000000..87c2383a --- /dev/null +++ b/vest/Makefile.am @@ -0,0 +1,32 @@ +bin_PROGRAMS = \ + mr_vest_map \ + mr_vest_reduce \ + scorer_test \ + lo_test \ + mr_vest_generate_mapper_input \ + fast_score \ + union_forests + +union_forests_SOURCES = union_forests.cc +union_forests_LDADD = $(top_srcdir)/src/libhg.a -lz + +fast_score_SOURCES = fast_score.cc ter.cc comb_scorer.cc scorer.cc viterbi_envelope.cc +fast_score_LDADD = $(top_srcdir)/src/libhg.a -lz + +mr_vest_generate_mapper_input_SOURCES = mr_vest_generate_mapper_input.cc line_optimizer.cc +mr_vest_generate_mapper_input_LDADD = $(top_srcdir)/src/libhg.a -lz + +mr_vest_map_SOURCES = viterbi_envelope.cc error_surface.cc mr_vest_map.cc scorer.cc ter.cc comb_scorer.cc line_optimizer.cc +mr_vest_map_LDADD = $(top_srcdir)/src/libhg.a -lz + +mr_vest_reduce_SOURCES = error_surface.cc mr_vest_reduce.cc scorer.cc ter.cc comb_scorer.cc line_optimizer.cc viterbi_envelope.cc +mr_vest_reduce_LDADD = $(top_srcdir)/src/libhg.a -lz + +scorer_test_SOURCES = scorer_test.cc scorer.cc ter.cc comb_scorer.cc viterbi_envelope.cc +scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(top_srcdir)/src/libhg.a -lz + +lo_test_SOURCES = lo_test.cc scorer.cc ter.cc comb_scorer.cc viterbi_envelope.cc error_surface.cc line_optimizer.cc +lo_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(top_srcdir)/src/libhg.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(BOOST_CPPFLAGS) $(GTEST_CPPFLAGS) -I$(top_srcdir)/src +AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB) diff --git a/vest/comb_scorer.cc b/vest/comb_scorer.cc new file mode 100644 index 00000000..7b2187f4 --- /dev/null +++ b/vest/comb_scorer.cc @@ -0,0 +1,81 @@ +#include "comb_scorer.h" + +#include + +using namespace std; + +class BLEUTERCombinationScore : public Score { + friend class BLEUTERCombinationScorer; + public: + ~BLEUTERCombinationScore(); + float ComputeScore() const { + return (bleu->ComputeScore() - ter->ComputeScore()) / 2.0f; + } + void ScoreDetails(string* details) const { + char buf[160]; + sprintf(buf, "Combi = %.2f, BLEU = %.2f, TER = %.2f", + ComputeScore()*100.0f, bleu->ComputeScore()*100.0f, ter->ComputeScore()*100.0f); + *details = buf; + } + void PlusEquals(const Score& delta) { + bleu->PlusEquals(*static_cast(delta).bleu); + ter->PlusEquals(*static_cast(delta).ter); + } + Score* GetZero() const { + BLEUTERCombinationScore* res = new BLEUTERCombinationScore; + res->bleu = bleu->GetZero(); + res->ter = ter->GetZero(); + return res; + } + void Subtract(const Score& rhs, Score* res) const { + bleu->Subtract(*static_cast(rhs).bleu, + static_cast(res)->bleu); + ter->Subtract(*static_cast(rhs).ter, + static_cast(res)->ter); + } + void Encode(std::string* out) const { + string bs, ts; + bleu->Encode(&bs); + ter->Encode(&ts); + out->clear(); + (*out) += static_cast(bs.size()); + (*out) += bs; + (*out) += ts; + } + bool IsAdditiveIdentity() const { + return bleu->IsAdditiveIdentity() && ter->IsAdditiveIdentity(); + } + private: + Score* bleu; + Score* ter; +}; + +BLEUTERCombinationScore::~BLEUTERCombinationScore() { + delete bleu; + delete ter; +} + +BLEUTERCombinationScorer::BLEUTERCombinationScorer(const vector >& refs) { + bleu_ = SentenceScorer::CreateSentenceScorer(IBM_BLEU, refs); + ter_ = SentenceScorer::CreateSentenceScorer(TER, refs); +} + +BLEUTERCombinationScorer::~BLEUTERCombinationScorer() { + delete bleu_; + delete ter_; +} + +Score* BLEUTERCombinationScorer::ScoreCandidate(const std::vector& hyp) const { + BLEUTERCombinationScore* res = new BLEUTERCombinationScore; + res->bleu = bleu_->ScoreCandidate(hyp); + res->ter = ter_->ScoreCandidate(hyp); + return res; +} + +Score* BLEUTERCombinationScorer::ScoreFromString(const std::string& in) { + int bss = in[0]; + BLEUTERCombinationScore* r = new BLEUTERCombinationScore; + r->bleu = SentenceScorer::CreateScoreFromString(IBM_BLEU, in.substr(1, bss)); + r->ter = SentenceScorer::CreateScoreFromString(TER, in.substr(1 + bss)); + return r; +} diff --git a/vest/dist-vest.pl b/vest/dist-vest.pl new file mode 100755 index 00000000..5528838c --- /dev/null +++ b/vest/dist-vest.pl @@ -0,0 +1,642 @@ +#!/usr/bin/env perl + +use Getopt::Long; +use IPC::Open2; +use strict; +use POSIX ":sys_wait_h"; + +my $mydir = `dirname $0`; +chomp $mydir; +# Default settings +my $srcFile = "/fs/cliplab/mteval/Evaluation/Chinese-English/mt03.src.txt"; +my $refFiles = "/fs/cliplab/mteval/Evaluation/Chinese-English/mt03.ref.txt.*"; +my $bin_dir = "/fs/clip-software/cdec/bin"; +$bin_dir = "/Users/redpony/cdyer-svn-root/cdec/vest/bin_dir"; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/fast_score"; +die "Can't find $FAST_SCORE" unless -x $FAST_SCORE; +my $MAPINPUT = "$bin_dir/mr_vest_generate_mapper_input"; +my $MAPPER = "$bin_dir/mr_vest_map"; +my $REDUCER = "$bin_dir/mr_vest_reduce"; +my $SCORER = $FAST_SCORE; +die "Can't find $MAPPER" unless -x $MAPPER; +my $forestUnion = "$bin_dir/union_forests"; +die "Can't find $forestUnion" unless -x $forestUnion; +my $cdec = "$bin_dir/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +my $decoder = $cdec; +my $lines_per_mapper = 440; +my $rand_directions = 10; +my $iteration = 1; +my $run_local = 0; +my $best_weights; +my $max_iterations = 40; +my $optimization_iters = 6; +my $num_rand_points = 20; +my $mert_nodes = join(" ", grep(/^c\d\d$/, split(/\n/, `pbsnodes -a`))); # "1 1 1 1 1" fails due to file staging conflicts +my $decode_nodes = "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"; # start 15 jobs +my $pmem = "3g"; +my $disable_clean = 0; +my %seen_weights; +my $normalize; +my $help = 0; +my $epsilon = 0.0001; +my $interval = 5; +my $dryrun = 0; +my $ranges; +my $restart = 0; +my $metric = "ibm_bleu"; +my $dir; +my $iniFile; +my $weights; +my $initialWeights; +my $decoderOpt; + +# Process command-line options +Getopt::Long::Configure("no_auto_abbrev"); +if (GetOptions( + "decoder=s" => \$decoderOpt, + "decode-nodes=s" => \$decode_nodes, + "dont-clean" => \$disable_clean, + "dry-run" => \$dryrun, + "epsilon" => \$epsilon, + "help" => \$help, + "interval" => \$interval, + "iteration=i" => \$iteration, + "local" => \$run_local, + "max-iterations=i" => \$max_iterations, + "mert-nodes=s" => \$mert_nodes, + "normalize=s" => \$normalize, + "pmem=s" => \$pmem, + "ranges=s" => \$ranges, + "rand-directions=i" => \$rand_directions, + "ref-files=s" => \$refFiles, + "metric=s" => \$metric, + "restart" => \$restart, + "source-file=s" => \$srcFile, + "weights=s" => \$initialWeights, + "workdir=s" => \$dir +) == 0 || @ARGV!=1 || $help) { + print_help(); + exit; +} + +if ($metric =~ /^(combi|ter)$/i) { + $lines_per_mapper = 40; +} + +($iniFile) = @ARGV; + +sub write_config; +sub enseg; +sub print_help; + +my $nodelist; +my $host =`hostname`; chomp $host; +my $bleu; +my $best_bleu = 0.0; +my $interval_count = 0; +my $epsilon_bleu = $best_bleu + $epsilon; +my $logfile; +my $projected_score; + +my $refs_comma_sep = get_comma_sep_refs($refFiles); + +unless ($dir){ + $dir = "vest"; +} +unless ($dir =~ /^\//){ # convert relative path to absolute path + my $basedir = `pwd`; + chomp $basedir; + $dir = "$basedir/$dir"; +} +if ($restart){ + $iniFile = `ls $dir/*.ini`; chomp $iniFile; + unless (-e $iniFile){ + die "ERROR: Could not find ini file in $dir to restart\n"; + } + $logfile = "$dir/mert.log"; + open(LOGFILE, ">>$logfile"); + print LOGFILE "RESTARTING STOPPED OPTIMIZATION\n\n"; + + # figure out best weights so far and iteration number + open(LOG, "$dir/mert.log"); + my $wi = 0; + while (my $line = ){ + chomp $line; + if ($line =~ /ITERATION (\d+)/) { + $iteration = $1; + } + } + + $iteration = $wi + 1; +} + +if ($decoderOpt){ $decoder = $decoderOpt; } + + +# Initializations and helper functions +srand; + +my @childpids = (); +my @cleanupcmds = (); + +sub cleanup { + print STDERR "Cleanup...\n"; + for my $pid (@childpids){ `kill $pid`; } + for my $cmd (@cleanupcmds){`$cmd`; } + exit 1; +}; +$SIG{INT} = "cleanup"; +$SIG{TERM} = "cleanup"; +$SIG{HUP} = "cleanup"; + +my $decoderBase = `basename $decoder`; chomp $decoderBase; +my $newIniFile = "$dir/$decoderBase.ini"; +my $parallelize = "$mydir/parallelize.pl"; +my $inputFileName = "$dir/input"; +my $user = $ENV{"USER"}; + +# process ini file +-e $iniFile || die "Error: could not open $iniFile for reading\n"; +open(INI, $iniFile); + +if ($dryrun){ + write_config(*STDERR); + exit 0; +} else { + if (-e $dir){ + unless($restart){ + die "ERROR: working dir $dir already exists\n\n"; + } + } else { + mkdir $dir; + mkdir "$dir/hgs"; + mkdir "$dir/hgs-current"; + unless (-e $initialWeights) { + print STDERR "Please specify an initial weights file with --initial-weights\n"; + print_help(); + exit; + } + `cp $initialWeights $dir/weights.0`; + die "Can't find weights.0" unless (-e "$dir/weights.0"); + } + unless($restart){ + $logfile = "$dir/mert.log"; + open(LOGFILE, ">$logfile"); + } + write_config(*LOGFILE); +} + + +# Generate initial files and values +unless ($restart){ `cp $iniFile $newIniFile`; } +$iniFile = $newIniFile; + +my $newsrc = "$dir/dev.input"; +unless($restart){ enseg($srcFile, $newsrc); } +$srcFile = $newsrc; +my $devSize = 0; +open F, "<$srcFile" or die "Can't read $srcFile: $!"; +while() { $devSize++; } +close F; + +unless($best_weights){ $best_weights = $weights; } +unless($projected_score){ $projected_score = 0.0; } +$seen_weights{$weights} = 1; + +my $random_seed = int(time / 1000); +my $lastWeightsFile; +my $lastPScore = 0; +# main optimization loop +while (1){ + print LOGFILE "\n\nITERATION $iteration\n==========\n"; + + # iteration-specific files + my $runFile="$dir/run.raw.$iteration"; + my $onebestFile="$dir/1best.$iteration"; + my $logdir="$dir/logs.$iteration"; + my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; + my $scorerLog="$logdir/scorer.log.$iteration"; + `mkdir -p $logdir`; + + #decode + print LOGFILE "DECODE\n"; + print LOGFILE `date`; + my $im1 = $iteration - 1; + my $weightsFile="$dir/weights.$im1"; + my $decoder_cmd = "$decoder -c $iniFile -w $weightsFile -O $dir/hgs-current"; + my $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -n \"$decode_nodes\" -- "; + if ($run_local) { $pcmd = "cat $srcFile |"; } + my $cmd = $pcmd . "$decoder_cmd 2> $decoderLog 1> $runFile"; + print LOGFILE "COMMAND:\n$cmd\n"; + my $result = 0; + $result = system($cmd); + unless ($result == 0){ + cleanup(); + print LOGFILE "ERROR: Parallel decoder returned non-zero exit code $result\n"; + die; + } + my $dec_score = `cat $runFile | $SCORER $refs_comma_sep -l $metric`; + chomp $dec_score; + print LOGFILE "DECODER SCORE: $dec_score\n"; + + # save space + `gzip $runFile`; + `gzip $decoderLog`; + + if ($iteration > $max_iterations){ + print LOGFILE "\nREACHED STOPPING CRITERION: Maximum iterations\n"; + last; + } + + # run optimizer + print LOGFILE "\nUNION FORESTS\n"; + print LOGFILE `date`; + my $mergeLog="$logdir/prune-merge.log.$iteration"; + $cmd = "$forestUnion -r $dir/hgs -n $dir/hgs-current -s $devSize"; + print LOGFILE "COMMAND:\n$cmd\n"; + $result = system($cmd); + unless ($result == 0){ + cleanup(); + print LOGFILE "ERROR: merge command returned non-zero exit code $result\n"; + die; + } + `rm -f $dir/hgs-current/*.json.gz`; # clean up old HGs, they've been moved to the repository + + my $score = 0; + my $icc = 0; + my $inweights="$dir/weights.$im1"; + for (my $opt_iter=1; $opt_iter<$optimization_iters; $opt_iter++) { + print LOGFILE "\nGENERATE OPTIMIZATION STRATEGY (OPT-ITERATION $opt_iter/$optimization_iters)\n"; + print LOGFILE `date`; + $icc++; + $cmd="$MAPINPUT -w $inweights -r $dir/hgs -s $devSize -d $rand_directions > $dir/agenda.$im1-$opt_iter"; + print LOGFILE "COMMAND:\n$cmd\n"; + $result = system($cmd); + unless ($result == 0){ + cleanup(); + print LOGFILE "ERROR: mapinput command returned non-zero exit code $result\n"; + die; + } + + `mkdir $dir/splag.$im1`; + $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1-$opt_iter $dir/splag.$im1/mapinput."; + print LOGFILE "COMMAND:\n$cmd\n"; + $result = system($cmd); + unless ($result == 0){ + cleanup(); + print LOGFILE "ERROR: split command returned non-zero exit code $result\n"; + die; + } + opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; + my @shards = grep { /^mapinput\./ } readdir(DIR); + closedir DIR; + die "No shards!" unless scalar @shards > 0; + my $joblist = ""; + my $nmappers = 0; + my @mapoutputs = (); + @cleanupcmds = (); + my %o2i = (); + my $first_shard = 1; + for my $shard (@shards) { + my $mapoutput = $shard; + my $client_name = $shard; + $client_name =~ s/mapinput.//; + $client_name = "fmert.$client_name"; + $mapoutput =~ s/mapinput/mapoutput/; + push @mapoutputs, "$dir/splag.$im1/$mapoutput"; + $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; + my $script = "$MAPPER -l $metric $refs_comma_sep < $dir/splag.$im1/$shard | sort -k1 > $dir/splag.$im1/$mapoutput"; + if ($run_local) { + print LOGFILE "COMMAND:\n$script\n"; + $result = system($script); + unless ($result == 0){ + cleanup(); + print LOGFILE "ERROR: mapper returned non-zero exit code $result\n"; + die; + } + } else { + my $todo = "qsub -q batch -l pmem=3000mb,walltime=5:00:00 -N $client_name -o /dev/null -e $logdir/$client_name.ER"; + local(*QOUT, *QIN); + open2(\*QOUT, \*QIN, $todo); + print QIN $script; + if ($first_shard) { print LOGFILE "$script\n"; $first_shard=0; } + close QIN; + $nmappers++; + while (my $jobid=){ + chomp $jobid; + push(@cleanupcmds, "`qdel $jobid 2> /dev/null`"); + $jobid =~ s/^(\d+)(.*?)$/\1/g; + print STDERR "short job id $jobid\n"; + if ($joblist == "") { $joblist = $jobid; } + else {$joblist = $joblist . "\|" . $jobid; } + } + close QOUT; + } + } + if ($run_local) { + } else { + print LOGFILE "Launched $nmappers mappers.\n"; + print LOGFILE "Waiting for mappers to complete...\n"; + while ($nmappers > 0) { + sleep 5; + my @livejobs = grep(/$joblist/, split(/\n/, `qstat`)); + $nmappers = scalar @livejobs; + } + print LOGFILE "All mappers complete.\n"; + } + my $tol = 0; + my $til = 0; + for my $mo (@mapoutputs) { + my $olines = get_lines($mo); + my $ilines = get_lines($o2i{$mo}); + $tol += $olines; + $til += $ilines; + die "$mo: output lines ($olines) doesn't match input lines ($ilines)" unless $olines==$ilines; + } + print LOGFILE "Results for $tol/$til lines\n"; + print LOGFILE "\nSORTING AND RUNNING FMERT REDUCER\n"; + print LOGFILE `date`; + $cmd="sort -k1 @mapoutputs | $REDUCER > $dir/redoutput.$im1"; + print LOGFILE "COMMAND:\n$cmd\n"; + $result = system($cmd); + unless ($result == 0){ + cleanup(); + print LOGFILE "ERROR: reducer command returned non-zero exit code $result\n"; + die; + } + $cmd="sort -rnk3 '-t|' $dir/redoutput.$im1 | head -1"; + my $best=`$cmd`; chomp $best; + print LOGFILE "$best\n"; + my ($oa, $x, $xscore) = split /\|/, $best; + $score = $xscore; + print LOGFILE "PROJECTED SCORE: $score\n"; + if (abs($x) < $epsilon) { + print LOGFILE "\nOPTIMIZER: no score improvement: abs($x) < $epsilon\n"; + last; + } + my ($origin, $axis) = split /\s+/, $oa; + + my %ori = convert($origin); + my %axi = convert($axis); + + my $finalFile="$dir/weights.$im1-$opt_iter"; + open W, ">$finalFile" or die "Can't write: $finalFile: $!"; + for my $k (sort keys %ori) { + my $v = $ori{$k} + $axi{$k} * $x; + print W "$k $v\n"; + } + + `rm -rf $dir/splag.$im1`; + $inweights = $finalFile; + } + $lastWeightsFile = "$dir/weights.$iteration"; + `cp $inweights $lastWeightsFile`; + if ($icc < 2) { + print LOGFILE "\nREACHED STOPPING CRITERION: score change too little\n"; + last; + } + $lastPScore = $score; + $iteration++; + print LOGFILE "\n==========\n"; +} + +print LOGFILE "\nFINAL WEIGHTS: $dir/$lastWeightsFile\n(Use -w with hiero)\n\n"; + +sub normalize_weights { + my ($rfn, $rpts, $feat) = @_; + my @feat_names = @$rfn; + my @pts = @$rpts; + my $z = 1.0; + for (my $i=0; $i < scalar @feat_names; $i++) { + if ($feat_names[$i] eq $feat) { + $z = $pts[$i]; + last; + } + } + for (my $i=0; $i < scalar @feat_names; $i++) { + $pts[$i] /= $z; + } + print LOGFILE " NORM WEIGHTS: @pts\n"; + return @pts; +} + +sub get_lines { + my $fn = shift @_; + open FL, "<$fn" or die "Couldn't read $fn: $!"; + my $lc = 0; + while() { $lc++; } + return $lc; +} + +sub get_comma_sep_refs { + my ($p) = @_; + my $o = `echo $p`; + chomp $o; + my @files = split /\s+/, $o; + return "-r " . join(' -r ', @files); +} + +sub read_weights_file { + my ($file) = @_; + open F, "<$file" or die "Couldn't read $file: $!"; + my @r = (); + my $pm = -1; + while() { + next if /^#/; + next if /^\s*$/; + chomp; + if (/^(.+)\s+(.+)$/) { + my $m = $1; + my $w = $2; + die "Weights out of order: $m <= $pm" unless $m > $pm; + push @r, $w; + } else { + warn "Unexpected feature name in weight file: $_"; + } + } + close F; + return join ' ', @r; +} + +# subs +sub write_config { + my $fh = shift; + my $cleanup = "yes"; + if ($disable_clean) {$cleanup = "no";} + + print $fh "\n"; + print $fh "DECODER: $decoder\n"; + print $fh "INI FILE: $iniFile\n"; + print $fh "WORKING DIR: $dir\n"; + print $fh "SOURCE (DEV): $srcFile\n"; + print $fh "REFS (DEV): $refFiles\n"; + print $fh "EVAL METRIC: $metric\n"; + print $fh "START ITERATION: $iteration\n"; + print $fh "MAX ITERATIONS: $max_iterations\n"; + print $fh "MERT NODES: $mert_nodes\n"; + print $fh "DECODE NODES: $decode_nodes\n"; + print $fh "HEAD NODE: $host\n"; + print $fh "PMEM (DECODING): $pmem\n"; + print $fh "CLEANUP: $cleanup\n"; + print $fh "INITIAL WEIGHTS: $initialWeights\n"; + + if ($restart){ + print $fh "PROJECTED BLEU: $projected_score\n"; + print $fh "BEST WEIGHTS: $best_weights\n"; + print $fh "BEST BLEU: $best_bleu\n"; + } +} + +sub update_weights_file { + my ($neww, $rfn, $rpts) = @_; + my @feats = @$rfn; + my @pts = @$rpts; + my $num_feats = scalar @feats; + my $num_pts = scalar @pts; + die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; + open G, ">$neww" or die; + for (my $i = 0; $i < $num_feats; $i++) { + my $f = $feats[$i]; + my $lambda = $pts[$i]; + print G "$f $lambda\n"; + } + close G; +} + +sub enseg { + my $src = shift; + my $newsrc = shift; + open(SRC, $src); + open(NEWSRC, ">$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + print NEWSRC "$line\n"; + $i++; + } + close SRC; + close NEWSRC; +} + +sub print_help { + + my $executable = `basename $0`; chomp $executable; + print << "Help"; + +Usage: $executable [options] + $executable --restart + + $executable [options] + Runs a complete MERT optimization and test set decoding, using + the decoder configuration in ini file. Note that many of the + options have default values that are inferred automatically + based on certain conventions. For details, refer to descriptions + of the options --decoder, --weights, and --workdir. + + $executable --restart + Continues an optimization run that was stopped for some reason, + using configuration information found in the working directory + left behind by the stopped run. + +Options: + + --local + Run the decoder and optimizer locally. + + --decoder + Decoder binary to use. + + --decode-nodes + A list of nodes used for parallel decoding. If specific nodes + are not desired, use "1" for each node requested. Defaults to + "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", which indicates a request for + 15 nodes. + + --dont-clean + If present, this flag prevents intermediate files, including + run files and cumulative files, from being automatically removed + after a successful optimization run (these files are left if the + run fails for any reason). If used, a makefile containing + cleanup commands is written to the directory. To clean up + the intermediate files, invoke make without any arguments. + + --dry-run + Prints out the settings and exits without doing anything. + + --epsilon + Require that the dev set BLEU score improve by at least + within iterations (controlled by parameter --interval). + If not specified, defaults to .002. + + --help + Print this message and exit. + + --interval + Require that the dev set BLEU score improve by at least + (controlled by parameter --epsilon) within iterations. + If not specified, defaults to 5. + + --iteration + Starting iteration number. If not specified, defaults to 1. + + --max-iterations + Maximum number of iterations to run. If not specified, defaults + to 10. + + --pmem + Amount of physical memory requested for parallel decoding jobs, + in the format expected by qsub. If not specified, defaults to + 2g. + + --ref-files + Dev set ref files. This option takes only a single string argument. + To use multiple files (including file globbing), this argument should + be quoted. If not specified, defaults to + /fs/cliplab/mteval/Evaluation/Chinese-English/mt03.ref.txt.* + + --metric + Metric to optimize. See fmert's --metric option for values. + Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi + + --normalize + After each iteration, rescale all feature weights such that feature- + name has a weight of 1.0. + + --rand-directions + MERT will attempt to optimize along all of the principle directions, + set this parameter to explore other directions. Defaults to 5. + + --source-file + Dev set source file. If not specified, defaults to + /fs/cliplab/mteval/Evaluation/Chinese-English/mt03.src.txt + + --weights + A file specifying initial feature weights. The format is + FeatureName_1 value1 + FeatureName_2 value2 + + --workdir + Directory for intermediate and output files. If not specified, the + name is derived from the ini filename. Assuming that the ini + filename begins with the decoder name and ends with ini, the default + name of the working directory is inferred from the middle part of + the filename. E.g. an ini file named decoder.foo.ini would have + a default working directory name foo. + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + diff --git a/vest/error_surface.cc b/vest/error_surface.cc new file mode 100644 index 00000000..4e0af35c --- /dev/null +++ b/vest/error_surface.cc @@ -0,0 +1,46 @@ +#include "error_surface.h" + +#include +#include + +using namespace std; + +ErrorSurface::~ErrorSurface() { + for (ErrorSurface::iterator i = begin(); i != end(); ++i) + //delete i->delta; + ; +} + +void ErrorSurface::Serialize(std::string* out) const { + const int segments = this->size(); + ostringstream os(ios::binary); + os.write((const char*)&segments,sizeof(segments)); + for (int i = 0; i < segments; ++i) { + const ErrorSegment& cur = (*this)[i]; + string senc; + cur.delta->Encode(&senc); + assert(senc.size() < 256); + unsigned char len = senc.size(); + os.write((const char*)&cur.x, sizeof(cur.x)); + os.write((const char*)&len, sizeof(len)); + os.write((const char*)&senc[0], len); + } + *out = os.str(); +} + +void ErrorSurface::Deserialize(ScoreType type, const std::string& in) { + istringstream is(in, ios::binary); + int segments; + is.read((char*)&segments, sizeof(segments)); + this->resize(segments); + for (int i = 0; i < segments; ++i) { + ErrorSegment& cur = (*this)[i]; + unsigned char len; + is.read((char*)&cur.x, sizeof(cur.x)); + is.read((char*)&len, sizeof(len)); + string senc(len, '\0'); assert(senc.size() == len); + is.read((char*)&senc[0], len); + cur.delta = SentenceScorer::CreateScoreFromString(type, senc); + } +} + diff --git a/vest/fast_score.cc b/vest/fast_score.cc new file mode 100644 index 00000000..45b60d78 --- /dev/null +++ b/vest/fast_score.cc @@ -0,0 +1,74 @@ +#include +#include + +#include +#include + +#include "filelib.h" +#include "tdict.h" +#include "scorer.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation(s) (tokenized text file)") + ("loss_function,l",po::value()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)") + ("in_file,i", po::value()->default_value("-"), "Input file") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r -r ...\n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string loss_function = conf["loss_function"].as(); + ScoreType type = ScoreTypeFromString(loss_function); + DocScorer ds(type, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl; + + ReadFile rf(conf["in_file"].as()); + Score* acc = NULL; + istream& in = *rf.stream(); + int lc = 0; + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + vector sent; + TD::ConvertSentence(line, &sent); + Score* sentscore = ds[lc]->ScoreCandidate(sent); + if (!acc) { acc = sentscore->GetZero(); } + acc->PlusEquals(*sentscore); + delete sentscore; + ++lc; + } + assert(lc > 0); + if (lc > ds.size()) { + cerr << "Too many (" << lc << ") translations in input, expected " << ds.size() << endl; + return 1; + } + if (lc != ds.size()) + cerr << "Fewer sentences in hyp (" << lc << ") than refs (" + << ds.size() << "): scoring partial set!\n"; + float score = acc->ComputeScore(); + string details; + acc->ScoreDetails(&details); + delete acc; + cerr << details << endl; + cout << score << endl; + return 0; +} diff --git a/vest/line_optimizer.cc b/vest/line_optimizer.cc new file mode 100644 index 00000000..98dcec34 --- /dev/null +++ b/vest/line_optimizer.cc @@ -0,0 +1,101 @@ +#include "line_optimizer.h" + +#include +#include + +#include "sparse_vector.h" +#include "scorer.h" + +using namespace std; + +typedef ErrorSurface::const_iterator ErrorIter; + +// sort by increasing x-ints +struct IntervalComp { + bool operator() (const ErrorIter& a, const ErrorIter& b) const { + return a->x < b->x; + } +}; + +double LineOptimizer::LineOptimize( + const vector& surfaces, + const LineOptimizer::ScoreType type, + float* best_score, + const double epsilon) { + vector all_ints; + for (vector::const_iterator i = surfaces.begin(); + i != surfaces.end(); ++i) { + const ErrorSurface& surface = *i; + for (ErrorIter j = surface.begin(); j != surface.end(); ++j) + all_ints.push_back(j); + } + sort(all_ints.begin(), all_ints.end(), IntervalComp()); + double last_boundary = all_ints.front()->x; + Score* acc = all_ints.front()->delta->GetZero(); + float& cur_best_score = *best_score; + cur_best_score = (type == MAXIMIZE_SCORE ? + -numeric_limits::max() : numeric_limits::max()); + bool left_edge = true; + double pos = numeric_limits::quiet_NaN(); + for (vector::iterator i = all_ints.begin(); + i != all_ints.end(); ++i) { + const ErrorSegment& seg = **i; + assert(seg.delta); + if (seg.x - last_boundary > epsilon) { + float sco = acc->ComputeScore(); + if ((type == MAXIMIZE_SCORE && sco > cur_best_score) || + (type == MINIMIZE_SCORE && sco < cur_best_score) ) { + cur_best_score = sco; + if (left_edge) { + pos = seg.x - 0.1; + left_edge = false; + } else { + pos = last_boundary + (seg.x - last_boundary) / 2; + } + // cerr << "NEW BEST: " << pos << " (score=" << cur_best_score << ")\n"; + } + // cerr << "---- s=" << sco << "\n"; + last_boundary = seg.x; + } + // cerr << "x-boundary=" << seg.x << "\n"; + acc->PlusEquals(*seg.delta); + } + float sco = acc->ComputeScore(); + if ((type == MAXIMIZE_SCORE && sco > cur_best_score) || + (type == MINIMIZE_SCORE && sco < cur_best_score) ) { + cur_best_score = sco; + if (left_edge) { + pos = 0; + } else { + pos = last_boundary + 1000.0; + } + } + delete acc; + return pos; +} + +void LineOptimizer::RandomUnitVector(const vector& features_to_optimize, + SparseVector* axis, + RandomNumberGenerator* rng) { + axis->clear(); + for (int i = 0; i < features_to_optimize.size(); ++i) + axis->set_value(features_to_optimize[i], rng->next() - 0.5); + (*axis) /= axis->l2norm(); +} + +void LineOptimizer::CreateOptimizationDirections( + const vector& features_to_optimize, + int additional_random_directions, + RandomNumberGenerator* rng, + vector >* dirs) { + const int num_directions = features_to_optimize.size() + additional_random_directions; + dirs->resize(num_directions); + for (int i = 0; i < num_directions; ++i) { + SparseVector& axis = (*dirs)[i]; + if (i < features_to_optimize.size()) + axis.set_value(features_to_optimize[i], 1.0); + else + RandomUnitVector(features_to_optimize, &axis, rng); + } + cerr << "Generated " << num_directions << " total axes to optimize along.\n"; +} diff --git a/vest/lo_test.cc b/vest/lo_test.cc new file mode 100644 index 00000000..0acae5e0 --- /dev/null +++ b/vest/lo_test.cc @@ -0,0 +1,201 @@ +#include +#include +#include + +#include +#include + +#include "fdict.h" +#include "hg.h" +#include "kbest.h" +#include "hg_io.h" +#include "filelib.h" +#include "inside_outside.h" +#include "viterbi.h" +#include "viterbi_envelope.h" +#include "line_optimizer.h" +#include "scorer.h" + +using namespace std; +using boost::shared_ptr; + +class OptTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +const char* ref11 = "australia reopens embassy in manila"; +const char* ref12 = "( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack ."; +const char* ref21 = "australia reopened manila embassy"; +const char* ref22 = "( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack ."; +const char* ref31 = "australia to reopen embassy in manila"; +const char* ref32 = "( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so - called confirmed terrorist attack threats ."; +const char* ref41 = "australia to re - open its embassy to manila"; +const char* ref42 = "( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so - called \" clear \" threat of terrorist attack 7 weeks ago ."; + +TEST_F(OptTest, TestCheckNaN) { + double x = 0; + double y = 0; + double z = x / y; + EXPECT_EQ(true, isnan(z)); +} + +TEST_F(OptTest,TestViterbiEnvelope) { + shared_ptr a1(new Segment(-1, 0)); + shared_ptr b1(new Segment(1, 0)); + shared_ptr a2(new Segment(-1, 1)); + shared_ptr b2(new Segment(1, -1)); + vector > sa; sa.push_back(a1); sa.push_back(b1); + vector > sb; sb.push_back(a2); sb.push_back(b2); + ViterbiEnvelope a(sa); + cerr << a << endl; + ViterbiEnvelope b(sb); + ViterbiEnvelope c = a; + c *= b; + cerr << a << " (*) " << b << " = " << c << endl; + EXPECT_EQ(3, c.size()); +} + +TEST_F(OptTest,TestViterbiEnvelopeInside) { + const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + Hypergraph hg; + istringstream instr(json); + HypergraphIO::ReadFromJSON(&instr, &hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 1.0); + hg.Reweight(wts); + vector, prob_t> > list; + std::vector > features; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + SparseVector dir; dir.set_value(FD::Convert("f1"), 1.0); + ViterbiEnvelopeWeightFunction wf(wts, dir); + ViterbiEnvelope env = Inside(hg, NULL, wf); + cerr << env << endl; + const vector >& segs = env.GetSortedSegs(); + dir *= segs[1]->x; + wts += dir; + hg.Reweight(wts); + KBest::KBestDerivations, ESentenceTraversal> kbest2(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest2.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + for (int i = 0; i < segs.size(); ++i) { + cerr << "seg=" << i << endl; + vector trans; + segs[i]->ConstructTranslation(&trans); + cerr << TD::GetString(trans) << endl; + } +} + +TEST_F(OptTest, TestS1) { + int fPhraseModel_0 = FD::Convert("PhraseModel_0"); + int fPhraseModel_1 = FD::Convert("PhraseModel_1"); + int fPhraseModel_2 = FD::Convert("PhraseModel_2"); + int fLanguageModel = FD::Convert("LanguageModel"); + int fWordPenalty = FD::Convert("WordPenalty"); + int fPassThrough = FD::Convert("PassThrough"); + SparseVector wts; + wts.set_value(fWordPenalty, 4.25); + wts.set_value(fLanguageModel, -1.1165); + wts.set_value(fPhraseModel_0, -0.96); + wts.set_value(fPhraseModel_1, -0.65); + wts.set_value(fPhraseModel_2, -0.77); + wts.set_value(fPassThrough, -10.0); + + vector to_optimize; + to_optimize.push_back(fWordPenalty); + to_optimize.push_back(fLanguageModel); + to_optimize.push_back(fPhraseModel_0); + to_optimize.push_back(fPhraseModel_1); + to_optimize.push_back(fPhraseModel_2); + + Hypergraph hg; + ReadFile rf("./test_data/0.json.gz"); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + hg.Reweight(wts); + + Hypergraph hg2; + ReadFile rf2("./test_data/1.json.gz"); + HypergraphIO::ReadFromJSON(rf2.stream(), &hg2); + hg2.Reweight(wts); + + vector > refs1(4); + TD::ConvertSentence(ref11, &refs1[0]); + TD::ConvertSentence(ref21, &refs1[1]); + TD::ConvertSentence(ref31, &refs1[2]); + TD::ConvertSentence(ref41, &refs1[3]); + vector > refs2(4); + TD::ConvertSentence(ref12, &refs2[0]); + TD::ConvertSentence(ref22, &refs2[1]); + TD::ConvertSentence(ref32, &refs2[2]); + TD::ConvertSentence(ref42, &refs2[3]); + ScoreType type = ScoreTypeFromString("ibm_bleu"); + SentenceScorer* scorer1 = SentenceScorer::CreateSentenceScorer(type, refs1); + SentenceScorer* scorer2 = SentenceScorer::CreateSentenceScorer(type, refs2); + vector envs(2); + + RandomNumberGenerator rng; + + vector > axes; + LineOptimizer::CreateOptimizationDirections( + to_optimize, + 10, + &rng, + &axes); + assert(axes.size() == 10 + to_optimize.size()); + for (int i = 0; i < axes.size(); ++i) + cerr << axes[i] << endl; + const SparseVector& axis = axes[0]; + + cerr << "Computing Viterbi envelope using inside algorithm...\n"; + cerr << "axis: " << axis << endl; + clock_t t_start=clock(); + ViterbiEnvelopeWeightFunction wf(wts, axis); + envs[0] = Inside(hg, NULL, wf); + envs[1] = Inside(hg2, NULL, wf); + + vector es(2); + scorer1->ComputeErrorSurface(envs[0], &es[0]); + scorer2->ComputeErrorSurface(envs[1], &es[1]); + cerr << envs[0].size() << " " << envs[1].size() << endl; + cerr << es[0].size() << " " << es[1].size() << endl; + envs.clear(); + clock_t t_env=clock(); + float score; + double m = LineOptimizer::LineOptimize(es, LineOptimizer::MAXIMIZE_SCORE, &score); + clock_t t_opt=clock(); + cerr << "line optimizer returned: " << m << " (SCORE=" << score << ")\n"; + EXPECT_FLOAT_EQ(0.48719698, score); + SparseVector res = axis; + res *= m; + res += wts; + cerr << "res: " << res << endl; + cerr << "ENVELOPE PROCESSING=" << (static_cast(t_env - t_start) / 1000.0) << endl; + cerr << " LINE OPTIMIZATION=" << (static_cast(t_opt - t_env) / 1000.0) << endl; + hg.Reweight(res); + hg2.Reweight(res); + vector t1,t2; + ViterbiESentence(hg, &t1); + ViterbiESentence(hg2, &t2); + cerr << TD::GetString(t1) << endl; + cerr << TD::GetString(t2) << endl; + delete scorer1; + delete scorer2; +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc new file mode 100644 index 00000000..c96a61e4 --- /dev/null +++ b/vest/mr_vest_generate_mapper_input.cc @@ -0,0 +1,72 @@ +#include +#include + +#include +#include + +#include "filelib.h" +#include "weights.h" +#include "line_optimizer.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("dev_set_size,s",po::value(),"[REQD] Development set size (# of parallel sentences)") + ("forest_repository,r",po::value(),"[REQD] Path to forest repository") + ("weights,w",po::value(),"[REQD] Current feature weights file") + ("optimize_feature,o",po::value >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)") + ("random_directions,d",po::value()->default_value(20),"Number of random directions to run the line optimizer in") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (conf->count("dev_set_size") == 0) { + cerr << "Please specify the size of the development set using -d N\n"; + flag = true; + } + if (conf->count("weights") == 0) { + cerr << "Please specify the starting-point weights using -w \n"; + flag = true; + } + if (conf->count("forest_repository") == 0) { + cerr << "Please specify the forest repository location using -r \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + RandomNumberGenerator rng; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + Weights weights; + vector features; + weights.InitFromFile(conf["weights"].as(), &features); + const string forest_repository = conf["forest_repository"].as(); + assert(DirectoryExists(forest_repository)); + SparseVector origin; + weights.InitSparseVector(&origin); + if (conf.count("optimize_feature") > 0) + features=conf["optimize_feature"].as >(); + vector > axes; + vector fids(features.size()); + for (int i = 0; i < features.size(); ++i) + fids[i] = FD::Convert(features[i]); + LineOptimizer::CreateOptimizationDirections( + fids, + conf["random_directions"].as(), + &rng, + &axes); + int dev_set_size = conf["dev_set_size"].as(); + for (int i = 0; i < dev_set_size; ++i) + for (int j = 0; j < axes.size(); ++j) + cout << forest_repository << '/' << i << ".json.gz " << i << ' ' << origin << ' ' << axes[j] << endl; + return 0; +} diff --git a/vest/mr_vest_map.cc b/vest/mr_vest_map.cc new file mode 100644 index 00000000..80e84218 --- /dev/null +++ b/vest/mr_vest_map.cc @@ -0,0 +1,98 @@ +#include +#include +#include +#include + +#include +#include + +#include "filelib.h" +#include "stringlib.h" +#include "sparse_vector.h" +#include "scorer.h" +#include "viterbi_envelope.h" +#include "inside_outside.h" +#include "error_surface.h" +#include "hg.h" +#include "hg_io.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("loss_function,l",po::value()->default_value("ibm_bleu"), "Loss function being optimized") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +bool ReadSparseVectorString(const string& s, SparseVector* v) { + vector fields; + Tokenize(s, ';', &fields); + if (fields.empty()) return false; + for (int i = 0; i < fields.size(); ++i) { + vector pair(2); + Tokenize(fields[i], '=', &pair); + if (pair.size() != 2) { + cerr << "Error parsing vector string: " << fields[i] << endl; + return false; + } + v->set_value(FD::Convert(pair[0]), atof(pair[1].c_str())); + } + return true; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string loss_function = conf["loss_function"].as(); + ScoreType type = ScoreTypeFromString(loss_function); + DocScorer ds(type, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl; + Hypergraph hg; + string last_file; + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + istringstream is(line); + int sent_id; + string file, s_origin, s_axis; + is >> file >> sent_id >> s_origin >> s_axis; + SparseVector origin; + assert(ReadSparseVectorString(s_origin, &origin)); + SparseVector axis; + assert(ReadSparseVectorString(s_axis, &axis)); + // cerr << "File: " << file << "\nAxis: " << axis << "\n X: " << origin << endl; + if (last_file != file) { + last_file = file; + ReadFile rf(file); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + } + ViterbiEnvelopeWeightFunction wf(origin, axis); + ViterbiEnvelope ve = Inside(hg, NULL, wf); + ErrorSurface es; + ds[sent_id]->ComputeErrorSurface(ve, &es); + //cerr << "Viterbi envelope has " << ve.size() << " segments\n"; + cerr << "Error surface has " << es.size() << " segments\n"; + string val; + es.Serialize(&val); + cout << 'M' << ' ' << s_origin << ' ' << s_axis << '\t'; + B64::b64encode(val.c_str(), val.size(), &cout); + cout << endl; + } + return 0; +} diff --git a/vest/mr_vest_reduce.cc b/vest/mr_vest_reduce.cc new file mode 100644 index 00000000..c1347065 --- /dev/null +++ b/vest/mr_vest_reduce.cc @@ -0,0 +1,80 @@ +#include +#include +#include +#include + +#include +#include + +#include "sparse_vector.h" +#include "error_surface.h" +#include "line_optimizer.h" +#include "hg_io.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("loss_function,l",po::value()->default_value("ibm_bleu"), "Loss function being optimized") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string loss_function = conf["loss_function"].as(); + ScoreType type = ScoreTypeFromString(loss_function); + LineOptimizer::ScoreType opt_type = LineOptimizer::MAXIMIZE_SCORE; + if (type == TER) + opt_type = LineOptimizer::MINIMIZE_SCORE; + string last_key; + vector esv; + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + size_t ks = line.find("\t"); + assert(string::npos != ks); + assert(ks > 2); + string key = line.substr(2, ks - 2); + string val = line.substr(ks + 1); + if (key != last_key) { + if (!last_key.empty()) { + float score; + double x = LineOptimizer::LineOptimize(esv, opt_type, &score); + cout << last_key << "|" << x << "|" << score << endl; + } + last_key = key; + esv.clear(); + } + if (val.size() % 4 != 0) { + cerr << "B64 encoding error 1! Skipping.\n"; + continue; + } + string encoded(val.size() / 4 * 3, '\0'); + if (!B64::b64decode(reinterpret_cast(&val[0]), val.size(), &encoded[0], encoded.size())) { + cerr << "B64 encoding error 2! Skipping.\n"; + continue; + } + esv.push_back(ErrorSurface()); + esv.back().Deserialize(type, encoded); + } + if (!esv.empty()) { + cerr << "ESV=" << esv.size() << endl; + for (int i = 0; i < esv.size(); ++i) { cerr << esv[i].size() << endl; } + float score; + double x = LineOptimizer::LineOptimize(esv, opt_type, &score); + cout << last_key << "|" << x << "|" << score << endl; + } + return 0; +} diff --git a/vest/scorer.cc b/vest/scorer.cc new file mode 100644 index 00000000..e242bb46 --- /dev/null +++ b/vest/scorer.cc @@ -0,0 +1,485 @@ +#include "scorer.h" + +#include +#include +#include +#include +#include +#include + +#include + +#include "viterbi_envelope.h" +#include "error_surface.h" +#include "ter.h" +#include "comb_scorer.h" +#include "tdict.h" +#include "stringlib.h" + +using boost::shared_ptr; +using namespace std; + +const bool minimize_segments = true; // if adjacent segments have equal scores, merge them + +ScoreType ScoreTypeFromString(const std::string& st) { + const string sl = LowercaseString(st); + if (sl == "ser") + return SER; + if (sl == "ter") + return TER; + if (sl == "bleu" || sl == "ibm_bleu") + return IBM_BLEU; + if (sl == "nist_bleu") + return NIST_BLEU; + if (sl == "koehn_bleu") + return Koehn_BLEU; + if (sl == "combi") + return BLEU_minus_TER_over_2; + cerr << "Don't understand score type '" << sl << "', defaulting to ibm_bleu.\n"; + return IBM_BLEU; +} + +Score::~Score() {} +SentenceScorer::~SentenceScorer() {} + +class SERScore : public Score { + friend class SERScorer; + public: + SERScore() : correct(0), total(0) {} + float ComputeScore() const { + return static_cast(correct) / static_cast(total); + } + void ScoreDetails(string* details) const { + ostringstream os; + os << "SER= " << ComputeScore() << " (" << correct << '/' << total << ')'; + *details = os.str(); + } + void PlusEquals(const Score& delta) { + correct += static_cast(delta).correct; + total += static_cast(delta).total; + } + Score* GetZero() const { return new SERScore; } + void Subtract(const Score& rhs, Score* res) const { + SERScore* r = static_cast(res); + r->correct = correct - static_cast(rhs).correct; + r->total = total - static_cast(rhs).total; + } + void Encode(std::string* out) const { + assert(!"not implemented"); + } + bool IsAdditiveIdentity() const { + return (total == 0 && correct == 0); // correct is always 0 <= n <= total + } + private: + int correct, total; +}; + +class SERScorer : public SentenceScorer { + public: + SERScorer(const vector >& references) : refs_(references) {} + Score* ScoreCandidate(const std::vector& hyp) const { + SERScore* res = new SERScore; + res->total = 1; + for (int i = 0; i < refs_.size(); ++i) + if (refs_[i] == hyp) res->correct = 1; + return res; + } + static Score* ScoreFromString(const std::string& data) { + assert(!"Not implemented"); + } + private: + vector > refs_; +}; + +class BLEUScore : public Score { + friend class BLEUScorerBase; + public: + BLEUScore(int n) : correct_ngram_hit_counts(0,n), hyp_ngram_counts(0,n) { + ref_len = 0; + hyp_len = 0; } + float ComputeScore() const; + void ScoreDetails(string* details) const; + void PlusEquals(const Score& delta); + Score* GetZero() const; + void Subtract(const Score& rhs, Score* res) const; + void Encode(std::string* out) const; + bool IsAdditiveIdentity() const { + if (fabs(ref_len) > 0.1f || hyp_len != 0) return false; + for (int i = 0; i < correct_ngram_hit_counts.size(); ++i) + if (hyp_ngram_counts[i] != 0 || + correct_ngram_hit_counts[i] != 0) return false; + return true; + } + private: + float ComputeScore(vector* precs, float* bp) const; + valarray correct_ngram_hit_counts; + valarray hyp_ngram_counts; + float ref_len; + int hyp_len; +}; + +class BLEUScorerBase : public SentenceScorer { + public: + BLEUScorerBase(const std::vector >& references, + int n + ); + Score* ScoreCandidate(const std::vector& hyp) const; + static Score* ScoreFromString(const std::string& in); + + protected: + virtual float ComputeRefLength(const vector& hyp) const = 0; + private: + struct NGramCompare { + int operator() (const vector& a, const vector& b) { + size_t as = a.size(); + size_t bs = b.size(); + const size_t s = (as < bs ? as : bs); + for (size_t i = 0; i < s; ++i) { + int d = a[i] - b[i]; + if (d < 0) return true; + if (d > 0) return false; + } + return as < bs; + } + }; + typedef map, pair, NGramCompare> NGramCountMap; + void CountRef(const vector& ref) { + NGramCountMap tc; + vector ngram(n_); + int s = ref.size(); + for (int j=0; j& p = ngrams_[i->first]; + if (p.first < i->second.first) + p = i->second; + } + } + + void ComputeNgramStats(const vector& sent, + valarray* correct, + valarray* hyp) const { + assert(correct->size() == n_); + assert(hyp->size() == n_); + vector ngram(n_); + (*correct) *= 0; + (*hyp) *= 0; + int s = sent.size(); + for (int j=0; j& p = ngrams_[ngram]; + if (p.second < p.first) { + ++p.second; + (*correct)[i-1]++; + } + // if the 1 gram isn't found, don't try to match don't need to match any 2- 3- .. grams: + if (!p.first) { + for (; i<=k; ++i) + (*hyp)[i-1]++; + } else { + (*hyp)[i-1]++; + } + } + } + } + + mutable NGramCountMap ngrams_; + int n_; + vector lengths_; +}; + +Score* BLEUScorerBase::ScoreFromString(const std::string& in) { + istringstream is(in); + int n; + is >> n; + BLEUScore* r = new BLEUScore(n); + is >> r->ref_len >> r->hyp_len; + + for (int i = 0; i < n; ++i) { + is >> r->correct_ngram_hit_counts[i]; + is >> r->hyp_ngram_counts[i]; + } + return r; +} + +class IBM_BLEUScorer : public BLEUScorerBase { + public: + IBM_BLEUScorer(const std::vector >& references, + int n=4) : BLEUScorerBase(references, n), lengths_(references.size()) { + for (int i=0; i < references.size(); ++i) + lengths_[i] = references[i].size(); + } + protected: + float ComputeRefLength(const vector& hyp) const { + if (lengths_.size() == 1) return lengths_[0]; + int bestd = 2000000; + int hl = hyp.size(); + int bl = -1; + for (vector::const_iterator ci = lengths_.begin(); ci != lengths_.end(); ++ci) { + int cl = *ci; + if (abs(cl - hl) < bestd) { + bestd = abs(cl - hl); + bl = cl; + } + } + return bl; + } + private: + vector lengths_; +}; + +class NIST_BLEUScorer : public BLEUScorerBase { + public: + NIST_BLEUScorer(const std::vector >& references, + int n=4) : BLEUScorerBase(references, n), + shortest_(references[0].size()) { + for (int i=1; i < references.size(); ++i) + if (references[i].size() < shortest_) + shortest_ = references[i].size(); + } + protected: + float ComputeRefLength(const vector& hyp) const { + return shortest_; + } + private: + float shortest_; +}; + +class Koehn_BLEUScorer : public BLEUScorerBase { + public: + Koehn_BLEUScorer(const std::vector >& references, + int n=4) : BLEUScorerBase(references, n), + avg_(0) { + for (int i=0; i < references.size(); ++i) + avg_ += references[i].size(); + avg_ /= references.size(); + } + protected: + float ComputeRefLength(const vector& hyp) const { + return avg_; + } + private: + float avg_; +}; + +SentenceScorer* SentenceScorer::CreateSentenceScorer(const ScoreType type, + const std::vector >& refs) { + switch (type) { + case IBM_BLEU: return new IBM_BLEUScorer(refs, 4); + case NIST_BLEU: return new NIST_BLEUScorer(refs, 4); + case Koehn_BLEU: return new Koehn_BLEUScorer(refs, 4); + case TER: return new TERScorer(refs); + case SER: return new SERScorer(refs); + case BLEU_minus_TER_over_2: return new BLEUTERCombinationScorer(refs); + default: + assert(!"Not implemented!"); + } +} + +Score* SentenceScorer::CreateScoreFromString(const ScoreType type, const std::string& in) { + switch (type) { + case IBM_BLEU: + case NIST_BLEU: + case Koehn_BLEU: + return BLEUScorerBase::ScoreFromString(in); + case TER: + return TERScorer::ScoreFromString(in); + case SER: + return SERScorer::ScoreFromString(in); + case BLEU_minus_TER_over_2: + return BLEUTERCombinationScorer::ScoreFromString(in); + default: + assert(!"Not implemented!"); + } +} + +void SentenceScorer::ComputeErrorSurface(const ViterbiEnvelope& ve, ErrorSurface* env) const { + vector prev_trans; + const vector >& ienv = ve.GetSortedSegs(); + env->resize(ienv.size()); + Score* prev_score = NULL; + int j = 0; + for (int i = 0; i < ienv.size(); ++i) { + const Segment& seg = *ienv[i]; + vector trans; + seg.ConstructTranslation(&trans); + // cerr << "Scoring: " << TD::GetString(trans) << endl; + if (trans == prev_trans) { + if (!minimize_segments) { + assert(prev_score); // if this fails, it means + // the decoder can generate null translations + ErrorSegment& out = (*env)[j]; + out.delta = prev_score->GetZero(); + out.x = seg.x; + ++j; + } + // cerr << "Identical translation, skipping scoring\n"; + } else { + Score* score = ScoreCandidate(trans); + // cerr << "score= " << score->ComputeScore() << "\n"; + Score* cur_delta = score->GetZero(); + // just record the score diffs + if (!prev_score) + prev_score = score->GetZero(); + + score->Subtract(*prev_score, cur_delta); + delete prev_score; + prev_trans.swap(trans); + prev_score = score; + if ((!minimize_segments) || (!cur_delta->IsAdditiveIdentity())) { + ErrorSegment& out = (*env)[j]; + out.delta = cur_delta; + out.x = seg.x; + ++j; + } + } + } + delete prev_score; + // cerr << " In segments: " << ienv.size() << endl; + // cerr << "Out segments: " << j << endl; + assert(j > 0); + env->resize(j); +} + +void BLEUScore::ScoreDetails(std::string* details) const { + char buf[2000]; + vector precs(4); + float bp; + float bleu = ComputeScore(&precs, &bp); + sprintf(buf, "BLEU = %.2f, %.1f|%.1f|%.1f|%.1f (brev=%.3f)", + bleu*100.0, + precs[0]*100.0, + precs[1]*100.0, + precs[2]*100.0, + precs[3]*100.0, + bp); + *details = buf; +} + +float BLEUScore::ComputeScore(vector* precs, float* bp) const { + float log_bleu = 0; + if (precs) precs->clear(); + int count = 0; + for (int i = 0; i < hyp_ngram_counts.size(); ++i) { + if (hyp_ngram_counts[i] > 0) { + float lprec = log(correct_ngram_hit_counts[i]) - log(hyp_ngram_counts[i]); + if (precs) precs->push_back(exp(lprec)); + log_bleu += lprec; + ++count; + } + } + log_bleu /= static_cast(count); + float lbp = 0.0; + if (hyp_len < ref_len) + lbp = (hyp_len - ref_len) / hyp_len; + log_bleu += lbp; + if (bp) *bp = exp(lbp); + return exp(log_bleu); +} + +float BLEUScore::ComputeScore() const { + return ComputeScore(NULL, NULL); +} + +void BLEUScore::Subtract(const Score& rhs, Score* res) const { + const BLEUScore& d = static_cast(rhs); + BLEUScore* o = static_cast(res); + o->ref_len = ref_len - d.ref_len; + o->hyp_len = hyp_len - d.hyp_len; + o->correct_ngram_hit_counts = correct_ngram_hit_counts - d.correct_ngram_hit_counts; + o->hyp_ngram_counts = hyp_ngram_counts - d.hyp_ngram_counts; +} + +void BLEUScore::PlusEquals(const Score& delta) { + const BLEUScore& d = static_cast(delta); + correct_ngram_hit_counts += d.correct_ngram_hit_counts; + hyp_ngram_counts += d.hyp_ngram_counts; + ref_len += d.ref_len; + hyp_len += d.hyp_len; +} + +Score* BLEUScore::GetZero() const { + return new BLEUScore(hyp_ngram_counts.size()); +} + +void BLEUScore::Encode(std::string* out) const { + ostringstream os; + const int n = correct_ngram_hit_counts.size(); + os << n << ' ' << ref_len << ' ' << hyp_len; + for (int i = 0; i < n; ++i) + os << ' ' << correct_ngram_hit_counts[i] << ' ' << hyp_ngram_counts[i]; + *out = os.str(); +} + +BLEUScorerBase::BLEUScorerBase(const std::vector >& references, + int n) : n_(n) { + for (vector >::const_iterator ci = references.begin(); + ci != references.end(); ++ci) { + lengths_.push_back(ci->size()); + CountRef(*ci); + } +} + +Score* BLEUScorerBase::ScoreCandidate(const vector& hyp) const { + BLEUScore* bs = new BLEUScore(n_); + for (NGramCountMap::iterator i=ngrams_.begin(); i != ngrams_.end(); ++i) + i->second.second = 0; + ComputeNgramStats(hyp, &bs->correct_ngram_hit_counts, &bs->hyp_ngram_counts); + bs->ref_len = ComputeRefLength(hyp); + bs->hyp_len = hyp.size(); + return bs; +} + +DocScorer::~DocScorer() { + for (int i=0; i < scorers_.size(); ++i) + delete scorers_[i]; +} + +DocScorer::DocScorer( + const ScoreType type, + const std::vector& ref_files) { + // TODO stop using valarray, start using ReadFile + cerr << "Loading references (" << ref_files.size() << " files)\n"; + valarray ifs(ref_files.size()); + for (int i=0; i < ref_files.size(); ++i) { + ifs[i].open(ref_files[i].c_str()); + assert(ifs[i].good()); + } + char buf[64000]; + bool expect_eof = false; + while (!ifs[0].eof()) { + vector > refs(ref_files.size()); + for (int i=0; i < ref_files.size(); ++i) { + if (ifs[i].eof()) break; + ifs[i].getline(buf, 64000); + refs[i].clear(); + if (strlen(buf) == 0) { + if (ifs[i].eof()) { + if (!expect_eof) { + assert(i == 0); + expect_eof = true; + } + break; + } + } else { + TD::ConvertSentence(buf, &refs[i]); + assert(!refs[i].empty()); + } + assert(!expect_eof); + } + if (!expect_eof) scorers_.push_back(SentenceScorer::CreateSentenceScorer(type, refs)); + } + cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n"; +} + diff --git a/vest/scorer_test.cc b/vest/scorer_test.cc new file mode 100644 index 00000000..fca219cc --- /dev/null +++ b/vest/scorer_test.cc @@ -0,0 +1,178 @@ +#include +#include +#include +#include + +#include "tdict.h" +#include "scorer.h" + +using namespace std; + +class ScorerTest : public testing::Test { + protected: + virtual void SetUp() { + refs0.resize(4); + refs1.resize(4); + TD::ConvertSentence("export of high-tech products in guangdong in first two months this year reached 3.76 billion us dollars", &refs0[0]); + TD::ConvertSentence("guangdong's export of new high technology products amounts to us $ 3.76 billion in first two months of this year", &refs0[1]); + TD::ConvertSentence("guangdong exports us $ 3.76 billion worth of high technology products in the first two months of this year", &refs0[2]); + TD::ConvertSentence("in the first 2 months this year , the export volume of new hi-tech products in guangdong province reached 3.76 billion us dollars .", &refs0[3]); + TD::ConvertSentence("xinhua news agency , guangzhou , march 16 ( reporter chen ji ) the latest statistics show that from january through february this year , the export of high-tech products in guangdong province reached 3.76 billion us dollars , up 34.8 \% over the same period last year and accounted for 25.5 \% of the total export in the province .", &refs1[0]); + TD::ConvertSentence("xinhua news agency , guangzhou , march 16 ( reporter : chen ji ) -- latest statistic indicates that guangdong's export of new high technology products amounts to us $ 3.76 billion , up 34.8 \% over corresponding period and accounts for 25.5 \% of the total exports of the province .", &refs1[1]); + TD::ConvertSentence("xinhua news agency report of march 16 from guangzhou ( by staff reporter chen ji ) - latest statistics indicate guangdong province exported us $ 3.76 billion worth of high technology products , up 34.8 percent from the same period last year , which account for 25.5 percent of the total exports of the province .", &refs1[2]); + TD::ConvertSentence("guangdong , march 16 , ( xinhua ) -- ( chen ji reports ) as the newest statistics shows , in january and feberuary this year , the export volume of new hi-tech products in guangdong province reached 3.76 billion us dollars , up 34.8 \% than last year , making up 25.5 \% of the province's total .", &refs1[3]); + TD::ConvertSentence("one guangdong province will next export us $ 3.76 high-tech product two months first this year 3.76 billion us dollars", &hyp1); + TD::ConvertSentence("xinhua news agency , guangzhou , 16th of march ( reporter chen ) -- latest statistics suggest that guangdong exports new advanced technology product totals $ 3.76 million , 34.8 percent last corresponding period and accounts for 25.5 percent of the total export province .", &hyp2); + } + + virtual void TearDown() { } + + vector > refs0; + vector > refs1; + vector hyp1; + vector hyp2; +}; + +TEST_F(ScorerTest, TestCreateFromFiles) { + vector files; + files.push_back("test_data/re.txt.0"); + files.push_back("test_data/re.txt.1"); + files.push_back("test_data/re.txt.2"); + files.push_back("test_data/re.txt.3"); + DocScorer ds(IBM_BLEU, files); +} + +TEST_F(ScorerTest, TestBLEUScorer) { + SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(IBM_BLEU, refs0); + SentenceScorer* s2 = SentenceScorer::CreateSentenceScorer(IBM_BLEU, refs1); + Score* b1 = s1->ScoreCandidate(hyp1); + EXPECT_FLOAT_EQ(0.23185077, b1->ComputeScore()); + Score* b2 = s2->ScoreCandidate(hyp2); + EXPECT_FLOAT_EQ(0.38101241, b2->ComputeScore()); + b1->PlusEquals(*b2); + EXPECT_FLOAT_EQ(0.348854, b1->ComputeScore()); + EXPECT_FALSE(b1->IsAdditiveIdentity()); + string details; + b1->ScoreDetails(&details); + EXPECT_EQ("BLEU = 34.89, 81.5|50.8|29.5|18.6 (brev=0.898)", details); + cerr << details << endl; + string enc; + b1->Encode(&enc); + Score* b3 = SentenceScorer::CreateScoreFromString(IBM_BLEU, enc); + details.clear(); + cerr << "Encoded BLEU score size: " << enc.size() << endl; + b3->ScoreDetails(&details); + cerr << details << endl; + EXPECT_FALSE(b3->IsAdditiveIdentity()); + EXPECT_EQ("BLEU = 34.89, 81.5|50.8|29.5|18.6 (brev=0.898)", details); + Score* bz = b3->GetZero(); + EXPECT_TRUE(bz->IsAdditiveIdentity()); + delete bz; + delete b1; + delete s1; + delete b2; + delete s2; +} + +TEST_F(ScorerTest, TestTERScorer) { + SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(TER, refs0); + SentenceScorer* s2 = SentenceScorer::CreateSentenceScorer(TER, refs1); + string details; + Score* t1 = s1->ScoreCandidate(hyp1); + t1->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + cerr << t1->ComputeScore() << endl; + Score* t2 = s2->ScoreCandidate(hyp2); + t2->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + cerr << t2->ComputeScore() << endl; + t1->PlusEquals(*t2); + cerr << t1->ComputeScore() << endl; + t1->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + EXPECT_EQ("TER = 44.16, 4| 8| 16| 6 (len=77)", details); + string enc; + t1->Encode(&enc); + Score* t3 = SentenceScorer::CreateScoreFromString(TER, enc); + details.clear(); + t3->ScoreDetails(&details); + EXPECT_EQ("TER = 44.16, 4| 8| 16| 6 (len=77)", details); + EXPECT_FALSE(t3->IsAdditiveIdentity()); + Score* tz = t3->GetZero(); + EXPECT_TRUE(tz->IsAdditiveIdentity()); + delete tz; + delete t3; + delete t1; + delete s1; + delete t2; + delete s2; +} + +TEST_F(ScorerTest, TestTERScorerSimple) { + vector > ref(1); + TD::ConvertSentence("1 2 3 A B", &ref[0]); + vector hyp; + TD::ConvertSentence("A B 1 2 3", &hyp); + SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(TER, ref); + string details; + Score* t1 = s1->ScoreCandidate(hyp); + t1->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + delete t1; + delete s1; +} + +TEST_F(ScorerTest, TestSERScorerSimple) { + vector > ref(1); + TD::ConvertSentence("A B C D", &ref[0]); + vector hyp1; + TD::ConvertSentence("A B C", &hyp1); + vector hyp2; + TD::ConvertSentence("A B C D", &hyp2); + SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(SER, ref); + string details; + Score* t1 = s1->ScoreCandidate(hyp1); + t1->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + Score* t2 = s1->ScoreCandidate(hyp2); + t2->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + t2->PlusEquals(*t1); + t2->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + delete t1; + delete t2; + delete s1; +} + +TEST_F(ScorerTest, TestCombiScorer) { + SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(BLEU_minus_TER_over_2, refs0); + string details; + Score* t1 = s1->ScoreCandidate(hyp1); + t1->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + cerr << t1->ComputeScore() << endl; + string enc; + t1->Encode(&enc); + Score* t2 = SentenceScorer::CreateScoreFromString(BLEU_minus_TER_over_2, enc); + details.clear(); + t2->ScoreDetails(&details); + cerr << "DETAILS: " << details << endl; + Score* cz = t2->GetZero(); + EXPECT_FALSE(t2->IsAdditiveIdentity()); + EXPECT_TRUE(cz->IsAdditiveIdentity()); + cz->PlusEquals(*t2); + EXPECT_FALSE(cz->IsAdditiveIdentity()); + string d2; + cz->ScoreDetails(&d2); + EXPECT_EQ(d2, details); + delete cz; + delete t2; + delete t1; +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/vest/ter.cc b/vest/ter.cc new file mode 100644 index 00000000..ef66f3b7 --- /dev/null +++ b/vest/ter.cc @@ -0,0 +1,518 @@ +#include "ter.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tdict.h" + +const bool ter_use_average_ref_len = true; +const int ter_short_circuit_long_sentences = -1; + +using namespace std; +using namespace std::tr1; + +struct COSTS { + static const float substitution; + static const float deletion; + static const float insertion; + static const float shift; +}; +const float COSTS::substitution = 1.0f; +const float COSTS::deletion = 1.0f; +const float COSTS::insertion = 1.0f; +const float COSTS::shift = 1.0f; + +static const int MAX_SHIFT_SIZE = 10; +static const int MAX_SHIFT_DIST = 50; + +struct Shift { + unsigned int d_; + Shift() : d_() {} + Shift(int b, int e, int m) : d_() { + begin(b); + end(e); + moveto(m); + } + inline int begin() const { + return d_ & 0x3ff; + } + inline int end() const { + return (d_ >> 10) & 0x3ff; + } + inline int moveto() const { + int m = (d_ >> 20) & 0x7ff; + if (m > 1024) { m -= 1024; m *= -1; } + return m; + } + inline void begin(int b) { + d_ &= 0xfffffc00u; + d_ |= (b & 0x3ff); + } + inline void end(int e) { + d_ &= 0xfff003ffu; + d_ |= (e & 0x3ff) << 10; + } + inline void moveto(int m) { + bool neg = (m < 0); + if (neg) { m *= -1; m += 1024; } + d_ &= 0xfffff; + d_ |= (m & 0x7ff) << 20; + } +}; + +class TERScorerImpl { + + public: + enum TransType { MATCH, SUBSTITUTION, INSERTION, DELETION }; + + explicit TERScorerImpl(const vector& ref) : ref_(ref) { + for (int i = 0; i < ref.size(); ++i) + rwexists_.insert(ref[i]); + } + + float Calculate(const vector& hyp, int* subs, int* ins, int* dels, int* shifts) const { + return CalculateAllShifts(hyp, subs, ins, dels, shifts); + } + + inline int GetRefLength() const { + return ref_.size(); + } + + private: + vector ref_; + set rwexists_; + + typedef unordered_map, set, boost::hash > > NgramToIntsMap; + mutable NgramToIntsMap nmap_; + + static float MinimumEditDistance( + const vector& hyp, + const vector& ref, + vector* path) { + vector > bmat(hyp.size() + 1, vector(ref.size() + 1, MATCH)); + vector > cmat(hyp.size() + 1, vector(ref.size() + 1, 0)); + for (int i = 0; i <= hyp.size(); ++i) + cmat[i][0] = i; + for (int j = 0; j <= ref.size(); ++j) + cmat[0][j] = j; + for (int i = 1; i <= hyp.size(); ++i) { + const WordID& hw = hyp[i-1]; + for (int j = 1; j <= ref.size(); ++j) { + const WordID& rw = ref[j-1]; + float& cur_c = cmat[i][j]; + TransType& cur_b = bmat[i][j]; + + if (rw == hw) { + cur_c = cmat[i-1][j-1]; + cur_b = MATCH; + } else { + cur_c = cmat[i-1][j-1] + COSTS::substitution; + cur_b = SUBSTITUTION; + } + float cwoi = cmat[i-1][j]; + if (cur_c > cwoi + COSTS::insertion) { + cur_c = cwoi + COSTS::insertion; + cur_b = INSERTION; + } + float cwod = cmat[i][j-1]; + if (cur_c > cwod + COSTS::deletion) { + cur_c = cwod + COSTS::deletion; + cur_b = DELETION; + } + } + } + + // trace back along the best path and record the transition types + path->clear(); + int i = hyp.size(); + int j = ref.size(); + while (i > 0 || j > 0) { + if (j == 0) { + --i; + path->push_back(INSERTION); + } else if (i == 0) { + --j; + path->push_back(DELETION); + } else { + TransType t = bmat[i][j]; + path->push_back(t); + switch (t) { + case SUBSTITUTION: + case MATCH: + --i; --j; break; + case INSERTION: + --i; break; + case DELETION: + --j; break; + } + } + } + reverse(path->begin(), path->end()); + return cmat[hyp.size()][ref.size()]; + } + + void BuildWordMatches(const vector& hyp, NgramToIntsMap* nmap) const { + nmap->clear(); + set exists_both; + for (int i = 0; i < hyp.size(); ++i) + if (rwexists_.find(hyp[i]) != rwexists_.end()) + exists_both.insert(hyp[i]); + for (int start=0; start cp; + int mlen = min(MAX_SHIFT_SIZE, static_cast(ref_.size() - start)); + for (int len=0; len& in, + int start, int end, int moveto, vector* out) { + // cerr << "ps: " << start << " " << end << " " << moveto << endl; + out->clear(); + if (moveto == -1) { + for (int i = start; i <= end; ++i) + out->push_back(in[i]); + for (int i = 0; i < start; ++i) + out->push_back(in[i]); + for (int i = end+1; i < in.size(); ++i) + out->push_back(in[i]); + } else if (moveto < start) { + for (int i = 0; i <= moveto; ++i) + out->push_back(in[i]); + for (int i = start; i <= end; ++i) + out->push_back(in[i]); + for (int i = moveto+1; i < start; ++i) + out->push_back(in[i]); + for (int i = end+1; i < in.size(); ++i) + out->push_back(in[i]); + } else if (moveto > end) { + for (int i = 0; i < start; ++i) + out->push_back(in[i]); + for (int i = end+1; i <= moveto; ++i) + out->push_back(in[i]); + for (int i = start; i <= end; ++i) + out->push_back(in[i]); + for (int i = moveto+1; i < in.size(); ++i) + out->push_back(in[i]); + } else { + for (int i = 0; i < start; ++i) + out->push_back(in[i]); + for (int i = end+1; (i < in.size()) && (i <= end + (moveto - start)); ++i) + out->push_back(in[i]); + for (int i = start; i <= end; ++i) + out->push_back(in[i]); + for (int i = (end + (moveto - start))+1; i < in.size(); ++i) + out->push_back(in[i]); + } + if (out->size() != in.size()) { + cerr << "ps: " << start << " " << end << " " << moveto << endl; + cerr << "in=" << TD::GetString(in) << endl; + cerr << "out=" << TD::GetString(*out) << endl; + } + assert(out->size() == in.size()); + // cerr << "ps: " << TD::GetString(*out) << endl; + } + + void GetAllPossibleShifts(const vector& hyp, + const vector& ralign, + const vector& herr, + const vector& rerr, + const int min_size, + vector >* shifts) const { + for (int start = 0; start < hyp.size(); ++start) { + vector cp(1, hyp[start]); + NgramToIntsMap::iterator niter = nmap_.find(cp); + if (niter == nmap_.end()) continue; + bool ok = false; + int moveto; + for (set::iterator i = niter->second.begin(); i != niter->second.end(); ++i) { + moveto = *i; + int rm = ralign[moveto]; + ok = (start != rm && + (rm - start) < MAX_SHIFT_DIST && + (start - rm - 1) < MAX_SHIFT_DIST); + if (ok) break; + } + if (!ok) continue; + cp.clear(); + for (int end = start + min_size - 1; + ok && end < hyp.size() && end < (start + MAX_SHIFT_SIZE); ++end) { + cp.push_back(hyp[end]); + vector& sshifts = (*shifts)[end - start]; + ok = false; + NgramToIntsMap::iterator niter = nmap_.find(cp); + if (niter == nmap_.end()) break; + bool any_herr = false; + for (int i = start; i <= end && !any_herr; ++i) + any_herr = herr[i]; + if (!any_herr) { + ok = true; + continue; + } + for (set::iterator mi = niter->second.begin(); + mi != niter->second.end(); ++mi) { + int moveto = *mi; + int rm = ralign[moveto]; + if (! ((rm != start) && + ((rm < start) || (rm > end)) && + (rm - start <= MAX_SHIFT_DIST) && + ((start - rm - 1) <= MAX_SHIFT_DIST))) continue; + ok = true; + bool any_rerr = false; + for (int i = 0; (i <= end - start) && (!any_rerr); ++i) + any_rerr = rerr[moveto+i]; + if (!any_rerr) continue; + for (int roff = 0; roff <= (end - start); ++roff) { + int rmr = ralign[moveto+roff]; + if ((start != rmr) && ((roff == 0) || (rmr != ralign[moveto]))) + sshifts.push_back(Shift(start, end, moveto + roff)); + } + } + } + } + } + + bool CalculateBestShift(const vector& cur, + const vector& hyp, + float curerr, + const vector& path, + vector* new_hyp, + float* newerr, + vector* new_path) const { + vector herr, rerr; + vector ralign; + int hpos = -1; + for (int i = 0; i < path.size(); ++i) { + switch (path[i]) { + case MATCH: + ++hpos; + herr.push_back(false); + rerr.push_back(false); + ralign.push_back(hpos); + break; + case SUBSTITUTION: + ++hpos; + herr.push_back(true); + rerr.push_back(true); + ralign.push_back(hpos); + break; + case INSERTION: + ++hpos; + herr.push_back(true); + break; + case DELETION: + rerr.push_back(true); + ralign.push_back(hpos); + break; + } + } +#if 0 + cerr << "RALIGN: "; + for (int i = 0; i < rerr.size(); ++i) + cerr << ralign[i] << " "; + cerr << endl; + cerr << "RERR: "; + for (int i = 0; i < rerr.size(); ++i) + cerr << (bool)rerr[i] << " "; + cerr << endl; + cerr << "HERR: "; + for (int i = 0; i < herr.size(); ++i) + cerr << (bool)herr[i] << " "; + cerr << endl; +#endif + + vector > shifts(MAX_SHIFT_SIZE + 1); + GetAllPossibleShifts(cur, ralign, herr, rerr, 1, &shifts); + float cur_best_shift_cost = 0; + *newerr = curerr; + vector cur_best_path; + vector cur_best_hyp; + + bool res = false; + for (int i = shifts.size() - 1; i >=0; --i) { + float curfix = curerr - (cur_best_shift_cost + *newerr); + float maxfix = 2.0f * (1 + i) - COSTS::shift; + if ((curfix > maxfix) || ((cur_best_shift_cost == 0) && (curfix == maxfix))) break; + for (int j = 0; j < shifts[i].size(); ++j) { + const Shift& s = shifts[i][j]; + curfix = curerr - (cur_best_shift_cost + *newerr); + maxfix = 2.0f * (1 + i) - COSTS::shift; // TODO remove? + if ((curfix > maxfix) || ((cur_best_shift_cost == 0) && (curfix == maxfix))) continue; + vector shifted(cur.size()); + PerformShift(cur, s.begin(), s.end(), ralign[s.moveto()], &shifted); + vector try_path; + float try_cost = MinimumEditDistance(shifted, ref_, &try_path); + float gain = (*newerr + cur_best_shift_cost) - (try_cost + COSTS::shift); + if (gain > 0.0f || ((cur_best_shift_cost == 0.0f) && (gain == 0.0f))) { + *newerr = try_cost; + cur_best_shift_cost = COSTS::shift; + new_path->swap(try_path); + new_hyp->swap(shifted); + res = true; + // cerr << "Found better shift " << s.begin() << "..." << s.end() << " moveto " << s.moveto() << endl; + } + } + } + + return res; + } + + static void GetPathStats(const vector& path, int* subs, int* ins, int* dels) { + *subs = *ins = *dels = 0; + for (int i = 0; i < path.size(); ++i) { + switch (path[i]) { + case SUBSTITUTION: + ++(*subs); + case MATCH: + break; + case INSERTION: + ++(*ins); break; + case DELETION: + ++(*dels); break; + } + } + } + + float CalculateAllShifts(const vector& hyp, + int* subs, int* ins, int* dels, int* shifts) const { + BuildWordMatches(hyp, &nmap_); + vector path; + float med_cost = MinimumEditDistance(hyp, ref_, &path); + float edits = 0; + vector cur = hyp; + *shifts = 0; + if (ter_short_circuit_long_sentences < 0 || + ref_.size() < ter_short_circuit_long_sentences) { + while (true) { + vector new_hyp; + vector new_path; + float new_med_cost; + if (!CalculateBestShift(cur, hyp, med_cost, path, &new_hyp, &new_med_cost, &new_path)) + break; + edits += COSTS::shift; + ++(*shifts); + med_cost = new_med_cost; + path.swap(new_path); + cur.swap(new_hyp); + } + } + GetPathStats(path, subs, ins, dels); + return med_cost + edits; + } +}; + +class TERScore : public Score { + friend class TERScorer; + + public: + static const unsigned kINSERTIONS = 0; + static const unsigned kDELETIONS = 1; + static const unsigned kSUBSTITUTIONS = 2; + static const unsigned kSHIFTS = 3; + static const unsigned kREF_WORDCOUNT = 4; + static const unsigned kDUMMY_LAST_ENTRY = 5; + + TERScore() : stats(0,kDUMMY_LAST_ENTRY) {} + float ComputeScore() const { + float edits = static_cast(stats[kINSERTIONS] + stats[kDELETIONS] + stats[kSUBSTITUTIONS] + stats[kSHIFTS]); + return edits / static_cast(stats[kREF_WORDCOUNT]); + } + void ScoreDetails(string* details) const; + void PlusEquals(const Score& delta) { + stats += static_cast(delta).stats; + } + Score* GetZero() const { + return new TERScore; + } + void Subtract(const Score& rhs, Score* res) const { + static_cast(res)->stats = stats - static_cast(rhs).stats; + } + void Encode(std::string* out) const { + ostringstream os; + os << stats[kINSERTIONS] << ' ' + << stats[kDELETIONS] << ' ' + << stats[kSUBSTITUTIONS] << ' ' + << stats[kSHIFTS] << ' ' + << stats[kREF_WORDCOUNT]; + *out = os.str(); + } + bool IsAdditiveIdentity() const { + for (int i = 0; i < kDUMMY_LAST_ENTRY; ++i) + if (stats[i] != 0) return false; + return true; + } + private: + valarray stats; +}; + +Score* TERScorer::ScoreFromString(const std::string& data) { + istringstream is(data); + TERScore* r = new TERScore; + is >> r->stats[TERScore::kINSERTIONS] + >> r->stats[TERScore::kDELETIONS] + >> r->stats[TERScore::kSUBSTITUTIONS] + >> r->stats[TERScore::kSHIFTS] + >> r->stats[TERScore::kREF_WORDCOUNT]; + return r; +} + +void TERScore::ScoreDetails(std::string* details) const { + char buf[200]; + sprintf(buf, "TER = %.2f, %3d|%3d|%3d|%3d (len=%d)", + ComputeScore() * 100.0f, + stats[kINSERTIONS], + stats[kDELETIONS], + stats[kSUBSTITUTIONS], + stats[kSHIFTS], + stats[kREF_WORDCOUNT]); + *details = buf; +} + +TERScorer::~TERScorer() { + for (vector::iterator i = impl_.begin(); i != impl_.end(); ++i) + delete *i; +} + +TERScorer::TERScorer(const vector >& refs) : impl_(refs.size()) { + for (int i = 0; i < refs.size(); ++i) + impl_[i] = new TERScorerImpl(refs[i]); +} + +Score* TERScorer::ScoreCandidate(const std::vector& hyp) const { + float best_score = numeric_limits::max(); + TERScore* res = new TERScore; + int avg_len = 0; + for (int i = 0; i < impl_.size(); ++i) + avg_len += impl_[i]->GetRefLength(); + avg_len /= impl_.size(); + for (int i = 0; i < impl_.size(); ++i) { + int subs, ins, dels, shifts; + float score = impl_[i]->Calculate(hyp, &subs, &ins, &dels, &shifts); + // cerr << "Component TER cost: " << score << endl; + if (score < best_score) { + res->stats[TERScore::kINSERTIONS] = ins; + res->stats[TERScore::kDELETIONS] = dels; + res->stats[TERScore::kSUBSTITUTIONS] = subs; + res->stats[TERScore::kSHIFTS] = shifts; + if (ter_use_average_ref_len) { + res->stats[TERScore::kREF_WORDCOUNT] = avg_len; + } else { + res->stats[TERScore::kREF_WORDCOUNT] = impl_[i]->GetRefLength(); + } + + best_score = score; + } + } + return res; +} diff --git a/vest/test_data/0.json.gz b/vest/test_data/0.json.gz new file mode 100644 index 00000000..30f8dd77 Binary files /dev/null and b/vest/test_data/0.json.gz differ diff --git a/vest/test_data/1.json.gz b/vest/test_data/1.json.gz new file mode 100644 index 00000000..c82cc179 Binary files /dev/null and b/vest/test_data/1.json.gz differ diff --git a/vest/test_data/c2e.txt.0 b/vest/test_data/c2e.txt.0 new file mode 100644 index 00000000..12c4abe9 --- /dev/null +++ b/vest/test_data/c2e.txt.0 @@ -0,0 +1,2 @@ +australia reopens embassy in manila +( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack . diff --git a/vest/test_data/c2e.txt.1 b/vest/test_data/c2e.txt.1 new file mode 100644 index 00000000..4ac12df1 --- /dev/null +++ b/vest/test_data/c2e.txt.1 @@ -0,0 +1,2 @@ +australia reopened manila embassy +( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack . diff --git a/vest/test_data/c2e.txt.2 b/vest/test_data/c2e.txt.2 new file mode 100644 index 00000000..2f67b72f --- /dev/null +++ b/vest/test_data/c2e.txt.2 @@ -0,0 +1,2 @@ +australia to reopen embassy in manila +( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so-called confirmed terrorist attack threats . diff --git a/vest/test_data/c2e.txt.3 b/vest/test_data/c2e.txt.3 new file mode 100644 index 00000000..5483cef6 --- /dev/null +++ b/vest/test_data/c2e.txt.3 @@ -0,0 +1,2 @@ +australia to re - open its embassy to manila +( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so-called " clear " threat of terrorist attack 7 weeks ago . diff --git a/vest/test_data/re.txt.0 b/vest/test_data/re.txt.0 new file mode 100644 index 00000000..86eff087 --- /dev/null +++ b/vest/test_data/re.txt.0 @@ -0,0 +1,5 @@ +erdogan states turkey to reject any pressures to urge it to recognize cyprus +ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara will reject any pressure by the european union to urge it to recognize cyprus . this comes two weeks before the summit of european union state and government heads who will decide whether or nor membership negotiations with ankara should be opened . +erdogan told " ntv " television station that " the european union cannot address us by imposing new conditions on us with regard to cyprus . +we will discuss this dossier in the course of membership negotiations . " +he added " let me be clear , i cannot sidestep turkey , this is something we cannot accept . " diff --git a/vest/test_data/re.txt.1 b/vest/test_data/re.txt.1 new file mode 100644 index 00000000..2140f198 --- /dev/null +++ b/vest/test_data/re.txt.1 @@ -0,0 +1,5 @@ +erdogan confirms turkey will resist any pressure to recognize cyprus +ankara 12 - 1 ( afp ) - the turkish head of government , recep tayyip erdogan , announced today ( wednesday ) that ankara would resist any pressure the european union might exercise in order to force it into recognizing cyprus . this comes two weeks before a summit of european union heads of state and government , who will decide whether or not to open membership negotiations with ankara . +erdogan said to the ntv television channel : " the european union cannot engage with us through imposing new conditions on us with regard to cyprus . +we shall discuss this issue in the course of the membership negotiations . " +he added : " let me be clear - i cannot confine turkey . this is something we do not accept . " diff --git a/vest/test_data/re.txt.2 b/vest/test_data/re.txt.2 new file mode 100644 index 00000000..94e46286 --- /dev/null +++ b/vest/test_data/re.txt.2 @@ -0,0 +1,5 @@ +erdogan confirms that turkey will reject any pressures to encourage it to recognize cyprus +ankara , 12 / 1 ( afp ) - the turkish prime minister recep tayyip erdogan declared today , wednesday , that ankara will reject any pressures that the european union may apply on it to encourage to recognize cyprus . this comes two weeks before a summit of the heads of countries and governments of the european union , who will decide on whether or not to start negotiations on joining with ankara . +erdogan told the ntv television station that " it is not possible for the european union to talk to us by imposing new conditions on us regarding cyprus . +we shall discuss this dossier during the negotiations on joining . " +and he added , " let me be clear . turkey's arm should not be twisted ; this is something we cannot accept . " diff --git a/vest/test_data/re.txt.3 b/vest/test_data/re.txt.3 new file mode 100644 index 00000000..f87c3308 --- /dev/null +++ b/vest/test_data/re.txt.3 @@ -0,0 +1,5 @@ +erdogan stresses that turkey will reject all pressures to force it to recognize cyprus +ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara would refuse all pressures applied on it by the european union to force it to recognize cyprus . that came two weeks before the summit of the presidents and prime ministers of the european union , who would decide on whether to open negotiations on joining with ankara or not . +erdogan said to " ntv " tv station that the " european union can not communicate with us by imposing on us new conditions related to cyprus . +we will discuss this file during the negotiations on joining . " +he added , " let me be clear . turkey's arm should not be twisted . this is unacceptable to us . " diff --git a/vest/union_forests.cc b/vest/union_forests.cc new file mode 100644 index 00000000..207ecb5c --- /dev/null +++ b/vest/union_forests.cc @@ -0,0 +1,73 @@ +#include +#include +#include + +#include +#include + +#include "hg.h" +#include "hg_io.h" +#include "filelib.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("dev_set_size,s",po::value(),"[REQD] Development set size (# of parallel sentences)") + ("forest_repository,r",po::value(),"[REQD] Path to forest repository") + ("new_forest_repository,n",po::value(),"[REQD] Path to new forest repository") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (conf->count("dev_set_size") == 0) { + cerr << "Please specify the size of the development set using -d N\n"; + flag = true; + } + if (conf->count("new_forest_repository") == 0) { + cerr << "Please specify the starting-point weights using -n PATH\n"; + flag = true; + } + if (conf->count("forest_repository") == 0) { + cerr << "Please specify the forest repository location using -r PATH\n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const int size = conf["dev_set_size"].as(); + const string repo = conf["forest_repository"].as(); + const string new_repo = conf["new_forest_repository"].as(); + for (int i = 0; i < size; ++i) { + ostringstream sfin, sfout; + sfin << new_repo << '/' << i << ".json.gz"; + sfout << repo << '/' << i << ".json.gz"; + const string fin = sfin.str(); + const string fout = sfout.str(); + Hypergraph existing_hg; + cerr << "Processing " << fin << endl; + assert(FileExists(fin)); + if (FileExists(fout)) { + ReadFile rf(fout); + assert(HypergraphIO::ReadFromJSON(rf.stream(), &existing_hg)); + } + Hypergraph new_hg; + if (true) { + ReadFile rf(fin); + assert(HypergraphIO::ReadFromJSON(rf.stream(), &new_hg)); + } + existing_hg.Union(new_hg); + WriteFile wf(fout); + assert(HypergraphIO::WriteToJSON(existing_hg, false, wf.stream())); + } + return 0; +} diff --git a/vest/viterbi_envelope.cc b/vest/viterbi_envelope.cc new file mode 100644 index 00000000..1122030a --- /dev/null +++ b/vest/viterbi_envelope.cc @@ -0,0 +1,167 @@ +#include "viterbi_envelope.h" + +#include +#include + +using namespace std; +using boost::shared_ptr; + +ostream& operator<<(ostream& os, const ViterbiEnvelope& env) { + os << '<'; + const vector >& segs = env.GetSortedSegs(); + for (int i = 0; i < segs.size(); ++i) + os << (i==0 ? "" : "|") << "x=" << segs[i]->x << ",b=" << segs[i]->b << ",m=" << segs[i]->m << ",p1=" << segs[i]->p1 << ",p2=" << segs[i]->p2; + return os << '>'; +} + +ViterbiEnvelope::ViterbiEnvelope(int i) { + if (i == 0) { + // do nothing - <> + } else if (i == 1) { + segs.push_back(shared_ptr(new Segment(0, 0, 0, shared_ptr(), shared_ptr()))); + assert(this->IsMultiplicativeIdentity()); + } else { + cerr << "Only can create ViterbiEnvelope semiring 0 and 1 with this constructor!\n"; + abort(); + } +} + +struct SlopeCompare { + bool operator() (const shared_ptr& a, const shared_ptr& b) const { + return a->m < b->m; + } +}; + +const ViterbiEnvelope& ViterbiEnvelope::operator+=(const ViterbiEnvelope& other) { + if (!other.is_sorted) other.Sort(); + if (segs.empty()) { + segs = other.segs; + return *this; + } + is_sorted = false; + int j = segs.size(); + segs.resize(segs.size() + other.segs.size()); + for (int i = 0; i < other.segs.size(); ++i) + segs[j++] = other.segs[i]; + assert(j == segs.size()); + return *this; +} + +void ViterbiEnvelope::Sort() const { + sort(segs.begin(), segs.end(), SlopeCompare()); + const int k = segs.size(); + int j = 0; + for (int i = 0; i < k; ++i) { + Segment l = *segs[i]; + l.x = kMinusInfinity; + // cerr << "m=" << l.m << endl; + if (0 < j) { + if (segs[j-1]->m == l.m) { // lines are parallel + if (l.b <= segs[j-1]->b) continue; + --j; + } + while(0 < j) { + l.x = (l.b - segs[j-1]->b) / (segs[j-1]->m - l.m); + if (segs[j-1]->x < l.x) break; + --j; + } + if (0 == j) l.x = kMinusInfinity; + } + *segs[j++] = l; + } + segs.resize(j); + is_sorted = true; +} + +const ViterbiEnvelope& ViterbiEnvelope::operator*=(const ViterbiEnvelope& other) { + if (other.IsMultiplicativeIdentity()) { return *this; } + if (this->IsMultiplicativeIdentity()) { (*this) = other; return *this; } + + if (!is_sorted) Sort(); + if (!other.is_sorted) other.Sort(); + + if (this->IsEdgeEnvelope()) { +// if (other.size() > 1) +// cerr << *this << " (TIMES) " << other << endl; + shared_ptr edge_parent = segs[0]; + const double& edge_b = edge_parent->b; + const double& edge_m = edge_parent->m; + segs.clear(); + for (int i = 0; i < other.segs.size(); ++i) { + const Segment& seg = *other.segs[i]; + const double m = seg.m + edge_m; + const double b = seg.b + edge_b; + const double& x = seg.x; // x's don't change with * + segs.push_back(shared_ptr(new Segment(x, m, b, edge_parent, other.segs[i]))); + assert(segs.back()->p1->rule); + } +// if (other.size() > 1) +// cerr << " = " << *this << endl; + } else { + vector > new_segs; + int this_i = 0; + int other_i = 0; + const int this_size = segs.size(); + const int other_size = other.segs.size(); + double cur_x = kMinusInfinity; // moves from left to right across the + // real numbers, stopping for all inter- + // sections + double this_next_val = (1 < this_size ? segs[1]->x : kPlusInfinity); + double other_next_val = (1 < other_size ? other.segs[1]->x : kPlusInfinity); + while (this_i < this_size && other_i < other_size) { + const Segment& this_seg = *segs[this_i]; + const Segment& other_seg= *other.segs[other_i]; + const double m = this_seg.m + other_seg.m; + const double b = this_seg.b + other_seg.b; + + new_segs.push_back(shared_ptr(new Segment(cur_x, m, b, segs[this_i], other.segs[other_i]))); + int comp = 0; + if (this_next_val < other_next_val) comp = -1; else + if (this_next_val > other_next_val) comp = 1; + if (0 == comp) { // the next values are equal, advance both indices + ++this_i; + ++other_i; + cur_x = this_next_val; // could be other_next_val (they're equal!) + this_next_val = (this_i+1 < this_size ? segs[this_i+1]->x : kPlusInfinity); + other_next_val = (other_i+1 < other_size ? other.segs[other_i+1]->x : kPlusInfinity); + } else { // advance the i with the lower x, update cur_x + if (-1 == comp) { + ++this_i; + cur_x = this_next_val; + this_next_val = (this_i+1 < this_size ? segs[this_i+1]->x : kPlusInfinity); + } else { + ++other_i; + cur_x = other_next_val; + other_next_val = (other_i+1 < other_size ? other.segs[other_i+1]->x : kPlusInfinity); + } + } + } + segs.swap(new_segs); + } + //cerr << "Multiply: result=" << (*this) << endl; + return *this; +} + +// recursively construct translation +void Segment::ConstructTranslation(vector* trans) const { + const Segment* cur = this; + vector > ant_trans; + while(!cur->rule) { + ant_trans.resize(ant_trans.size() + 1); + cur->p2->ConstructTranslation(&ant_trans.back()); + cur = cur->p1.get(); + } + size_t ant_size = ant_trans.size(); + vector*> pants(ant_size); + --ant_size; + for (int i = 0; i < pants.size(); ++i) pants[ant_size - i] = &ant_trans[i]; + cur->rule->ESubstitute(pants, trans); +} + +ViterbiEnvelope ViterbiEnvelopeWeightFunction::operator()(const Hypergraph::Edge& e) const { + const double m = direction.dot(e.feature_values_); + const double b = origin.dot(e.feature_values_); + Segment* seg = new Segment(m, b, e.rule_); + return ViterbiEnvelope(1, seg); +} + -- cgit v1.2.3