summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
committerChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
commit671c21451542e2dd20e45b4033d44d8e8735f87b (patch)
treeb1773b077dd65b826f067a423d26f7942ce4e043
initial check in
-rw-r--r--Makefile.am3
-rw-r--r--README24
-rw-r--r--configure.ac46
-rw-r--r--src/JSON_parser.h152
-rw-r--r--src/Makefile.am67
-rw-r--r--src/aligner.cc204
-rw-r--r--src/aligner.h23
-rw-r--r--src/apply_models.cc344
-rw-r--r--src/apply_models.h20
-rw-r--r--src/array2d.h171
-rw-r--r--src/bottom_up_parser.cc260
-rw-r--r--src/bottom_up_parser.h27
-rw-r--r--src/cdec.cc474
-rw-r--r--src/cdec_ff.cc18
-rw-r--r--src/collapse_weights.cc102
-rw-r--r--src/dict.h40
-rw-r--r--src/dict_test.cc30
-rw-r--r--src/earley_composer.cc726
-rw-r--r--src/earley_composer.h29
-rw-r--r--src/exp_semiring.h71
-rw-r--r--src/fdict.cc4
-rw-r--r--src/fdict.h21
-rw-r--r--src/ff.cc93
-rw-r--r--src/ff.h121
-rw-r--r--src/ff_factory.cc35
-rw-r--r--src/ff_factory.h39
-rw-r--r--src/ff_itg_span.h7
-rw-r--r--src/ff_test.cc134
-rw-r--r--src/ff_wordalign.cc221
-rw-r--r--src/ff_wordalign.h133
-rw-r--r--src/filelib.cc22
-rw-r--r--src/filelib.h66
-rw-r--r--src/forest_writer.cc23
-rw-r--r--src/forest_writer.h16
-rw-r--r--src/freqdict.cc23
-rw-r--r--src/freqdict.h19
-rw-r--r--src/fst_translator.cc91
-rw-r--r--src/grammar.cc163
-rw-r--r--src/grammar.h83
-rw-r--r--src/grammar_test.cc59
-rw-r--r--src/gzstream.cc165
-rw-r--r--src/gzstream.h121
-rw-r--r--src/hg.cc483
-rw-r--r--src/hg.h225
-rw-r--r--src/hg_intersect.cc121
-rw-r--r--src/hg_intersect.h13
-rw-r--r--src/hg_io.cc585
-rw-r--r--src/hg_io.h37
-rw-r--r--src/hg_test.cc441
-rw-r--r--src/ibm_model1.cc4
-rw-r--r--src/inside_outside.h111
-rw-r--r--src/json_parse.cc50
-rw-r--r--src/json_parse.h58
-rw-r--r--src/kbest.h207
-rw-r--r--src/lattice.cc27
-rw-r--r--src/lattice.h31
-rw-r--r--src/lexcrf.cc112
-rw-r--r--src/lexcrf.h18
-rw-r--r--src/lm_ff.cc328
-rw-r--r--src/lm_ff.h32
-rw-r--r--src/logval.h136
-rw-r--r--src/maxtrans_blunsom.cc287
-rw-r--r--src/parser_test.cc35
-rw-r--r--src/phrasebased_translator.cc206
-rw-r--r--src/phrasebased_translator.h18
-rw-r--r--src/phrasetable_fst.cc141
-rw-r--r--src/phrasetable_fst.h34
-rw-r--r--src/prob.h8
-rw-r--r--src/sampler.h136
-rw-r--r--src/scfg_translator.cc66
-rw-r--r--src/sentence_metadata.h42
-rw-r--r--src/small_vector.h187
-rw-r--r--src/small_vector_test.cc129
-rw-r--r--src/sparse_vector.cc98
-rw-r--r--src/sparse_vector.h264
-rw-r--r--src/stringlib.cc97
-rw-r--r--src/stringlib.h91
-rw-r--r--src/synparse.cc212
-rw-r--r--src/tdict.cc49
-rw-r--r--src/tdict.h19
-rw-r--r--src/timing_stats.cc24
-rw-r--r--src/timing_stats.h25
-rw-r--r--src/translator.h54
-rw-r--r--src/trule.cc237
-rw-r--r--src/trule.h122
-rw-r--r--src/trule_test.cc65
-rw-r--r--src/ttables.cc31
-rw-r--r--src/ttables.h87
-rw-r--r--src/viterbi.cc39
-rw-r--r--src/viterbi.h130
-rw-r--r--src/weights.cc73
-rw-r--r--src/weights.h21
-rw-r--r--src/weights_test.cc28
-rw-r--r--src/wordid.h6
-rw-r--r--training/Makefile.am39
-rw-r--r--training/atools.cc207
-rwxr-xr-xtraining/cluster-em.pl110
-rwxr-xr-xtraining/cluster-ptrain.pl144
-rw-r--r--training/grammar_convert.cc316
-rw-r--r--training/lbfgs.h1459
-rw-r--r--training/lbfgs_test.cc112
-rwxr-xr-xtraining/make-lexcrf-grammar.pl236
-rw-r--r--training/model1.cc103
-rw-r--r--training/mr_em_train.cc270
-rw-r--r--training/mr_optimize_reduce.cc243
-rw-r--r--training/optimize.cc114
-rw-r--r--training/optimize.h104
-rw-r--r--training/optimize_test.cc105
-rw-r--r--training/plftools.cc93
-rw-r--r--vest/Makefile.am32
-rw-r--r--vest/comb_scorer.cc81
-rwxr-xr-xvest/dist-vest.pl642
-rw-r--r--vest/error_surface.cc46
-rw-r--r--vest/fast_score.cc74
-rw-r--r--vest/line_optimizer.cc101
-rw-r--r--vest/lo_test.cc201
-rw-r--r--vest/mr_vest_generate_mapper_input.cc72
-rw-r--r--vest/mr_vest_map.cc98
-rw-r--r--vest/mr_vest_reduce.cc80
-rw-r--r--vest/scorer.cc485
-rw-r--r--vest/scorer_test.cc178
-rw-r--r--vest/ter.cc518
-rw-r--r--vest/test_data/0.json.gzbin0 -> 13709 bytes
-rw-r--r--vest/test_data/1.json.gzbin0 -> 204803 bytes
-rw-r--r--vest/test_data/c2e.txt.02
-rw-r--r--vest/test_data/c2e.txt.12
-rw-r--r--vest/test_data/c2e.txt.22
-rw-r--r--vest/test_data/c2e.txt.32
-rw-r--r--vest/test_data/re.txt.05
-rw-r--r--vest/test_data/re.txt.15
-rw-r--r--vest/test_data/re.txt.25
-rw-r--r--vest/test_data/re.txt.35
-rw-r--r--vest/union_forests.cc73
-rw-r--r--vest/viterbi_envelope.cc167
134 files changed, 17101 insertions, 0 deletions
diff --git a/Makefile.am b/Makefile.am
new file mode 100644
index 00000000..38e9e59a
--- /dev/null
+++ b/Makefile.am
@@ -0,0 +1,3 @@
+SUBDIRS = src training vest
+AUTOMAKE_OPTIONS = foreign
+
diff --git a/README b/README
new file mode 100644
index 00000000..63990ffb
--- /dev/null
+++ b/README
@@ -0,0 +1,24 @@
+cdec is a fast decoder.
+
+ .. more coming ...
+
+COPYRIGHT AND LICENSE
+------------------------------------------------------------------------------
+Copyright (c) 2009 by Chris Dyer <redpony@gmail.com>
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may
+not use this file except in compliance with the License. You may obtain
+a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+The LBFGS implementation contains code from the Computational
+Crystallography Toolbox which is copyright (c) 2006 by The Regents of the
+University of California, through Lawrence Berkeley National Laboratory.
+For more information on their license, refer to http://cctbx.sourceforge.net/
diff --git a/configure.ac b/configure.ac
new file mode 100644
index 00000000..76307998
--- /dev/null
+++ b/configure.ac
@@ -0,0 +1,46 @@
+AC_INIT
+AM_INIT_AUTOMAKE(cdec,0.1)
+AC_CONFIG_HEADERS(config.h)
+AC_PROG_RANLIB
+AC_PROG_CC
+AC_PROG_CXX
+AC_LANG_CPLUSPLUS
+AX_BOOST_BASE
+AX_BOOST_PROGRAM_OPTIONS
+CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS"
+AC_CHECK_HEADER(boost/math/special_functions/digamma.hpp,
+ [AC_DEFINE([HAVE_BOOST_DIGAMMA], [], [flag for boost::math::digamma])])
+
+GTEST_LIB_CHECK
+AC_PROG_INSTALL
+
+AC_ARG_WITH(srilm,
+ [AC_HELP_STRING([--with-srilm=PATH], [(optional) path to SRI's LM toolkit])],
+ [with_srilm=$withval],
+ [with_srilm=no]
+ )
+
+AM_CONDITIONAL([SRI_LM], false)
+
+if test "x$with_srilm" != 'xno'
+then
+ SAVE_CPPFLAGS="$CPPFLAGS"
+ CPPFLAGS="$CPPFLAGS -I${with_srilm}/include"
+
+ AC_CHECK_HEADER(Ngram.h,
+ [AC_DEFINE([HAVE_SRILM], [], [flag for SRILM])],
+ [AC_MSG_ERROR([Cannot find SRILM!])])
+
+ LIB_SRILM="-loolm -ldstruct -lmisc"
+ # ROOT/lib/i686-m64/liboolm.a
+ # ROOT/lib/i686-m64/libdstruct.a
+ # ROOT/lib/i686-m64/libmisc.a
+ MY_ARCH=`${with_srilm}/sbin/machine-type`
+ LDFLAGS="$LDFLAGS -L${with_srilm}/lib/${MY_ARCH}"
+ LIBS="$LIBS $LIB_SRILM"
+ FMTLIBS="$FMTLIBS liboolm.a libdstruct.a libmisc.a"
+ AM_CONDITIONAL([SRI_LM], true)
+fi
+
+AC_OUTPUT(Makefile src/Makefile vest/Makefile)
+
diff --git a/src/JSON_parser.h b/src/JSON_parser.h
new file mode 100644
index 00000000..ceb5b24b
--- /dev/null
+++ b/src/JSON_parser.h
@@ -0,0 +1,152 @@
+#ifndef JSON_PARSER_H
+#define JSON_PARSER_H
+
+/* JSON_parser.h */
+
+
+#include <stddef.h>
+
+/* Windows DLL stuff */
+#ifdef _WIN32
+# ifdef JSON_PARSER_DLL_EXPORTS
+# define JSON_PARSER_DLL_API __declspec(dllexport)
+# else
+# define JSON_PARSER_DLL_API __declspec(dllimport)
+# endif
+#else
+# define JSON_PARSER_DLL_API
+#endif
+
+/* Determine the integer type use to parse non-floating point numbers */
+#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1
+typedef long long JSON_int_t;
+#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld"
+#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld"
+#else
+typedef long JSON_int_t;
+#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld"
+#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld"
+#endif
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef enum
+{
+ JSON_T_NONE = 0,
+ JSON_T_ARRAY_BEGIN, // 1
+ JSON_T_ARRAY_END, // 2
+ JSON_T_OBJECT_BEGIN, // 3
+ JSON_T_OBJECT_END, // 4
+ JSON_T_INTEGER, // 5
+ JSON_T_FLOAT, // 6
+ JSON_T_NULL, // 7
+ JSON_T_TRUE, // 8
+ JSON_T_FALSE, // 9
+ JSON_T_STRING, // 10
+ JSON_T_KEY, // 11
+ JSON_T_MAX // 12
+} JSON_type;
+
+typedef struct JSON_value_struct {
+ union {
+ JSON_int_t integer_value;
+
+ long double float_value;
+
+ struct {
+ const char* value;
+ size_t length;
+ } str;
+ } vu;
+} JSON_value;
+
+typedef struct JSON_parser_struct* JSON_parser;
+
+/*! \brief JSON parser callback
+
+ \param ctx The pointer passed to new_JSON_parser.
+ \param type An element of JSON_type but not JSON_T_NONE.
+ \param value A representation of the parsed value. This parameter is NULL for
+ JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END,
+ JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned
+ as zero-terminated C strings.
+
+ \return Non-zero if parsing should continue, else zero.
+*/
+typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value);
+
+
+/*! \brief The structure used to configure a JSON parser object
+
+ \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise
+ the depth is the limit
+ \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity.
+ \param Callback context. This parameter may be NULL.
+ \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting.
+ \param allowComments. To allow C style comments in JSON, set to non-zero.
+ \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero.
+
+ \return The parser object.
+*/
+typedef struct {
+ JSON_parser_callback callback;
+ void* callback_ctx;
+ int depth;
+ int allow_comments;
+ int handle_floats_manually;
+} JSON_config;
+
+
+/*! \brief Initializes the JSON parser configuration structure to default values.
+
+ The default configuration is
+ - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c)
+ - no parsing, just checking for JSON syntax
+ - no comments
+
+ \param config. Used to configure the parser.
+*/
+JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config);
+
+/*! \brief Create a JSON parser object
+
+ \param config. Used to configure the parser. Set to NULL to use the default configuration.
+ See init_JSON_config
+
+ \return The parser object.
+*/
+JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config);
+
+/*! \brief Destroy a previously created JSON parser object. */
+JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc);
+
+/*! \brief Parse a character.
+
+ \return Non-zero, if all characters passed to this function are part of are valid JSON.
+*/
+JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char);
+
+/*! \brief Finalize parsing.
+
+ Call this method once after all input characters have been consumed.
+
+ \return Non-zero, if all parsed characters are valid JSON, zero otherwise.
+*/
+JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc);
+
+/*! \brief Determine if a given string is valid JSON white space
+
+ \return Non-zero if the string is valid, zero otherwise.
+*/
+JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s);
+
+
+#ifdef __cplusplus
+}
+#endif
+
+
+#endif /* JSON_PARSER_H */
diff --git a/src/Makefile.am b/src/Makefile.am
new file mode 100644
index 00000000..2af6cab7
--- /dev/null
+++ b/src/Makefile.am
@@ -0,0 +1,67 @@
+bin_PROGRAMS = \
+ dict_test \
+ weights_test \
+ trule_test \
+ hg_test \
+ ff_test \
+ parser_test \
+ grammar_test \
+ cdec \
+ small_vector_test
+
+cdec_SOURCES = cdec.cc forest_writer.cc maxtrans_blunsom.cc cdec_ff.cc ff_factory.cc timing_stats.cc
+small_vector_test_SOURCES = small_vector_test.cc
+small_vector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+parser_test_SOURCES = parser_test.cc
+parser_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+dict_test_SOURCES = dict_test.cc
+dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+ff_test_SOURCES = ff_test.cc
+ff_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+grammar_test_SOURCES = grammar_test.cc
+grammar_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+hg_test_SOURCES = hg_test.cc
+hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+trule_test_SOURCES = trule_test.cc
+trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+weights_test_SOURCES = weights_test.cc
+weights_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+
+LDADD = libhg.a
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS)
+AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB) -lz
+
+noinst_LIBRARIES = libhg.a
+
+libhg_a_SOURCES = \
+ fst_translator.cc \
+ scfg_translator.cc \
+ hg.cc \
+ hg_io.cc \
+ viterbi.cc \
+ lattice.cc \
+ logging.cc \
+ aligner.cc \
+ gzstream.cc \
+ apply_models.cc \
+ earley_composer.cc \
+ phrasetable_fst.cc \
+ sparse_vector.cc \
+ trule.cc \
+ filelib.cc \
+ stringlib.cc \
+ fdict.cc \
+ tdict.cc \
+ weights.cc \
+ ttables.cc \
+ ff.cc \
+ lm_ff.cc \
+ ff_wordalign.cc \
+ hg_intersect.cc \
+ lexcrf.cc \
+ bottom_up_parser.cc \
+ phrasebased_translator.cc \
+ JSON_parser.c \
+ json_parse.cc \
+ grammar.cc
diff --git a/src/aligner.cc b/src/aligner.cc
new file mode 100644
index 00000000..d9d067e5
--- /dev/null
+++ b/src/aligner.cc
@@ -0,0 +1,204 @@
+#include "aligner.h"
+
+#include "array2d.h"
+#include "hg.h"
+#include "inside_outside.h"
+#include <set>
+
+using namespace std;
+
+struct EdgeCoverageInfo {
+ set<int> src_indices;
+ set<int> trg_indices;
+};
+
+static bool is_digit(char x) { return x >= '0' && x <= '9'; }
+
+boost::shared_ptr<Array2D<bool> > AlignerTools::ReadPharaohAlignmentGrid(const string& al) {
+ int max_x = 0;
+ int max_y = 0;
+ int i = 0;
+ while (i < al.size()) {
+ int x = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ x *= 10;
+ x += al[i] - '0';
+ ++i;
+ }
+ if (x > max_x) max_x = x;
+ assert(i < al.size());
+ assert(al[i] == '-');
+ ++i;
+ int y = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ y *= 10;
+ y += al[i] - '0';
+ ++i;
+ }
+ if (y > max_y) max_y = y;
+ while(i < al.size() && al[i] == ' ') { ++i; }
+ }
+
+ boost::shared_ptr<Array2D<bool> > grid(new Array2D<bool>(max_x + 1, max_y + 1));
+ i = 0;
+ while (i < al.size()) {
+ int x = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ x *= 10;
+ x += al[i] - '0';
+ ++i;
+ }
+ assert(i < al.size());
+ assert(al[i] == '-');
+ ++i;
+ int y = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ y *= 10;
+ y += al[i] - '0';
+ ++i;
+ }
+ (*grid)(x, y) = true;
+ while(i < al.size() && al[i] == ' ') { ++i; }
+ }
+ // cerr << *grid << endl;
+ return grid;
+}
+
+void AlignerTools::SerializePharaohFormat(const Array2D<bool>& alignment, ostream* out) {
+ bool need_space = false;
+ for (int i = 0; i < alignment.width(); ++i)
+ for (int j = 0; j < alignment.height(); ++j)
+ if (alignment(i,j)) {
+ if (need_space) (*out) << ' '; else need_space = true;
+ (*out) << i << '-' << j;
+ }
+ (*out) << endl;
+}
+
+// compute the coverage vectors of each edge
+// prereq: all derivations yield the same string pair
+void ComputeCoverages(const Hypergraph& g,
+ vector<EdgeCoverageInfo>* pcovs) {
+ for (int i = 0; i < g.edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = g.edges_[i];
+ EdgeCoverageInfo& cov = (*pcovs)[i];
+ // no words
+ if (edge.rule_->EWords() == 0 || edge.rule_->FWords() == 0)
+ continue;
+ // aligned to NULL (crf ibm variant only)
+ if (edge.prev_i_ == -1 || edge.i_ == -1)
+ continue;
+ assert(edge.j_ >= 0);
+ assert(edge.prev_j_ >= 0);
+ if (edge.Arity() == 0) {
+ for (int k = edge.i_; k < edge.j_; ++k)
+ cov.trg_indices.insert(k);
+ for (int k = edge.prev_i_; k < edge.prev_j_; ++k)
+ cov.src_indices.insert(k);
+ } else {
+ // note: this code, which handles mixed NT and terminal
+ // rules assumes that nodes uniquely define a src and trg
+ // span.
+ int k = edge.prev_i_;
+ int j = 0;
+ const vector<WordID>& f = edge.rule_->e(); // rules are inverted
+ while (k < edge.prev_j_) {
+ if (f[j] > 0) {
+ cov.src_indices.insert(k);
+ // cerr << "src: " << k << endl;
+ ++k;
+ ++j;
+ } else {
+ const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[-f[j]]];
+ assert(tailnode.in_edges_.size() > 0);
+ // any edge will do:
+ const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()];
+ //cerr << "skip " << (rep_edge.prev_j_ - rep_edge.prev_i_) << endl; // src span
+ k += (rep_edge.prev_j_ - rep_edge.prev_i_); // src span
+ ++j;
+ }
+ }
+ int tc = 0;
+ const vector<WordID>& e = edge.rule_->f(); // rules are inverted
+ k = edge.i_;
+ j = 0;
+ // cerr << edge.rule_->AsString() << endl;
+ // cerr << "i=" << k << " j=" << edge.j_ << endl;
+ while (k < edge.j_) {
+ //cerr << " k=" << k << endl;
+ if (e[j] > 0) {
+ cov.trg_indices.insert(k);
+ // cerr << "trg: " << k << endl;
+ ++k;
+ ++j;
+ } else {
+ assert(tc < edge.tail_nodes_.size());
+ const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[tc]];
+ assert(tailnode.in_edges_.size() > 0);
+ // any edge will do:
+ const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()];
+ // cerr << "t skip " << (rep_edge.j_ - rep_edge.i_) << endl; // src span
+ k += (rep_edge.j_ - rep_edge.i_); // src span
+ ++j;
+ ++tc;
+ }
+ }
+ //abort();
+ }
+ }
+}
+
+void AlignerTools::WriteAlignment(const string& input,
+ const Lattice& ref,
+ const Hypergraph& g,
+ bool map_instead_of_viterbi) {
+ if (!map_instead_of_viterbi) {
+ assert(!"not implemented!");
+ }
+ vector<prob_t> edge_posteriors(g.edges_.size());
+ {
+ SparseVector<prob_t> posts;
+ InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(g, &posts);
+ for (int i = 0; i < edge_posteriors.size(); ++i)
+ edge_posteriors[i] = posts[i];
+ }
+ vector<EdgeCoverageInfo> edge2cov(g.edges_.size());
+ ComputeCoverages(g, &edge2cov);
+
+ Lattice src;
+ // currently only dealing with src text, even if the
+ // model supports lattice translation (which it probably does)
+ LatticeTools::ConvertTextToLattice(input, &src);
+ // TODO assert that src is a "real lattice"
+
+ Array2D<prob_t> align(src.size(), ref.size(), prob_t::Zero());
+ for (int c = 0; c < g.edges_.size(); ++c) {
+ const prob_t& p = edge_posteriors[c];
+ const EdgeCoverageInfo& eci = edge2cov[c];
+ for (set<int>::const_iterator si = eci.src_indices.begin();
+ si != eci.src_indices.end(); ++si) {
+ for (set<int>::const_iterator ti = eci.trg_indices.begin();
+ ti != eci.trg_indices.end(); ++ti) {
+ align(*si, *ti) += p;
+ }
+ }
+ }
+ prob_t threshold(0.9);
+ const bool use_soft_threshold = true; // TODO configure
+
+ Array2D<bool> grid(src.size(), ref.size(), false);
+ for (int j = 0; j < ref.size(); ++j) {
+ if (use_soft_threshold) {
+ threshold = prob_t::Zero();
+ for (int i = 0; i < src.size(); ++i)
+ if (align(i, j) > threshold) threshold = align(i, j);
+ //threshold *= prob_t(0.99);
+ }
+ for (int i = 0; i < src.size(); ++i)
+ grid(i, j) = align(i, j) >= threshold;
+ }
+ cerr << align << endl;
+ cerr << grid << endl;
+ SerializePharaohFormat(grid, &cout);
+};
+
diff --git a/src/aligner.h b/src/aligner.h
new file mode 100644
index 00000000..970c72f2
--- /dev/null
+++ b/src/aligner.h
@@ -0,0 +1,23 @@
+#ifndef _ALIGNER_H_
+
+#include <string>
+#include <iostream>
+#include <boost/shared_ptr.hpp>
+#include "array2d.h"
+#include "lattice.h"
+
+class Hypergraph;
+
+struct AlignerTools {
+ static boost::shared_ptr<Array2D<bool> > ReadPharaohAlignmentGrid(const std::string& al);
+ static void SerializePharaohFormat(const Array2D<bool>& alignment, std::ostream* out);
+
+ // assumption: g contains derivations of input/ref and
+ // ONLY input/ref.
+ static void WriteAlignment(const std::string& input,
+ const Lattice& ref,
+ const Hypergraph& g,
+ bool map_instead_of_viterbi = true);
+};
+
+#endif
diff --git a/src/apply_models.cc b/src/apply_models.cc
new file mode 100644
index 00000000..8efb331b
--- /dev/null
+++ b/src/apply_models.cc
@@ -0,0 +1,344 @@
+#include "apply_models.h"
+
+#include <vector>
+#include <algorithm>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/functional/hash.hpp>
+
+#include "hg.h"
+#include "ff.h"
+
+using namespace std;
+using namespace std::tr1;
+
+struct Candidate;
+typedef SmallVector JVector;
+typedef vector<Candidate*> CandidateHeap;
+typedef vector<Candidate*> CandidateList;
+
+// life cycle: candidates are created, placed on the heap
+// and retrieved by their estimated cost, when they're
+// retrieved, they're incorporated into the +LM hypergraph
+// where they also know the head node index they are
+// attached to. After they are added to the +LM hypergraph
+// vit_prob_ and est_prob_ fields may be updated as better
+// derivations are found (this happens since the successor's
+// of derivation d may have a better score- they are
+// explored lazily). However, the updates don't happen
+// when a candidate is in the heap so maintaining the heap
+// property is not an issue.
+struct Candidate {
+ int node_index_; // -1 until incorporated
+ // into the +LM forest
+ const Hypergraph::Edge* in_edge_; // in -LM forest
+ Hypergraph::Edge out_edge_;
+ string state_;
+ const JVector j_;
+ prob_t vit_prob_; // these are fixed until the cand
+ // is popped, then they may be updated
+ prob_t est_prob_;
+
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j,
+ const Hypergraph& out_hg,
+ const vector<CandidateList>& D,
+ const SentenceMetadata& smeta,
+ const ModelSet& models,
+ bool is_goal) :
+ node_index_(-1),
+ in_edge_(&e),
+ j_(j) {
+ InitializeCandidate(out_hg, smeta, D, models, is_goal);
+ }
+
+ // used to query uniqueness
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j) : in_edge_(&e), j_(j) {}
+
+ bool IsIncorporatedIntoHypergraph() const {
+ return node_index_ >= 0;
+ }
+
+ void InitializeCandidate(const Hypergraph& out_hg,
+ const SentenceMetadata& smeta,
+ const vector<vector<Candidate*> >& D,
+ const ModelSet& models,
+ const bool is_goal) {
+ const Hypergraph::Edge& in_edge = *in_edge_;
+ out_edge_.rule_ = in_edge.rule_;
+ out_edge_.feature_values_ = in_edge.feature_values_;
+ out_edge_.i_ = in_edge.i_;
+ out_edge_.j_ = in_edge.j_;
+ out_edge_.prev_i_ = in_edge.prev_i_;
+ out_edge_.prev_j_ = in_edge.prev_j_;
+ Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_;
+ tail.resize(j_.size());
+ prob_t p = prob_t::One();
+ // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl;
+ for (int i = 0; i < tail.size(); ++i) {
+ const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]];
+ assert(ant.IsIncorporatedIntoHypergraph());
+ tail[i] = ant.node_index_;
+ p *= ant.vit_prob_;
+ }
+ prob_t edge_estimate = prob_t::One();
+ if (is_goal) {
+ assert(tail.size() == 1);
+ const string& ant_state = out_hg.nodes_[tail.front()].state_;
+ models.AddFinalFeatures(ant_state, &out_edge_);
+ } else {
+ models.AddFeaturesToEdge(smeta, out_hg, &out_edge_, &state_, &edge_estimate);
+ }
+ vit_prob_ = out_edge_.edge_prob_ * p;
+ est_prob_ = vit_prob_ * edge_estimate;
+ }
+};
+
+ostream& operator<<(ostream& os, const Candidate& cand) {
+ os << "CAND[";
+ if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; }
+ else { os << "+LM_node=" << cand.node_index_; }
+ os << " edge=" << cand.in_edge_->id_;
+ os << " j=<";
+ for (int i = 0; i < cand.j_.size(); ++i)
+ os << (i==0 ? "" : " ") << cand.j_[i];
+ os << "> vit=" << log(cand.vit_prob_);
+ os << " est=" << log(cand.est_prob_);
+ return os << ']';
+}
+
+struct HeapCandCompare {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ < r->est_prob_;
+ }
+};
+
+struct EstProbSorter {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ > r->est_prob_;
+ }
+};
+
+// the same candidate <edge, j> can be added multiple times if
+// j is multidimensional (if you're going NW in Manhattan, you
+// can first go north, then west, or you can go west then north)
+// this is a hash function on the relevant variables from
+// Candidate to enforce this.
+struct CandidateUniquenessHash {
+ size_t operator()(const Candidate* c) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ c->in_edge_->id_;
+ for (int i = 0; i < c->j_.size(); ++i)
+ x = ((x << 5) + x) ^ c->j_[i];
+ return x;
+ }
+};
+
+struct CandidateUniquenessEquals {
+ bool operator()(const Candidate* a, const Candidate* b) const {
+ return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_);
+ }
+};
+
+typedef unordered_set<const Candidate*, CandidateUniquenessHash, CandidateUniquenessEquals> UniqueCandidateSet;
+typedef unordered_map<string, Candidate*, boost::hash<string> > State2Node;
+
+class CubePruningRescorer {
+
+public:
+ CubePruningRescorer(const ModelSet& m,
+ const SentenceMetadata& sm,
+ const Hypergraph& i,
+ int pop_limit,
+ Hypergraph* o) :
+ models(m),
+ smeta(sm),
+ in(i),
+ out(*o),
+ D(in.nodes_.size()),
+ pop_limit_(pop_limit) {
+ cerr << " Rescoring forest (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
+ }
+
+ void Apply() {
+ int num_nodes = in.nodes_.size();
+ int goal_id = num_nodes - 1;
+ int pregoal = goal_id - 1;
+ int every = 1;
+ if (num_nodes > 100) every = 10;
+ assert(in.nodes_[pregoal].out_edges_.size() == 1);
+ cerr << " ";
+ for (int i = 0; i < in.nodes_.size(); ++i) {
+ if (i % every == 0) cerr << '.';
+ KBest(i, i == goal_id);
+ }
+ cerr << endl;
+ cerr << " Best path: " << log(D[goal_id].front()->vit_prob_)
+ << "\t" << log(D[goal_id].front()->est_prob_) << endl;
+ out.PruneUnreachable(D[goal_id].front()->node_index_);
+ FreeAll();
+ }
+
+ private:
+ void FreeAll() {
+ for (int i = 0; i < D.size(); ++i) {
+ CandidateList& D_i = D[i];
+ for (int j = 0; j < D_i.size(); ++j)
+ delete D_i[j];
+ }
+ D.clear();
+ }
+
+ void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) {
+ Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_);
+ new_edge->feature_values_ = item->out_edge_.feature_values_;
+ new_edge->edge_prob_ = item->out_edge_.edge_prob_;
+ new_edge->i_ = item->out_edge_.i_;
+ new_edge->j_ = item->out_edge_.j_;
+ new_edge->prev_i_ = item->out_edge_.prev_i_;
+ new_edge->prev_j_ = item->out_edge_.prev_j_;
+ Candidate*& o_item = (*s2n)[item->state_];
+ if (!o_item) o_item = item;
+
+ int& node_id = o_item->node_index_;
+ if (node_id < 0) {
+ Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, item->state_);
+ node_id = new_node->id_;
+ }
+ Hypergraph::Node* node = &out.nodes_[node_id];
+ out.ConnectEdgeToHeadNode(new_edge, node);
+
+ // update candidate if we have a better derivation
+ // note: the difference between the vit score and the estimated
+ // score is the same for all items with a common residual DP
+ // state
+ if (item->vit_prob_ > o_item->vit_prob_) {
+ assert(o_item->state_ == item->state_); // sanity check!
+ o_item->est_prob_ = item->est_prob_;
+ o_item->vit_prob_ = item->vit_prob_;
+ }
+ if (item != o_item) freelist->push_back(item);
+ }
+
+ void KBest(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_cands;
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, smeta, models, is_goal));
+ assert(unique_cands.insert(cand.back()).second); // these should all be unique!
+ }
+// cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+ PushSucc(*item, is_goal, &cand, &unique_cands);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i)
+ D_v[c++] = i->second;
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check
+ }
+ }
+ }
+ }
+
+ const ModelSet& models;
+ const SentenceMetadata& smeta;
+ const Hypergraph& in;
+ Hypergraph& out;
+
+ vector<CandidateList> D; // maps nodes in in-HG to the
+ // equivalent nodes (many due to state
+ // splits) in the out-HG.
+ const int pop_limit_;
+};
+
+struct NoPruningRescorer {
+ NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) :
+ models(m),
+ in(i),
+ out(*o) {
+ cerr << " Rescoring forest (full intersection)\n";
+ }
+
+ void RescoreNode(const int node_num, const bool is_goal) {
+ }
+
+ void Apply() {
+ int num_nodes = in.nodes_.size();
+ int goal_id = num_nodes - 1;
+ int pregoal = goal_id - 1;
+ int every = 1;
+ if (num_nodes > 100) every = 10;
+ assert(in.nodes_[pregoal].out_edges_.size() == 1);
+ cerr << " ";
+ for (int i = 0; i < in.nodes_.size(); ++i) {
+ if (i % every == 0) cerr << '.';
+ RescoreNode(i, i == goal_id);
+ }
+ cerr << endl;
+ }
+
+ private:
+ const ModelSet& models;
+ const Hypergraph& in;
+ Hypergraph& out;
+};
+
+// each node in the graph has one of these, it keeps track of
+void ApplyModelSet(const Hypergraph& in,
+ const SentenceMetadata& smeta,
+ const ModelSet& models,
+ const PruningConfiguration& config,
+ Hypergraph* out) {
+ int pl = config.pop_limit;
+ if (pl > 100 && in.nodes_.size() > 80000) {
+ cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
+ pl = 30;
+ }
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+}
+
diff --git a/src/apply_models.h b/src/apply_models.h
new file mode 100644
index 00000000..08fce037
--- /dev/null
+++ b/src/apply_models.h
@@ -0,0 +1,20 @@
+#ifndef _APPLY_MODELS_H_
+#define _APPLY_MODELS_H_
+
+struct ModelSet;
+struct Hypergraph;
+struct SentenceMetadata;
+
+struct PruningConfiguration {
+ const int algorithm; // 0 = full intersection, 1 = cube pruning
+ const int pop_limit; // max number of pops off the heap at each node
+ explicit PruningConfiguration(int k) : algorithm(1), pop_limit(k) {}
+};
+
+void ApplyModelSet(const Hypergraph& in,
+ const SentenceMetadata& smeta,
+ const ModelSet& models,
+ const PruningConfiguration& config,
+ Hypergraph* out);
+
+#endif
diff --git a/src/array2d.h b/src/array2d.h
new file mode 100644
index 00000000..09d84d0b
--- /dev/null
+++ b/src/array2d.h
@@ -0,0 +1,171 @@
+#ifndef ARRAY2D_H_
+#define ARRAY2D_H_
+
+#include <iostream>
+#include <algorithm>
+#include <cassert>
+#include <vector>
+#include <string>
+
+template<typename T>
+class Array2D {
+ public:
+ typedef typename std::vector<T>::reference reference;
+ typedef typename std::vector<T>::const_reference const_reference;
+ typedef typename std::vector<T>::iterator iterator;
+ typedef typename std::vector<T>::const_iterator const_iterator;
+ Array2D() : width_(0), height_(0) {}
+ Array2D(int w, int h, const T& d = T()) :
+ width_(w), height_(h), data_(w*h, d) {}
+ Array2D(const Array2D& rhs) :
+ width_(rhs.width_), height_(rhs.height_), data_(rhs.data_) {}
+ void resize(int w, int h, const T& d = T()) {
+ data_.resize(w * h, d);
+ width_ = w;
+ height_ = h;
+ }
+ const Array2D& operator=(const Array2D& rhs) {
+ data_ = rhs.data_;
+ width_ = rhs.width_;
+ height_ = rhs.height_;
+ return *this;
+ }
+ void fill(const T& v) { data_.assign(data_.size(), v); }
+ int width() const { return width_; }
+ int height() const { return height_; }
+ reference operator()(int i, int j) {
+ return data_[offset(i, j)];
+ }
+ void clear() { data_.clear(); width_=0; height_=0; }
+ const_reference operator()(int i, int j) const {
+ return data_[offset(i, j)];
+ }
+ iterator begin_col(int j) {
+ return data_.begin() + offset(0,j);
+ }
+ const_iterator begin_col(int j) const {
+ return data_.begin() + offset(0,j);
+ }
+ iterator end_col(int j) {
+ return data_.begin() + offset(0,j) + width_;
+ }
+ const_iterator end_col(int j) const {
+ return data_.begin() + offset(0,j) + width_;
+ }
+ iterator end() { return data_.end(); }
+ const_iterator end() const { return data_.end(); }
+ const Array2D<T>& operator*=(const T& x) {
+ std::transform(data_.begin(), data_.end(), data_.begin(),
+ std::bind2nd(std::multiplies<T>(), x));
+ }
+ const Array2D<T>& operator/=(const T& x) {
+ std::transform(data_.begin(), data_.end(), data_.begin(),
+ std::bind2nd(std::divides<T>(), x));
+ }
+ const Array2D<T>& operator+=(const Array2D<T>& m) {
+ std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::plus<T>());
+ }
+ const Array2D<T>& operator-=(const Array2D<T>& m) {
+ std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::minus<T>());
+ }
+
+ private:
+ inline int offset(int i, int j) const {
+ assert(i<width_);
+ assert(j<height_);
+ return i + j * width_;
+ }
+
+ int width_;
+ int height_;
+
+ std::vector<T> data_;
+};
+
+template <typename T>
+Array2D<T> operator*(const Array2D<T>& l, const T& scalar) {
+ Array2D<T> res(l);
+ res *= scalar;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator*(const T& scalar, const Array2D<T>& l) {
+ Array2D<T> res(l);
+ res *= scalar;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator/(const Array2D<T>& l, const T& scalar) {
+ Array2D<T> res(l);
+ res /= scalar;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator+(const Array2D<T>& l, const Array2D<T>& r) {
+ Array2D<T> res(l);
+ res += r;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator-(const Array2D<T>& l, const Array2D<T>& r) {
+ Array2D<T> res(l);
+ res -= r;
+ return res;
+}
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& os, const Array2D<T>& m) {
+ for (int i=0; i<m.width(); ++i) {
+ for (int j=0; j<m.height(); ++j)
+ os << '\t' << m(i,j);
+ os << '\n';
+ }
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Array2D<bool>& m) {
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10);
+ os << "\n";
+ for (int i=0; i<m.width(); ++i) {
+ os << (i%10);
+ for (int j=0; j<m.height(); ++j)
+ os << (m(i,j) ? '*' : '.');
+ os << (i%10) << "\n";
+ }
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10);
+ os << "\n";
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Array2D<std::vector<bool> >& m) {
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10) << "\t";
+ os << "\n";
+ for (int i=0; i<m.width(); ++i) {
+ os << (i%10);
+ for (int j=0; j<m.height(); ++j) {
+ const std::vector<bool>& ar = m(i,j);
+ for (int k=0; k<ar.size(); ++k)
+ os << (ar[k] ? '*' : '.');
+ }
+ os << "\t";
+ os << (i%10) << "\n";
+ }
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10) << "\t";
+ os << "\n";
+ return os;
+}
+
+#endif
+
diff --git a/src/bottom_up_parser.cc b/src/bottom_up_parser.cc
new file mode 100644
index 00000000..349ed2de
--- /dev/null
+++ b/src/bottom_up_parser.cc
@@ -0,0 +1,260 @@
+#include "bottom_up_parser.h"
+
+#include <map>
+
+#include "hg.h"
+#include "array2d.h"
+#include "tdict.h"
+
+using namespace std;
+
+class ActiveChart;
+class PassiveChart {
+ public:
+ PassiveChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest);
+ ~PassiveChart();
+
+ inline const vector<int>& operator()(int i, int j) const { return chart_(i,j); }
+ bool Parse();
+ inline int size() const { return chart_.width(); }
+ inline bool GoalFound() const { return goal_idx_ >= 0; }
+ inline int GetGoalIndex() const { return goal_idx_; }
+
+ private:
+ void ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail);
+ void ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes);
+ void ApplyUnaryRules(const int i, const int j);
+
+ const vector<GrammarPtr>& grammars_;
+ const Lattice& input_;
+ Hypergraph* forest_;
+ Array2D<vector<int> > chart_; // chart_(i,j) is the list of nodes derived spanning i,j
+ typedef map<int, int> Cat2NodeMap;
+ Array2D<Cat2NodeMap> nodemap_;
+ vector<ActiveChart*> act_chart_;
+ const WordID goal_cat_; // category that is being searched for at [0,n]
+ TRulePtr goal_rule_;
+ int goal_idx_; // index of goal node, if found
+
+ static WordID kGOAL; // [Goal]
+};
+
+WordID PassiveChart::kGOAL = 0;
+
+class ActiveChart {
+ public:
+ ActiveChart(const Hypergraph* hg, const PassiveChart& psv_chart) :
+ hg_(hg),
+ act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {}
+
+ struct ActiveItem {
+ ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, double lcost) :
+ gptr_(g), ant_nodes_(a), lattice_cost(lcost) {}
+ explicit ActiveItem(const GrammarIter* g) :
+ gptr_(g), ant_nodes_(), lattice_cost() {}
+
+ void ExtendTerminal(int symbol, double src_cost, vector<ActiveItem>* out_cell) const {
+ const GrammarIter* ni = gptr_->Extend(symbol);
+ if (ni) out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
+ }
+ void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const {
+ int symbol = hg->nodes_[node_index].cat_;
+ const GrammarIter* ni = gptr_->Extend(symbol);
+ if (!ni) return;
+ Hypergraph::TailNodeVector na(ant_nodes_.size() + 1);
+ for (int i = 0; i < ant_nodes_.size(); ++i)
+ na[i] = ant_nodes_[i];
+ na[ant_nodes_.size()] = node_index;
+ out_cell->push_back(ActiveItem(ni, na, lattice_cost));
+ }
+
+ const GrammarIter* gptr_;
+ Hypergraph::TailNodeVector ant_nodes_;
+ double lattice_cost; // TODO? use SparseVector<double>
+ };
+
+ inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); }
+ void SeedActiveChart(const Grammar& g) {
+ int size = act_chart_.width();
+ for (int i = 0; i < size; ++i)
+ if (g.HasRuleForSpan(i,i))
+ act_chart_(i,i).push_back(ActiveItem(g.GetRoot()));
+ }
+
+ void ExtendActiveItems(int i, int k, int j) {
+ //cerr << " LOOK(" << i << "," << k << ") for completed items in (" << k << "," << j << ")\n";
+ vector<ActiveItem>& cell = act_chart_(i,j);
+ const vector<ActiveItem>& icell = act_chart_(i,k);
+ const vector<int>& idxs = psv_chart_(k, j);
+ //if (!idxs.empty()) { cerr << "FOUND IN (" << k << "," << j << ")\n"; }
+ for (vector<ActiveItem>::const_iterator di = icell.begin(); di != icell.end(); ++di) {
+ for (vector<int>::const_iterator ni = idxs.begin(); ni != idxs.end(); ++ni) {
+ di->ExtendNonTerminal(hg_, *ni, &cell);
+ }
+ }
+ }
+
+ void AdvanceDotsForAllItemsInCell(int i, int j, const vector<vector<LatticeArc> >& input) {
+ //cerr << "ADVANCE(" << i << "," << j << ")\n";
+ for (int k=i+1; k < j; ++k)
+ ExtendActiveItems(i, k, j);
+
+ const vector<LatticeArc>& out_arcs = input[j-1];
+ for (vector<LatticeArc>::const_iterator ai = out_arcs.begin();
+ ai != out_arcs.end(); ++ai) {
+ const WordID& f = ai->label;
+ const double& c = ai->cost;
+ const int& len = ai->dist2next;
+ //VLOG(1) << "F: " << TD::Convert(f) << endl;
+ const vector<ActiveItem>& ec = act_chart_(i, j-1);
+ for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di)
+ di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1));
+ }
+ }
+
+ private:
+ const Hypergraph* hg_;
+ Array2D<vector<ActiveItem> > act_chart_;
+ const PassiveChart& psv_chart_;
+};
+
+PassiveChart::PassiveChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest) :
+ grammars_(grammars),
+ input_(input),
+ forest_(forest),
+ chart_(input.size()+1, input.size()+1),
+ nodemap_(input.size()+1, input.size()+1),
+ goal_cat_(TD::Convert(goal) * -1),
+ goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),
+ goal_idx_(-1) {
+ act_chart_.resize(grammars_.size());
+ for (int i = 0; i < grammars_.size(); ++i)
+ act_chart_[i] = new ActiveChart(forest, *this);
+ if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
+ cerr << " Goal category: [" << goal << ']' << endl;
+}
+
+void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes) {
+ Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
+ new_edge->prev_i_ = r->prev_i;
+ new_edge->prev_j_ = r->prev_j;
+ new_edge->i_ = i;
+ new_edge->j_ = j;
+ new_edge->feature_values_ = r->GetFeatureValues();
+ Cat2NodeMap& c2n = nodemap_(i,j);
+ const bool is_goal = (r->GetLHS() == kGOAL);
+ const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
+ Hypergraph::Node* node = NULL;
+ if (ni == c2n.end()) {
+ node = forest_->AddNode(r->GetLHS(), "");
+ c2n[r->GetLHS()] = node->id_;
+ if (is_goal) {
+ assert(goal_idx_ == -1);
+ goal_idx_ = node->id_;
+ } else {
+ chart_(i,j).push_back(node->id_);
+ }
+ } else {
+ node = &forest_->nodes_[ni->second];
+ }
+ forest_->ConnectEdgeToHeadNode(new_edge, node);
+}
+
+void PassiveChart::ApplyRules(const int i,
+ const int j,
+ const RuleBin* rules,
+ const Hypergraph::TailNodeVector& tail) {
+ const int n = rules->GetNumRules();
+ for (int k = 0; k < n; ++k)
+ ApplyRule(i, j, rules->GetIthRule(k), tail);
+}
+
+void PassiveChart::ApplyUnaryRules(const int i, const int j) {
+ const vector<int>& nodes = chart_(i,j); // reference is important!
+ for (int gi = 0; gi < grammars_.size(); ++gi) {
+ if (!grammars_[gi]->HasRuleForSpan(i,j)) continue;
+ for (int di = 0; di < nodes.size(); ++di) {
+ const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+ const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat);
+ for (int ri = 0; ri < unaries.size(); ++ri) {
+ // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl;
+ const Hypergraph::TailNodeVector ant(1, nodes[di]);
+ ApplyRule(i, j, unaries[ri], ant); // may update nodes
+ }
+ }
+ }
+}
+
+bool PassiveChart::Parse() {
+ forest_->nodes_.reserve(input_.size() * input_.size() * 2);
+ forest_->edges_.reserve(input_.size() * input_.size() * 1000); // TODO: reservation??
+ goal_idx_ = -1;
+ for (int gi = 0; gi < grammars_.size(); ++gi)
+ act_chart_[gi]->SeedActiveChart(*grammars_[gi]);
+
+ cerr << " ";
+ for (int l=1; l<input_.size()+1; ++l) {
+ cerr << '.';
+ for (int i=0; i<input_.size() + 1 - l; ++i) {
+ int j = i + l;
+ for (int gi = 0; gi < grammars_.size(); ++gi) {
+ const Grammar& g = *grammars_[gi];
+ if (g.HasRuleForSpan(i, j)) {
+ act_chart_[gi]->AdvanceDotsForAllItemsInCell(i, j, input_);
+
+ const vector<ActiveChart::ActiveItem>& cell = (*act_chart_[gi])(i,j);
+ for (vector<ActiveChart::ActiveItem>::const_iterator ai = cell.begin();
+ ai != cell.end(); ++ai) {
+ const RuleBin* rules = (ai->gptr_->GetRules());
+ if (!rules) continue;
+ ApplyRules(i, j, rules, ai->ant_nodes_);
+ }
+ }
+ }
+ ApplyUnaryRules(i,j);
+
+ for (int gi = 0; gi < grammars_.size(); ++gi) {
+ const Grammar& g = *grammars_[gi];
+ // deal with non-terminals that were just proved
+ if (g.HasRuleForSpan(i, j))
+ act_chart_[gi]->ExtendActiveItems(i, i, j);
+ }
+ }
+ const vector<int>& dh = chart_(0, input_.size());
+ for (int di = 0; di < dh.size(); ++di) {
+ const Hypergraph::Node& node = forest_->nodes_[dh[di]];
+ if (node.cat_ == goal_cat_) {
+ Hypergraph::TailNodeVector ant(1, node.id_);
+ ApplyRule(0, input_.size(), goal_rule_, ant);
+ }
+ }
+ }
+ cerr << endl;
+
+ if (GoalFound())
+ forest_->PruneUnreachable(forest_->nodes_.size() - 1);
+ return GoalFound();
+}
+
+PassiveChart::~PassiveChart() {
+ for (int i = 0; i < act_chart_.size(); ++i)
+ delete act_chart_[i];
+}
+
+ExhaustiveBottomUpParser::ExhaustiveBottomUpParser(
+ const string& goal_sym,
+ const vector<GrammarPtr>& grammars) :
+ goal_sym_(goal_sym),
+ grammars_(grammars) {}
+
+bool ExhaustiveBottomUpParser::Parse(const Lattice& input,
+ Hypergraph* forest) const {
+ PassiveChart chart(goal_sym_, grammars_, input, forest);
+ return chart.Parse();
+}
diff --git a/src/bottom_up_parser.h b/src/bottom_up_parser.h
new file mode 100644
index 00000000..546bfb54
--- /dev/null
+++ b/src/bottom_up_parser.h
@@ -0,0 +1,27 @@
+#ifndef _BOTTOM_UP_PARSER_H_
+#define _BOTTOM_UP_PARSER_H_
+
+#include <vector>
+#include <string>
+
+#include "lattice.h"
+#include "grammar.h"
+
+class Hypergraph;
+
+class ExhaustiveBottomUpParser {
+ public:
+ ExhaustiveBottomUpParser(const std::string& goal_sym,
+ const std::vector<GrammarPtr>& grammars);
+
+ // returns true if goal reached spanning the full input
+ // forest contains the full (i.e., unpruned) parse forest
+ bool Parse(const Lattice& input,
+ Hypergraph* forest) const;
+
+ private:
+ const std::string goal_sym_;
+ const std::vector<GrammarPtr> grammars_;
+};
+
+#endif
diff --git a/src/cdec.cc b/src/cdec.cc
new file mode 100644
index 00000000..c5780cef
--- /dev/null
+++ b/src/cdec.cc
@@ -0,0 +1,474 @@
+#include <iostream>
+#include <fstream>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "timing_stats.h"
+#include "translator.h"
+#include "phrasebased_translator.h"
+#include "aligner.h"
+#include "stringlib.h"
+#include "forest_writer.h"
+#include "filelib.h"
+#include "sampler.h"
+#include "sparse_vector.h"
+#include "lexcrf.h"
+#include "weights.h"
+#include "tdict.h"
+#include "ff.h"
+#include "ff_factory.h"
+#include "hg_intersect.h"
+#include "apply_models.h"
+#include "viterbi.h"
+#include "kbest.h"
+#include "inside_outside.h"
+#include "exp_semiring.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+using namespace std::tr1;
+using boost::shared_ptr;
+namespace po = boost::program_options;
+
+// some globals ...
+boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
+
+namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); }
+
+void ShowBanner() {
+ cerr << "cdec v1.0 (c) 2009 by Chris Dyer\n";
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, or PB. Specify LexicalCRF for experimental unsupervised CRF word alignment")
+ ("input,i",po::value<string>()->default_value("-"),"Source file")
+ ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
+ ("weights,w",po::value<string>(),"Feature weights file")
+ ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
+ ("list_feature_functions,L","List available feature functions")
+ ("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
+ ("k_best,k",po::value<int>(),"Extract the k best derivations")
+ ("unique_k_best,r", "Unique k-best translation list")
+ ("aligner,a", "Run as a word/phrase aligner (src & ref required)")
+ ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
+ ("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)")
+ ("scfg_extra_glue_grammar", po::value<string>(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)")
+ ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)")
+ ("scfg_default_nt,d",po::value<string>()->default_value("X"),"Default non-terminal symbol in SCFG")
+ ("scfg_max_span_limit,S",po::value<int>()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)")
+ ("show_tree_structure,T", "Show the Viterbi derivation structure")
+ ("show_expected_length", "Show the expected translation length under the model")
+ ("show_partition,z", "Compute and show the partition (inside score)")
+ ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
+ ("graphviz","Show (constrained) translation forest in GraphViz format")
+ ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
+ ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
+ ("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
+ ("gradient,G","Compute d log p(e|f) / d lambda_i and write to STDOUT (src & ref required)")
+ ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
+ ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
+ ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
+ ("forest_output,O",po::value<string>(),"Directory to write forests to")
+ ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ const string cfg = (*conf)["config"].as<string>();
+ cerr << "Configuration file: " << cfg << endl;
+ ifstream config(cfg.c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("list_feature_functions")) {
+ cerr << "Available feature functions (specify with -F):\n";
+ global_ff_registry->DisplayList();
+ cerr << endl;
+ exit(1);
+ }
+
+ if (conf->count("help") || conf->count("grammar") == 0) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+
+ const string formalism = LowercaseString((*conf)["formalism"].as<string>());
+ if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', or 'lexcrf'\n";
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+// TODO move out of cdec into some sampling decoder file
+void SampleRecurse(const Hypergraph& hg, const vector<SampleSet>& ss, int n, vector<WordID>* out) {
+ const SampleSet& s = ss[n];
+ int i = rng->SelectSample(s);
+ const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]];
+ vector<vector<WordID> > ants(edge.tail_nodes_.size());
+ for (int j = 0; j < ants.size(); ++j)
+ SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]);
+
+ vector<const vector<WordID>*> pants(ants.size());
+ for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j];
+ edge.rule_->ESubstitute(pants, out);
+}
+
+struct SampleSort {
+ bool operator()(const pair<int,string>& a, const pair<int,string>& b) const {
+ return a.first > b.first;
+ }
+};
+
+// TODO move out of cdec into some sampling decoder file
+void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
+ unordered_map<string, int, boost::hash<string> > m;
+ hg->PushWeightsToGoal();
+ const int num_nodes = hg->nodes_.size();
+ vector<SampleSet> ss(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ SampleSet& s = ss[i];
+ const vector<int>& in_edges = hg->nodes_[i].in_edges_;
+ for (int j = 0; j < in_edges.size(); ++j) {
+ s.add(hg->edges_[in_edges[j]].edge_prob_);
+ }
+ }
+ for (int i = 0; i < samples; ++i) {
+ vector<WordID> yield;
+ SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield);
+ const string trans = TD::GetString(yield);
+ ++m[trans];
+ }
+ vector<pair<int, string> > dist;
+ for (unordered_map<string, int, boost::hash<string> >::iterator i = m.begin();
+ i != m.end(); ++i) {
+ dist.push_back(make_pair(i->second, i->first));
+ }
+ sort(dist.begin(), dist.end(), SampleSort());
+ if (k) {
+ for (int i = 0; i < k; ++i)
+ cout << dist[i].first << " ||| " << dist[i].second << endl;
+ } else {
+ cout << dist[0].second << endl;
+ }
+}
+
+// TODO decoder output should probably be moved to another file
+void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique) {
+ if (unique) {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ } else {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ }
+}
+
+struct ELengthWeightFunction {
+ double operator()(const Hypergraph::Edge& e) const {
+ return e.rule_->ELength() - e.rule_->Arity();
+ }
+};
+
+
+struct TRPHash {
+ size_t operator()(const TRulePtr& o) const { return reinterpret_cast<size_t>(o.get()); }
+};
+static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) {
+ static unordered_set<TRulePtr, TRPHash> written;
+ for (int i = 0; i < hg.edges_.size(); ++i) {
+ const TRulePtr& rule = hg.edges_[i].rule_;
+ if (written.insert(rule).second) {
+ (*os) << rule->AsString() << endl;
+ }
+ }
+}
+
+void register_feature_functions();
+
+int main(int argc, char** argv) {
+ global_ff_registry.reset(new FFRegistry);
+ register_feature_functions();
+ ShowBanner();
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const bool write_gradient = conf.count("gradient");
+ const bool feature_expectations = conf.count("feature_expectations");
+ if (write_gradient && feature_expectations) {
+ cerr << "You can only specify --gradient or --feature_expectations, not both!\n";
+ exit(1);
+ }
+ const bool output_training_vector = (write_gradient || feature_expectations);
+
+ boost::shared_ptr<Translator> translator;
+ const string formalism = LowercaseString(conf["formalism"].as<string>());
+ if (formalism == "scfg")
+ translator.reset(new SCFGTranslator(conf));
+ else if (formalism == "fst")
+ translator.reset(new FSTTranslator(conf));
+ else if (formalism == "pb")
+ translator.reset(new PhraseBasedTranslator(conf));
+ else if (formalism == "lexcrf")
+ translator.reset(new LexicalCRF(conf));
+ else
+ assert(!"error");
+
+ vector<double> wv;
+ Weights w;
+ if (conf.count("weights")) {
+ w.InitFromFile(conf["weights"].as<string>());
+ wv.resize(FD::NumFeats());
+ w.InitVector(&wv);
+ }
+
+ // set up additional scoring features
+ vector<shared_ptr<FeatureFunction> > pffs;
+ vector<const FeatureFunction*> late_ffs;
+ if (conf.count("feature_function") > 0) {
+ const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >();
+ for (int i = 0; i < add_ffs.size(); ++i) {
+ string ff, param;
+ SplitCommandAndParam(add_ffs[i], &ff, &param);
+ if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
+ else cerr << " (no config parameters)\n";
+ shared_ptr<FeatureFunction> pff = global_ff_registry->Create(ff, param);
+ if (!pff) { exit(1); }
+ // TODO check that multiple features aren't trying to set the same fid
+ pffs.push_back(pff);
+ late_ffs.push_back(pff.get());
+ }
+ }
+ ModelSet late_models(wv, late_ffs);
+
+ const int sample_max_trans = conf.count("max_translation_sample") ?
+ conf["max_translation_sample"].as<int>() : 0;
+ if (sample_max_trans)
+ rng.reset(new RandomNumberGenerator<boost::mt19937>);
+ const bool aligner_mode = conf.count("aligner");
+ const bool minimal_forests = conf.count("minimal_forests");
+ const bool graphviz = conf.count("graphviz");
+ const bool encode_b64 = conf["vector_format"].as<string>() == "b64";
+ const bool kbest = conf.count("k_best");
+ const bool unique_kbest = conf.count("unique_k_best");
+ shared_ptr<WriteFile> extract_file;
+ if (conf.count("extract_rules"))
+ extract_file.reset(new WriteFile(conf["extract_rules"].as<string>()));
+
+ int combine_size = conf["combine_size"].as<int>();
+ if (combine_size < 1) combine_size = 1;
+ const string input = conf["input"].as<string>();
+ cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl;
+ ReadFile in_read(input);
+ istream *in = in_read.stream();
+ assert(*in);
+
+ SparseVector<double> acc_vec; // accumulate gradient
+ double acc_obj = 0; // accumulate objective
+ int g_count = 0; // number of gradient pieces computed
+ int sent_id = -1; // line counter
+
+ while(*in) {
+ Timer::Summarize();
+ ++sent_id;
+ string buf;
+ getline(*in, buf);
+ if (buf.empty()) continue;
+ map<string, string> sgml;
+ ProcessAndStripSGML(&buf, &sgml);
+ if (sgml.find("id") != sgml.end())
+ sent_id = atoi(sgml["id"].c_str());
+
+ cerr << "\nINPUT: ";
+ if (buf.size() < 100)
+ cerr << buf << endl;
+ else {
+ size_t x = buf.rfind(" ", 100);
+ if (x == string::npos) x = 100;
+ cerr << buf.substr(0, x) << " ..." << endl;
+ }
+ cerr << " id = " << sent_id << endl;
+ string to_translate;
+ Lattice ref;
+ ParseTranslatorInputLattice(buf, &to_translate, &ref);
+ const bool has_ref = ref.size() > 0;
+ SentenceMetadata smeta(sent_id, ref);
+ const bool hadoop_counters = (write_gradient);
+ Hypergraph forest; // -LM forest
+ Timer t("Translation");
+ if (!translator->Translate(to_translate, &smeta, wv, &forest)) {
+ cerr << " NO PARSE FOUND.\n";
+ if (hadoop_counters)
+ cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl;
+ cout << endl << flush;
+ continue;
+ }
+ cerr << " -LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " -LM forest (paths): " << forest.NumberOfPaths() << endl;
+ if (conf.count("show_expected_length")) {
+ const PRPair<double, double> res =
+ Inside<PRPair<double, double>,
+ PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
+ cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
+ }
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " -LM partition log(Z): " << log(z) << endl;
+ }
+ if (extract_file)
+ ExtractRulesDedupe(forest, extract_file->stream());
+ vector<WordID> trans;
+ const prob_t vs = ViterbiESentence(forest, &trans);
+ cerr << " -LM Viterbi: " << TD::GetString(trans) << endl;
+ if (conf.count("show_tree_structure"))
+ cerr << " -LM tree: " << ViterbiETree(forest) << endl;;
+ cerr << " -LM Viterbi: " << log(vs) << endl;
+
+ bool has_late_models = !late_models.empty();
+ if (has_late_models) {
+ forest.Reweight(wv);
+ forest.SortInEdgesByEdgeWeights();
+ Hypergraph lm_forest;
+ int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as<int>();
+ ApplyModelSet(forest,
+ smeta,
+ late_models,
+ PruningConfiguration(cubepruning_pop_limit),
+ &lm_forest);
+ forest.swap(lm_forest);
+ forest.Reweight(wv);
+ trans.clear();
+ ViterbiESentence(forest, &trans);
+ cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl;
+ cerr << " +LM Viterbi: " << TD::GetString(trans) << endl;
+ }
+ if (conf.count("forest_output") && !has_ref) {
+ ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
+ assert(writer.Write(forest, minimal_forests));
+ }
+
+ if (sample_max_trans) {
+ MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
+ } else {
+ if (kbest) {
+ DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest);
+ } else {
+ if (!graphviz && !has_ref) {
+ cout << TD::GetString(trans) << endl << flush;
+ }
+ }
+ }
+
+ const int max_trans_beam_size = conf.count("max_translation_beam") ?
+ conf["max_translation_beam"].as<int>() : 0;
+ if (max_trans_beam_size) {
+ Hack::MaxTrans(forest, max_trans_beam_size);
+ continue;
+ }
+
+ if (graphviz && !has_ref) forest.PrintGraphviz();
+
+ // the following are only used if write_gradient is true!
+ SparseVector<double> full_exp, ref_exp, gradient;
+ double log_z = 0, log_ref_z = 0;
+ if (write_gradient)
+ log_z = log(
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(forest, &full_exp));
+
+ if (has_ref) {
+ if (HG::Intersect(ref, &forest)) {
+ cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl;
+ forest.Reweight(wv);
+ cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
+ if (hadoop_counters)
+ cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl;
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " Contst. partition log(Z): " << log(z) << endl;
+ }
+ //DumpKBest(sent_id, forest, 1000);
+ if (conf.count("forest_output")) {
+ ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
+ assert(writer.Write(forest, minimal_forests));
+ }
+ if (aligner_mode && !output_training_vector)
+ AlignerTools::WriteAlignment(to_translate, ref, forest);
+ if (write_gradient) {
+ log_ref_z = log(
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(forest, &ref_exp));
+ if (log_z < log_ref_z) {
+ cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl;
+ exit(1);
+ }
+ //cerr << "FULL: " << full_exp << endl;
+ //cerr << " REF: " << ref_exp << endl;
+ ref_exp -= full_exp;
+ acc_vec += ref_exp;
+ acc_obj += (log_z - log_ref_z);
+ }
+ if (feature_expectations) {
+ acc_obj += log(
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(forest, &ref_exp));
+ acc_vec += ref_exp;
+ }
+
+ if (output_training_vector) {
+ ++g_count;
+ if (g_count % combine_size == 0) {
+ if (encode_b64) {
+ cout << "0\t";
+ B64::Encode(acc_obj, acc_vec, &cout);
+ cout << endl << flush;
+ } else {
+ cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
+ }
+ acc_vec.clear();
+ acc_obj = 0;
+ }
+ }
+ if (conf.count("graphviz")) forest.PrintGraphviz();
+ } else {
+ cerr << " REFERENCE UNREACHABLE.\n";
+ if (write_gradient) {
+ if (hadoop_counters)
+ cerr << "reporter:counter:UserCounters,EFParseFailed,1" << endl;
+ cout << endl << flush;
+ }
+ }
+ }
+ }
+ if (output_training_vector && !acc_vec.empty()) {
+ if (encode_b64) {
+ cout << "0\t";
+ B64::Encode(acc_obj, acc_vec, &cout);
+ cout << endl << flush;
+ } else {
+ cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
+ }
+ }
+}
+
diff --git a/src/cdec_ff.cc b/src/cdec_ff.cc
new file mode 100644
index 00000000..86abfb4a
--- /dev/null
+++ b/src/cdec_ff.cc
@@ -0,0 +1,18 @@
+#include <boost/shared_ptr.hpp>
+
+#include "ff.h"
+#include "lm_ff.h"
+#include "ff_factory.h"
+#include "ff_wordalign.h"
+
+boost::shared_ptr<FFRegistry> global_ff_registry;
+
+void register_feature_functions() {
+ global_ff_registry->Register("LanguageModel", new FFFactory<LanguageModel>);
+ global_ff_registry->Register("WordPenalty", new FFFactory<WordPenalty>);
+ global_ff_registry->Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
+ global_ff_registry->Register("MarkovJump", new FFFactory<MarkovJump>);
+ global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory<BlunsomSynchronousParseHack>);
+ global_ff_registry->Register("AlignerResults", new FFFactory<AlignerResults>);
+};
+
diff --git a/src/collapse_weights.cc b/src/collapse_weights.cc
new file mode 100644
index 00000000..5e0f3f72
--- /dev/null
+++ b/src/collapse_weights.cc
@@ -0,0 +1,102 @@
+#include <iostream>
+#include <fstream>
+#include <tr1/unordered_map>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "prob.h"
+#include "filelib.h"
+#include "trule.h"
+#include "weights.h"
+
+namespace po = boost::program_options;
+using namespace std;
+
+typedef std::tr1::unordered_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > > MarginalMap;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("grammar,g", po::value<string>(), "Grammar file")
+ ("weights,w", po::value<string>(), "Weights file");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ const string cfg = (*conf)["config"].as<string>();
+ cerr << "Configuration file: " << cfg << endl;
+ ifstream config(cfg.c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || !conf->count("grammar") || !conf->count("weights")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const string wfile = conf["weights"].as<string>();
+ const string gfile = conf["grammar"].as<string>();
+ Weights wm;
+ wm.InitFromFile(wfile);
+ vector<double> w;
+ wm.InitVector(&w);
+ MarginalMap e_tots;
+ MarginalMap f_tots;
+ prob_t tot;
+ {
+ ReadFile rf(gfile);
+ assert(*rf.stream());
+ istream& in = *rf.stream();
+ cerr << "Computing marginals...\n";
+ int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ ++lc;
+ if (line.empty()) continue;
+ TRule tr(line, true);
+ if (tr.GetFeatureValues().empty())
+ cerr << "Line " << lc << ": empty features - may introduce bias\n";
+ prob_t prob;
+ prob.logeq(tr.GetFeatureValues().dot(w));
+ e_tots[tr.e_] += prob;
+ f_tots[tr.f_] += prob;
+ tot += prob;
+ }
+ }
+ bool normalized = (fabs(log(tot)) < 0.001);
+ cerr << "Total: " << tot << (normalized ? " [normalized]" : " [scaled]") << endl;
+ ReadFile rf(gfile);
+ istream&in = *rf.stream();
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ TRule tr(line, true);
+ const double lp = tr.GetFeatureValues().dot(w);
+ if (isinf(lp)) { continue; }
+ tr.scores_.clear();
+
+ cout << tr.AsString() << " ||| F_and_E=" << lp - log(tot);
+ if (!normalized) {
+ cout << ";ZF_and_E=" << lp;
+ }
+ cout << ";F_given_E=" << lp - log(e_tots[tr.e_])
+ << ";E_given_F=" << lp - log(f_tots[tr.f_]) << endl;
+ }
+ return 0;
+}
+
diff --git a/src/dict.h b/src/dict.h
new file mode 100644
index 00000000..bae9debe
--- /dev/null
+++ b/src/dict.h
@@ -0,0 +1,40 @@
+#ifndef DICT_H_
+#define DICT_H_
+
+#include <cassert>
+#include <cstring>
+#include <tr1/unordered_map>
+#include <string>
+#include <vector>
+
+#include <boost/functional/hash.hpp>
+
+#include "wordid.h"
+
+class Dict {
+ typedef std::tr1::unordered_map<std::string, WordID, boost::hash<std::string> > Map;
+ public:
+ Dict() : b0_("<bad0>") { words_.reserve(1000); }
+ inline int max() const { return words_.size(); }
+ inline WordID Convert(const std::string& word) {
+ Map::iterator i = d_.find(word);
+ if (i == d_.end()) {
+ words_.push_back(word);
+ d_[word] = words_.size();
+ return words_.size();
+ } else {
+ return i->second;
+ }
+ }
+ inline const std::string& Convert(const WordID& id) const {
+ if (id == 0) return b0_;
+ assert(id <= words_.size());
+ return words_[id-1];
+ }
+ private:
+ const std::string b0_;
+ std::vector<std::string> words_;
+ Map d_;
+};
+
+#endif
diff --git a/src/dict_test.cc b/src/dict_test.cc
new file mode 100644
index 00000000..5c5d84f0
--- /dev/null
+++ b/src/dict_test.cc
@@ -0,0 +1,30 @@
+#include "dict.h"
+
+#include <gtest/gtest.h>
+#include <cassert>
+
+class DTest : public testing::Test {
+ public:
+ DTest() {}
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(DTest, Convert) {
+ Dict d;
+ WordID a = d.Convert("foo");
+ WordID b = d.Convert("bar");
+ std::string x = "foo";
+ WordID c = d.Convert(x);
+ EXPECT_NE(a, b);
+ EXPECT_EQ(a, c);
+ EXPECT_EQ(d.Convert(a), "foo");
+ EXPECT_EQ(d.Convert(b), "bar");
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/src/earley_composer.cc b/src/earley_composer.cc
new file mode 100644
index 00000000..a59686e0
--- /dev/null
+++ b/src/earley_composer.cc
@@ -0,0 +1,726 @@
+#include "earley_composer.h"
+
+#include <iostream>
+#include <fstream>
+#include <map>
+#include <queue>
+#include <tr1/unordered_set>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <boost/lexical_cast.hpp>
+
+#include "phrasetable_fst.h"
+#include "sparse_vector.h"
+#include "tdict.h"
+#include "hg.h"
+
+using boost::shared_ptr;
+namespace po = boost::program_options;
+using namespace std;
+using namespace std::tr1;
+
+// Define the following macro if you want to see lots of debugging output
+// when you run the chart parser
+#undef DEBUG_CHART_PARSER
+
+// A few constants used by the chart parser ///////////////
+static const int kMAX_NODES = 2000000;
+static const string kPHRASE_STRING = "X";
+static bool constants_need_init = true;
+static WordID kUNIQUE_START;
+static WordID kPHRASE;
+static TRulePtr kX1X2;
+static TRulePtr kX1;
+static WordID kEPS;
+static TRulePtr kEPSRule;
+
+static void InitializeConstants() {
+ if (constants_need_init) {
+ kPHRASE = TD::Convert(kPHRASE_STRING) * -1;
+ kUNIQUE_START = TD::Convert("S") * -1;
+ kX1X2.reset(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]"));
+ kX1.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
+ kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
+ kEPS = TD::Convert("<eps>");
+ constants_need_init = false;
+ }
+}
+////////////////////////////////////////////////////////////
+
+class EGrammarNode {
+ friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);
+ friend void AddGrammarRule(const string& r, map<WordID, EGrammarNode>* g);
+ public:
+#ifdef DEBUG_CHART_PARSER
+ string hint;
+#endif
+ EGrammarNode() : is_some_rule_complete(false), is_root(false) {}
+ const map<WordID, EGrammarNode>& GetTerminals() const { return tptr; }
+ const map<WordID, EGrammarNode>& GetNonTerminals() const { return ntptr; }
+ bool HasNonTerminals() const { return (!ntptr.empty()); }
+ bool HasTerminals() const { return (!tptr.empty()); }
+ bool RuleCompletes() const {
+ return (is_some_rule_complete || (ntptr.empty() && tptr.empty()));
+ }
+ bool GrammarContinues() const {
+ return !(ntptr.empty() && tptr.empty());
+ }
+ bool IsRoot() const {
+ return is_root;
+ }
+ // these are the features associated with the rule from the start
+ // node up to this point. If you use these features, you must
+ // not Extend() this rule.
+ const SparseVector<double>& GetCFGProductionFeatures() const {
+ return input_features;
+ }
+
+ const EGrammarNode* Extend(const WordID& t) const {
+ if (t < 0) {
+ map<WordID, EGrammarNode>::const_iterator it = ntptr.find(t);
+ if (it == ntptr.end()) return NULL;
+ return &it->second;
+ } else {
+ map<WordID, EGrammarNode>::const_iterator it = tptr.find(t);
+ if (it == tptr.end()) return NULL;
+ return &it->second;
+ }
+ }
+
+ private:
+ map<WordID, EGrammarNode> tptr;
+ map<WordID, EGrammarNode> ntptr;
+ SparseVector<double> input_features;
+ bool is_some_rule_complete;
+ bool is_root;
+};
+typedef map<WordID, EGrammarNode> EGrammar; // indexed by the rule LHS
+
+// edges are immutable once created
+struct Edge {
+#ifdef DEBUG_CHART_PARSER
+ static int id_count;
+ const int id;
+#endif
+ const WordID cat; // lhs side of rule proved/being proved
+ const EGrammarNode* const dot; // dot position
+ const FSTNode* const q; // start of span
+ const FSTNode* const r; // end of span
+ const Edge* const active_parent; // back pointer, NULL for PREDICT items
+ const Edge* const passive_parent; // back pointer, NULL for SCAN and PREDICT items
+ const TargetPhraseSet* const tps; // translations
+ shared_ptr<SparseVector<double> > features; // features from CFG rule
+
+ bool IsPassive() const {
+ // when a rule is completed, this value will be set
+ return static_cast<bool>(features);
+ }
+ bool IsActive() const { return !IsPassive(); }
+ bool IsInitial() const {
+ return !(active_parent || passive_parent);
+ }
+ bool IsCreatedByScan() const {
+ return active_parent && !passive_parent && !dot->IsRoot();
+ }
+ bool IsCreatedByPredict() const {
+ return dot->IsRoot();
+ }
+ bool IsCreatedByComplete() const {
+ return active_parent && passive_parent;
+ }
+
+ // constructor for PREDICT
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {}
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r, const Edge* act_parent) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(act_parent), passive_parent(NULL), tps(NULL) {}
+
+ // constructors for SCAN
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const TargetPhraseSet* translations) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations) {}
+
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const TargetPhraseSet* translations,
+ const SparseVector<double>& feats) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations),
+ features(new SparseVector<double>(feats)) {}
+
+ // constructors for COMPLETE
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const Edge *pas_par) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL) {
+ assert(pas_par->IsPassive());
+ assert(act_par->IsActive());
+ }
+
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const Edge *pas_par, const SparseVector<double>& feats) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL),
+ features(new SparseVector<double>(feats)) {
+ assert(pas_par->IsPassive());
+ assert(act_par->IsActive());
+ }
+
+ // constructor for COMPLETE query
+ Edge(const FSTNode* _r) :
+#ifdef DEBUG_CHART_PARSER
+ id(0),
+#endif
+ cat(0), dot(NULL), q(NULL),
+ r(_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {}
+ // constructor for MERGE quere
+ Edge(const FSTNode* _q, int) :
+#ifdef DEBUG_CHART_PARSER
+ id(0),
+#endif
+ cat(0), dot(NULL), q(_q),
+ r(NULL), active_parent(NULL), passive_parent(NULL), tps(NULL) {}
+};
+#ifdef DEBUG_CHART_PARSER
+int Edge::id_count = 0;
+#endif
+
+ostream& operator<<(ostream& os, const Edge& e) {
+ string type = "PREDICT";
+ if (e.IsCreatedByScan())
+ type = "SCAN";
+ else if (e.IsCreatedByComplete())
+ type = "COMPLETE";
+ os << "["
+#ifdef DEBUG_CHART_PARSER
+ << '(' << e.id << ") "
+#else
+ << '(' << &e << ") "
+#endif
+ << "q=" << e.q << ", r=" << e.r
+ << ", cat="<< TD::Convert(e.cat*-1) << ", dot="
+ << e.dot
+#ifdef DEBUG_CHART_PARSER
+ << e.dot->hint
+#endif
+ << (e.IsActive() ? ", Active" : ", Passive")
+ << ", " << type;
+#ifdef DEBUG_CHART_PARSER
+ if (e.active_parent) { os << ", act.parent=(" << e.active_parent->id << ')'; }
+ if (e.passive_parent) { os << ", psv.parent=(" << e.passive_parent->id << ')'; }
+#endif
+ if (e.tps) { os << ", tps=" << e.tps; }
+ return os << ']';
+}
+
+struct Traversal {
+ const Edge* const edge; // result from the active / passive combination
+ const Edge* const active;
+ const Edge* const passive;
+ Traversal(const Edge* me, const Edge* a, const Edge* p) : edge(me), active(a), passive(p) {}
+};
+
+struct UniqueTraversalHash {
+ size_t operator()(const Traversal* t) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(t->active);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(t->passive);
+ x = ((x << 5) + x) ^ t->edge->IsActive();
+ return x;
+ }
+};
+
+struct UniqueTraversalEquals {
+ size_t operator()(const Traversal* a, const Traversal* b) const {
+ return (a->passive == b->passive && a->active == b->active && a->edge->IsActive() == b->edge->IsActive());
+ }
+};
+
+struct UniqueEdgeHash {
+ size_t operator()(const Edge* e) const {
+ size_t x = 5381;
+ if (e->IsActive()) {
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->dot);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->q);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->r);
+ x = ((x << 5) + x) ^ static_cast<size_t>(e->cat);
+ x += 13;
+ } else { // with passive edges, we don't care about the dot
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->q);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->r);
+ x = ((x << 5) + x) ^ static_cast<size_t>(e->cat);
+ }
+ return x;
+ }
+};
+
+struct UniqueEdgeEquals {
+ bool operator()(const Edge* a, const Edge* b) const {
+ if (a->IsActive() != b->IsActive()) return false;
+ if (a->IsActive()) {
+ return (a->cat == b->cat) && (a->dot == b->dot) && (a->q == b->q) && (a->r == b->r);
+ } else {
+ return (a->cat == b->cat) && (a->q == b->q) && (a->r == b->r);
+ }
+ }
+};
+
+struct REdgeHash {
+ size_t operator()(const Edge* e) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->r);
+ return x;
+ }
+};
+
+struct REdgeEquals {
+ bool operator()(const Edge* a, const Edge* b) const {
+ return (a->r == b->r);
+ }
+};
+
+struct QEdgeHash {
+ size_t operator()(const Edge* e) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->q);
+ return x;
+ }
+};
+
+struct QEdgeEquals {
+ bool operator()(const Edge* a, const Edge* b) const {
+ return (a->q == b->q);
+ }
+};
+
+struct EdgeQueue {
+ queue<const Edge*> q;
+ EdgeQueue() {}
+ void clear() { while(!q.empty()) q.pop(); }
+ bool HasWork() const { return !q.empty(); }
+ const Edge* Next() { const Edge* res = q.front(); q.pop(); return res; }
+ void AddEdge(const Edge* s) { q.push(s); }
+};
+
+class EarleyComposerImpl {
+ public:
+ EarleyComposerImpl(WordID start_cat, const FSTNode& q_0) : start_cat_(start_cat), q_0_(&q_0) {}
+
+ // returns false if the intersection is empty
+ bool Compose(const EGrammar& g, Hypergraph* forest) {
+ goal_node = NULL;
+ EGrammar::const_iterator sit = g.find(start_cat_);
+ forest->ReserveNodes(kMAX_NODES);
+ assert(sit != g.end());
+ Edge* init = new Edge(start_cat_, &sit->second, q_0_);
+ assert(IncorporateNewEdge(init));
+ while (exp_agenda.HasWork() || agenda.HasWork()) {
+ while(exp_agenda.HasWork()) {
+ const Edge* edge = exp_agenda.Next();
+ FinishEdge(edge, forest);
+ }
+ if (agenda.HasWork()) {
+ const Edge* edge = agenda.Next();
+#ifdef DEBUG_CHART_PARSER
+ cerr << "processing (" << edge->id << ')' << endl;
+#endif
+ if (edge->IsActive()) {
+ if (edge->dot->HasTerminals())
+ DoScan(edge);
+ if (edge->dot->HasNonTerminals()) {
+ DoMergeWithPassives(edge);
+ DoPredict(edge, g);
+ }
+ } else {
+ DoComplete(edge);
+ }
+ }
+ }
+ if (goal_node) {
+ forest->PruneUnreachable(goal_node->id_);
+ forest->EpsilonRemove(kEPS);
+ }
+ FreeAll();
+ return goal_node;
+ }
+
+ void FreeAll() {
+ for (int i = 0; i < free_list_.size(); ++i)
+ delete free_list_[i];
+ free_list_.clear();
+ for (int i = 0; i < traversal_free_list_.size(); ++i)
+ delete traversal_free_list_[i];
+ traversal_free_list_.clear();
+ all_traversals.clear();
+ exp_agenda.clear();
+ agenda.clear();
+ tps2node.clear();
+ edge2node.clear();
+ all_edges.clear();
+ passive_edges.clear();
+ active_edges.clear();
+ }
+
+ ~EarleyComposerImpl() {
+ FreeAll();
+ }
+
+ // returns the total number of edges created during composition
+ int EdgesCreated() const {
+ return free_list_.size();
+ }
+
+ private:
+ void DoScan(const Edge* edge) {
+ // here, we assume that the FST will potentially have many more outgoing
+ // edges than the grammar, which will be just a couple. If you want to
+ // efficiently handle the case where both are relatively large, this code
+ // will need to change how the intersection is done. The best general
+ // solution would probably be the Baeza-Yates double binary search.
+
+ const EGrammarNode* dot = edge->dot;
+ const FSTNode* r = edge->r;
+ const map<WordID, EGrammarNode>& terms = dot->GetTerminals();
+ for (map<WordID, EGrammarNode>::const_iterator git = terms.begin();
+ git != terms.end(); ++git) {
+ const FSTNode* next_r = r->Extend(git->first);
+ if (!next_r) continue;
+ const EGrammarNode* next_dot = &git->second;
+ const bool grammar_continues = next_dot->GrammarContinues();
+ const bool rule_completes = next_dot->RuleCompletes();
+ assert(grammar_continues || rule_completes);
+ const SparseVector<double>& input_features = next_dot->GetCFGProductionFeatures();
+ // create up to 4 new edges!
+ if (next_r->HasOutgoingNonEpsilonEdges()) { // are there further symbols in the FST?
+ const TargetPhraseSet* translations = NULL;
+ if (rule_completes)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations, input_features));
+ if (grammar_continues)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations));
+ }
+ if (next_r->HasData()) { // indicates a loop back to q_0 in the FST
+ const TargetPhraseSet* translations = next_r->GetTranslations();
+ if (rule_completes)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations, input_features));
+ if (grammar_continues)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations));
+ }
+ }
+ }
+
+ void DoPredict(const Edge* edge, const EGrammar& g) {
+ const EGrammarNode* dot = edge->dot;
+ const map<WordID, EGrammarNode>& non_terms = dot->GetNonTerminals();
+ for (map<WordID, EGrammarNode>::const_iterator git = non_terms.begin();
+ git != non_terms.end(); ++git) {
+ const WordID nt_to_predict = git->first;
+ //cerr << edge->id << " -- " << TD::Convert(nt_to_predict*-1) << endl;
+ EGrammar::const_iterator egi = g.find(nt_to_predict);
+ if (egi == g.end()) {
+ cerr << "[ERROR] Can't find any grammar rules with a LHS of type "
+ << TD::Convert(-1*nt_to_predict) << '!' << endl;
+ continue;
+ }
+ assert(edge->IsActive());
+ const EGrammarNode* new_dot = &egi->second;
+ Edge* new_edge = new Edge(nt_to_predict, new_dot, edge->r, edge);
+ IncorporateNewEdge(new_edge);
+ }
+ }
+
+ void DoComplete(const Edge* passive) {
+#ifdef DEBUG_CHART_PARSER
+ cerr << " complete: " << *passive << endl;
+#endif
+ const WordID completed_nt = passive->cat;
+ const FSTNode* q = passive->q;
+ const FSTNode* next_r = passive->r;
+ const Edge query(q);
+ const pair<unordered_multiset<const Edge*, REdgeHash, REdgeEquals>::iterator,
+ unordered_multiset<const Edge*, REdgeHash, REdgeEquals>::iterator > p =
+ active_edges.equal_range(&query);
+ for (unordered_multiset<const Edge*, REdgeHash, REdgeEquals>::iterator it = p.first;
+ it != p.second; ++it) {
+ const Edge* active = *it;
+#ifdef DEBUG_CHART_PARSER
+ cerr << " pos: " << *active << endl;
+#endif
+ const EGrammarNode* next_dot = active->dot->Extend(completed_nt);
+ if (!next_dot) continue;
+ const SparseVector<double>& input_features = next_dot->GetCFGProductionFeatures();
+ // add up to 2 rules
+ if (next_dot->RuleCompletes())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features));
+ if (next_dot->GrammarContinues())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive));
+ }
+ }
+
+ void DoMergeWithPassives(const Edge* active) {
+ // edge is active, has non-terminals, we need to find the passives that can extend it
+ assert(active->IsActive());
+ assert(active->dot->HasNonTerminals());
+#ifdef DEBUG_CHART_PARSER
+ cerr << " merge active with passives: ACT=" << *active << endl;
+#endif
+ const Edge query(active->r, 1);
+ const pair<unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals>::iterator,
+ unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals>::iterator > p =
+ passive_edges.equal_range(&query);
+ for (unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals>::iterator it = p.first;
+ it != p.second; ++it) {
+ const Edge* passive = *it;
+ const EGrammarNode* next_dot = active->dot->Extend(passive->cat);
+ if (!next_dot) continue;
+ const FSTNode* next_r = passive->r;
+ const SparseVector<double>& input_features = next_dot->GetCFGProductionFeatures();
+ if (next_dot->RuleCompletes())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features));
+ if (next_dot->GrammarContinues())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive));
+ }
+ }
+
+ // take ownership of edge memory, add to various indexes, etc
+ // returns true if this edge is new
+ bool IncorporateNewEdge(Edge* edge) {
+ free_list_.push_back(edge);
+ if (edge->passive_parent && edge->active_parent) {
+ Traversal* t = new Traversal(edge, edge->active_parent, edge->passive_parent);
+ traversal_free_list_.push_back(t);
+ if (all_traversals.find(t) != all_traversals.end()) {
+ return false;
+ } else {
+ all_traversals.insert(t);
+ }
+ }
+ exp_agenda.AddEdge(edge);
+ return true;
+ }
+
+ bool FinishEdge(const Edge* edge, Hypergraph* hg) {
+ bool is_new = false;
+ if (all_edges.find(edge) == all_edges.end()) {
+#ifdef DEBUG_CHART_PARSER
+ cerr << *edge << " is NEW\n";
+#endif
+ all_edges.insert(edge);
+ is_new = true;
+ if (edge->IsPassive()) passive_edges.insert(edge);
+ if (edge->IsActive()) active_edges.insert(edge);
+ agenda.AddEdge(edge);
+ } else {
+#ifdef DEBUG_CHART_PARSER
+ cerr << *edge << " is NOT NEW.\n";
+#endif
+ }
+ AddEdgeToTranslationForest(edge, hg);
+ return is_new;
+ }
+
+ // build the translation forest
+ void AddEdgeToTranslationForest(const Edge* edge, Hypergraph* hg) {
+ assert(hg->nodes_.size() < kMAX_NODES);
+ Hypergraph::Node* tps = NULL;
+ // first add any target language rules
+ if (edge->tps) {
+ Hypergraph::Node*& node = tps2node[(size_t)edge->tps];
+ if (!node) {
+ // cerr << "Creating phrases for " << edge->tps << endl;
+ const vector<TRulePtr>& rules = edge->tps->GetRules();
+ node = hg->AddNode(kPHRASE, "");
+ for (int i = 0; i < rules.size(); ++i) {
+ Hypergraph::Edge* hg_edge = hg->AddEdge(rules[i], Hypergraph::TailNodeVector());
+ hg_edge->feature_values_ += rules[i]->GetFeatureValues();
+ hg->ConnectEdgeToHeadNode(hg_edge, node);
+ }
+ }
+ tps = node;
+ }
+ Hypergraph::Node*& head_node = edge2node[edge];
+ if (!head_node)
+ head_node = hg->AddNode(kPHRASE, "");
+ if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) {
+ assert(goal_node == NULL || goal_node == head_node);
+ goal_node = head_node;
+ }
+ Hypergraph::TailNodeVector tail;
+ SparseVector<double> extra;
+ if (edge->IsCreatedByPredict()) {
+ // extra.set_value(FD::Convert("predict"), 1);
+ } else if (edge->IsCreatedByScan()) {
+ tail.push_back(edge2node[edge->active_parent]->id_);
+ if (tps) {
+ tail.push_back(tps->id_);
+ }
+ //extra.set_value(FD::Convert("scan"), 1);
+ } else if (edge->IsCreatedByComplete()) {
+ tail.push_back(edge2node[edge->active_parent]->id_);
+ tail.push_back(edge2node[edge->passive_parent]->id_);
+ //extra.set_value(FD::Convert("complete"), 1);
+ } else {
+ assert(!"unexpected edge type!");
+ }
+ //cerr << head_node->id_ << "<--" << *edge << endl;
+
+#ifdef DEBUG_CHART_PARSER
+ for (int i = 0; i < tail.size(); ++i)
+ if (tail[i] == head_node->id_) {
+ cerr << "ERROR: " << *edge << "\n i=" << i << endl;
+ if (i == 1) { cerr << "\tP: " << *edge->passive_parent << endl; }
+ if (i == 0) { cerr << "\tA: " << *edge->active_parent << endl; }
+ assert(!"self-loop found!");
+ }
+#endif
+ Hypergraph::Edge* hg_edge = NULL;
+ if (tail.size() == 0) {
+ hg_edge = hg->AddEdge(kEPSRule, tail);
+ } else if (tail.size() == 1) {
+ hg_edge = hg->AddEdge(kX1, tail);
+ } else if (tail.size() == 2) {
+ hg_edge = hg->AddEdge(kX1X2, tail);
+ }
+ if (edge->features)
+ hg_edge->feature_values_ += *edge->features;
+ hg_edge->feature_values_ += extra;
+ hg->ConnectEdgeToHeadNode(hg_edge, head_node);
+ }
+
+ Hypergraph::Node* goal_node;
+ EdgeQueue exp_agenda;
+ EdgeQueue agenda;
+ unordered_map<size_t, Hypergraph::Node*> tps2node;
+ unordered_map<const Edge*, Hypergraph::Node*, UniqueEdgeHash, UniqueEdgeEquals> edge2node;
+ unordered_set<const Traversal*, UniqueTraversalHash, UniqueTraversalEquals> all_traversals;
+ unordered_set<const Edge*, UniqueEdgeHash, UniqueEdgeEquals> all_edges;
+ unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals> passive_edges;
+ unordered_multiset<const Edge*, REdgeHash, REdgeEquals> active_edges;
+ vector<Edge*> free_list_;
+ vector<Traversal*> traversal_free_list_;
+ const WordID start_cat_;
+ const FSTNode* const q_0_;
+};
+
+#ifdef DEBUG_CHART_PARSER
+static string TrimRule(const string& r) {
+ size_t start = r.find(" |||") + 5;
+ size_t end = r.rfind(" |||");
+ return r.substr(start, end - start);
+}
+#endif
+
+void AddGrammarRule(const string& r, EGrammar* g) {
+ const size_t pos = r.find(" ||| ");
+ if (pos == string::npos || r[0] != '[') {
+ cerr << "Bad rule: " << r << endl;
+ return;
+ }
+ const size_t rpos = r.rfind(" ||| ");
+ string feats;
+ string rs = r;
+ if (rpos != pos) {
+ feats = r.substr(rpos + 5);
+ rs = r.substr(0, rpos);
+ }
+ string rhs = rs.substr(pos + 5);
+ string trule = rs + " ||| " + rhs + " ||| " + feats;
+ TRule tr(trule);
+#ifdef DEBUG_CHART_PARSER
+ string hint_last_rule;
+#endif
+ EGrammarNode* cur = &(*g)[tr.GetLHS()];
+ cur->is_root = true;
+ for (int i = 0; i < tr.FLength(); ++i) {
+ WordID sym = tr.f()[i];
+#ifdef DEBUG_CHART_PARSER
+ hint_last_rule = TD::Convert(sym < 0 ? -sym : sym);
+ cur->hint += " <@@> (*" + hint_last_rule + ") " + TrimRule(tr.AsString());
+#endif
+ if (sym < 0)
+ cur = &cur->ntptr[sym];
+ else
+ cur = &cur->tptr[sym];
+ }
+#ifdef DEBUG_CHART_PARSER
+ cur->hint += " <@@> (" + hint_last_rule + "*) " + TrimRule(tr.AsString());
+#endif
+ cur->is_some_rule_complete = true;
+ cur->input_features = tr.GetFeatureValues();
+}
+
+EarleyComposer::~EarleyComposer() {
+ delete pimpl_;
+}
+
+EarleyComposer::EarleyComposer(const FSTNode* fst) {
+ InitializeConstants();
+ pimpl_ = new EarleyComposerImpl(kUNIQUE_START, *fst);
+}
+
+bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest) {
+ // first, convert the src forest into an EGrammar
+ EGrammar g;
+ const int nedges = src_forest.edges_.size();
+ const int nnodes = src_forest.nodes_.size();
+ vector<int> cats(nnodes);
+ bool assign_cats = false;
+ for (int i = 0; i < nnodes; ++i)
+ if (assign_cats) {
+ cats[i] = TD::Convert("CAT_" + boost::lexical_cast<string>(i)) * -1;
+ } else {
+ cats[i] = src_forest.nodes_[i].cat_;
+ }
+ // construct the grammar
+ for (int i = 0; i < nedges; ++i) {
+ const Hypergraph::Edge& edge = src_forest.edges_[i];
+ const vector<WordID>& src = edge.rule_->f();
+ EGrammarNode* cur = &g[cats[edge.head_node_]];
+ cur->is_root = true;
+ int ntc = 0;
+ for (int j = 0; j < src.size(); ++j) {
+ WordID sym = src[j];
+ if (sym <= 0) {
+ sym = cats[edge.tail_nodes_[ntc]];
+ ++ntc;
+ cur = &cur->ntptr[sym];
+ } else {
+ cur = &cur->tptr[sym];
+ }
+ }
+ cur->is_some_rule_complete = true;
+ cur->input_features = edge.feature_values_;
+ }
+ EGrammarNode& goal_rule = g[kUNIQUE_START];
+ assert((goal_rule.ntptr.size() == 1 && goal_rule.tptr.size() == 0) ||
+ (goal_rule.ntptr.size() == 0 && goal_rule.tptr.size() == 1));
+
+ return pimpl_->Compose(g, trg_forest);
+}
+
+bool EarleyComposer::Compose(istream* in, Hypergraph* trg_forest) {
+ EGrammar g;
+ while(*in) {
+ string line;
+ getline(*in, line);
+ if (line.empty()) continue;
+ AddGrammarRule(line, &g);
+ }
+
+ return pimpl_->Compose(g, trg_forest);
+}
diff --git a/src/earley_composer.h b/src/earley_composer.h
new file mode 100644
index 00000000..9f786bf6
--- /dev/null
+++ b/src/earley_composer.h
@@ -0,0 +1,29 @@
+#ifndef _EARLEY_COMPOSER_H_
+#define _EARLEY_COMPOSER_H_
+
+#include <iostream>
+
+class EarleyComposerImpl;
+class FSTNode;
+class Hypergraph;
+
+class EarleyComposer {
+ public:
+ ~EarleyComposer();
+ EarleyComposer(const FSTNode* phrasetable_root);
+ bool Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);
+
+ // reads the grammar from a file. There must be a single top-level
+ // S -> X rule. Anything else is possible. Format is:
+ // [S] ||| [SS,1]
+ // [SS] ||| [NP,1] [VP,2] ||| Feature1=0.2 Feature2=-2.3
+ // [SS] ||| [VP,1] [NP,2] ||| Feature1=0.8
+ // [NP] ||| [DET,1] [N,2] ||| Feature3=2
+ // ...
+ bool Compose(std::istream* grammar_file, Hypergraph* trg_forest);
+
+ private:
+ EarleyComposerImpl* pimpl_;
+};
+
+#endif
diff --git a/src/exp_semiring.h b/src/exp_semiring.h
new file mode 100644
index 00000000..f91beee4
--- /dev/null
+++ b/src/exp_semiring.h
@@ -0,0 +1,71 @@
+#ifndef _EXP_SEMIRING_H_
+#define _EXP_SEMIRING_H_
+
+#include <iostream>
+
+// this file implements the first-order expectation semiring described
+// in Li & Eisner (EMNLP 2009)
+
+// requirements:
+// RType * RType ==> RType
+// PType * PType ==> PType
+// RType * PType ==> RType
+// good examples:
+// PType scalar, RType vector
+// BAD examples:
+// PType vector, RType scalar
+template <typename PType, typename RType>
+struct PRPair {
+ PRPair() : p(), r() {}
+ // Inside algorithm requires that T(0) and T(1)
+ // return the 0 and 1 values of the semiring
+ explicit PRPair(double x) : p(x), r() {}
+ PRPair(const PType& p, const RType& r) : p(p), r(r) {}
+ PRPair& operator+=(const PRPair& o) {
+ p += o.p;
+ r += o.r;
+ return *this;
+ }
+ PRPair& operator*=(const PRPair& o) {
+ r = (o.r * p) + (o.p * r);
+ p *= o.p;
+ return *this;
+ }
+ PType p;
+ RType r;
+};
+
+template <typename P, typename R>
+std::ostream& operator<<(std::ostream& o, const PRPair<P,R>& x) {
+ return o << '<' << x.p << ", " << x.r << '>';
+}
+
+template <typename P, typename R>
+const PRPair<P,R> operator+(const PRPair<P,R>& a, const PRPair<P,R>& b) {
+ PRPair<P,R> result = a;
+ result += b;
+ return result;
+}
+
+template <typename P, typename R>
+const PRPair<P,R> operator*(const PRPair<P,R>& a, const PRPair<P,R>& b) {
+ PRPair<P,R> result = a;
+ result *= b;
+ return result;
+}
+
+template <typename P, typename PWeightFunction, typename R, typename RWeightFunction>
+struct PRWeightFunction {
+ explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(),
+ const RWeightFunction& rwf = RWeightFunction()) :
+ pweight(pwf), rweight(rwf) {}
+ PRPair<P,R> operator()(const Hypergraph::Edge& e) const {
+ const P p = pweight(e);
+ const R r = rweight(e);
+ return PRPair<P,R>(p, r * p);
+ }
+ const PWeightFunction pweight;
+ const RWeightFunction rweight;
+};
+
+#endif
diff --git a/src/fdict.cc b/src/fdict.cc
new file mode 100644
index 00000000..83aa7cea
--- /dev/null
+++ b/src/fdict.cc
@@ -0,0 +1,4 @@
+#include "fdict.h"
+
+Dict FD::dict_;
+
diff --git a/src/fdict.h b/src/fdict.h
new file mode 100644
index 00000000..ff491cfb
--- /dev/null
+++ b/src/fdict.h
@@ -0,0 +1,21 @@
+#ifndef _FDICT_H_
+#define _FDICT_H_
+
+#include <string>
+#include <vector>
+#include "dict.h"
+
+struct FD {
+ static Dict dict_;
+ static inline int NumFeats() {
+ return dict_.max() + 1;
+ }
+ static inline WordID Convert(const std::string& s) {
+ return dict_.Convert(s);
+ }
+ static inline const std::string& Convert(const WordID& w) {
+ return dict_.Convert(w);
+ }
+};
+
+#endif
diff --git a/src/ff.cc b/src/ff.cc
new file mode 100644
index 00000000..488e6468
--- /dev/null
+++ b/src/ff.cc
@@ -0,0 +1,93 @@
+#include "ff.h"
+
+#include "tdict.h"
+#include "hg.h"
+
+using namespace std;
+
+FeatureFunction::~FeatureFunction() {}
+
+
+void FeatureFunction::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ (void) ant_state;
+ (void) features;
+}
+
+// Hiero and Joshua use log_10(e) as the value, so I do to
+WordPenalty::WordPenalty(const string& param) :
+ fid_(FD::Convert("WordPenalty")),
+ value_(-1.0 / log(10)) {
+ if (!param.empty()) {
+ cerr << "Warning WordPenalty ignoring parameter: " << param << endl;
+ }
+}
+
+void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ (void) smeta;
+ (void) ant_states;
+ (void) state;
+ (void) estimated_features;
+ features->set_value(fid_, edge.rule_->EWords() * value_);
+}
+
+ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) :
+ models_(models),
+ weights_(w),
+ state_size_(0),
+ model_state_pos_(models.size()) {
+ for (int i = 0; i < models_.size(); ++i) {
+ model_state_pos_[i] = state_size_;
+ state_size_ += models_[i]->NumBytesContext();
+ }
+}
+
+void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
+ const Hypergraph& hg,
+ Hypergraph::Edge* edge,
+ string* context,
+ prob_t* combination_cost_estimate) const {
+ context->resize(state_size_);
+ memset(&(*context)[0], 0, state_size_);
+ SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL
+ if (combination_cost_estimate) *combination_cost_estimate = prob_t::One();
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ void* cur_ff_context = NULL;
+ vector<const void*> ants(edge->tail_nodes_.size());
+ bool has_context = ff.NumBytesContext() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ cur_ff_context = &(*context)[spos];
+ for (int i = 0; i < ants.size(); ++i) {
+ ants[i] = &hg.nodes_[edge->tail_nodes_[i]].state_[spos];
+ }
+ }
+ ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context);
+ }
+ if (combination_cost_estimate)
+ combination_cost_estimate->logeq(est_vals.dot(weights_));
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+
+void ModelSet::AddFinalFeatures(const std::string& state, Hypergraph::Edge* edge) const {
+ assert(1 == edge->rule_->Arity());
+
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ const void* ant_state = NULL;
+ bool has_context = ff.NumBytesContext() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ ant_state = &state[spos];
+ }
+ ff.FinalTraversalFeatures(ant_state, &edge->feature_values_);
+ }
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+
diff --git a/src/ff.h b/src/ff.h
new file mode 100644
index 00000000..c97e2fe2
--- /dev/null
+++ b/src/ff.h
@@ -0,0 +1,121 @@
+#ifndef _FF_H_
+#define _FF_H_
+
+#include <vector>
+
+#include "fdict.h"
+#include "hg.h"
+
+class SentenceMetadata;
+class FeatureFunction; // see definition below
+
+// if you want to develop a new feature, inherit from this class and
+// override TraversalFeaturesImpl(...). If it's a feature that returns /
+// depends on context, you may also need to implement
+// FinalTraversalFeatures(...)
+class FeatureFunction {
+ public:
+ FeatureFunction() : state_size_() {}
+ explicit FeatureFunction(int state_size) : state_size_(state_size) {}
+ virtual ~FeatureFunction();
+
+ // returns the number of bytes of context that this feature function will
+ // (maximally) use. By default, 0 ("stateless" models in Hiero/Joshua).
+ // NOTE: this value is fixed for the instance of your class, you cannot
+ // use different amounts of memory for different nodes in the forest.
+ inline int NumBytesContext() const { return state_size_; }
+
+ // Compute the feature values and (if this applies) the estimates of the
+ // feature values when this edge is used incorporated into a larger context
+ inline void TraversalFeatures(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_state) const {
+ TraversalFeaturesImpl(smeta, edge, ant_contexts,
+ features, estimated_features, out_state);
+ // TODO it's easy for careless feature function developers to overwrite
+ // the end of their state and clobber someone else's memory. These bugs
+ // will be horrendously painful to track down. There should be some
+ // optional strict mode that's enforced here that adds some kind of
+ // barrier between the blocks reserved for the residual contexts
+ }
+
+ // if there's some state left when you transition to the goal state, score
+ // it here. For example, the language model computes the cost of adding
+ // <s> and </s>.
+ virtual void FinalTraversalFeatures(const void* residual_state,
+ SparseVector<double>* final_features) const;
+
+ protected:
+ // context is a pointer to a buffer of size NumBytesContext() that the
+ // feature function can write its state to. It's up to the feature function
+ // to determine how much space it needs and to determine how to encode its
+ // residual contextual information since it is OPAQUE to all clients outside
+ // of the particular FeatureFunction class. There is one exception:
+ // equality of the contents (i.e., memcmp) is required to determine whether
+ // two states can be combined.
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const = 0;
+
+ // !!! ONLY call this from subclass *CONSTRUCTORS* !!!
+ void SetStateSize(size_t state_size) {
+ state_size_ = state_size;
+ }
+
+ private:
+ int state_size_;
+};
+
+// word penalty feature, for each word on the E side of a rule,
+// add value_
+class WordPenalty : public FeatureFunction {
+ public:
+ WordPenalty(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ const int fid_;
+ const double value_;
+};
+
+// this class is a set of FeatureFunctions that can be used to score, rescore,
+// etc. a (translation?) forest
+class ModelSet {
+ public:
+ ModelSet() : state_size_(0) {}
+
+ ModelSet(const std::vector<double>& weights,
+ const std::vector<const FeatureFunction*>& models);
+
+ // sets edge->feature_values_ and edge->edge_prob_
+ // NOTE: edge must not necessarily be in hg.edges_ but its TAIL nodes
+ // must be.
+ void AddFeaturesToEdge(const SentenceMetadata& smeta,
+ const Hypergraph& hg,
+ Hypergraph::Edge* edge,
+ std::string* residual_context,
+ prob_t* combination_cost_estimate = NULL) const;
+
+ void AddFinalFeatures(const std::string& residual_context,
+ Hypergraph::Edge* edge) const;
+
+ bool empty() const { return models_.empty(); }
+ private:
+ std::vector<const FeatureFunction*> models_;
+ std::vector<double> weights_;
+ int state_size_;
+ std::vector<int> model_state_pos_;
+};
+
+#endif
diff --git a/src/ff_factory.cc b/src/ff_factory.cc
new file mode 100644
index 00000000..1854e0bb
--- /dev/null
+++ b/src/ff_factory.cc
@@ -0,0 +1,35 @@
+#include "ff_factory.h"
+
+#include "ff.h"
+
+using boost::shared_ptr;
+using namespace std;
+
+FFFactoryBase::~FFFactoryBase() {}
+
+void FFRegistry::DisplayList() const {
+ for (map<string, shared_ptr<FFFactoryBase> >::const_iterator it = reg_.begin();
+ it != reg_.end(); ++it) {
+ cerr << " " << it->first << endl;
+ }
+}
+
+shared_ptr<FeatureFunction> FFRegistry::Create(const string& ffname, const string& param) const {
+ map<string, shared_ptr<FFFactoryBase> >::const_iterator it = reg_.find(ffname);
+ shared_ptr<FeatureFunction> res;
+ if (it == reg_.end()) {
+ cerr << "I don't know how to create feature " << ffname << endl;
+ } else {
+ res = it->second->Create(param);
+ }
+ return res;
+}
+
+void FFRegistry::Register(const string& ffname, FFFactoryBase* factory) {
+ if (reg_.find(ffname) != reg_.end()) {
+ cerr << "Duplicate registration of FeatureFunction with name " << ffname << "!\n";
+ abort();
+ }
+ reg_[ffname].reset(factory);
+}
+
diff --git a/src/ff_factory.h b/src/ff_factory.h
new file mode 100644
index 00000000..bc586567
--- /dev/null
+++ b/src/ff_factory.h
@@ -0,0 +1,39 @@
+#ifndef _FF_FACTORY_H_
+#define _FF_FACTORY_H_
+
+#include <iostream>
+#include <string>
+#include <map>
+
+#include <boost/shared_ptr.hpp>
+
+class FeatureFunction;
+class FFRegistry;
+class FFFactoryBase;
+extern boost::shared_ptr<FFRegistry> global_ff_registry;
+
+class FFRegistry {
+ friend int main(int argc, char** argv);
+ friend class FFFactoryBase;
+ public:
+ boost::shared_ptr<FeatureFunction> Create(const std::string& ffname, const std::string& param) const;
+ void DisplayList() const;
+ void Register(const std::string& ffname, FFFactoryBase* factory);
+ private:
+ FFRegistry() {}
+ std::map<std::string, boost::shared_ptr<FFFactoryBase> > reg_;
+};
+
+struct FFFactoryBase {
+ virtual ~FFFactoryBase();
+ virtual boost::shared_ptr<FeatureFunction> Create(const std::string& param) const = 0;
+};
+
+template<class FF>
+class FFFactory : public FFFactoryBase {
+ boost::shared_ptr<FeatureFunction> Create(const std::string& param) const {
+ return boost::shared_ptr<FeatureFunction>(new FF(param));
+ }
+};
+
+#endif
diff --git a/src/ff_itg_span.h b/src/ff_itg_span.h
new file mode 100644
index 00000000..b990f86a
--- /dev/null
+++ b/src/ff_itg_span.h
@@ -0,0 +1,7 @@
+#ifndef _FF_ITG_SPAN_H_
+#define _FF_ITG_SPAN_H_
+
+class ITGSpanFeatures : public FeatureFunction {
+};
+
+#endif
diff --git a/src/ff_test.cc b/src/ff_test.cc
new file mode 100644
index 00000000..1c20f9ac
--- /dev/null
+++ b/src/ff_test.cc
@@ -0,0 +1,134 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "hg.h"
+#include "lm_ff.h"
+#include "ff.h"
+#include "trule.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+LanguageModel* lm_ = NULL;
+LanguageModel* lm3_ = NULL;
+
+class FFTest : public testing::Test {
+ public:
+ FFTest() : smeta(0,Lattice()) {
+ if (!lm_) {
+ static LanguageModel slm("-o 2 ./test_data/brown.lm.gz");
+ lm_ = &slm;
+ static LanguageModel slm3("./test_data/dummy.3gram.lm -o 3");
+ lm3_ = &slm3;
+ }
+ }
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+ SentenceMetadata smeta;
+};
+
+TEST_F(FFTest,LanguageModel) {
+ vector<const FeatureFunction*> ms(1, lm_);
+ TRulePtr tr1(new TRule("[X] ||| [X,1] said"));
+ TRulePtr tr2(new TRule("[X] ||| the man said"));
+ TRulePtr tr3(new TRule("[X] ||| the fat man"));
+ Hypergraph hg;
+ const int lm_fid = FD::Convert("LanguageModel");
+ vector<double> w(lm_fid + 1,1);
+ ModelSet models(w, ms);
+ string state;
+ Hypergraph::Edge edge;
+ edge.rule_ = tr2;
+ models.AddFeaturesToEdge(smeta, hg, &edge, &state);
+ double lm1 = edge.feature_values_.dot(w);
+ cerr << "lm=" << edge.feature_values_[lm_fid] << endl;
+
+ hg.nodes_.resize(1);
+ hg.edges_.resize(2);
+ hg.edges_[0].rule_ = tr3;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_);
+ hg.edges_[1].tail_nodes_.push_back(0);
+ hg.edges_[1].rule_ = tr1;
+ string state2;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2);
+ double tot = hg.edges_[1].feature_values_[lm_fid] + hg.edges_[0].feature_values_[lm_fid];
+ cerr << "lm=" << tot << endl;
+ EXPECT_TRUE(state2 == state);
+ EXPECT_FALSE(state == hg.nodes_[0].state_);
+}
+
+TEST_F(FFTest, Small) {
+ WordPenalty wp("");
+ vector<const FeatureFunction*> ms(2, lm_);
+ ms[1] = &wp;
+ TRulePtr tr1(new TRule("[X] ||| [X,1] said"));
+ TRulePtr tr2(new TRule("[X] ||| john said"));
+ TRulePtr tr3(new TRule("[X] ||| john"));
+ cerr << "RULE: " << tr1->AsString() << endl;
+ Hypergraph hg;
+ vector<double> w(2); w[0]=1.0; w[1]=-2.0;
+ ModelSet models(w, ms);
+ string state;
+ Hypergraph::Edge edge;
+ edge.rule_ = tr2;
+ cerr << tr2->AsString() << endl;
+ models.AddFeaturesToEdge(smeta, hg, &edge, &state);
+ double s1 = edge.feature_values_.dot(w);
+ cerr << "lm=" << edge.feature_values_[0] << endl;
+ cerr << "wp=" << edge.feature_values_[1] << endl;
+
+ hg.nodes_.resize(1);
+ hg.edges_.resize(2);
+ hg.edges_[0].rule_ = tr3;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_);
+ double acc = hg.edges_[0].feature_values_.dot(w);
+ cerr << hg.edges_[0].feature_values_[0] << endl;
+ hg.edges_[1].tail_nodes_.push_back(0);
+ hg.edges_[1].rule_ = tr1;
+ string state2;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2);
+ acc += hg.edges_[1].feature_values_.dot(w);
+ double tot = hg.edges_[1].feature_values_[0] + hg.edges_[0].feature_values_[0];
+ cerr << "lm=" << tot << endl;
+ cerr << "acc=" << acc << endl;
+ cerr << " s1=" << s1 << endl;
+ EXPECT_TRUE(state2 == state);
+ EXPECT_FALSE(state == hg.nodes_[0].state_);
+ EXPECT_FLOAT_EQ(acc, s1);
+}
+
+TEST_F(FFTest, LM3) {
+ int x = lm3_->NumBytesContext();
+ Hypergraph::Edge edge1;
+ edge1.rule_.reset(new TRule("[X] ||| x y ||| one ||| 1.0 -2.4 3.0"));
+ Hypergraph::Edge edge2;
+ edge2.rule_.reset(new TRule("[X] ||| [X,1] a ||| [X,1] two ||| 1.0 -2.4 3.0"));
+ Hypergraph::Edge edge3;
+ edge3.rule_.reset(new TRule("[X] ||| [X,1] a ||| zero [X,1] two ||| 1.0 -2.4 3.0"));
+ vector<const void*> ants1;
+ string state(x, '\0');
+ SparseVector<double> feats;
+ SparseVector<double> est;
+ lm3_->TraversalFeatures(smeta, edge1, ants1, &feats, &est, (void *)&state[0]);
+ cerr << "returned " << feats << endl;
+ cerr << edge1.feature_values_ << endl;
+ cerr << lm3_->DebugStateToString((const void*)&state[0]) << endl;
+ EXPECT_EQ("[ one ]", lm3_->DebugStateToString((const void*)&state[0]));
+ ants1.push_back((const void*)&state[0]);
+ string state2(x, '\0');
+ lm3_->TraversalFeatures(smeta, edge2, ants1, &feats, &est, (void *)&state2[0]);
+ cerr << lm3_->DebugStateToString((const void*)&state2[0]) << endl;
+ EXPECT_EQ("[ one two ]", lm3_->DebugStateToString((const void*)&state2[0]));
+ string state3(x, '\0');
+ lm3_->TraversalFeatures(smeta, edge3, ants1, &feats, &est, (void *)&state3[0]);
+ cerr << lm3_->DebugStateToString((const void*)&state3[0]) << endl;
+ EXPECT_EQ("[ zero one <{STAR}> one two ]", lm3_->DebugStateToString((const void*)&state3[0]));
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/ff_wordalign.cc b/src/ff_wordalign.cc
new file mode 100644
index 00000000..e605ac8d
--- /dev/null
+++ b/src/ff_wordalign.cc
@@ -0,0 +1,221 @@
+#include "ff_wordalign.h"
+
+#include <string>
+#include <cmath>
+
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "hg.h"
+#include "fdict.h"
+#include "aligner.h"
+#include "tdict.h" // Blunsom hack
+#include "filelib.h" // Blunsom hack
+
+using namespace std;
+
+RelativeSentencePosition::RelativeSentencePosition(const string& param) :
+ fid_(FD::Convert("RelativeSentencePosition")) {}
+
+void RelativeSentencePosition::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ // if the source word is either null or the generated word
+ // has no position in the reference
+ if (edge.i_ == -1 || edge.prev_i_ == -1)
+ return;
+
+ assert(smeta.GetTargetLength() > 0);
+ const double val = fabs(static_cast<double>(edge.i_) / smeta.GetSourceLength() -
+ static_cast<double>(edge.prev_i_) / smeta.GetTargetLength());
+ features->set_value(fid_, val);
+// cerr << f_len_ << " " << e_len_ << " [" << edge.i_ << "," << edge.j_ << "|" << edge.prev_i_ << "," << edge.prev_j_ << "]\t" << edge.rule_->AsString() << "\tVAL=" << val << endl;
+}
+
+MarkovJump::MarkovJump(const string& param) :
+ FeatureFunction(1),
+ fid_(FD::Convert("MarkovJump")),
+ individual_params_per_jumpsize_(false),
+ condition_on_flen_(false) {
+ cerr << " MarkovJump: Blunsom&Cohn feature";
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ if (argc > 0) {
+ if (argc != 1 || !(argv[0] == "-f" || argv[0] == "-i" || argv[0] == "-if")) {
+ cerr << "MarkovJump: expected parameters to be -f, -i, or -if\n";
+ exit(1);
+ }
+ individual_params_per_jumpsize_ = (argv[0][1] == 'i');
+ condition_on_flen_ = (argv[0][argv[0].size() - 1] == 'f');
+ if (individual_params_per_jumpsize_) {
+ template_ = "Jump:000";
+ cerr << ", individual jump parameters";
+ if (condition_on_flen_) {
+ template_ += ":F00";
+ cerr << " (split by f-length)";
+ }
+ }
+ }
+ cerr << endl;
+}
+
+void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ unsigned char& dpstate = *((unsigned char*)state);
+ if (edge.Arity() == 0) {
+ dpstate = static_cast<unsigned int>(edge.i_);
+ } else if (edge.Arity() == 1) {
+ dpstate = *((unsigned char*)ant_states[0]);
+ } else if (edge.Arity() == 2) {
+ int left_index = *((unsigned char*)ant_states[0]);
+ int right_index = *((unsigned char*)ant_states[1]);
+ if (right_index == -1)
+ dpstate = static_cast<unsigned int>(left_index);
+ else
+ dpstate = static_cast<unsigned int>(right_index);
+ const int jumpsize = right_index - left_index;
+ features->set_value(fid_, fabs(jumpsize - 1)); // Blunsom and Cohn def
+
+ if (individual_params_per_jumpsize_) {
+ string fname = template_;
+ int param = jumpsize;
+ if (jumpsize < 0) {
+ param *= -1;
+ fname[5]='L';
+ } else if (jumpsize > 0) {
+ fname[5]='R';
+ }
+ if (param) {
+ fname[6] = '0' + (param / 10);
+ fname[7] = '0' + (param % 10);
+ }
+ if (condition_on_flen_) {
+ const int flen = smeta.GetSourceLength();
+ fname[10] = '0' + (flen / 10);
+ fname[11] = '0' + (flen % 10);
+ }
+ features->set_value(FD::Convert(fname), 1.0);
+ }
+ } else {
+ assert(!"something really unexpected is happening");
+ }
+}
+
+AlignerResults::AlignerResults(const std::string& param) :
+ cur_sent_(-1),
+ cur_grid_(NULL) {
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ if (argc != 2) {
+ cerr << "Required format: AlignerResults [FeatureName] [file.pharaoh]\n";
+ exit(1);
+ }
+ cerr << " feature: " << argv[0] << "\talignments: " << argv[1] << endl;
+ fid_ = FD::Convert(argv[0]);
+ ReadFile rf(argv[1]);
+ istream& in = *rf.stream(); int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (!in) break;
+ ++lc;
+ is_aligned_.push_back(AlignerTools::ReadPharaohAlignmentGrid(line));
+ }
+ cerr << " Loaded " << lc << " refs\n";
+}
+
+void AlignerResults::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ if (edge.i_ == -1 || edge.prev_i_ == -1)
+ return;
+
+ if (cur_sent_ != smeta.GetSentenceID()) {
+ assert(smeta.HasReference());
+ cur_sent_ = smeta.GetSentenceID();
+ assert(cur_sent_ < is_aligned_.size());
+ cur_grid_ = is_aligned_[cur_sent_].get();
+ }
+
+ //cerr << edge.rule_->AsString() << endl;
+
+ int j = edge.i_; // source side (f)
+ int i = edge.prev_i_; // target side (e)
+ if (j < cur_grid_->height() && i < cur_grid_->width() && (*cur_grid_)(i, j)) {
+// if (edge.rule_->e_[0] == smeta.GetReference()[i][0].label) {
+ features->set_value(fid_, 1.0);
+// cerr << edge.rule_->AsString() << " (" << i << "," << j << ")\n";
+// }
+ }
+}
+
+BlunsomSynchronousParseHack::BlunsomSynchronousParseHack(const string& param) :
+ FeatureFunction((100 / 8) + 1), fid_(FD::Convert("NotRef")), cur_sent_(-1) {
+ ReadFile rf(param);
+ istream& in = *rf.stream(); int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (!in) break;
+ ++lc;
+ refs_.push_back(vector<WordID>());
+ TD::ConvertSentence(line, &refs_.back());
+ }
+ cerr << " Loaded " << lc << " refs\n";
+}
+
+void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ if (cur_sent_ != smeta.GetSentenceID()) {
+ // assert(smeta.HasReference());
+ cur_sent_ = smeta.GetSentenceID();
+ assert(cur_sent_ < refs_.size());
+ cur_ref_ = &refs_[cur_sent_];
+ cur_map_.clear();
+ for (int i = 0; i < cur_ref_->size(); ++i) {
+ vector<WordID> phrase;
+ for (int j = i; j < cur_ref_->size(); ++j) {
+ phrase.push_back((*cur_ref_)[j]);
+ cur_map_[phrase] = i;
+ }
+ }
+ }
+ //cerr << edge.rule_->AsString() << endl;
+ for (int i = 0; i < ant_states.size(); ++i) {
+ if (DoesNotBelong(ant_states[i])) {
+ //cerr << " ant " << i << " does not belong\n";
+ return;
+ }
+ }
+ vector<vector<WordID> > ants(ant_states.size());
+ vector<const vector<WordID>* > pants(ant_states.size());
+ for (int i = 0; i < ant_states.size(); ++i) {
+ AppendAntecedentString(ant_states[i], &ants[i]);
+ //cerr << " ant[" << i << "]: " << ((int)*(static_cast<const unsigned char*>(ant_states[i]))) << " " << TD::GetString(ants[i]) << endl;
+ pants[i] = &ants[i];
+ }
+ vector<WordID> yield;
+ edge.rule_->ESubstitute(pants, &yield);
+ //cerr << "YIELD: " << TD::GetString(yield) << endl;
+ Vec2Int::iterator it = cur_map_.find(yield);
+ if (it == cur_map_.end()) {
+ features->set_value(fid_, 1);
+ //cerr << " BAD!\n";
+ return;
+ }
+ SetStateMask(it->second, it->second + yield.size(), state);
+}
+
diff --git a/src/ff_wordalign.h b/src/ff_wordalign.h
new file mode 100644
index 00000000..1581641c
--- /dev/null
+++ b/src/ff_wordalign.h
@@ -0,0 +1,133 @@
+#ifndef _FF_WORD_ALIGN_H_
+#define _FF_WORD_ALIGN_H_
+
+#include "ff.h"
+#include "array2d.h"
+
+class RelativeSentencePosition : public FeatureFunction {
+ public:
+ RelativeSentencePosition(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ const int fid_;
+};
+
+class MarkovJump : public FeatureFunction {
+ public:
+ MarkovJump(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ const int fid_;
+ bool individual_params_per_jumpsize_;
+ bool condition_on_flen_;
+ std::string template_;
+};
+
+class AlignerResults : public FeatureFunction {
+ public:
+ AlignerResults(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ int fid_;
+ std::vector<boost::shared_ptr<Array2D<bool> > > is_aligned_;
+ mutable int cur_sent_;
+ const Array2D<bool> mutable* cur_grid_;
+};
+
+#include <tr1/unordered_map>
+#include <boost/functional/hash.hpp>
+#include <cassert>
+class BlunsomSynchronousParseHack : public FeatureFunction {
+ public:
+ BlunsomSynchronousParseHack(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ inline bool DoesNotBelong(const void* state) const {
+ for (int i = 0; i < NumBytesContext(); ++i) {
+ if (*(static_cast<const unsigned char*>(state) + i)) return false;
+ }
+ return true;
+ }
+
+ inline void AppendAntecedentString(const void* state, std::vector<WordID>* yield) const {
+ int i = 0;
+ int ind = 0;
+ while (i < NumBytesContext() && !(*(static_cast<const unsigned char*>(state) + i))) { ++i; ind += 8; }
+ // std::cerr << i << " " << NumBytesContext() << std::endl;
+ assert(i != NumBytesContext());
+ assert(ind < cur_ref_->size());
+ int cur = *(static_cast<const unsigned char*>(state) + i);
+ int comp = 1;
+ while (comp < 256 && (comp & cur) == 0) { comp <<= 1; ++ind; }
+ assert(ind < cur_ref_->size());
+ assert(comp < 256);
+ do {
+ assert(ind < cur_ref_->size());
+ yield->push_back((*cur_ref_)[ind]);
+ ++ind;
+ comp <<= 1;
+ if (comp == 256) {
+ comp = 1;
+ ++i;
+ cur = *(static_cast<const unsigned char*>(state) + i);
+ }
+ } while (comp & cur);
+ }
+
+ inline void SetStateMask(int start, int end, void* state) const {
+ assert((end / 8) < NumBytesContext());
+ int i = 0;
+ int comp = 1;
+ for (int j = 0; j < start; ++j) {
+ comp <<= 1;
+ if (comp == 256) {
+ ++i;
+ comp = 1;
+ }
+ }
+ //std::cerr << "SM: " << i << "\n";
+ for (int j = start; j < end; ++j) {
+ *(static_cast<unsigned char*>(state) + i) |= comp;
+ //std::cerr << " " << comp << "\n";
+ comp <<= 1;
+ if (comp == 256) {
+ ++i;
+ comp = 1;
+ }
+ }
+ //std::cerr << " MASK: " << ((int)*(static_cast<unsigned char*>(state))) << "\n";
+ }
+
+ const int fid_;
+ mutable int cur_sent_;
+ typedef std::tr1::unordered_map<std::vector<WordID>, int, boost::hash<std::vector<WordID> > > Vec2Int;
+ mutable Vec2Int cur_map_;
+ const std::vector<WordID> mutable * cur_ref_;
+ mutable std::vector<std::vector<WordID> > refs_;
+};
+
+#endif
diff --git a/src/filelib.cc b/src/filelib.cc
new file mode 100644
index 00000000..79ad2847
--- /dev/null
+++ b/src/filelib.cc
@@ -0,0 +1,22 @@
+#include "filelib.h"
+
+#include <unistd.h>
+#include <sys/stat.h>
+
+using namespace std;
+
+bool FileExists(const std::string& fn) {
+ struct stat info;
+ int s = stat(fn.c_str(), &info);
+ return (s==0);
+}
+
+bool DirectoryExists(const string& dir) {
+ if (access(dir.c_str(),0) == 0) {
+ struct stat status;
+ stat(dir.c_str(), &status);
+ if (status.st_mode & S_IFDIR) return true;
+ }
+ return false;
+}
+
diff --git a/src/filelib.h b/src/filelib.h
new file mode 100644
index 00000000..62cb9427
--- /dev/null
+++ b/src/filelib.h
@@ -0,0 +1,66 @@
+#ifndef _FILELIB_H_
+#define _FILELIB_H_
+
+#include <cassert>
+#include <string>
+#include <iostream>
+#include <cstdlib>
+#include "gzstream.h"
+
+// reads from standard in if filename is -
+// uncompresses if file ends with .gz
+// otherwise, reads from a normal file
+class ReadFile {
+ public:
+ ReadFile(const std::string& filename) :
+ no_delete_on_exit_(filename == "-"),
+ in_(no_delete_on_exit_ ? static_cast<std::istream*>(&std::cin) :
+ (EndsWith(filename, ".gz") ?
+ static_cast<std::istream*>(new igzstream(filename.c_str())) :
+ static_cast<std::istream*>(new std::ifstream(filename.c_str())))) {
+ if (!*in_) {
+ std::cerr << "Failed to open " << filename << std::endl;
+ abort();
+ }
+ }
+ ~ReadFile() {
+ if (!no_delete_on_exit_) delete in_;
+ }
+
+ inline std::istream* stream() { return in_; }
+
+ private:
+ static bool EndsWith(const std::string& f, const std::string& suf) {
+ return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
+ }
+ const bool no_delete_on_exit_;
+ std::istream* const in_;
+};
+
+class WriteFile {
+ public:
+ WriteFile(const std::string& filename) :
+ no_delete_on_exit_(filename == "-"),
+ out_(no_delete_on_exit_ ? static_cast<std::ostream*>(&std::cout) :
+ (EndsWith(filename, ".gz") ?
+ static_cast<std::ostream*>(new ogzstream(filename.c_str())) :
+ static_cast<std::ostream*>(new std::ofstream(filename.c_str())))) {}
+ ~WriteFile() {
+ (*out_) << std::flush;
+ if (!no_delete_on_exit_) delete out_;
+ }
+
+ inline std::ostream* stream() { return out_; }
+
+ private:
+ static bool EndsWith(const std::string& f, const std::string& suf) {
+ return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
+ }
+ const bool no_delete_on_exit_;
+ std::ostream* const out_;
+};
+
+bool FileExists(const std::string& file_name);
+bool DirectoryExists(const std::string& dir_name);
+
+#endif
diff --git a/src/forest_writer.cc b/src/forest_writer.cc
new file mode 100644
index 00000000..a9117d18
--- /dev/null
+++ b/src/forest_writer.cc
@@ -0,0 +1,23 @@
+#include "forest_writer.h"
+
+#include <iostream>
+
+#include <boost/lexical_cast.hpp>
+
+#include "filelib.h"
+#include "hg_io.h"
+#include "hg.h"
+
+using namespace std;
+
+ForestWriter::ForestWriter(const std::string& path, int num) :
+ fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {}
+
+bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) {
+ assert(!used_);
+ used_ = true;
+ cerr << " Writing forest to " << fname_ << endl;
+ WriteFile wf(fname_);
+ return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream());
+}
+
diff --git a/src/forest_writer.h b/src/forest_writer.h
new file mode 100644
index 00000000..819a8940
--- /dev/null
+++ b/src/forest_writer.h
@@ -0,0 +1,16 @@
+#ifndef _FOREST_WRITER_H_
+#define _FOREST_WRITER_H_
+
+#include <string>
+
+class Hypergraph;
+
+struct ForestWriter {
+ ForestWriter(const std::string& path, int num);
+ bool Write(const Hypergraph& forest, bool minimal_rules);
+
+ const std::string fname_;
+ bool used_;
+};
+
+#endif
diff --git a/src/freqdict.cc b/src/freqdict.cc
new file mode 100644
index 00000000..4cfffe58
--- /dev/null
+++ b/src/freqdict.cc
@@ -0,0 +1,23 @@
+#include <iostream>
+#include <fstream>
+#include <cassert>
+#include "freqdict.h"
+
+void FreqDict::load(const std::string& fname) {
+ std::ifstream ifs(fname.c_str());
+ int cc=0;
+ while (!ifs.eof()) {
+ std::string word;
+ ifs >> word;
+ if (word.size() == 0) continue;
+ if (word[0] == '#') continue;
+ double count = 0;
+ ifs >> count;
+ assert(count > 0.0); // use -log(f)
+ counts_[word]=count;
+ ++cc;
+ if (cc % 10000 == 0) { std::cerr << "."; }
+ }
+ std::cerr << "\n";
+ std::cerr << "Loaded " << cc << " words\n";
+}
diff --git a/src/freqdict.h b/src/freqdict.h
new file mode 100644
index 00000000..c9bb4c42
--- /dev/null
+++ b/src/freqdict.h
@@ -0,0 +1,19 @@
+#ifndef _FREQDICT_H_
+#define _FREQDICT_H_
+
+#include <map>
+#include <string>
+
+class FreqDict {
+ public:
+ void load(const std::string& fname);
+ float frequency(const std::string& word) const {
+ std::map<std::string,float>::const_iterator i = counts_.find(word);
+ if (i == counts_.end()) return 0;
+ return i->second;
+ }
+ private:
+ std::map<std::string, float> counts_;
+};
+
+#endif
diff --git a/src/fst_translator.cc b/src/fst_translator.cc
new file mode 100644
index 00000000..57feb227
--- /dev/null
+++ b/src/fst_translator.cc
@@ -0,0 +1,91 @@
+#include "translator.h"
+
+#include <sstream>
+#include <boost/shared_ptr.hpp>
+
+#include "sentence_metadata.h"
+#include "filelib.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "earley_composer.h"
+#include "phrasetable_fst.h"
+#include "tdict.h"
+
+using namespace std;
+
+struct FSTTranslatorImpl {
+ FSTTranslatorImpl(const boost::program_options::variables_map& conf) :
+ goal_sym(conf["goal"].as<string>()),
+ kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")),
+ kGOAL(TD::Convert("Goal") * -1),
+ add_pass_through_rules(conf.count("add_pass_through_rules")) {
+ fst.reset(LoadTextPhrasetable(conf["grammar"].as<vector<string> >()));
+ ec.reset(new EarleyComposer(fst.get()));
+ }
+
+ bool Translate(const string& input,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ bool composed = false;
+ if (input.find("{\"rules\"") == 0) {
+ istringstream is(input);
+ Hypergraph src_cfg_hg;
+ assert(HypergraphIO::ReadFromJSON(&is, &src_cfg_hg));
+ if (add_pass_through_rules) {
+ SparseVector<double> feats;
+ feats.set_value(FD::Convert("PassThrough"), 1);
+ for (int i = 0; i < src_cfg_hg.edges_.size(); ++i) {
+ const vector<WordID>& f = src_cfg_hg.edges_[i].rule_->f_;
+ for (int j = 0; j < f.size(); ++j) {
+ if (f[j] > 0) {
+ fst->AddPassThroughTranslation(f[j], feats);
+ }
+ }
+ }
+ }
+ composed = ec->Compose(src_cfg_hg, forest);
+ } else {
+ const string dummy_grammar("[" + goal_sym + "] ||| " + input + " ||| TOP=1");
+ cerr << " Dummy grammar: " << dummy_grammar << endl;
+ istringstream is(dummy_grammar);
+ if (add_pass_through_rules) {
+ vector<WordID> words;
+ TD::ConvertSentence(input, &words);
+ SparseVector<double> feats;
+ feats.set_value(FD::Convert("PassThrough"), 1);
+ for (int i = 0; i < words.size(); ++i)
+ fst->AddPassThroughTranslation(words[i], feats);
+ }
+ composed = ec->Compose(&is, forest);
+ }
+ if (composed) {
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1, "");
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ forest->Reweight(weights);
+ }
+ if (add_pass_through_rules)
+ fst->ClearPassThroughTranslations();
+ return composed;
+ }
+
+ const string goal_sym;
+ const TRulePtr kGOAL_RULE;
+ const WordID kGOAL;
+ const bool add_pass_through_rules;
+ boost::shared_ptr<EarleyComposer> ec;
+ boost::shared_ptr<FSTNode> fst;
+};
+
+FSTTranslator::FSTTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new FSTTranslatorImpl(conf)) {}
+
+bool FSTTranslator::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ smeta->SetSourceLength(0); // don't know how to compute this
+ return pimpl_->Translate(input, weights, minus_lm_forest);
+}
+
diff --git a/src/grammar.cc b/src/grammar.cc
new file mode 100644
index 00000000..69e38320
--- /dev/null
+++ b/src/grammar.cc
@@ -0,0 +1,163 @@
+#include "grammar.h"
+
+#include <algorithm>
+#include <utility>
+#include <map>
+
+#include "filelib.h"
+#include "tdict.h"
+
+using namespace std;
+
+const vector<TRulePtr> Grammar::NO_RULES;
+
+RuleBin::~RuleBin() {}
+GrammarIter::~GrammarIter() {}
+Grammar::~Grammar() {}
+
+bool Grammar::HasRuleForSpan(int i, int j) const {
+ (void) i;
+ (void) j;
+ return true; // always true by default
+}
+
+struct TextRuleBin : public RuleBin {
+ int GetNumRules() const {
+ return rules_.size();
+ }
+ TRulePtr GetIthRule(int i) const {
+ return rules_[i];
+ }
+ void AddRule(TRulePtr t) {
+ rules_.push_back(t);
+ }
+ int Arity() const {
+ return rules_.front()->Arity();
+ }
+ void Dump() const {
+ for (int i = 0; i < rules_.size(); ++i)
+ VLOG(1) << rules_[i]->AsString() << endl;
+ }
+ private:
+ vector<TRulePtr> rules_;
+};
+
+struct TextGrammarNode : public GrammarIter {
+ TextGrammarNode() : rb_(NULL) {}
+ ~TextGrammarNode() {
+ delete rb_;
+ }
+ const GrammarIter* Extend(int symbol) const {
+ map<WordID, TextGrammarNode>::const_iterator i = tree_.find(symbol);
+ if (i == tree_.end()) return NULL;
+ return &i->second;
+ }
+
+ const RuleBin* GetRules() const {
+ if (rb_) {
+ //rb_->Dump();
+ }
+ return rb_;
+ }
+
+ map<WordID, TextGrammarNode> tree_;
+ TextRuleBin* rb_;
+};
+
+struct TGImpl {
+ TextGrammarNode root_;
+};
+
+TextGrammar::TextGrammar() : max_span_(10), pimpl_(new TGImpl) {}
+TextGrammar::TextGrammar(const string& file) :
+ max_span_(10),
+ pimpl_(new TGImpl) {
+ ReadFromFile(file);
+}
+
+const GrammarIter* TextGrammar::GetRoot() const {
+ return &pimpl_->root_;
+}
+
+void TextGrammar::AddRule(const TRulePtr& rule) {
+ if (rule->IsUnary()) {
+ rhs2unaries_[rule->f().front()].push_back(rule);
+ unaries_.push_back(rule);
+ } else {
+ TextGrammarNode* cur = &pimpl_->root_;
+ for (int i = 0; i < rule->f_.size(); ++i)
+ cur = &cur->tree_[rule->f_[i]];
+ if (cur->rb_ == NULL)
+ cur->rb_ = new TextRuleBin;
+ cur->rb_->AddRule(rule);
+ }
+}
+
+void TextGrammar::ReadFromFile(const string& filename) {
+ ReadFile in(filename);
+ istream& in_file = *in.stream();
+ assert(in_file);
+ long long int rule_count = 0;
+ bool fl = false;
+ while(in_file) {
+ string line;
+ getline(in_file, line);
+ if (line.empty()) continue;
+ ++rule_count;
+ if (rule_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
+ if (rule_count % 2000000 == 0) { cerr << " [" << rule_count << "]\n"; fl = false; }
+ TRulePtr rule(TRule::CreateRuleSynchronous(line));
+ if (rule) {
+ AddRule(rule);
+ } else {
+ if (fl) { cerr << endl; }
+ cerr << "Skipping badly formatted rule in line " << rule_count << " of " << filename << endl;
+ fl = false;
+ }
+ }
+ if (fl) cerr << endl;
+ cerr << " " << rule_count << " rules read.\n";
+}
+
+bool TextGrammar::HasRuleForSpan(int i, int j) const {
+ return (max_span_ >= (j - i));
+}
+
+GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {}
+
+GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) {
+ TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));
+ TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["
+ + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));
+
+ AddRule(stop_glue);
+ AddRule(glue);
+ //cerr << "GLUE: " << stop_glue->AsString() << endl;
+ //cerr << "GLUE: " << glue->AsString() << endl;
+}
+
+bool GlueGrammar::HasRuleForSpan(int i, int j) const {
+ (void) j;
+ return (i == 0);
+}
+
+PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) :
+ has_rule_(input.size() + 1) {
+ for (int i = 0; i < input.size(); ++i) {
+ const vector<LatticeArc>& alts = input[i];
+ for (int k = 0; k < alts.size(); ++k) {
+ const int j = alts[k].dist2next + i;
+ has_rule_[i].insert(j);
+ const string& src = TD::Convert(alts[k].label);
+ TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1"));
+ AddRule(pt);
+// cerr << "PT: " << pt->AsString() << endl;
+ }
+ }
+}
+
+bool PassThroughGrammar::HasRuleForSpan(int i, int j) const {
+ const set<int>& hr = has_rule_[i];
+ if (i == j) { return !hr.empty(); }
+ return (hr.find(j) != hr.end());
+}
diff --git a/src/grammar.h b/src/grammar.h
new file mode 100644
index 00000000..4a03c505
--- /dev/null
+++ b/src/grammar.h
@@ -0,0 +1,83 @@
+#ifndef GRAMMAR_H_
+#define GRAMMAR_H_
+
+#include <vector>
+#include <map>
+#include <set>
+#include <boost/shared_ptr.hpp>
+
+#include "lattice.h"
+#include "trule.h"
+
+struct RuleBin {
+ virtual ~RuleBin();
+ virtual int GetNumRules() const = 0;
+ virtual TRulePtr GetIthRule(int i) const = 0;
+ virtual int Arity() const = 0;
+};
+
+struct GrammarIter {
+ virtual ~GrammarIter();
+ virtual const RuleBin* GetRules() const = 0;
+ virtual const GrammarIter* Extend(int symbol) const = 0;
+};
+
+struct Grammar {
+ typedef std::map<WordID, std::vector<TRulePtr> > Cat2Rules;
+ static const std::vector<TRulePtr> NO_RULES;
+
+ virtual ~Grammar();
+ virtual const GrammarIter* GetRoot() const = 0;
+ virtual bool HasRuleForSpan(int i, int j) const;
+
+ // cat is the category to be rewritten
+ inline const std::vector<TRulePtr>& GetAllUnaryRules() const {
+ return unaries_;
+ }
+
+ // get all the unary rules that rewrite category cat
+ inline const std::vector<TRulePtr>& GetUnaryRulesForRHS(const WordID& cat) const {
+ Cat2Rules::const_iterator found = rhs2unaries_.find(cat);
+ if (found == rhs2unaries_.end())
+ return NO_RULES;
+ else
+ return found->second;
+ }
+
+ protected:
+ Cat2Rules rhs2unaries_; // these must be filled in by subclasses!
+ std::vector<TRulePtr> unaries_;
+};
+
+typedef boost::shared_ptr<Grammar> GrammarPtr;
+
+class TGImpl;
+struct TextGrammar : public Grammar {
+ TextGrammar();
+ TextGrammar(const std::string& file);
+ void SetMaxSpan(int m) { max_span_ = m; }
+ virtual const GrammarIter* GetRoot() const;
+ void AddRule(const TRulePtr& rule);
+ void ReadFromFile(const std::string& filename);
+ virtual bool HasRuleForSpan(int i, int j) const;
+ const std::vector<TRulePtr>& GetUnaryRules(const WordID& cat) const;
+ private:
+ int max_span_;
+ boost::shared_ptr<TGImpl> pimpl_;
+};
+
+struct GlueGrammar : public TextGrammar {
+ // read glue grammar from file
+ explicit GlueGrammar(const std::string& file);
+ GlueGrammar(const std::string& goal_nt, const std::string& default_nt); // "S", "X"
+ virtual bool HasRuleForSpan(int i, int j) const;
+};
+
+struct PassThroughGrammar : public TextGrammar {
+ PassThroughGrammar(const Lattice& input, const std::string& cat);
+ virtual bool HasRuleForSpan(int i, int j) const;
+ private:
+ std::vector<std::set<int> > has_rule_; // index by [i][j]
+};
+
+#endif
diff --git a/src/grammar_test.cc b/src/grammar_test.cc
new file mode 100644
index 00000000..62b8f958
--- /dev/null
+++ b/src/grammar_test.cc
@@ -0,0 +1,59 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "trule.h"
+#include "tdict.h"
+#include "grammar.h"
+#include "bottom_up_parser.h"
+#include "ff.h"
+#include "weights.h"
+
+using namespace std;
+
+class GrammarTest : public testing::Test {
+ public:
+ GrammarTest() {
+ wts.InitFromFile("test_data/weights.gt");
+ }
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+ Weights wts;
+};
+
+TEST_F(GrammarTest,TestTextGrammar) {
+ vector<double> w;
+ vector<const FeatureFunction*> ms;
+ ModelSet models(w, ms);
+
+ TextGrammar g;
+ TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3", true));
+ TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4", true));
+ TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3", true));
+ cerr << r1->AsString() << endl;
+ g.AddRule(r1);
+ g.AddRule(r2);
+ g.AddRule(r3);
+}
+
+TEST_F(GrammarTest,TestTextGrammarFile) {
+ GrammarPtr g(new TextGrammar("./test_data/grammar.prune"));
+ vector<GrammarPtr> grammars(1, g);
+
+ LatticeArc a(TD::Convert("ein"), 0.0, 1);
+ LatticeArc b(TD::Convert("haus"), 0.0, 1);
+ Lattice lattice(2);
+ lattice[0].push_back(a);
+ lattice[1].push_back(b);
+ Hypergraph forest;
+ ExhaustiveBottomUpParser parser("PHRASE", grammars);
+ parser.Parse(lattice, &forest);
+ forest.PrintGraphviz();
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/gzstream.cc b/src/gzstream.cc
new file mode 100644
index 00000000..9703e6ad
--- /dev/null
+++ b/src/gzstream.cc
@@ -0,0 +1,165 @@
+// ============================================================================
+// gzstream, C++ iostream classes wrapping the zlib compression library.
+// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+// ============================================================================
+//
+// File : gzstream.C
+// Revision : $Revision: 1.1 $
+// Revision_date : $Date: 2006/03/30 04:05:52 $
+// Author(s) : Deepak Bandyopadhyay, Lutz Kettner
+//
+// Standard streambuf implementation following Nicolai Josuttis, "The
+// Standard C++ Library".
+// ============================================================================
+
+#include "gzstream.h"
+#include <iostream>
+#include <cstring>
+
+#ifdef GZSTREAM_NAMESPACE
+namespace GZSTREAM_NAMESPACE {
+#endif
+
+// ----------------------------------------------------------------------------
+// Internal classes to implement gzstream. See header file for user classes.
+// ----------------------------------------------------------------------------
+
+// --------------------------------------
+// class gzstreambuf:
+// --------------------------------------
+
+gzstreambuf* gzstreambuf::open( const char* name, int open_mode) {
+ if ( is_open())
+ return (gzstreambuf*)0;
+ mode = open_mode;
+ // no append nor read/write mode
+ if ((mode & std::ios::ate) || (mode & std::ios::app)
+ || ((mode & std::ios::in) && (mode & std::ios::out)))
+ return (gzstreambuf*)0;
+ char fmode[10];
+ char* fmodeptr = fmode;
+ if ( mode & std::ios::in)
+ *fmodeptr++ = 'r';
+ else if ( mode & std::ios::out)
+ *fmodeptr++ = 'w';
+ *fmodeptr++ = 'b';
+ *fmodeptr = '\0';
+ file = gzopen( name, fmode);
+ if (file == 0)
+ return (gzstreambuf*)0;
+ opened = 1;
+ return this;
+}
+
+gzstreambuf * gzstreambuf::close() {
+ if ( is_open()) {
+ sync();
+ opened = 0;
+ if ( gzclose( file) == Z_OK)
+ return this;
+ }
+ return (gzstreambuf*)0;
+}
+
+int gzstreambuf::underflow() { // used for input buffer only
+ if ( gptr() && ( gptr() < egptr()))
+ return * reinterpret_cast<unsigned char *>( gptr());
+
+ if ( ! (mode & std::ios::in) || ! opened)
+ return EOF;
+ // Josuttis' implementation of inbuf
+ int n_putback = gptr() - eback();
+ if ( n_putback > 4)
+ n_putback = 4;
+ memcpy( buffer + (4 - n_putback), gptr() - n_putback, n_putback);
+
+ int num = gzread( file, buffer+4, bufferSize-4);
+ if (num <= 0) // ERROR or EOF
+ return EOF;
+
+ // reset buffer pointers
+ setg( buffer + (4 - n_putback), // beginning of putback area
+ buffer + 4, // read position
+ buffer + 4 + num); // end of buffer
+
+ // return next character
+ return * reinterpret_cast<unsigned char *>( gptr());
+}
+
+int gzstreambuf::flush_buffer() {
+ // Separate the writing of the buffer from overflow() and
+ // sync() operation.
+ int w = pptr() - pbase();
+ if ( gzwrite( file, pbase(), w) != w)
+ return EOF;
+ pbump( -w);
+ return w;
+}
+
+int gzstreambuf::overflow( int c) { // used for output buffer only
+ if ( ! ( mode & std::ios::out) || ! opened)
+ return EOF;
+ if (c != EOF) {
+ *pptr() = c;
+ pbump(1);
+ }
+ if ( flush_buffer() == EOF)
+ return EOF;
+ return c;
+}
+
+int gzstreambuf::sync() {
+ // Changed to use flush_buffer() instead of overflow( EOF)
+ // which caused improper behavior with std::endl and flush(),
+ // bug reported by Vincent Ricard.
+ if ( pptr() && pptr() > pbase()) {
+ if ( flush_buffer() == EOF)
+ return -1;
+ }
+ return 0;
+}
+
+// --------------------------------------
+// class gzstreambase:
+// --------------------------------------
+
+gzstreambase::gzstreambase( const char* name, int mode) {
+ init( &buf);
+ open( name, mode);
+}
+
+gzstreambase::~gzstreambase() {
+ buf.close();
+}
+
+void gzstreambase::open( const char* name, int open_mode) {
+ if ( ! buf.open( name, open_mode))
+ clear( rdstate() | std::ios::badbit);
+}
+
+void gzstreambase::close() {
+ if ( buf.is_open())
+ if ( ! buf.close())
+ clear( rdstate() | std::ios::badbit);
+}
+
+#ifdef GZSTREAM_NAMESPACE
+} // namespace GZSTREAM_NAMESPACE
+#endif
+
+// ============================================================================
+// EOF //
diff --git a/src/gzstream.h b/src/gzstream.h
new file mode 100644
index 00000000..ad9785fd
--- /dev/null
+++ b/src/gzstream.h
@@ -0,0 +1,121 @@
+// ============================================================================
+// gzstream, C++ iostream classes wrapping the zlib compression library.
+// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+// ============================================================================
+//
+// File : gzstream.h
+// Revision : $Revision: 1.1 $
+// Revision_date : $Date: 2006/03/30 04:05:52 $
+// Author(s) : Deepak Bandyopadhyay, Lutz Kettner
+//
+// Standard streambuf implementation following Nicolai Josuttis, "The
+// Standard C++ Library".
+// ============================================================================
+
+#ifndef GZSTREAM_H
+#define GZSTREAM_H 1
+
+// standard C++ with new header file names and std:: namespace
+#include <iostream>
+#include <fstream>
+#include <zlib.h>
+
+#ifdef GZSTREAM_NAMESPACE
+namespace GZSTREAM_NAMESPACE {
+#endif
+
+// ----------------------------------------------------------------------------
+// Internal classes to implement gzstream. See below for user classes.
+// ----------------------------------------------------------------------------
+
+class gzstreambuf : public std::streambuf {
+private:
+ static const int bufferSize = 47+256; // size of data buff
+ // totals 512 bytes under g++ for igzstream at the end.
+
+ gzFile file; // file handle for compressed file
+ char buffer[bufferSize]; // data buffer
+ char opened; // open/close state of stream
+ int mode; // I/O mode
+
+ int flush_buffer();
+public:
+ gzstreambuf() : opened(0) {
+ setp( buffer, buffer + (bufferSize-1));
+ setg( buffer + 4, // beginning of putback area
+ buffer + 4, // read position
+ buffer + 4); // end position
+ // ASSERT: both input & output capabilities will not be used together
+ }
+ int is_open() { return opened; }
+ gzstreambuf* open( const char* name, int open_mode);
+ gzstreambuf* close();
+ ~gzstreambuf() { close(); }
+
+ virtual int overflow( int c = EOF);
+ virtual int underflow();
+ virtual int sync();
+};
+
+class gzstreambase : virtual public std::ios {
+protected:
+ gzstreambuf buf;
+public:
+ gzstreambase() { init(&buf); }
+ gzstreambase( const char* name, int open_mode);
+ ~gzstreambase();
+ void open( const char* name, int open_mode);
+ void close();
+ gzstreambuf* rdbuf() { return &buf; }
+};
+
+// ----------------------------------------------------------------------------
+// User classes. Use igzstream and ogzstream analogously to ifstream and
+// ofstream respectively. They read and write files based on the gz*
+// function interface of the zlib. Files are compatible with gzip compression.
+// ----------------------------------------------------------------------------
+
+class igzstream : public gzstreambase, public std::istream {
+public:
+ igzstream() : std::istream( &buf) {}
+ igzstream( const char* name, int open_mode = std::ios::in)
+ : gzstreambase( name, open_mode), std::istream( &buf) {}
+ gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); }
+ void open( const char* name, int open_mode = std::ios::in) {
+ gzstreambase::open( name, open_mode);
+ }
+};
+
+class ogzstream : public gzstreambase, public std::ostream {
+public:
+ ogzstream() : std::ostream( &buf) {}
+ ogzstream( const char* name, int mode = std::ios::out)
+ : gzstreambase( name, mode), std::ostream( &buf) {}
+ gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); }
+ void open( const char* name, int open_mode = std::ios::out) {
+ gzstreambase::open( name, open_mode);
+ }
+};
+
+#ifdef GZSTREAM_NAMESPACE
+} // namespace GZSTREAM_NAMESPACE
+#endif
+
+#endif // GZSTREAM_H
+// ============================================================================
+// EOF //
+
diff --git a/src/hg.cc b/src/hg.cc
new file mode 100644
index 00000000..dd8f8eba
--- /dev/null
+++ b/src/hg.cc
@@ -0,0 +1,483 @@
+#include "hg.h"
+
+#include <cassert>
+#include <numeric>
+#include <set>
+#include <map>
+#include <iostream>
+
+#include "viterbi.h"
+#include "inside_outside.h"
+#include "tdict.h"
+
+using namespace std;
+
+double Hypergraph::NumberOfPaths() const {
+ return Inside<double, TransitionCountWeightFunction>(*this);
+}
+
+prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector<prob_t>* posts) const {
+ const ScaledEdgeProb weight(scale);
+ SparseVector<double> pv;
+ const double inside = InsideOutside<prob_t,
+ ScaledEdgeProb,
+ SparseVector<double>,
+ EdgeFeaturesWeightFunction>(*this, &pv, weight);
+ posts->resize(edges_.size());
+ for (int i = 0; i < edges_.size(); ++i)
+ (*posts)[i] = prob_t(pv.value(i));
+ return prob_t(inside);
+}
+
+prob_t Hypergraph::ComputeBestPathThroughEdges(vector<prob_t>* post) const {
+ vector<prob_t> in(edges_.size());
+ vector<prob_t> out(edges_.size());
+ post->resize(edges_.size());
+
+ vector<prob_t> ins_node_best(nodes_.size());
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Node& node = nodes_[i];
+ prob_t& node_ins_best = ins_node_best[i];
+ if (node.in_edges_.empty()) node_ins_best = prob_t::One();
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ const Edge& edge = edges_[node.in_edges_[j]];
+ prob_t& in_edge_sco = in[node.in_edges_[j]];
+ in_edge_sco = edge.edge_prob_;
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k)
+ in_edge_sco *= ins_node_best[edge.tail_nodes_[k]];
+ if (in_edge_sco > node_ins_best) node_ins_best = in_edge_sco;
+ }
+ }
+ const prob_t ins_sco = ins_node_best[nodes_.size() - 1];
+
+ // sanity check
+ int tots = 0;
+ for (int i = 0; i < nodes_.size(); ++i) { if (nodes_[i].out_edges_.empty()) tots++; }
+ assert(tots == 1);
+
+ // compute outside scores, potentially using inside scores
+ vector<prob_t> out_node_best(nodes_.size());
+ for (int i = nodes_.size() - 1; i >= 0; --i) {
+ const Node& node = nodes_[i];
+ prob_t& node_out_best = out_node_best[node.id_];
+ if (node.out_edges_.empty()) node_out_best = prob_t::One();
+ for (int j = 0; j < node.out_edges_.size(); ++j) {
+ const Edge& edge = edges_[node.out_edges_[j]];
+ prob_t sco = edge.edge_prob_ * out_node_best[edge.head_node_];
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ if (edge.tail_nodes_[k] != i)
+ sco *= ins_node_best[edge.tail_nodes_[k]];
+ }
+ if (sco > node_out_best) node_out_best = sco;
+ }
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ out[node.in_edges_[j]] = node_out_best;
+ }
+ }
+
+ for (int i = 0; i < in.size(); ++i)
+ (*post)[i] = in[i] * out[i];
+
+ return ins_sco;
+}
+
+void Hypergraph::PushWeightsToSource(double scale) {
+ vector<prob_t> posts;
+ ComputeEdgePosteriors(scale, &posts);
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Hypergraph::Node& node = nodes_[i];
+ prob_t z = prob_t::Zero();
+ for (int j = 0; j < node.out_edges_.size(); ++j)
+ z += posts[node.out_edges_[j]];
+ for (int j = 0; j < node.out_edges_.size(); ++j) {
+ edges_[node.out_edges_[j]].edge_prob_ = posts[node.out_edges_[j]] / z;
+ }
+ }
+}
+
+void Hypergraph::PushWeightsToGoal(double scale) {
+ vector<prob_t> posts;
+ ComputeEdgePosteriors(scale, &posts);
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Hypergraph::Node& node = nodes_[i];
+ prob_t z = prob_t::Zero();
+ for (int j = 0; j < node.in_edges_.size(); ++j)
+ z += posts[node.in_edges_[j]];
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ edges_[node.in_edges_[j]].edge_prob_ = posts[node.in_edges_[j]] / z;
+ }
+ }
+}
+
+void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge) {
+ assert(prune_edge.size() == edges_.size());
+ TopologicallySortNodesAndEdges(nodes_.size() - 1, &prune_edge);
+}
+
+void Hypergraph::DensityPruneInsideOutside(const double scale,
+ const bool use_sum_prod_semiring,
+ const double density,
+ const vector<bool>* preserve_mask) {
+ assert(density >= 1.0);
+ const int plen = ViterbiPathLength(*this);
+ vector<WordID> bp;
+ int rnum = min(static_cast<int>(edges_.size()), static_cast<int>(density * static_cast<double>(plen)));
+ if (rnum == edges_.size()) {
+ cerr << "No pruning required: denisty already sufficient";
+ return;
+ }
+ vector<prob_t> io(edges_.size());
+ if (use_sum_prod_semiring)
+ ComputeEdgePosteriors(scale, &io);
+ else
+ ComputeBestPathThroughEdges(&io);
+ assert(edges_.size() == io.size());
+ vector<prob_t> sorted = io;
+ nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater<prob_t>());
+ const double cutoff = sorted[rnum];
+ vector<bool> prune(edges_.size());
+ for (int i = 0; i < edges_.size(); ++i) {
+ prune[i] = (io[i] < cutoff);
+ if (preserve_mask && (*preserve_mask)[i]) prune[i] = false;
+ }
+ PruneEdges(prune);
+}
+
+void Hypergraph::BeamPruneInsideOutside(
+ const double scale,
+ const bool use_sum_prod_semiring,
+ const double alpha,
+ const vector<bool>* preserve_mask) {
+ assert(alpha > 0.0);
+ assert(scale > 0.0);
+ vector<prob_t> io(edges_.size());
+ if (use_sum_prod_semiring)
+ ComputeEdgePosteriors(scale, &io);
+ else
+ ComputeBestPathThroughEdges(&io);
+ assert(edges_.size() == io.size());
+ prob_t best; // initializes to zero
+ for (int i = 0; i < io.size(); ++i)
+ if (io[i] > best) best = io[i];
+ const prob_t aprob(exp(-alpha));
+ const prob_t cutoff = best * aprob;
+ vector<bool> prune(edges_.size());
+ //cerr << preserve_mask.size() << " " << edges_.size() << endl;
+ int pc = 0;
+ for (int i = 0; i < io.size(); ++i) {
+ const bool prune_edge = (io[i] < cutoff);
+ if (prune_edge) ++pc;
+ prune[i] = (io[i] < cutoff);
+ if (preserve_mask && (*preserve_mask)[i]) prune[i] = false;
+ }
+ cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n";
+ PruneEdges(prune);
+}
+
+void Hypergraph::PrintGraphviz() const {
+ int ei = 0;
+ cerr << "digraph G {\n rankdir=LR;\n nodesep=.05;\n";
+ for (vector<Edge>::const_iterator i = edges_.begin();
+ i != edges_.end(); ++i) {
+ const Edge& edge=*i;
+ ++ei;
+ static const string none = "<null>";
+ string rule = (edge.rule_ ? edge.rule_->AsString(false) : none);
+
+ cerr << " A_" << ei << " [label=\"" << rule << " p=" << edge.edge_prob_
+ << " F:" << edge.feature_values_
+ << "\" shape=\"rect\"];\n";
+ for (int i = 0; i < edge.tail_nodes_.size(); ++i) {
+ cerr << " " << edge.tail_nodes_[i] << " -> A_" << ei << ";\n";
+ }
+ cerr << " A_" << ei << " -> " << edge.head_node_ << ";\n";
+ }
+ for (vector<Node>::const_iterator ni = nodes_.begin();
+ ni != nodes_.end(); ++ni) {
+ cerr << " " << ni->id_ << "[label=\"" << (ni->cat_ < 0 ? TD::Convert(ni->cat_ * -1) : "")
+ //cerr << " " << ni->id_ << "[label=\"" << ni->cat_
+ << " n=" << ni->id_
+// << ",x=" << &*ni
+// << ",in=" << ni->in_edges_.size()
+// << ",out=" << ni->out_edges_.size()
+ << "\"];\n";
+ }
+ cerr << "}\n";
+}
+
+void Hypergraph::Union(const Hypergraph& other) {
+ if (&other == this) return;
+ if (nodes_.empty()) { nodes_ = other.nodes_; edges_ = other.edges_; return; }
+ int noff = nodes_.size();
+ int eoff = edges_.size();
+ int ogoal = other.nodes_.size() - 1;
+ int cgoal = noff - 1;
+ // keep a single goal node, so add nodes.size - 1
+ nodes_.resize(nodes_.size() + ogoal);
+ // add all edges
+ edges_.resize(edges_.size() + other.edges_.size());
+
+ for (int i = 0; i < ogoal; ++i) {
+ const Node& on = other.nodes_[i];
+ Node& cn = nodes_[i + noff];
+ cn.id_ = i + noff;
+ cn.in_edges_.resize(on.in_edges_.size());
+ for (int j = 0; j < on.in_edges_.size(); ++j)
+ cn.in_edges_[j] = on.in_edges_[j] + eoff;
+
+ cn.out_edges_.resize(on.out_edges_.size());
+ for (int j = 0; j < on.out_edges_.size(); ++j)
+ cn.out_edges_[j] = on.out_edges_[j] + eoff;
+ }
+
+ for (int i = 0; i < other.edges_.size(); ++i) {
+ const Edge& oe = other.edges_[i];
+ Edge& ce = edges_[i + eoff];
+ ce.id_ = i + eoff;
+ ce.rule_ = oe.rule_;
+ ce.feature_values_ = oe.feature_values_;
+ if (oe.head_node_ == ogoal) {
+ ce.head_node_ = cgoal;
+ nodes_[cgoal].in_edges_.push_back(ce.id_);
+ } else {
+ ce.head_node_ = oe.head_node_ + noff;
+ }
+ ce.tail_nodes_.resize(oe.tail_nodes_.size());
+ for (int j = 0; j < oe.tail_nodes_.size(); ++j)
+ ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff;
+ }
+
+ TopologicallySortNodesAndEdges(cgoal);
+}
+
+int Hypergraph::MarkReachable(const Node& node,
+ vector<bool>* rmap,
+ const vector<bool>* prune_edges) const {
+ int total = 0;
+ if (!(*rmap)[node.id_]) {
+ total = 1;
+ (*rmap)[node.id_] = true;
+ for (int i = 0; i < node.in_edges_.size(); ++i) {
+ if (!(prune_edges && (*prune_edges)[node.in_edges_[i]])) {
+ for (int j = 0; j < edges_[node.in_edges_[i]].tail_nodes_.size(); ++j)
+ total += MarkReachable(nodes_[edges_[node.in_edges_[i]].tail_nodes_[j]], rmap, prune_edges);
+ }
+ }
+ }
+ return total;
+}
+
+void Hypergraph::PruneUnreachable(int goal_node_id) {
+ TopologicallySortNodesAndEdges(goal_node_id, NULL);
+}
+
+void Hypergraph::RemoveNoncoaccessibleStates(int goal_node_id) {
+ if (goal_node_id < 0) goal_node_id += nodes_.size();
+ assert(goal_node_id >= 0);
+ assert(goal_node_id < nodes_.size());
+
+ // TODO finish implementation
+ abort();
+}
+
+void Hypergraph::TopologicallySortNodesAndEdges(int goal_index,
+ const vector<bool>* prune_edges) {
+ vector<Edge> sedges(edges_.size());
+ // figure out which nodes are reachable from the goal
+ vector<bool> reachable(nodes_.size(), false);
+ int num_reachable = MarkReachable(nodes_[goal_index], &reachable, prune_edges);
+ vector<Node> snodes(num_reachable); snodes.clear();
+
+ // enumerate all reachable nodes in topologically sorted order
+ vector<int> old_node_to_new_id(nodes_.size(), -1);
+ vector<int> node_to_incount(nodes_.size(), -1);
+ vector<bool> node_processed(nodes_.size(), false);
+ typedef map<int, set<int> > PQueue;
+ PQueue pri_q;
+ for (int i = 0; i < nodes_.size(); ++i) {
+ if (!reachable[i])
+ continue;
+ const int inedges = nodes_[i].in_edges_.size();
+ int incount = inedges;
+ for (int j = 0; j < inedges; ++j)
+ if (edges_[nodes_[i].in_edges_[j]].tail_nodes_.size() == 0 ||
+ (prune_edges && (*prune_edges)[nodes_[i].in_edges_[j]]))
+ --incount;
+ // cerr << &nodes_[i] <<" : incount=" << incount << "\tout=" << nodes_[i].out_edges_.size() << "\t(in-edges=" << inedges << ")\n";
+ assert(node_to_incount[i] == -1);
+ node_to_incount[i] = incount;
+ pri_q[incount].insert(i);
+ }
+
+ int edge_count = 0;
+ while (!pri_q.empty()) {
+ PQueue::iterator iter = pri_q.find(0);
+ assert(iter != pri_q.end());
+ assert(!iter->second.empty());
+
+ // get first node with incount = 0
+ const int cur_index = *iter->second.begin();
+ const Node& node = nodes_[cur_index];
+ assert(reachable[cur_index]);
+ //cerr << "node: " << node << endl;
+ const int new_node_index = snodes.size();
+ old_node_to_new_id[cur_index] = new_node_index;
+ snodes.push_back(node);
+ Node& new_node = snodes.back();
+ new_node.id_ = new_node_index;
+ new_node.out_edges_.clear();
+
+ // fix up edges - we can now process the in edges and
+ // the out edges of their tails
+ int oi = 0;
+ for (int i = 0; i < node.in_edges_.size(); ++i, ++oi) {
+ if (prune_edges && (*prune_edges)[node.in_edges_[i]]) {
+ --oi;
+ continue;
+ }
+ new_node.in_edges_[oi] = edge_count;
+ Edge& edge = sedges[edge_count];
+ edge.id_ = edge_count;
+ ++edge_count;
+ const Edge& old_edge = edges_[node.in_edges_[i]];
+ edge.rule_ = old_edge.rule_;
+ edge.feature_values_ = old_edge.feature_values_;
+ edge.head_node_ = new_node_index;
+ edge.tail_nodes_.resize(old_edge.tail_nodes_.size());
+ edge.edge_prob_ = old_edge.edge_prob_;
+ edge.i_ = old_edge.i_;
+ edge.j_ = old_edge.j_;
+ edge.prev_i_ = old_edge.prev_i_;
+ edge.prev_j_ = old_edge.prev_j_;
+ for (int j = 0; j < old_edge.tail_nodes_.size(); ++j) {
+ const Node& old_tail_node = nodes_[old_edge.tail_nodes_[j]];
+ edge.tail_nodes_[j] = old_node_to_new_id[old_tail_node.id_];
+ snodes[edge.tail_nodes_[j]].out_edges_.push_back(edge_count-1);
+ assert(edge.tail_nodes_[j] != new_node_index);
+ }
+ }
+ assert(oi <= new_node.in_edges_.size());
+ new_node.in_edges_.resize(oi);
+
+ for (int i = 0; i < node.out_edges_.size(); ++i) {
+ const Edge& edge = edges_[node.out_edges_[i]];
+ const int next_index = edge.head_node_;
+ assert(cur_index != next_index);
+ if (!reachable[next_index]) continue;
+ if (prune_edges && (*prune_edges)[edge.id_]) continue;
+
+ bool dontReduce = false;
+ for (int j = 0; j < edge.tail_nodes_.size() && !dontReduce; ++j) {
+ int tail_index = edge.tail_nodes_[j];
+ dontReduce = (tail_index != cur_index) && !node_processed[tail_index];
+ }
+ if (dontReduce)
+ continue;
+
+ const int incount = node_to_incount[next_index];
+ if (incount <= 0) {
+ cerr << "incount = " << incount << ", should be > 0!\n";
+ cerr << "do you have a cycle in your hypergraph?\n";
+ abort();
+ }
+ PQueue::iterator it = pri_q.find(incount);
+ assert(it != pri_q.end());
+ it->second.erase(next_index);
+ if (it->second.empty()) pri_q.erase(it);
+
+ // reinsert node with reduced incount
+ pri_q[incount-1].insert(next_index);
+ --node_to_incount[next_index];
+ }
+
+ // remove node from set
+ iter->second.erase(cur_index);
+ if (iter->second.empty())
+ pri_q.erase(iter);
+ node_processed[cur_index] = true;
+ }
+
+ sedges.resize(edge_count);
+ nodes_.swap(snodes);
+ edges_.swap(sedges);
+ assert(nodes_.back().out_edges_.size() == 0);
+}
+
+TRulePtr Hypergraph::kEPSRule;
+TRulePtr Hypergraph::kUnaryRule;
+
+void Hypergraph::EpsilonRemove(WordID eps) {
+ if (!kEPSRule) {
+ kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
+ kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
+ }
+ vector<bool> kill(edges_.size(), false);
+ for (int i = 0; i < edges_.size(); ++i) {
+ const Edge& edge = edges_[i];
+ if (edge.tail_nodes_.empty() &&
+ edge.rule_->f_.size() == 1 &&
+ edge.rule_->f_[0] == eps) {
+ kill[i] = true;
+ if (!edge.feature_values_.empty()) {
+ Node& node = nodes_[edge.head_node_];
+ if (node.in_edges_.size() != 1) {
+ cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n";
+ // this *probably* means that there are multiple derivations of the
+ // same sequence via different paths through the input forest
+ // this needs to be investigated and fixed
+ } else {
+ for (int j = 0; j < node.out_edges_.size(); ++j)
+ edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_;
+ // cerr << "PROMOTED " << edge.feature_values_ << endl;
+ }
+ }
+ }
+ }
+ bool created_eps = false;
+ PruneEdges(kill);
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Node& node = nodes_[i];
+ if (node.in_edges_.empty()) {
+ for (int j = 0; j < node.out_edges_.size(); ++j) {
+ Edge& edge = edges_[node.out_edges_[j]];
+ if (edge.rule_->Arity() == 2) {
+ assert(edge.rule_->f_.size() == 2);
+ assert(edge.rule_->e_.size() == 2);
+ edge.rule_ = kUnaryRule;
+ int cur = node.id_;
+ int t = -1;
+ assert(edge.tail_nodes_.size() == 2);
+ for (int i = 0; i < 2; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; }
+ assert(t != -1);
+ edge.tail_nodes_.resize(1);
+ edge.tail_nodes_[0] = t;
+ } else {
+ edge.rule_ = kEPSRule;
+ edge.rule_->f_[0] = eps;
+ edge.rule_->e_[0] = eps;
+ edge.tail_nodes_.clear();
+ created_eps = true;
+ }
+ }
+ }
+ }
+ vector<bool> k2(edges_.size(), false);
+ PruneEdges(k2);
+ if (created_eps) EpsilonRemove(eps);
+}
+
+struct EdgeWeightSorter {
+ const Hypergraph& hg;
+ EdgeWeightSorter(const Hypergraph& h) : hg(h) {}
+ bool operator()(int a, int b) const {
+ return hg.edges_[a].edge_prob_ > hg.edges_[b].edge_prob_;
+ }
+};
+
+void Hypergraph::SortInEdgesByEdgeWeights() {
+ for (int i = 0; i < nodes_.size(); ++i) {
+ Node& node = nodes_[i];
+ sort(node.in_edges_.begin(), node.in_edges_.end(), EdgeWeightSorter(*this));
+ }
+}
+
diff --git a/src/hg.h b/src/hg.h
new file mode 100644
index 00000000..7a2658b8
--- /dev/null
+++ b/src/hg.h
@@ -0,0 +1,225 @@
+#ifndef _HG_H_
+#define _HG_H_
+
+#include <string>
+#include <vector>
+
+#include "small_vector.h"
+#include "sparse_vector.h"
+#include "wordid.h"
+#include "trule.h"
+#include "prob.h"
+
+// class representing an acyclic hypergraph
+// - edges have 1 head, 0..n tails
+class Hypergraph {
+ public:
+ Hypergraph() {}
+
+ // SmallVector is a fast, small vector<int> implementation for sizes <= 2
+ typedef SmallVector TailNodeVector;
+
+ // TODO get rid of state_ and cat_?
+ struct Node {
+ Node() : id_(), cat_() {}
+ int id_; // equal to this object's position in the nodes_ vector
+ WordID cat_; // non-terminal category if <0, 0 if not set
+ std::vector<int> in_edges_; // contents refer to positions in edges_
+ std::vector<int> out_edges_; // contents refer to positions in edges_
+ std::string state_; // opaque state
+ };
+
+ // TODO get rid of edge_prob_? (can be computed on the fly as the dot
+ // product of the weight vector and the feature values)
+ struct Edge {
+ Edge() : i_(-1), j_(-1), prev_i_(-1), prev_j_(-1) {}
+ inline int Arity() const { return tail_nodes_.size(); }
+ int head_node_; // refers to a position in nodes_
+ TailNodeVector tail_nodes_; // contents refer to positions in nodes_
+ TRulePtr rule_;
+ SparseVector<double> feature_values_;
+ prob_t edge_prob_; // dot product of weights and feat_values
+ int id_; // equal to this object's position in the edges_ vector
+
+ // span info. typically, i_ and j_ refer to indices in the source sentence
+ // if a synchronous parse has been executed i_ and j_ will refer to indices
+ // in the target sentence / lattice and prev_i_ prev_j_ will refer to
+ // positions in the source. Note: it is up to the translator implementation
+ // to properly set these values. For some models (like the Forest-input
+ // phrase based model) it may not be straightforward to do. if these values
+ // are not properly set, most things will work but alignment and any features
+ // that depend on them will be broken.
+ short int i_;
+ short int j_;
+ short int prev_i_;
+ short int prev_j_;
+ };
+
+ void swap(Hypergraph& other) {
+ other.nodes_.swap(nodes_);
+ other.edges_.swap(edges_);
+ }
+
+ void ResizeNodes(int size) {
+ nodes_.resize(size);
+ for (int i = 0; i < size; ++i) nodes_[i].id_ = i;
+ }
+
+ // reserves space in the nodes vector to prevent memory locations
+ // from changing
+ void ReserveNodes(size_t n, size_t e = 0) {
+ nodes_.reserve(n);
+ if (e) edges_.reserve(e);
+ }
+
+ Edge* AddEdge(const TRulePtr& rule, const TailNodeVector& tail) {
+ edges_.push_back(Edge());
+ Edge* edge = &edges_.back();
+ edge->rule_ = rule;
+ edge->tail_nodes_ = tail;
+ edge->id_ = edges_.size() - 1;
+ for (int i = 0; i < edge->tail_nodes_.size(); ++i)
+ nodes_[edge->tail_nodes_[i]].out_edges_.push_back(edge->id_);
+ return edge;
+ }
+
+ Node* AddNode(const WordID& cat, const std::string& state = "") {
+ nodes_.push_back(Node());
+ nodes_.back().cat_ = cat;
+ nodes_.back().state_ = state;
+ nodes_.back().id_ = nodes_.size() - 1;
+ return &nodes_.back();
+ }
+
+ void ConnectEdgeToHeadNode(const int edge_id, const int head_id) {
+ edges_[edge_id].head_node_ = head_id;
+ nodes_[head_id].in_edges_.push_back(edge_id);
+ }
+
+ // TODO remove this - use the version that takes indices
+ void ConnectEdgeToHeadNode(Edge* edge, Node* head) {
+ edge->head_node_ = head->id_;
+ head->in_edges_.push_back(edge->id_);
+ }
+
+ // merge the goal node from other with this goal node
+ void Union(const Hypergraph& other);
+
+ void PrintGraphviz() const;
+
+ // compute the total number of paths in the forest
+ double NumberOfPaths() const;
+
+ // BEWARE. this assumes that the source and target language
+ // strings are identical and that there are no loops.
+ // It assumes a bunch of other things about where the
+ // epsilons will be. It tries to assert failure if you
+ // break these assumptions, but it may not.
+ // TODO - make this work
+ void EpsilonRemove(WordID eps);
+
+ // multiple the weights vector by the edge feature vector
+ // (inner product) to set the edge probabilities
+ template <typename V>
+ void Reweight(const V& weights) {
+ for (int i = 0; i < edges_.size(); ++i) {
+ Edge& e = edges_[i];
+ e.edge_prob_.logeq(e.feature_values_.dot(weights));
+ }
+ }
+
+ // computes inside and outside scores for each
+ // edge in the hypergraph
+ // alpha->size = edges_.size = beta->size
+ // returns inside prob of goal node
+ prob_t ComputeEdgePosteriors(double scale,
+ std::vector<prob_t>* posts) const;
+
+ // find the score of the very best path passing through each edge
+ prob_t ComputeBestPathThroughEdges(std::vector<prob_t>* posts) const;
+
+ // move weights as near to the source as possible, resulting in a
+ // stochastic automaton. ONLY FUNCTIONAL FOR *LATTICES*.
+ // See M. Mohri and M. Riley. A Weight Pushing Algorithm for Large
+ // Vocabulary Speech Recognition. 2001.
+ // the log semiring (NOT tropical) is used
+ void PushWeightsToSource(double scale = 1.0);
+ // same, except weights are pushed to the goal, works for HGs,
+ // not just lattices
+ void PushWeightsToGoal(double scale = 1.0);
+
+ void SortInEdgesByEdgeWeights();
+
+ void PruneUnreachable(int goal_node_id); // DEPRECATED
+
+ void RemoveNoncoaccessibleStates(int goal_node_id = -1);
+
+ // remove edges from the hypergraph if prune_edge[edge_id] is true
+ void PruneEdges(const std::vector<bool>& prune_edge);
+
+ // if you don't know, use_sum_prod_semiring should be false
+ void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density,
+ const std::vector<bool>* preserve_mask = NULL);
+
+ // prunes any edge whose score on the best path taking that edge is more than alpha away
+ // from the score of the global best past (or the highest edge posterior)
+ void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha,
+ const std::vector<bool>* preserve_mask = NULL);
+
+ void clear() {
+ nodes_.clear();
+ edges_.clear();
+ }
+
+ inline size_t NumberOfEdges() const { return edges_.size(); }
+ inline size_t NumberOfNodes() const { return nodes_.size(); }
+ inline bool empty() const { return nodes_.empty(); }
+
+ // nodes_ is sorted in topological order
+ std::vector<Node> nodes_;
+ // edges_ is not guaranteed to be in any particular order
+ std::vector<Edge> edges_;
+
+ // reorder nodes_ so they are in topological order
+ // source nodes at 0 sink nodes at size-1
+ void TopologicallySortNodesAndEdges(int goal_idx,
+ const std::vector<bool>* prune_edges = NULL);
+ private:
+ // returns total nodes reachable
+ int MarkReachable(const Node& node,
+ std::vector<bool>* rmap,
+ const std::vector<bool>* prune_edges) const;
+
+ static TRulePtr kEPSRule;
+ static TRulePtr kUnaryRule;
+};
+
+// common WeightFunctions, map an edge -> WeightType
+// for generic Viterbi/Inside algorithms
+struct EdgeProb {
+ inline const prob_t& operator()(const Hypergraph::Edge& e) const { return e.edge_prob_; }
+};
+
+struct ScaledEdgeProb {
+ ScaledEdgeProb(const double& alpha) : alpha_(alpha) {}
+ inline prob_t operator()(const Hypergraph::Edge& e) const { return e.edge_prob_.pow(alpha_); }
+ const double alpha_;
+};
+
+struct EdgeFeaturesWeightFunction {
+ inline const SparseVector<double>& operator()(const Hypergraph::Edge& e) const { return e.feature_values_; }
+};
+
+struct TransitionEventWeightFunction {
+ inline SparseVector<prob_t> operator()(const Hypergraph::Edge& e) const {
+ SparseVector<prob_t> result;
+ result.set_value(e.id_, prob_t::One());
+ return result;
+ }
+};
+
+struct TransitionCountWeightFunction {
+ inline double operator()(const Hypergraph::Edge& e) const { (void)e; return 1.0; }
+};
+
+#endif
diff --git a/src/hg_intersect.cc b/src/hg_intersect.cc
new file mode 100644
index 00000000..a5e8913a
--- /dev/null
+++ b/src/hg_intersect.cc
@@ -0,0 +1,121 @@
+#include "hg_intersect.h"
+
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/lexical_cast.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "tdict.h"
+#include "hg.h"
+#include "trule.h"
+#include "wordid.h"
+#include "bottom_up_parser.h"
+
+using boost::lexical_cast;
+using namespace std::tr1;
+using namespace std;
+
+struct RuleFilter {
+ unordered_map<vector<WordID>, bool, boost::hash<vector<WordID> > > exists_;
+ bool true_lattice;
+ RuleFilter(const Lattice& target, int max_phrase_size) {
+ true_lattice = false;
+ for (int i = 0; i < target.size(); ++i) {
+ vector<WordID> phrase;
+ int lim = min(static_cast<int>(target.size()), i + max_phrase_size);
+ for (int j = i; j < lim; ++j) {
+ if (target[j].size() > 1) { true_lattice = true; break; }
+ phrase.push_back(target[j][0].label);
+ exists_[phrase] = true;
+ }
+ }
+ vector<WordID> sos(1, TD::Convert("<s>"));
+ exists_[sos] = true;
+ }
+ bool operator()(const TRule& r) const {
+ // TODO do some smarter filtering for lattices
+ if (true_lattice) return false; // don't filter "true lattice" input
+ const vector<WordID>& e = r.e();
+ for (int i = 0; i < e.size(); ++i) {
+ if (e[i] <= 0) continue;
+ vector<WordID> phrase;
+ for (int j = i; j < e.size(); ++j) {
+ if (e[j] <= 0) break;
+ phrase.push_back(e[j]);
+ if (exists_.count(phrase) == 0) return true;
+ }
+ }
+ return false;
+ }
+};
+
+bool HG::Intersect(const Lattice& target, Hypergraph* hg) {
+ vector<bool> rem(hg->edges_.size(), false);
+ const RuleFilter filter(target, 15); // TODO make configurable
+ for (int i = 0; i < rem.size(); ++i)
+ rem[i] = filter(*hg->edges_[i].rule_);
+ hg->PruneEdges(rem);
+
+ const int nedges = hg->edges_.size();
+ const int nnodes = hg->nodes_.size();
+
+ TextGrammar* g = new TextGrammar;
+ GrammarPtr gp(g);
+ vector<int> cats(nnodes);
+ // each node in the translation forest becomes a "non-terminal" in the new
+ // grammar, create the labels here
+ for (int i = 0; i < nnodes; ++i)
+ cats[i] = TD::Convert("CAT_" + lexical_cast<string>(i)) * -1;
+
+ // construct the grammar
+ for (int i = 0; i < nedges; ++i) {
+ const Hypergraph::Edge& edge = hg->edges_[i];
+ const vector<WordID>& tgt = edge.rule_->e();
+ const vector<WordID>& src = edge.rule_->f();
+ TRulePtr rule(new TRule);
+ rule->prev_i = edge.i_;
+ rule->prev_j = edge.j_;
+ rule->lhs_ = cats[edge.head_node_];
+ vector<WordID>& f = rule->f_;
+ vector<WordID>& e = rule->e_;
+ f.resize(tgt.size()); // swap source and target, since the parser
+ e.resize(src.size()); // parses using the source side!
+ Hypergraph::TailNodeVector tn(edge.tail_nodes_.size());
+ int ntc = 0;
+ for (int j = 0; j < tgt.size(); ++j) {
+ const WordID& cur = tgt[j];
+ if (cur > 0) {
+ f[j] = cur;
+ } else {
+ tn[ntc++] = cur;
+ f[j] = cats[edge.tail_nodes_[-cur]];
+ }
+ }
+ ntc = 0;
+ for (int j = 0; j < src.size(); ++j) {
+ const WordID& cur = src[j];
+ if (cur > 0) {
+ e[j] = cur;
+ } else {
+ e[j] = tn[ntc++];
+ }
+ }
+ rule->scores_ = edge.feature_values_;
+ rule->parent_rule_ = edge.rule_;
+ rule->ComputeArity();
+ //cerr << "ADD: " << rule->AsString() << endl;
+
+ g->AddRule(rule);
+ }
+ g->SetMaxSpan(target.size() + 1);
+ const string& new_goal = TD::Convert(cats.back() * -1);
+ vector<GrammarPtr> grammars(1, gp);
+ Hypergraph tforest;
+ ExhaustiveBottomUpParser parser(new_goal, grammars);
+ if (!parser.Parse(target, &tforest))
+ return false;
+ else
+ hg->swap(tforest);
+ return true;
+}
+
diff --git a/src/hg_intersect.h b/src/hg_intersect.h
new file mode 100644
index 00000000..826bdaae
--- /dev/null
+++ b/src/hg_intersect.h
@@ -0,0 +1,13 @@
+#ifndef _HG_INTERSECT_H_
+#define _HG_INTERSECT_H_
+
+#include <vector>
+
+#include "lattice.h"
+
+class Hypergraph;
+struct HG {
+ static bool Intersect(const Lattice& target, Hypergraph* hg);
+};
+
+#endif
diff --git a/src/hg_io.cc b/src/hg_io.cc
new file mode 100644
index 00000000..629e65f1
--- /dev/null
+++ b/src/hg_io.cc
@@ -0,0 +1,585 @@
+#include "hg_io.h"
+
+#include <sstream>
+#include <iostream>
+
+#include <boost/lexical_cast.hpp>
+
+#include "tdict.h"
+#include "json_parse.h"
+#include "hg.h"
+
+using namespace std;
+
+struct HGReader : public JSONParser {
+ HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; }
+
+ void CreateNode(const string& cat, const vector<int>& in_edges) {
+ WordID c = TD::Convert("X") * -1;
+ if (!cat.empty()) c = TD::Convert(cat) * -1;
+ Hypergraph::Node* node = hg.AddNode(c, "");
+ for (int i = 0; i < in_edges.size(); ++i) {
+ if (in_edges[i] >= hg.edges_.size()) {
+ cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i]
+ << ", but hg only has " << hg.edges_.size() << " edges!\n";
+ abort();
+ }
+ hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node);
+ }
+ }
+ void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVector& tail) {
+ Hypergraph::Edge* edge = hg.AddEdge(rule, tail);
+ feats->swap(edge->feature_values_);
+ }
+
+ bool HandleJSONEvent(int type, const JSON_value* value) {
+ switch(state) {
+ case -1:
+ assert(type == JSON_T_OBJECT_BEGIN);
+ state = 0;
+ break;
+ case 0:
+ if (type == JSON_T_OBJECT_END) {
+ //cerr << "HG created\n"; // TODO, signal some kind of callback
+ } else if (type == JSON_T_KEY) {
+ string val = value->vu.str.value;
+ if (val == "features") { assert(fdict.empty()); state = 1; }
+ else if (val == "is_sorted") { state = 3; }
+ else if (val == "rules") { assert(rules.empty()); state = 4; }
+ else if (val == "node") { state = 8; }
+ else if (val == "edges") { state = 13; }
+ else { cerr << "Unexpected key: " << val << endl; return false; }
+ }
+ break;
+
+ // features
+ case 1:
+ if(type == JSON_T_NULL) { state = 0; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 2;
+ break;
+ case 2:
+ if(type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_STRING);
+ fdict.push_back(FD::Convert(value->vu.str.value));
+ break;
+
+ // is_sorted
+ case 3:
+ assert(type == JSON_T_TRUE || type == JSON_T_FALSE);
+ is_sorted = (type == JSON_T_TRUE);
+ if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; }
+ state = 0;
+ break;
+
+ // rules
+ case 4:
+ if(type == JSON_T_NULL) { state = 0; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 5;
+ break;
+ case 5:
+ if(type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_INTEGER);
+ state = 6;
+ rule_id = value->vu.integer_value;
+ break;
+ case 6:
+ assert(type == JSON_T_STRING);
+ rules[rule_id] = TRulePtr(new TRule(value->vu.str.value));
+ state = 5;
+ break;
+
+ // Nodes
+ case 8:
+ assert(type == JSON_T_OBJECT_BEGIN);
+ ++nodes;
+ in_edges.clear();
+ cat.clear();
+ state = 9; break;
+ case 9:
+ if (type == JSON_T_OBJECT_END) {
+ //cerr << "Creating NODE\n";
+ CreateNode(cat, in_edges);
+ state = 0; break;
+ }
+ assert(type == JSON_T_KEY);
+ cur_key = value->vu.str.value;
+ if (cur_key == "cat") { assert(cat.empty()); state = 10; break; }
+ if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; }
+ cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n";
+ return false;
+ case 10:
+ assert(type == JSON_T_STRING || type == JSON_T_NULL);
+ cat = value->vu.str.value;
+ state = 9; break;
+ case 11:
+ if (type == JSON_T_NULL) { state = 9; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 12; break;
+ case 12:
+ if (type == JSON_T_ARRAY_END) { state = 9; break; }
+ assert(type == JSON_T_INTEGER);
+ //cerr << "in_edges: " << value->vu.integer_value << endl;
+ in_edges.push_back(value->vu.integer_value);
+ break;
+
+ // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12},
+ // { "tail": null, "feats" : [0,0.87,1,0.02], "rule": 17},
+ // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}]
+ case 13:
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 14;
+ break;
+ case 14:
+ if (type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_OBJECT_BEGIN);
+ //cerr << "New edge\n";
+ ++edges;
+ cur_rule.reset(); feats.clear(); tail.clear();
+ state = 15; break;
+ case 15:
+ if (type == JSON_T_OBJECT_END) {
+ CreateEdge(cur_rule, &feats, tail);
+ state = 14; break;
+ }
+ assert(type == JSON_T_KEY);
+ cur_key = value->vu.str.value;
+ //cerr << "edge key " << cur_key << endl;
+ if (cur_key == "rule") { assert(!cur_rule); state = 16; break; }
+ if (cur_key == "feats") { assert(feats.empty()); state = 17; break; }
+ if (cur_key == "tail") { assert(tail.empty()); state = 20; break; }
+ cerr << "Unexpected key " << cur_key << " in edge specification\n";
+ return false;
+ case 16: // edge.rule
+ if (type == JSON_T_INTEGER) {
+ int rule_id = value->vu.integer_value;
+ if (rules.find(rule_id) == rules.end()) {
+ // rules list must come before the edge definitions!
+ cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n";
+ return false;
+ }
+ cur_rule = rules[rule_id];
+ } else if (type == JSON_T_STRING) {
+ cur_rule.reset(new TRule(value->vu.str.value));
+ } else {
+ cerr << "Rule must be either a rule id or a rule string" << endl;
+ return false;
+ }
+ // cerr << "Edge: rule=" << cur_rule->AsString() << endl;
+ state = 15;
+ break;
+ case 17: // edge.feats
+ if (type == JSON_T_NULL) { state = 15; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 18; break;
+ case 18:
+ if (type == JSON_T_ARRAY_END) { state = 15; break; }
+ if (type != JSON_T_INTEGER && type != JSON_T_STRING) {
+ cerr << "Unexpected feature id type\n"; return false;
+ }
+ if (type == JSON_T_INTEGER) {
+ fid = value->vu.integer_value;
+ assert(fid < fdict.size());
+ fid = fdict[fid];
+ } else if (JSON_T_STRING) {
+ fid = FD::Convert(value->vu.str.value);
+ } else { abort(); }
+ state = 19;
+ break;
+ case 19:
+ {
+ assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT);
+ double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) :
+ strtod(value->vu.str.value, NULL));
+ feats.set_value(fid, val);
+ state = 18;
+ break;
+ }
+ case 20: // edge.tail
+ if (type == JSON_T_NULL) { state = 15; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 21; break;
+ case 21:
+ if (type == JSON_T_ARRAY_END) { state = 15; break; }
+ assert(type == JSON_T_INTEGER);
+ tail.push_back(value->vu.integer_value);
+ break;
+ }
+ return true;
+ }
+ string rp;
+ string cat;
+ SmallVector tail;
+ vector<int> in_edges;
+ TRulePtr cur_rule;
+ map<int, TRulePtr> rules;
+ vector<int> fdict;
+ SparseVector<double> feats;
+ int state;
+ int fid;
+ int nodes;
+ int edges;
+ string cur_key;
+ Hypergraph& hg;
+ int rule_id;
+ bool nodes_needed;
+ bool edges_needed;
+ bool is_sorted;
+};
+
+bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
+ hg->clear();
+ HGReader reader(hg);
+ return reader.Parse(in);
+}
+
+static void WriteRule(const TRule& r, ostream* out) {
+ if (!r.lhs_) { (*out) << "[X] ||| "; }
+ JSONParser::WriteEscapedString(r.AsString(), out);
+}
+
+bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
+ map<const TRule*, int> rid;
+ ostream& o = *out;
+ rid[NULL] = 0;
+ o << '{';
+ if (!remove_rules) {
+ o << "\"rules\":[";
+ for (int i = 0; i < hg.edges_.size(); ++i) {
+ const TRule* r = hg.edges_[i].rule_.get();
+ int &id = rid[r];
+ if (!id) {
+ id=rid.size() - 1;
+ if (id > 1) o << ',';
+ o << id << ',';
+ WriteRule(*r, &o);
+ };
+ }
+ o << "],";
+ }
+ const bool use_fdict = FD::NumFeats() < 1000;
+ if (use_fdict) {
+ o << "\"features\":[";
+ for (int i = 1; i < FD::NumFeats(); ++i) {
+ o << (i==1 ? "":",") << '"' << FD::Convert(i) << '"';
+ }
+ o << "],";
+ }
+ vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
+ int edge_count = 0;
+ for (int i = 0; i < hg.nodes_.size(); ++i) {
+ const Hypergraph::Node& node = hg.nodes_[i];
+ if (i > 0) { o << ","; }
+ o << "\"edges\":[";
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
+ edgemap[edge.id_] = edge_count;
+ ++edge_count;
+ o << (j == 0 ? "" : ",") << "{";
+
+ o << "\"tail\":[";
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
+ }
+ o << "],";
+
+ o << "\"feats\":[";
+ bool first = true;
+ for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
+ if (!it->second) continue;
+ if (!first) o << ',';
+ if (use_fdict)
+ o << (it->first - 1);
+ else
+ o << '"' << FD::Convert(it->first) << '"';
+ o << ',' << it->second;
+ first = false;
+ }
+ o << "]";
+ if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
+ o << "}";
+ }
+ o << "],";
+
+ o << "\"node\":{\"in_edges\":[";
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ int mapped_edge = edgemap[node.in_edges_[j]];
+ assert(mapped_edge >= 0);
+ o << (j == 0 ? "" : ",") << mapped_edge;
+ }
+ o << "]";
+ if (node.cat_ < 0) { o << ",\"cat\":\"" << TD::Convert(node.cat_ * -1) << '"'; }
+ o << "}";
+ }
+ o << "}\n";
+ return true;
+}
+
+bool needs_escape[128];
+void InitEscapes() {
+ memset(needs_escape, false, 128);
+ needs_escape[static_cast<size_t>('\'')] = true;
+ needs_escape[static_cast<size_t>('\\')] = true;
+}
+
+string HypergraphIO::Escape(const string& s) {
+ size_t len = s.size();
+ for (int i = 0; i < s.size(); ++i) {
+ unsigned char c = s[i];
+ if (c < 128 && needs_escape[c]) ++len;
+ }
+ if (len == s.size()) return s;
+ string res(len, ' ');
+ size_t o = 0;
+ for (int i = 0; i < s.size(); ++i) {
+ unsigned char c = s[i];
+ if (c < 128 && needs_escape[c])
+ res[o++] = '\\';
+ res[o++] = c;
+ }
+ assert(o == len);
+ return res;
+}
+
+string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses) {
+ static bool first = true;
+ if (first) { InitEscapes(); first = false; }
+ if (hg.nodes_.empty()) return "()";
+ ostringstream os;
+ if (include_global_parentheses) os << '(';
+ static const string EPS="*EPS*";
+ for (int i = 0; i < hg.nodes_.size()-1; ++i) {
+ os << '(';
+ if (hg.nodes_[i].out_edges_.empty()) abort();
+ for (int j = 0; j < hg.nodes_[i].out_edges_.size(); ++j) {
+ const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]];
+ const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS;
+ double prob = log(e.edge_prob_);
+ if (isinf(prob)) { prob = -9e20; }
+ if (isnan(prob)) { prob = 0; }
+ os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),";
+ }
+ os << "),";
+ }
+ if (include_global_parentheses) os << ')';
+ return os.str();
+}
+
+namespace PLF {
+
+const string chars = "'\\";
+const char& quote = chars[0];
+const char& slash = chars[1];
+
+// safe get
+inline char get(const std::string& in, int c) {
+ if (c < 0 || c >= (int)in.size()) return 0;
+ else return in[(size_t)c];
+}
+
+// consume whitespace
+inline void eatws(const std::string& in, int& c) {
+ while (get(in,c) == ' ') { c++; }
+}
+
+// from 'foo' return foo
+std::string getEscapedString(const std::string& in, int &c)
+{
+ eatws(in,c);
+ if (get(in,c++) != quote) return "ERROR";
+ std::string res;
+ char cur = 0;
+ do {
+ cur = get(in,c++);
+ if (cur == slash) { res += get(in,c++); }
+ else if (cur != quote) { res += cur; }
+ } while (get(in,c) != quote && (c < (int)in.size()));
+ c++;
+ eatws(in,c);
+ return res;
+}
+
+// basically atof
+float getFloat(const std::string& in, int &c)
+{
+ std::string tmp;
+ eatws(in,c);
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ tmp += get(in,c++);
+ }
+ eatws(in,c);
+ return atof(tmp.c_str());
+}
+
+// basically atoi
+int getInt(const std::string& in, int &c)
+{
+ std::string tmp;
+ eatws(in,c);
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ tmp += get(in,c++);
+ }
+ eatws(in,c);
+ return atoi(tmp.c_str());
+}
+
+// maximum number of nodes permitted
+#define MAX_NODES 100000000
+// parse ('foo', 0.23)
+void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
+ if (get(in,c++) != '(') { assert(!"PCN/PLF parse error: expected ( at start of cn alt block\n"); }
+ vector<WordID> ewords(2, 0);
+ ewords[1] = TD::Convert(getEscapedString(in,c));
+ TRulePtr r(new TRule(ewords));
+ //cerr << "RULE: " << r->AsString() << endl;
+ if (get(in,c++) != ',') { assert(!"PCN/PLF parse error: expected , after string\n"); }
+ size_t cnNext = 1;
+ std::vector<float> probs;
+ probs.push_back(getFloat(in,c));
+ while (get(in,c) == ',') {
+ c++;
+ float val = getFloat(in,c);
+ probs.push_back(val);
+ }
+ //if we read more than one prob, this was a lattice, last item was column increment
+ if (probs.size()>1) {
+ cnNext = static_cast<size_t>(probs.back());
+ probs.pop_back();
+ if (cnNext < 1) { assert(!"PCN/PLF parse error: bad link length at last element of cn alt block\n"); }
+ }
+ if (get(in,c++) != ')') { assert(!"PCN/PLF parse error: expected ) at end of cn alt block\n"); }
+ eatws(in,c);
+ Hypergraph::TailNodeVector tail(1, cur_node);
+ Hypergraph::Edge* edge = hg->AddEdge(r, tail);
+ //cerr << " <--" << cur_node << endl;
+ int head_node = cur_node + cnNext;
+ assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory
+ if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
+ hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
+ for (int i = 0; i < probs.size(); ++i)
+ edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]);
+}
+
+// parse (('foo', 0.23), ('bar', 0.77))
+void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) {
+ //cerr << "PLF READING NODE " << cur_node << endl;
+ if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); }
+ if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); }
+ eatws(in,c);
+ while (1) {
+ if (c > (int)in.size()) { break; }
+ if (get(in,c) == ')') {
+ c++;
+ eatws(in,c);
+ break;
+ }
+ if (get(in,c) == ',' && get(in,c+1) == ')') {
+ c+=2;
+ eatws(in,c);
+ break;
+ }
+ if (get(in,c) == ',') { c++; eatws(in,c); }
+ ReadPLFEdge(in, c, cur_node, hg);
+ }
+}
+
+} // namespace PLF
+
+void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) {
+ hg->clear();
+ int c = 0;
+ int cur_node = 0;
+ if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); }
+ while (1) {
+ if (c > (int)in.size()) { break; }
+ if (PLF::get(in,c) == ')') {
+ c++;
+ PLF::eatws(in,c);
+ break;
+ }
+ if (PLF::get(in,c) == ',' && PLF::get(in,c+1) == ')') {
+ c+=2;
+ PLF::eatws(in,c);
+ break;
+ }
+ if (PLF::get(in,c) == ',') { c++; PLF::eatws(in,c); }
+ PLF::ReadPLFNode(in, c, cur_node, line, hg);
+ ++cur_node;
+ }
+ assert(cur_node == hg->nodes_.size() - 1);
+}
+
+void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
+ Lattice& l = *pl;
+ Hypergraph g;
+ ReadFromPLF(plf, &g, 0);
+ const int num_nodes = g.nodes_.size() - 1;
+ l.resize(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ vector<LatticeArc>& alts = l[i];
+ const Hypergraph::Node& node = g.nodes_[i];
+ const int num_alts = node.out_edges_.size();
+ alts.resize(num_alts);
+ for (int j = 0; j < num_alts; ++j) {
+ const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]];
+ alts[j].label = edge.rule_->e_[1];
+ alts[j].cost = edge.feature_values_.value(FD::Convert("Feature_0"));
+ alts[j].dist2next = edge.head_node_ - node.id_;
+ }
+ }
+}
+
+namespace B64 {
+
+static const char cb64[]="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+static const char cd64[]="|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`abcdefghijklmnopq";
+
+static void encodeblock(const unsigned char* in, ostream* os, int len) {
+ char out[4];
+ out[0] = cb64[ in[0] >> 2 ];
+ out[1] = cb64[ ((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4) ];
+ out[2] = (len > 1 ? cb64[ ((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6) ] : '=');
+ out[3] = (len > 2 ? cb64[ in[2] & 0x3f ] : '=');
+ os->write(out, 4);
+}
+
+void b64encode(const char* data, const size_t size, ostream* out) {
+ size_t cur = 0;
+ while(cur < size) {
+ int len = min(static_cast<size_t>(3), size - cur);
+ encodeblock(reinterpret_cast<const unsigned char*>(&data[cur]), out, len);
+ cur += len;
+ }
+}
+
+static void decodeblock(const unsigned char* in, unsigned char* out) {
+ out[0] = (unsigned char ) (in[0] << 2 | in[1] >> 4);
+ out[1] = (unsigned char ) (in[1] << 4 | in[2] >> 2);
+ out[2] = (unsigned char ) (((in[2] << 6) & 0xc0) | in[3]);
+}
+
+bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize) {
+ size_t cur = 0;
+ size_t ocur = 0;
+ unsigned char in[4];
+ while(cur < insize) {
+ assert(ocur < outsize);
+ for (int i = 0; i < 4; ++i) {
+ unsigned char v = data[cur];
+ v = (unsigned char) ((v < 43 || v > 122) ? '\0' : cd64[ v - 43 ]);
+ if (!v) {
+ cerr << "B64 decode error at offset " << cur << " offending character: " << (int)data[cur] << endl;
+ return false;
+ }
+ v = (unsigned char) ((v == '$') ? '\0' : v - 61);
+ if (v) in[i] = v - 1; else in[i] = 0;
+ ++cur;
+ }
+ decodeblock(in, reinterpret_cast<unsigned char*>(&out[ocur]));
+ ocur += 3;
+ }
+ return true;
+}
+}
+
diff --git a/src/hg_io.h b/src/hg_io.h
new file mode 100644
index 00000000..69a516c1
--- /dev/null
+++ b/src/hg_io.h
@@ -0,0 +1,37 @@
+#ifndef _HG_IO_H_
+#define _HG_IO_H_
+
+#include <iostream>
+
+#include "lattice.h"
+class Hypergraph;
+
+struct HypergraphIO {
+
+ // the format is basically a list of nodes and edges in topological order
+ // any edge you read, you must have already read its tail nodes
+ // any node you read, you must have already read its incoming edges
+ // this may make writing a bit more challenging if your forest is not
+ // topologically sorted (but that probably doesn't happen very often),
+ // but it makes reading much more memory efficient.
+ // see test_data/small.json.gz for an email encoding
+ static bool ReadFromJSON(std::istream* in, Hypergraph* out);
+
+ // if remove_rules is used, the hypergraph is serialized without rule information
+ // (so it only contains structure and feature information)
+ static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out);
+
+ // serialization utils
+ static void ReadFromPLF(const std::string& in, Hypergraph* out, int line = 0);
+ // return PLF string representation (undefined behavior on non-lattices)
+ static std::string AsPLF(const Hypergraph& hg, bool include_global_parentheses = true);
+ static void PLFtoLattice(const std::string& plf, Lattice* pl);
+ static std::string Escape(const std::string& s); // PLF helper
+};
+
+namespace B64 {
+ bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize);
+ void b64encode(const char* data, const size_t size, std::ostream* out);
+}
+
+#endif
diff --git a/src/hg_test.cc b/src/hg_test.cc
new file mode 100644
index 00000000..ecd97508
--- /dev/null
+++ b/src/hg_test.cc
@@ -0,0 +1,441 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "tdict.h"
+
+#include "json_parse.h"
+#include "filelib.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "hg_intersect.h"
+#include "viterbi.h"
+#include "kbest.h"
+#include "inside_outside.h"
+
+using namespace std;
+
+class HGTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+ void CreateHG(Hypergraph* hg) const;
+ void CreateHG_int(Hypergraph* hg) const;
+ void CreateHG_tiny(Hypergraph* hg) const;
+ void CreateHGBalanced(Hypergraph* hg) const;
+ void CreateLatticeHG(Hypergraph* hg) const;
+ void CreateTinyLatticeHG(Hypergraph* hg) const;
+};
+
+void HGTest::CreateTinyLatticeHG(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] b\",4,\"[X] ||| [1] B'\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.2],\"rule\":1},{\"tail\":[0],\"feats\":[0,-0.6],\"rule\":2}],\"node\":{\"in_edges\":[0,1]},\"edges\":[{\"tail\":[1],\"feats\":[0,-0.1],\"rule\":3},{\"tail\":[1],\"feats\":[0,-0.9],\"rule\":4}],\"node\":{\"in_edges\":[2,3]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateLatticeHG(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] A A\",4,\"[X] ||| [1] b\",5,\"[X] ||| [1] c\",6,\"[X] ||| [1] B C\",7,\"[X] ||| [1] A B C\",8,\"[X] ||| [1] CC\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[2,-0.3],\"rule\":1},{\"tail\":[0],\"feats\":[2,-0.6],\"rule\":2},{\"tail\":[0],\"feats\":[2,-1.7],\"rule\":3}],\"node\":{\"in_edges\":[0,1,2]},\"edges\":[{\"tail\":[1],\"feats\":[2,-0.5],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[2],\"feats\":[2,-0.6],\"rule\":5},{\"tail\":[1],\"feats\":[2,-0.8],\"rule\":6},{\"tail\":[0],\"feats\":[2,-0.01],\"rule\":7},{\"tail\":[2],\"feats\":[2,-0.8],\"rule\":8}],\"node\":{\"in_edges\":[4,5,6,7]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHG_tiny(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| <s>\",2,\"[X] ||| X [1]\",3,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,-2,1,-99],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.5,1,-0.8],\"rule\":2},{\"tail\":[0],\"feats\":[0,-0.7,1,-0.9],\"rule\":3}],\"node\":{\"in_edges\":[1,2]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHG_int(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| b\",3,\"[X] ||| a [1]\",4,\"[X] ||| [1] b\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,0.1],\"rule\":1},{\"tail\":[],\"feats\":[0,0.1],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X\"},\"edges\":[{\"tail\":[0],\"feats\":[0,0.3],\"rule\":3},{\"tail\":[0],\"feats\":[0,0.2],\"rule\":4}],\"node\":{\"in_edges\":[2,3],\"cat\":\"Goal\"}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHG(Hypergraph* hg) const {
+ string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHGBalanced(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| i\",2,\"[X] ||| a\",3,\"[X] ||| b\",4,\"[X] ||| [1] [2]\",5,\"[X] ||| [1] [2]\",6,\"[X] ||| c\",7,\"[X] ||| d\",8,\"[X] ||| [1] [2]\",9,\"[X] ||| [1] [2]\",10,\"[X] ||| [1] [2]\",11,\"[X] ||| [1] [2]\",12,\"[X] ||| [1] [2]\",13,\"[X] ||| [1] [2]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[1,2],\"feats\":[],\"rule\":4},{\"tail\":[2,1],\"feats\":[],\"rule\":5}],\"node\":{\"in_edges\":[3,4]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":6}],\"node\":{\"in_edges\":[5]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":7}],\"node\":{\"in_edges\":[6]},\"edges\":[{\"tail\":[4,5],\"feats\":[],\"rule\":8},{\"tail\":[5,4],\"feats\":[],\"rule\":9}],\"node\":{\"in_edges\":[7,8]},\"edges\":[{\"tail\":[3,6],\"feats\":[],\"rule\":10},{\"tail\":[6,3],\"feats\":[],\"rule\":11}],\"node\":{\"in_edges\":[9,10]},\"edges\":[{\"tail\":[7,0],\"feats\":[],\"rule\":12},{\"tail\":[0,7],\"feats\":[],\"rule\":13}],\"node\":{\"in_edges\":[11,12]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+TEST_F(HGTest,Controlled) {
+ Hypergraph hg;
+ CreateHG_tiny(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t prob = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "prob: " << prob << "\n";
+ EXPECT_FLOAT_EQ(-80.839996, log(prob));
+ EXPECT_EQ("X <s>", TD::GetString(trans));
+ vector<prob_t> post;
+ hg.PrintGraphviz();
+ prob_t c2 = Inside<prob_t, ScaledEdgeProb>(hg, NULL, ScaledEdgeProb(0.6));
+ EXPECT_FLOAT_EQ(-47.8577, log(c2));
+}
+
+TEST_F(HGTest,Union) {
+ Hypergraph hg1;
+ Hypergraph hg2;
+ CreateHG_tiny(&hg1);
+ CreateHG(&hg2);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 1.0);
+ hg1.Reweight(wts);
+ hg2.Reweight(wts);
+ prob_t c1,c2,c3,c4;
+ vector<WordID> t1,t2,t3,t4;
+ c1 = ViterbiESentence(hg1, &t1);
+ c2 = ViterbiESentence(hg2, &t2);
+ int l2 = ViterbiPathLength(hg2);
+ cerr << c1 << "\t" << TD::GetString(t1) << endl;
+ cerr << c2 << "\t" << TD::GetString(t2) << endl;
+ hg1.Union(hg2);
+ hg1.Reweight(wts);
+ c3 = ViterbiESentence(hg1, &t3);
+ int l3 = ViterbiPathLength(hg1);
+ cerr << c3 << "\t" << TD::GetString(t3) << endl;
+ EXPECT_FLOAT_EQ(c2, c3);
+ EXPECT_EQ(TD::GetString(t2), TD::GetString(t3));
+ EXPECT_EQ(l2, l3);
+
+ wts.set_value(FD::Convert("f2"), -1);
+ hg1.Reweight(wts);
+ c4 = ViterbiESentence(hg1, &t4);
+ cerr << c4 << "\t" << TD::GetString(t4) << endl;
+ EXPECT_EQ("Z <s>", TD::GetString(t4));
+ EXPECT_FLOAT_EQ(98.82, log(c4));
+
+ vector<pair<vector<WordID>, prob_t> > list;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg1, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg1.nodes_.size() - 1, i);
+ if (!d) break;
+ list.push_back(make_pair(d->yield, d->score));
+ }
+ EXPECT_TRUE(list[0].first == t4);
+ EXPECT_FLOAT_EQ(log(list[0].second), log(c4));
+ EXPECT_EQ(list.size(), 6);
+ EXPECT_FLOAT_EQ(log(list.back().second / list.front().second), -97.7);
+}
+
+TEST_F(HGTest,ControlledKBest) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ vector<double> w(2); w[0]=0.4; w[1]=0.8;
+ hg.Reweight(w);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+
+ int best = 0;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << endl;
+ ++best;
+ }
+ EXPECT_EQ(4, best);
+}
+
+
+TEST_F(HGTest,InsideScore) {
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ Hypergraph hg;
+ CreateTinyLatticeHG(&hg);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+ prob_t inside = Inside<prob_t, EdgeProb>(hg);
+ EXPECT_FLOAT_EQ(1.7934048, inside); // computed by hand
+ vector<prob_t> post;
+ inside = hg.ComputeBestPathThroughEdges(&post);
+ EXPECT_FLOAT_EQ(-0.3, log(inside)); // computed by hand
+ EXPECT_EQ(post.size(), 4);
+ for (int i = 0; i < 4; ++i) {
+ cerr << "edge post: " << log(post[i]) << '\t' << hg.edges_[i].rule_->AsString() << endl;
+ }
+}
+
+
+TEST_F(HGTest,PruneInsideOutside) {
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Feature_1"), 1.0);
+ Hypergraph hg;
+ CreateLatticeHG(&hg);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+ //hg.DensityPruneInsideOutside(0.5, false, 2.0);
+ hg.BeamPruneInsideOutside(0.5, false, 0.5);
+ cost = ViterbiESentence(hg, &trans);
+ cerr << "Ncst: " << cost << endl;
+ cerr << TD::GetString(trans) << "\n";
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestPruneEdges) {
+ Hypergraph hg;
+ CreateLatticeHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+ vector<bool> prune(hg.edges_.size(), true);
+ prune[6] = false;
+ hg.PruneEdges(prune);
+ cerr << "Pruned:\n";
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestIntersect) {
+ Hypergraph hg;
+ CreateHG_int(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+
+ int best = 0;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << endl;
+ ++best;
+ }
+ EXPECT_EQ(4, best);
+
+ Lattice target(2);
+ target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1));
+ target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1));
+ HG::Intersect(target, &hg);
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestPrune2) {
+ Hypergraph hg;
+ CreateHG_int(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+ vector<bool> rem(hg.edges_.size(), false);
+ rem[0] = true;
+ rem[1] = true;
+ hg.PruneEdges(rem);
+ hg.PrintGraphviz();
+ cerr << "TODO: fix this pruning behavior-- the resulting HG should be empty!\n";
+}
+
+TEST_F(HGTest,Sample) {
+ Hypergraph hg;
+ CreateLatticeHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Feature_1"), 0.0);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,PLF) {
+ Hypergraph hg;
+ string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)";
+ HypergraphIO::ReadFromPLF(inplf, &hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Feature_0"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+ string outplf = HypergraphIO::AsPLF(hg);
+ cerr << " IN: " << inplf << endl;
+ cerr << "OUT: " << outplf << endl;
+ assert(inplf == outplf);
+}
+
+TEST_F(HGTest,PushWeightsToGoal) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ vector<double> w(2); w[0]=0.4; w[1]=0.8;
+ hg.Reweight(w);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+ hg.PushWeightsToGoal();
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestSpecialKBest) {
+ Hypergraph hg;
+ CreateHGBalanced(&hg);
+ vector<double> w(1); w[0]=0;
+ hg.Reweight(w);
+ vector<pair<vector<WordID>, prob_t> > list;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 100000);
+ for (int i = 0; i < 100000; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << endl;
+ }
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest, TestGenericViterbi) {
+ Hypergraph hg;
+ CreateHG_tiny(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ const prob_t prob = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "prob: " << prob << "\n";
+ EXPECT_FLOAT_EQ(-80.839996, log(prob));
+ EXPECT_EQ("X <s>", TD::GetString(trans));
+}
+
+TEST_F(HGTest, TestGenericInside) {
+ Hypergraph hg;
+ CreateTinyLatticeHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ vector<prob_t> inside;
+ prob_t ins = Inside<prob_t, EdgeProb>(hg, &inside);
+ EXPECT_FLOAT_EQ(1.7934048, ins); // computed by hand
+ vector<prob_t> outside;
+ Outside<prob_t, EdgeProb>(hg, inside, &outside);
+ EXPECT_EQ(3, outside.size());
+ EXPECT_FLOAT_EQ(1.7934048, outside[0]);
+ EXPECT_FLOAT_EQ(1.3114071, outside[1]);
+ EXPECT_FLOAT_EQ(1.0, outside[2]);
+}
+
+TEST_F(HGTest,TestGenericInside2) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ vector<prob_t> inside, outside;
+ prob_t ins = Inside<prob_t, EdgeProb>(hg, &inside);
+ Outside<prob_t, EdgeProb>(hg, inside, &outside);
+ for (int i = 0; i < hg.nodes_.size(); ++i)
+ cerr << i << "\t" << log(inside[i]) << "\t" << log(outside[i]) << endl;
+ EXPECT_FLOAT_EQ(0, log(inside[0]));
+ EXPECT_FLOAT_EQ(-1.7861683, log(outside[0]));
+ EXPECT_FLOAT_EQ(-0.4, log(inside[1]));
+ EXPECT_FLOAT_EQ(-1.3861683, log(outside[1]));
+ EXPECT_FLOAT_EQ(-0.8, log(inside[2]));
+ EXPECT_FLOAT_EQ(-0.986168, log(outside[2]));
+ EXPECT_FLOAT_EQ(-0.96, log(inside[3]));
+ EXPECT_FLOAT_EQ(-0.8261683, log(outside[3]));
+ EXPECT_FLOAT_EQ(-1.562512, log(inside[4]));
+ EXPECT_FLOAT_EQ(-0.22365622, log(outside[4]));
+ EXPECT_FLOAT_EQ(-1.7861683, log(inside[5]));
+ EXPECT_FLOAT_EQ(0, log(outside[5]));
+}
+
+TEST_F(HGTest,TestAddExpectations) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ SparseVector<double> feat_exps;
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(hg, &feat_exps);
+ EXPECT_FLOAT_EQ(-2.5439765, feat_exps[FD::Convert("f1")]);
+ EXPECT_FLOAT_EQ(-2.6357865, feat_exps[FD::Convert("f2")]);
+ cerr << feat_exps << endl;
+ SparseVector<prob_t> posts;
+ InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(hg, &posts);
+}
+
+TEST_F(HGTest, Small) {
+ ReadFile rf("test_data/small.json.gz");
+ Hypergraph hg;
+ assert(HypergraphIO::ReadFromJSON(rf.stream(), &hg));
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Model_0"), -2.0);
+ wts.set_value(FD::Convert("Model_1"), -0.5);
+ wts.set_value(FD::Convert("Model_2"), -1.1);
+ wts.set_value(FD::Convert("Model_3"), -1.0);
+ wts.set_value(FD::Convert("Model_4"), -1.0);
+ wts.set_value(FD::Convert("Model_5"), 0.5);
+ wts.set_value(FD::Convert("Model_6"), 0.2);
+ wts.set_value(FD::Convert("Model_7"), -3.0);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ vector<prob_t> post;
+ prob_t c2 = Inside<prob_t, ScaledEdgeProb>(hg, NULL, ScaledEdgeProb(0.6));
+ EXPECT_FLOAT_EQ(2.1431036, log(c2));
+}
+
+TEST_F(HGTest, JSONTest) {
+ ostringstream os;
+ JSONParser::WriteEscapedString("\"I don't know\", she said.", &os);
+ EXPECT_EQ("\"\\\"I don't know\\\", she said.\"", os.str());
+ ostringstream os2;
+ JSONParser::WriteEscapedString("yes", &os2);
+ EXPECT_EQ("\"yes\"", os2.str());
+}
+
+TEST_F(HGTest, TestGenericKBest) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ //CreateHGBalanced(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 1.0);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 1000);
+ for (int i = 0; i < 1000; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << " F:" << d->feature_values << endl;
+ }
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/ibm_model1.cc b/src/ibm_model1.cc
new file mode 100644
index 00000000..eb9c617c
--- /dev/null
+++ b/src/ibm_model1.cc
@@ -0,0 +1,4 @@
+#include <iostream>
+
+
+
diff --git a/src/inside_outside.h b/src/inside_outside.h
new file mode 100644
index 00000000..9114c9d7
--- /dev/null
+++ b/src/inside_outside.h
@@ -0,0 +1,111 @@
+#ifndef _INSIDE_H_
+#define _INSIDE_H_
+
+#include <vector>
+#include <algorithm>
+#include "hg.h"
+
+// run the inside algorithm and return the inside score
+// if result is non-NULL, result will contain the inside
+// score for each node
+// NOTE: WeightType(0) must construct the semiring's additive identity
+// WeightType(1) must construct the semiring's multiplicative identity
+template<typename WeightType, typename WeightFunction>
+WeightType Inside(const Hypergraph& hg,
+ std::vector<WeightType>* result = NULL,
+ const WeightFunction& weight = WeightFunction()) {
+ const int num_nodes = hg.nodes_.size();
+ std::vector<WeightType> dummy;
+ std::vector<WeightType>& inside_score = result ? *result : dummy;
+ inside_score.resize(num_nodes);
+ std::fill(inside_score.begin(), inside_score.end(), WeightType());
+ for (int i = 0; i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ WeightType* const cur_node_inside_score = &inside_score[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ if (num_in_edges == 0) {
+ *cur_node_inside_score = WeightType(1);
+ continue;
+ }
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ WeightType score = weight(edge);
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ const int tail_node_index = edge.tail_nodes_[k];
+ score *= inside_score[tail_node_index];
+ }
+ *cur_node_inside_score += score;
+ }
+ }
+ return inside_score.back();
+}
+
+template<typename WeightType, typename WeightFunction>
+void Outside(const Hypergraph& hg,
+ std::vector<WeightType>& inside_score,
+ std::vector<WeightType>* result,
+ const WeightFunction& weight = WeightFunction()) {
+ assert(result);
+ const int num_nodes = hg.nodes_.size();
+ assert(inside_score.size() == num_nodes);
+ std::vector<WeightType>& outside_score = *result;
+ outside_score.resize(num_nodes);
+ std::fill(outside_score.begin(), outside_score.end(), WeightType(0));
+ outside_score.back() = WeightType(1);
+ for (int i = num_nodes - 1; i >= 0; --i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ const WeightType& head_node_outside_score = outside_score[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ const WeightType head_and_edge_weight = weight(edge) * head_node_outside_score;
+ const int num_tail_nodes = edge.tail_nodes_.size();
+ for (int k = 0; k < num_tail_nodes; ++k) {
+ const int update_tail_node_index = edge.tail_nodes_[k];
+ WeightType* const tail_outside_score = &outside_score[update_tail_node_index];
+ WeightType inside_contribution = WeightType(1);
+ for (int l = 0; l < num_tail_nodes; ++l) {
+ const int other_tail_node_index = edge.tail_nodes_[l];
+ if (update_tail_node_index != other_tail_node_index)
+ inside_contribution *= inside_score[other_tail_node_index];
+ }
+ *tail_outside_score += head_and_edge_weight * inside_contribution;
+ }
+ }
+ }
+}
+
+// this is the Inside-Outside optimization described in Li et al. (EMNLP 2009)
+// for computing the inside algorithm over expensive semirings
+// (such as expectations over features). See Figure 4. It is slightly different
+// in that x/k is returned not (k,x)
+// NOTE: RType * PType must be valid (and yield RType)
+template<typename PType, typename WeightFunction, typename RType, typename WeightFunction2>
+PType InsideOutside(const Hypergraph& hg,
+ RType* result_x,
+ const WeightFunction& weight1 = WeightFunction(),
+ const WeightFunction2& weight2 = WeightFunction2()) {
+ const int num_nodes = hg.nodes_.size();
+ std::vector<PType> inside, outside;
+ const PType z = Inside<PType,WeightFunction>(hg, &inside, weight1);
+ Outside<PType,WeightFunction>(hg, inside, &outside, weight1);
+ RType& x = *result_x;
+ x = RType();
+ for (int i = 0; i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ PType prob = outside[i];
+ prob *= weight1(edge);
+ const int num_tail_nodes = edge.tail_nodes_.size();
+ for (int k = 0; k < num_tail_nodes; ++k)
+ prob *= inside[edge.tail_nodes_[k]];
+ prob /= z;
+ x += weight2(edge) * prob;
+ }
+ }
+ return z;
+}
+
+#endif
diff --git a/src/json_parse.cc b/src/json_parse.cc
new file mode 100644
index 00000000..f6fdfea8
--- /dev/null
+++ b/src/json_parse.cc
@@ -0,0 +1,50 @@
+#include "json_parse.h"
+
+#include <string>
+#include <iostream>
+
+using namespace std;
+
+static const char *json_hex_chars = "0123456789abcdef";
+
+void JSONParser::WriteEscapedString(const string& in, ostream* out) {
+ int pos = 0;
+ int start_offset = 0;
+ unsigned char c = 0;
+ (*out) << '"';
+ while(pos < in.size()) {
+ c = in[pos];
+ switch(c) {
+ case '\b':
+ case '\n':
+ case '\r':
+ case '\t':
+ case '"':
+ case '\\':
+ case '/':
+ if(pos - start_offset > 0)
+ (*out) << in.substr(start_offset, pos - start_offset);
+ if(c == '\b') (*out) << "\\b";
+ else if(c == '\n') (*out) << "\\n";
+ else if(c == '\r') (*out) << "\\r";
+ else if(c == '\t') (*out) << "\\t";
+ else if(c == '"') (*out) << "\\\"";
+ else if(c == '\\') (*out) << "\\\\";
+ else if(c == '/') (*out) << "\\/";
+ start_offset = ++pos;
+ break;
+ default:
+ if(c < ' ') {
+ cerr << "Warning, bad character (" << static_cast<int>(c) << ") in string\n";
+ if(pos - start_offset > 0)
+ (*out) << in.substr(start_offset, pos - start_offset);
+ (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf];
+ start_offset = ++pos;
+ } else pos++;
+ }
+ }
+ if(pos - start_offset > 0)
+ (*out) << in.substr(start_offset, pos - start_offset);
+ (*out) << '"';
+}
+
diff --git a/src/json_parse.h b/src/json_parse.h
new file mode 100644
index 00000000..c3cba954
--- /dev/null
+++ b/src/json_parse.h
@@ -0,0 +1,58 @@
+#ifndef _JSON_WRAPPER_H_
+#define _JSON_WRAPPER_H_
+
+#include <iostream>
+#include <cassert>
+#include "JSON_parser.h"
+
+class JSONParser {
+ public:
+ JSONParser() {
+ init_JSON_config(&config);
+ hack.mf = &JSONParser::Callback;
+ config.depth = 10;
+ config.callback_ctx = reinterpret_cast<void*>(this);
+ config.callback = hack.cb;
+ config.allow_comments = 1;
+ config.handle_floats_manually = 1;
+ jc = new_JSON_parser(&config);
+ }
+ virtual ~JSONParser() {
+ delete_JSON_parser(jc);
+ }
+ bool Parse(std::istream* in) {
+ int count = 0;
+ int lc = 1;
+ for (; in ; ++count) {
+ int next_char = in->get();
+ if (!in->good()) break;
+ if (lc == '\n') { ++lc; }
+ if (!JSON_parser_char(jc, next_char)) {
+ std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl;
+ return false;
+ }
+ }
+ if (!JSON_parser_done(jc)) {
+ std::cerr << "JSON_parser_done: syntax error\n";
+ return false;
+ }
+ return true;
+ }
+ static void WriteEscapedString(const std::string& in, std::ostream* out);
+ protected:
+ virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0;
+ private:
+ int Callback(int type, const JSON_value* value) {
+ if (HandleJSONEvent(type, value)) return 1;
+ return 0;
+ }
+ JSON_parser_struct* jc;
+ JSON_config config;
+ typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value);
+ union CBHack {
+ JSON_parser_callback cb;
+ MF mf;
+ } hack;
+};
+
+#endif
diff --git a/src/kbest.h b/src/kbest.h
new file mode 100644
index 00000000..cd9b6c2b
--- /dev/null
+++ b/src/kbest.h
@@ -0,0 +1,207 @@
+#ifndef _HG_KBEST_H_
+#define _HG_KBEST_H_
+
+#include <vector>
+#include <utility>
+#include <tr1/unordered_set>
+
+#include <boost/shared_ptr.hpp>
+
+#include "wordid.h"
+#include "hg.h"
+
+namespace KBest {
+ // default, don't filter any derivations from the k-best list
+ struct NoFilter {
+ bool operator()(const std::vector<WordID>& yield) {
+ (void) yield;
+ return false;
+ }
+ };
+
+ // optional, filter unique yield strings
+ struct FilterUnique {
+ std::tr1::unordered_set<std::vector<WordID>, boost::hash<std::vector<WordID> > > unique;
+
+ bool operator()(const std::vector<WordID>& yield) {
+ return !unique.insert(yield).second;
+ }
+ };
+
+ // utility class to lazily create the k-best derivations from a forest, uses
+ // the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005)
+ template<typename T, // yield type (returned by Traversal)
+ typename Traversal,
+ typename DerivationFilter = NoFilter,
+ typename WeightType = prob_t,
+ typename WeightFunction = EdgeProb>
+ struct KBestDerivations {
+ KBestDerivations(const Hypergraph& hg,
+ const size_t k,
+ const Traversal& tf = Traversal(),
+ const WeightFunction& wf = WeightFunction()) :
+ traverse(tf), w(wf), g(hg), nds(g.nodes_.size()), k_prime(k) {}
+
+ ~KBestDerivations() {
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ struct Derivation {
+ Derivation(const Hypergraph::Edge& e,
+ const SmallVector& jv,
+ const WeightType& w,
+ const SparseVector<double>& f) :
+ edge(&e),
+ j(jv),
+ score(w),
+ feature_values(f) {}
+
+ // dummy constructor, just for query
+ Derivation(const Hypergraph::Edge& e,
+ const SmallVector& jv) : edge(&e), j(jv) {}
+
+ T yield;
+ const Hypergraph::Edge* const edge;
+ const SmallVector j;
+ const WeightType score;
+ const SparseVector<double> feature_values;
+ };
+ struct HeapCompare {
+ bool operator()(const Derivation* a, const Derivation* b) const {
+ return a->score < b->score;
+ }
+ };
+ struct DerivationCompare {
+ bool operator()(const Derivation* a, const Derivation* b) const {
+ return a->score > b->score;
+ }
+ };
+ struct DerivationUniquenessHash {
+ size_t operator()(const Derivation* d) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ d->edge->id_;
+ for (int i = 0; i < d->j.size(); ++i)
+ x = ((x << 5) + x) ^ d->j[i];
+ return x;
+ }
+ };
+ struct DerivationUniquenessEquals {
+ bool operator()(const Derivation* a, const Derivation* b) const {
+ return (a->edge == b->edge) && (a->j == b->j);
+ }
+ };
+ typedef std::vector<Derivation*> CandidateHeap;
+ typedef std::vector<Derivation*> DerivationList;
+ typedef std::tr1::unordered_set<
+ const Derivation*, DerivationUniquenessHash, DerivationUniquenessEquals> UniqueDerivationSet;
+
+ struct NodeDerivationState {
+ CandidateHeap cand;
+ DerivationList D;
+ DerivationFilter filter;
+ UniqueDerivationSet ds;
+ explicit NodeDerivationState(const DerivationFilter& f = DerivationFilter()) : filter(f) {}
+ };
+
+ Derivation* LazyKthBest(int v, int k) {
+ NodeDerivationState& s = GetCandidates(v);
+ CandidateHeap& cand = s.cand;
+ DerivationList& D = s.D;
+ DerivationFilter& filter = s.filter;
+ bool add_next = true;
+ while (D.size() <= k) {
+ if (add_next && D.size() > 0) {
+ const Derivation* d = D.back();
+ LazyNext(d, &cand, &s.ds);
+ }
+ add_next = false;
+
+ if (cand.size() > 0) {
+ std::pop_heap(cand.begin(), cand.end(), HeapCompare());
+ Derivation* d = cand.back();
+ cand.pop_back();
+ std::vector<const T*> ants(d->edge->Arity());
+ for (int j = 0; j < ants.size(); ++j)
+ ants[j] = &LazyKthBest(d->edge->tail_nodes_[j], d->j[j])->yield;
+ traverse(*d->edge, ants, &d->yield);
+ if (!filter(d->yield)) {
+ D.push_back(d);
+ add_next = true;
+ }
+ } else {
+ break;
+ }
+ }
+ if (k < D.size()) return D[k]; else return NULL;
+ }
+
+ private:
+ // creates a derivation object with all fields set but the yield
+ // the yield is computed in LazyKthBest before the derivation is added to D
+ // returns NULL if j refers to derivation numbers larger than the
+ // antecedent structure define
+ Derivation* CreateDerivation(const Hypergraph::Edge& e, const SmallVector& j) {
+ WeightType score = w(e);
+ SparseVector<double> feats = e.feature_values_;
+ for (int i = 0; i < e.Arity(); ++i) {
+ const Derivation* ant = LazyKthBest(e.tail_nodes_[i], j[i]);
+ if (!ant) { return NULL; }
+ score *= ant->score;
+ feats += ant->feature_values;
+ }
+ freelist.push_back(new Derivation(e, j, score, feats));
+ return freelist.back();
+ }
+
+ NodeDerivationState& GetCandidates(int v) {
+ NodeDerivationState& s = nds[v];
+ if (!s.D.empty() || !s.cand.empty()) return s;
+
+ const Hypergraph::Node& node = g.nodes_[v];
+ for (int i = 0; i < node.in_edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = g.edges_[node.in_edges_[i]];
+ SmallVector jv(edge.Arity(), 0);
+ Derivation* d = CreateDerivation(edge, jv);
+ assert(d);
+ s.cand.push_back(d);
+ }
+
+ const int effective_k = std::min(k_prime, s.cand.size());
+ const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k;
+ std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare());
+ s.cand.resize(effective_k);
+ std::make_heap(s.cand.begin(), s.cand.end(), HeapCompare());
+
+ return s;
+ }
+
+ void LazyNext(const Derivation* d, CandidateHeap* cand, UniqueDerivationSet* ds) {
+ for (int i = 0; i < d->j.size(); ++i) {
+ SmallVector j = d->j;
+ ++j[i];
+ const Derivation* ant = LazyKthBest(d->edge->tail_nodes_[i], j[i]);
+ if (ant) {
+ Derivation query_unique(*d->edge, j);
+ if (ds->count(&query_unique) == 0) {
+ Derivation* new_d = CreateDerivation(*d->edge, j);
+ if (new_d) {
+ cand->push_back(new_d);
+ std::push_heap(cand->begin(), cand->end(), HeapCompare());
+ assert(ds->insert(new_d).second); // insert into uniqueness set, sanity check
+ }
+ }
+ }
+ }
+ }
+
+ const Traversal traverse;
+ const WeightFunction w;
+ const Hypergraph& g;
+ std::vector<NodeDerivationState> nds;
+ std::vector<Derivation*> freelist;
+ const size_t k_prime;
+ };
+}
+
+#endif
diff --git a/src/lattice.cc b/src/lattice.cc
new file mode 100644
index 00000000..aa1df3db
--- /dev/null
+++ b/src/lattice.cc
@@ -0,0 +1,27 @@
+#include "lattice.h"
+
+#include "tdict.h"
+#include "hg_io.h"
+
+using namespace std;
+
+bool LatticeTools::LooksLikePLF(const string &line) {
+ return (line.size() > 5) && (line.substr(0,4) == "((('");
+}
+
+void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
+ Lattice& l = *pl;
+ vector<WordID> ids;
+ TD::ConvertSentence(text, &ids);
+ l.resize(ids.size());
+ for (int i = 0; i < l.size(); ++i)
+ l[i].push_back(LatticeArc(ids[i], 0.0, 1));
+}
+
+void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
+ if (LooksLikePLF(text_or_plf))
+ HypergraphIO::PLFtoLattice(text_or_plf, pl);
+ else
+ ConvertTextToLattice(text_or_plf, pl);
+}
+
diff --git a/src/lattice.h b/src/lattice.h
new file mode 100644
index 00000000..1177e768
--- /dev/null
+++ b/src/lattice.h
@@ -0,0 +1,31 @@
+#ifndef __LATTICE_H_
+#define __LATTICE_H_
+
+#include <string>
+#include <vector>
+#include "wordid.h"
+
+struct LatticeArc {
+ WordID label;
+ double cost;
+ int dist2next;
+ LatticeArc() : label(), cost(), dist2next() {}
+ LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {}
+};
+
+class Lattice : public std::vector<std::vector<LatticeArc> > {
+ public:
+ Lattice() {}
+ explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) :
+ std::vector<std::vector<LatticeArc> >(t, v) {}
+
+ // TODO add distance functions
+};
+
+struct LatticeTools {
+ static bool LooksLikePLF(const std::string &line);
+ static void ConvertTextToLattice(const std::string& text, Lattice* pl);
+ static void ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);
+};
+
+#endif
diff --git a/src/lexcrf.cc b/src/lexcrf.cc
new file mode 100644
index 00000000..33455a3d
--- /dev/null
+++ b/src/lexcrf.cc
@@ -0,0 +1,112 @@
+#include "lexcrf.h"
+
+#include <iostream>
+
+#include "filelib.h"
+#include "hg.h"
+#include "tdict.h"
+#include "grammar.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+struct LexicalCRFImpl {
+ LexicalCRFImpl(const boost::program_options::variables_map& conf) :
+ use_null(false),
+ kXCAT(TD::Convert("X")*-1),
+ kNULL(TD::Convert("<eps>")),
+ kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")),
+ kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) {
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ assert(gfiles.size() == 1);
+ ReadFile rf(gfiles.front());
+ TextGrammar *tg = new TextGrammar;
+ grammar.reset(tg);
+ istream* in = rf.stream();
+ int lc = 0;
+ bool flag = false;
+ while(*in) {
+ string line;
+ getline(*in, line);
+ if (line.empty()) continue;
+ ++lc;
+ TRulePtr r(TRule::CreateRulePhrasetable(line));
+ tg->AddRule(r);
+ if (lc % 50000 == 0) { cerr << '.'; flag = true; }
+ if (lc % 2000000 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
+ }
+ if (flag) cerr << endl;
+ cerr << "Loaded " << lc << " rules\n";
+ }
+
+ void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) {
+ const int e_len = smeta.GetTargetLength();
+ assert(e_len > 0);
+ const int f_len = lattice.size();
+ // hack to tell the feature function system how big the sentence pair is
+ const int f_start = (use_null ? -1 : 0);
+ int prev_node_id = -1;
+ for (int i = 0; i < e_len; ++i) { // for each word in the *ref*
+ Hypergraph::Node* node = forest->AddNode(kXCAT);
+ const int new_node_id = node->id_;
+ for (int j = f_start; j < f_len; ++j) { // for each word in the source
+ const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label);
+ const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym);
+ if (!gi) {
+ cerr << "No translations found for: " << TD::Convert(src_sym) << "\n";
+ abort();
+ }
+ const RuleBin* rb = gi->GetRules();
+ assert(rb);
+ for (int k = 0; k < rb->GetNumRules(); ++k) {
+ TRulePtr rule = rb->GetIthRule(k);
+ Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+ edge->i_ = j;
+ edge->j_ = j+1;
+ edge->prev_i_ = i;
+ edge->prev_j_ = i+1;
+ edge->feature_values_ += edge->rule_->GetFeatureValues();
+ forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+ }
+ }
+ if (prev_node_id >= 0) {
+ const int comb_node_id = forest->AddNode(kXCAT)->id_;
+ Hypergraph::TailNodeVector tail(2, prev_node_id);
+ tail[1] = new_node_id;
+ const int edge_id = forest->AddEdge(kBINARY, tail)->id_;
+ forest->ConnectEdgeToHeadNode(edge_id, comb_node_id);
+ prev_node_id = comb_node_id;
+ } else {
+ prev_node_id = new_node_id;
+ }
+ }
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("[Goal]")*-1);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ }
+
+ private:
+ const bool use_null;
+ const WordID kXCAT;
+ const WordID kNULL;
+ const TRulePtr kBINARY;
+ const TRulePtr kGOAL_RULE;
+ GrammarPtr grammar;
+};
+
+LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) :
+ pimpl_(new LexicalCRFImpl(conf)) {}
+
+bool LexicalCRF::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ Lattice lattice;
+ LatticeTools::ConvertTextToLattice(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ pimpl_->BuildTrellis(lattice, *smeta, forest);
+ forest->Reweight(weights);
+ return true;
+}
+
diff --git a/src/lexcrf.h b/src/lexcrf.h
new file mode 100644
index 00000000..99362c81
--- /dev/null
+++ b/src/lexcrf.h
@@ -0,0 +1,18 @@
+#ifndef _LEXCRF_H_
+#define _LEXCRF_H_
+
+#include "translator.h"
+#include "lattice.h"
+
+struct LexicalCRFImpl;
+struct LexicalCRF : public Translator {
+ LexicalCRF(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* forest);
+ private:
+ boost::shared_ptr<LexicalCRFImpl> pimpl_;
+};
+
+#endif
diff --git a/src/lm_ff.cc b/src/lm_ff.cc
new file mode 100644
index 00000000..f95140de
--- /dev/null
+++ b/src/lm_ff.cc
@@ -0,0 +1,328 @@
+#include "lm_ff.h"
+
+#include <sstream>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <netinet/in.h>
+#include <netdb.h>
+
+#include "tdict.h"
+#include "Vocab.h"
+#include "Ngram.h"
+#include "hg.h"
+#include "stringlib.h"
+
+using namespace std;
+
+struct LMClient {
+ struct Cache {
+ map<WordID, Cache> tree;
+ float prob;
+ Cache() : prob() {}
+ };
+
+ LMClient(const char* host) : port(6666) {
+ s = strchr(host, ':');
+ if (s != NULL) {
+ *s = '\0';
+ ++s;
+ port = atoi(s);
+ }
+ sock = socket(AF_INET, SOCK_STREAM, 0);
+ hp = gethostbyname(host);
+ if (hp == NULL) {
+ cerr << "unknown host " << host << endl;
+ abort();
+ }
+ bzero((char *)&server, sizeof(server));
+ bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length);
+ server.sin_family = hp->h_addrtype;
+ server.sin_port = htons(port);
+
+ int errors = 0;
+ while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) {
+ cerr << "Error: connect()\n";
+ sleep(1);
+ errors++;
+ if (errors > 3) exit(1);
+ }
+ cerr << "Connected to LM on " << host << " on port " << port << endl;
+ }
+
+ float wordProb(int word, int* context) {
+ Cache* cur = &cache;
+ int i = 0;
+ while (context[i] > 0) {
+ cur = &cur->tree[context[i++]];
+ }
+ cur = &cur->tree[word];
+ if (cur->prob) { return cur->prob; }
+
+ i = 0;
+ ostringstream os;
+ os << "prob " << TD::Convert(word);
+ while (context[i] > 0) {
+ os << ' ' << TD::Convert(context[i++]);
+ }
+ os << endl;
+ string out = os.str();
+ write(sock, out.c_str(), out.size());
+ int r = read(sock, res, 6);
+ int errors = 0;
+ int cnt = 0;
+ while (1) {
+ if (r < 0) {
+ errors++; sleep(1);
+ cerr << "Error: read()\n";
+ if (errors > 5) exit(1);
+ } else if (r==0 || res[cnt] == '\n') { break; }
+ else {
+ cnt += r;
+ if (cnt==6) break;
+ read(sock, &res[cnt], 6-cnt);
+ }
+ }
+ cur->prob = *reinterpret_cast<float*>(res);
+ return cur->prob;
+ }
+
+ void clear() {
+ cache.tree.clear();
+ }
+
+ private:
+ Cache cache;
+ int sock, port;
+ char *s;
+ struct hostent *hp;
+ struct sockaddr_in server;
+ char res[8];
+};
+
+class LanguageModelImpl {
+ public:
+ LanguageModelImpl(int order, const string& f) :
+ ngram_(*TD::dict_), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
+ floor_(-100.0),
+ client_(NULL),
+ kSTART(TD::Convert("<s>")),
+ kSTOP(TD::Convert("</s>")),
+ kUNKNOWN(TD::Convert("<unk>")),
+ kNONE(-1),
+ kSTAR(TD::Convert("<{STAR}>")) {
+ if (f.find("lm://") == 0) {
+ client_ = new LMClient(f.substr(5).c_str());
+ } else {
+ File file(f.c_str(), "r", 0);
+ assert(file);
+ cerr << "Reading " << order_ << "-gram LM from " << f << endl;
+ ngram_.read(file, false);
+ }
+ }
+
+ ~LanguageModelImpl() {
+ delete client_;
+ }
+
+ inline int StateSize(const void* state) const {
+ return *(static_cast<const char*>(state) + state_size_);
+ }
+
+ inline void SetStateSize(int size, void* state) const {
+ *(static_cast<char*>(state) + state_size_) = size;
+ }
+
+ inline double LookupProbForBufferContents(int i) {
+ double p = client_ ?
+ client_->wordProb(buffer_[i], &buffer_[i+1])
+ : ngram_.wordProb(buffer_[i], (VocabIndex*)&buffer_[i+1]);
+ if (p < floor_) p = floor_;
+ return p;
+ }
+
+ string DebugStateToString(const void* state) const {
+ int len = StateSize(state);
+ const int* astate = reinterpret_cast<const int*>(state);
+ string res = "[";
+ for (int i = 0; i < len; ++i) {
+ res += " ";
+ res += TD::Convert(astate[i]);
+ }
+ res += " ]";
+ return res;
+ }
+
+ inline double ProbNoRemnant(int i, int len) {
+ int edge = len;
+ bool flag = true;
+ double sum = 0.0;
+ while (i >= 0) {
+ if (buffer_[i] == kSTAR) {
+ edge = i;
+ flag = false;
+ } else if (buffer_[i] <= 0) {
+ edge = i;
+ flag = true;
+ } else {
+ if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+ sum += LookupProbForBufferContents(i);
+ }
+ --i;
+ }
+ return sum;
+ }
+
+ double EstimateProb(const vector<WordID>& phrase) {
+ int len = phrase.size();
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ int i = len - 1;
+ for (int j = 0; j < len; ++j,--i)
+ buffer_[i] = phrase[j];
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ double EstimateProb(const void* state) {
+ int len = StateSize(state);
+ // cerr << "residual len: " << len << endl;
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ const int* astate = reinterpret_cast<const int*>(state);
+ int i = len - 1;
+ for (int j = 0; j < len; ++j,--i)
+ buffer_[i] = astate[j];
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ double FinalTraversalCost(const void* state) {
+ int slen = StateSize(state);
+ int len = slen + 2;
+ // cerr << "residual len: " << len << endl;
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ buffer_[len-1] = kSTART;
+ const int* astate = reinterpret_cast<const int*>(state);
+ int i = len - 2;
+ for (int j = 0; j < slen; ++j,--i)
+ buffer_[i] = astate[j];
+ buffer_[i] = kSTOP;
+ assert(i == 0);
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
+ int len = rule.ELength() - rule.Arity();
+ for (int i = 0; i < ant_states.size(); ++i)
+ len += StateSize(ant_states[i]);
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ int i = len - 1;
+ const vector<WordID>& e = rule.e();
+ for (int j = 0; j < e.size(); ++j) {
+ if (e[j] < 1) {
+ const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+ int slen = StateSize(astate);
+ for (int k = 0; k < slen; ++k)
+ buffer_[i--] = astate[k];
+ } else {
+ buffer_[i--] = e[j];
+ }
+ }
+
+ double sum = 0.0;
+ int* remnant = reinterpret_cast<int*>(vstate);
+ int j = 0;
+ i = len - 1;
+ int edge = len;
+
+ while (i >= 0) {
+ if (buffer_[i] == kSTAR) {
+ edge = i;
+ } else if (edge-i >= order_) {
+ sum += LookupProbForBufferContents(i);
+ } else if (edge == len && remnant) {
+ remnant[j++] = buffer_[i];
+ }
+ --i;
+ }
+ if (!remnant) return sum;
+
+ if (edge != len || len >= order_) {
+ remnant[j++] = kSTAR;
+ if (order_-1 < edge) edge = order_-1;
+ for (int i = edge-1; i >= 0; --i)
+ remnant[j++] = buffer_[i];
+ }
+
+ SetStateSize(j, vstate);
+ return sum;
+ }
+
+ static int OrderToStateSize(int order) {
+ return ((order-1) * 2 + 1) * sizeof(WordID) + 1;
+ }
+
+ private:
+ Ngram ngram_;
+ vector<WordID> buffer_;
+ const int order_;
+ const int state_size_;
+ const double floor_;
+ LMClient* client_;
+
+ public:
+ const WordID kSTART;
+ const WordID kSTOP;
+ const WordID kUNKNOWN;
+ const WordID kNONE;
+ const WordID kSTAR;
+};
+
+LanguageModel::LanguageModel(const string& param) :
+ fid_(FD::Convert("LanguageModel")) {
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ int order = 3;
+ // TODO add support for -n FeatureName
+ string filename;
+ if (argc < 1) { cerr << "LanguageModel requires a filename, minimally!\n"; abort(); }
+ else if (argc == 1) { filename = argv[0]; }
+ else if (argc == 2 || argc > 3) { cerr << "Don't understand 'LanguageModel " << param << "'\n"; }
+ else if (argc == 3) {
+ if (argv[0] == "-o") {
+ order = atoi(argv[1].c_str());
+ filename = argv[2];
+ } else if (argv[1] == "-o") {
+ order = atoi(argv[2].c_str());
+ filename = argv[0];
+ }
+ }
+ SetStateSize(LanguageModelImpl::OrderToStateSize(order));
+ pimpl_ = new LanguageModelImpl(order, filename);
+}
+
+LanguageModel::~LanguageModel() {
+ delete pimpl_;
+}
+
+string LanguageModel::DebugStateToString(const void* state) const{
+ return pimpl_->DebugStateToString(state);
+}
+
+void LanguageModel::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ (void) smeta;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+ estimated_features->set_value(fid_, pimpl_->EstimateProb(state));
+}
+
+void LanguageModel::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
+
diff --git a/src/lm_ff.h b/src/lm_ff.h
new file mode 100644
index 00000000..cd717360
--- /dev/null
+++ b/src/lm_ff.h
@@ -0,0 +1,32 @@
+#ifndef _LM_FF_H_
+#define _LM_FF_H_
+
+#include <vector>
+#include <string>
+
+#include "hg.h"
+#include "ff.h"
+
+class LanguageModelImpl;
+
+class LanguageModel : public FeatureFunction {
+ public:
+ // param = "filename.lm [-o n]"
+ LanguageModel(const std::string& param);
+ ~LanguageModel();
+ virtual void FinalTraversalFeatures(const void* context,
+ SparseVector<double>* features) const;
+ std::string DebugStateToString(const void* state) const;
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ const int fid_;
+ mutable LanguageModelImpl* pimpl_;
+};
+
+#endif
diff --git a/src/logval.h b/src/logval.h
new file mode 100644
index 00000000..a8ca620c
--- /dev/null
+++ b/src/logval.h
@@ -0,0 +1,136 @@
+#ifndef LOGVAL_H_
+#define LOGVAL_H_
+
+#include <cmath>
+#include <limits>
+
+template <typename T>
+class LogVal {
+ public:
+ LogVal() : v_(-std::numeric_limits<T>::infinity()) {}
+ explicit LogVal(double x) : v_(std::log(x)) {}
+ LogVal<T>(const LogVal<T>& o) : v_(o.v_) {}
+ static LogVal<T> One() { return LogVal(1); }
+ static LogVal<T> Zero() { return LogVal(); }
+
+ void logeq(const T& v) { v_ = v; }
+
+ LogVal& operator+=(const LogVal& a) {
+ if (a.v_ == -std::numeric_limits<T>::infinity()) return *this;
+ if (a.v_ < v_) {
+ v_ = v_ + log1p(std::exp(a.v_ - v_));
+ } else {
+ v_ = a.v_ + log1p(std::exp(v_ - a.v_));
+ }
+ return *this;
+ }
+
+ LogVal& operator*=(const LogVal& a) {
+ v_ += a.v_;
+ return *this;
+ }
+
+ LogVal& operator*=(const T& a) {
+ v_ += log(a);
+ return *this;
+ }
+
+ LogVal& operator/=(const LogVal& a) {
+ v_ -= a.v_;
+ return *this;
+ }
+
+ LogVal& poweq(const T& power) {
+ if (power == 0) v_ = 0; else v_ *= power;
+ return *this;
+ }
+
+ LogVal pow(const T& power) const {
+ LogVal res = *this;
+ res.poweq(power);
+ return res;
+ }
+
+ operator T() const {
+ return std::exp(v_);
+ }
+
+ T v_;
+};
+
+template<typename T>
+LogVal<T> operator+(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res += o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator*(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res *= o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator*(const LogVal<T>& o1, const T& o2) {
+ LogVal<T> res(o1);
+ res *= o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator*(const T& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o2);
+ res *= o1;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator/(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res /= o2;
+ return res;
+}
+
+template<typename T>
+T log(const LogVal<T>& o) {
+ return o.v_;
+}
+
+template <typename T>
+LogVal<T> pow(const LogVal<T>& b, const T& e) {
+ return b.pow(e);
+}
+
+template <typename T>
+bool operator<(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ < rhs.v_);
+}
+
+template <typename T>
+bool operator<=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ <= rhs.v_);
+}
+
+template <typename T>
+bool operator>(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ > rhs.v_);
+}
+
+template <typename T>
+bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ >= rhs.v_);
+}
+
+template <typename T>
+bool operator==(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ == rhs.v_);
+}
+
+template <typename T>
+bool operator!=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ != rhs.v_);
+}
+
+#endif
diff --git a/src/maxtrans_blunsom.cc b/src/maxtrans_blunsom.cc
new file mode 100644
index 00000000..4a6680e0
--- /dev/null
+++ b/src/maxtrans_blunsom.cc
@@ -0,0 +1,287 @@
+#include "apply_models.h"
+
+#include <vector>
+#include <algorithm>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/tuple/tuple.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "tdict.h"
+#include "hg.h"
+#include "ff.h"
+
+using boost::tuple;
+using namespace std;
+using namespace std::tr1;
+
+namespace Hack {
+
+struct Candidate;
+typedef SmallVector JVector;
+typedef vector<Candidate*> CandidateHeap;
+typedef vector<Candidate*> CandidateList;
+
+// life cycle: candidates are created, placed on the heap
+// and retrieved by their estimated cost, when they're
+// retrieved, they're incorporated into the +LM hypergraph
+// where they also know the head node index they are
+// attached to. After they are added to the +LM hypergraph
+// inside_prob_ and est_prob_ fields may be updated as better
+// derivations are found (this happens since the successor's
+// of derivation d may have a better score- they are
+// explored lazily). However, the updates don't happen
+// when a candidate is in the heap so maintaining the heap
+// property is not an issue.
+struct Candidate {
+ int node_index_; // -1 until incorporated
+ // into the +LM forest
+ const Hypergraph::Edge* in_edge_; // in -LM forest
+ Hypergraph::Edge out_edge_;
+ vector<WordID> state_;
+ const JVector j_;
+ prob_t inside_prob_; // these are fixed until the cand
+ // is popped, then they may be updated
+ prob_t est_prob_;
+
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j,
+ const vector<CandidateList>& D,
+ bool is_goal) :
+ node_index_(-1),
+ in_edge_(&e),
+ j_(j) {
+ InitializeCandidate(D, is_goal);
+ }
+
+ // used to query uniqueness
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j) : in_edge_(&e), j_(j) {}
+
+ bool IsIncorporatedIntoHypergraph() const {
+ return node_index_ >= 0;
+ }
+
+ void InitializeCandidate(const vector<vector<Candidate*> >& D,
+ const bool is_goal) {
+ const Hypergraph::Edge& in_edge = *in_edge_;
+ out_edge_.rule_ = in_edge.rule_;
+ out_edge_.feature_values_ = in_edge.feature_values_;
+ Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_;
+ tail.resize(j_.size());
+ prob_t p = prob_t::One();
+ // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl;
+ vector<const vector<WordID>* > ants(tail.size());
+ for (int i = 0; i < tail.size(); ++i) {
+ const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]];
+ ants[i] = &ant.state_;
+ assert(ant.IsIncorporatedIntoHypergraph());
+ tail[i] = ant.node_index_;
+ p *= ant.inside_prob_;
+ }
+ prob_t edge_estimate = prob_t::One();
+ if (is_goal) {
+ assert(tail.size() == 1);
+ out_edge_.edge_prob_ = in_edge.edge_prob_;
+ } else {
+ in_edge.rule_->ESubstitute(ants, &state_);
+ out_edge_.edge_prob_ = in_edge.edge_prob_;
+ }
+ inside_prob_ = out_edge_.edge_prob_ * p;
+ est_prob_ = inside_prob_ * edge_estimate;
+ }
+};
+
+ostream& operator<<(ostream& os, const Candidate& cand) {
+ os << "CAND[";
+ if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; }
+ else { os << "+LM_node=" << cand.node_index_; }
+ os << " edge=" << cand.in_edge_->id_;
+ os << " j=<";
+ for (int i = 0; i < cand.j_.size(); ++i)
+ os << (i==0 ? "" : " ") << cand.j_[i];
+ os << "> vit=" << log(cand.inside_prob_);
+ os << " est=" << log(cand.est_prob_);
+ return os << ']';
+}
+
+struct HeapCandCompare {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ < r->est_prob_;
+ }
+};
+
+struct EstProbSorter {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ > r->est_prob_;
+ }
+};
+
+// the same candidate <edge, j> can be added multiple times if
+// j is multidimensional (if you're going NW in Manhattan, you
+// can first go north, then west, or you can go west then north)
+// this is a hash function on the relevant variables from
+// Candidate to enforce this.
+struct CandidateUniquenessHash {
+ size_t operator()(const Candidate* c) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ c->in_edge_->id_;
+ for (int i = 0; i < c->j_.size(); ++i)
+ x = ((x << 5) + x) ^ c->j_[i];
+ return x;
+ }
+};
+
+struct CandidateUniquenessEquals {
+ bool operator()(const Candidate* a, const Candidate* b) const {
+ return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_);
+ }
+};
+
+typedef unordered_set<const Candidate*, CandidateUniquenessHash, CandidateUniquenessEquals> UniqueCandidateSet;
+typedef unordered_map<vector<WordID>, Candidate*, boost::hash<vector<WordID> > > State2Node;
+
+class MaxTransBeamSearch {
+
+public:
+ MaxTransBeamSearch(const Hypergraph& i, int pop_limit, Hypergraph* o) :
+ in(i),
+ out(*o),
+ D(in.nodes_.size()),
+ pop_limit_(pop_limit) {
+ cerr << " Finding max translation (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
+ }
+
+ void Apply() {
+ int num_nodes = in.nodes_.size();
+ int goal_id = num_nodes - 1;
+ int pregoal = goal_id - 1;
+ assert(in.nodes_[pregoal].out_edges_.size() == 1);
+ cerr << " ";
+ for (int i = 0; i < in.nodes_.size(); ++i) {
+ cerr << '.';
+ KBest(i, i == goal_id);
+ }
+ cerr << endl;
+ int best_node = D[goal_id].front()->in_edge_->tail_nodes_.front();
+ Candidate& best = *D[best_node].front();
+ cerr << " Best path: " << log(best.inside_prob_)
+ << "\t" << log(best.est_prob_) << endl;
+ cout << TD::GetString(D[best_node].front()->state_) << endl;
+ FreeAll();
+ }
+
+ private:
+ void FreeAll() {
+ for (int i = 0; i < D.size(); ++i) {
+ CandidateList& D_i = D[i];
+ for (int j = 0; j < D_i.size(); ++j)
+ delete D_i[j];
+ }
+ D.clear();
+ }
+
+ void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) {
+ Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_);
+ new_edge->feature_values_ = item->out_edge_.feature_values_;
+ new_edge->edge_prob_ = item->out_edge_.edge_prob_;
+ Candidate*& o_item = (*s2n)[item->state_];
+ if (!o_item) o_item = item;
+
+ int& node_id = o_item->node_index_;
+ if (node_id < 0) {
+ Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, "");
+ node_id = new_node->id_;
+ }
+ Hypergraph::Node* node = &out.nodes_[node_id];
+ out.ConnectEdgeToHeadNode(new_edge, node);
+
+ if (item != o_item) {
+ assert(o_item->state_ == item->state_); // sanity check!
+ o_item->est_prob_ += item->est_prob_;
+ o_item->inside_prob_ += item->inside_prob_;
+ freelist->push_back(item);
+ }
+ }
+
+ void KBest(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_cands;
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, D, is_goal));
+ assert(unique_cands.insert(cand.back()).second); // these should all be unique!
+ }
+// cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+ PushSucc(*item, is_goal, &cand, &unique_cands);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i)
+ D_v[c++] = i->second;
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, D, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check
+ }
+ }
+ }
+ }
+
+ const Hypergraph& in;
+ Hypergraph& out;
+
+ vector<CandidateList> D; // maps nodes in in-HG to the
+ // equivalent nodes (many due to state
+ // splits) in the out-HG.
+ const int pop_limit_;
+};
+
+// each node in the graph has one of these, it keeps track of
+void MaxTrans(const Hypergraph& in,
+ int beam_size) {
+ Hypergraph out;
+ MaxTransBeamSearch ma(in, beam_size, &out);
+ ma.Apply();
+}
+
+}
diff --git a/src/parser_test.cc b/src/parser_test.cc
new file mode 100644
index 00000000..da1fbd89
--- /dev/null
+++ b/src/parser_test.cc
@@ -0,0 +1,35 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "hg.h"
+#include "trule.h"
+#include "bottom_up_parser.h"
+#include "tdict.h"
+
+using namespace std;
+
+class ChartTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(ChartTest,LanguageModel) {
+ LatticeArc a(TD::Convert("ein"), 0.0, 1);
+ LatticeArc b(TD::Convert("haus"), 0.0, 1);
+ Lattice lattice(2);
+ lattice[0].push_back(a);
+ lattice[1].push_back(b);
+ Hypergraph forest;
+ GrammarPtr g(new TextGrammar);
+ vector<GrammarPtr> grammars(1, g);
+ ExhaustiveBottomUpParser parser("PHRASE", grammars);
+ parser.Parse(lattice, &forest);
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/phrasebased_translator.cc b/src/phrasebased_translator.cc
new file mode 100644
index 00000000..5eb70876
--- /dev/null
+++ b/src/phrasebased_translator.cc
@@ -0,0 +1,206 @@
+#include "phrasebased_translator.h"
+
+#include <queue>
+#include <iostream>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/tuple/tuple.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "sentence_metadata.h"
+#include "tdict.h"
+#include "hg.h"
+#include "filelib.h"
+#include "lattice.h"
+#include "phrasetable_fst.h"
+#include "array2d.h"
+
+using namespace std;
+using namespace std::tr1;
+using namespace boost::tuples;
+
+struct Coverage : public vector<bool> {
+ explicit Coverage(int n, bool v = false) : vector<bool>(n, v), first_gap() {}
+ void Cover(int i, int j) {
+ vector<bool>::iterator it = this->begin() + i;
+ vector<bool>::iterator end = this->begin() + j;
+ while (it != end)
+ *it++ = true;
+ if (first_gap == i) {
+ first_gap = j;
+ it = end;
+ while (*it && it != this->end()) {
+ ++it;
+ ++first_gap;
+ }
+ }
+ }
+ bool Collides(int i, int j) const {
+ vector<bool>::const_iterator it = this->begin() + i;
+ vector<bool>::const_iterator end = this->begin() + j;
+ while (it != end)
+ if (*it++) return true;
+ return false;
+ }
+ int GetFirstGap() const { return first_gap; }
+ private:
+ int first_gap;
+};
+struct CoverageHash {
+ size_t operator()(const Coverage& cov) const {
+ return hasher_(static_cast<const vector<bool>&>(cov));
+ }
+ private:
+ boost::hash<vector<bool> > hasher_;
+};
+ostream& operator<<(ostream& os, const Coverage& cov) {
+ os << '[';
+ for (int i = 0; i < cov.size(); ++i)
+ os << (cov[i] ? '*' : '.');
+ return os << " gap=" << cov.GetFirstGap() << ']';
+}
+
+typedef unordered_map<Coverage, int, CoverageHash> CoverageNodeMap;
+typedef unordered_set<Coverage, CoverageHash> UniqueCoverageSet;
+
+struct PhraseBasedTranslatorImpl {
+ PhraseBasedTranslatorImpl(const boost::program_options::variables_map& conf) :
+ add_pass_through_rules(conf.count("add_pass_through_rules")),
+ max_distortion(conf["pb_max_distortion"].as<int>()),
+ kSOURCE_RULE(new TRule("[X] ||| [X,1] ||| [X,1]", true)),
+ kCONCAT_RULE(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]", true)),
+ kNT_TYPE(TD::Convert("X") * -1) {
+ assert(max_distortion >= 0);
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ assert(gfiles.size() == 1);
+ cerr << "Reading phrasetable from " << gfiles.front() << endl;
+ ReadFile in(gfiles.front());
+ fst.reset(LoadTextPhrasetable(in.stream()));
+ }
+
+ struct State {
+ State(const Coverage& c, int _i, int _j, const FSTNode* q) :
+ coverage(c), i(_i), j(_j), fst(q) {}
+ Coverage coverage;
+ int i;
+ int j;
+ const FSTNode* fst;
+ };
+
+ // we keep track of unique coverages that have been extended since it's
+ // possible to "extend" the same coverage twice, e.g. translate "a b c"
+ // with phrases "a" "b" "a b" and "c". There are two ways to cover "a b"
+ void EnqueuePossibleContinuations(const Coverage& coverage, queue<State>* q, UniqueCoverageSet* ucs) {
+ if (ucs->insert(coverage).second) {
+ const int gap = coverage.GetFirstGap();
+ const int end = min(static_cast<int>(coverage.size()), gap + max_distortion + 1);
+ for (int i = gap; i < end; ++i)
+ if (!coverage[i]) q->push(State(coverage, i, i, fst.get()));
+ }
+ }
+
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ Lattice lattice;
+ LatticeTools::ConvertTextOrPLF(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);
+ minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);
+ if (add_pass_through_rules) {
+ SparseVector<double> feats;
+ feats.set_value(FD::Convert("PassThrough"), 1);
+ for (int i = 0; i < lattice.size(); ++i) {
+ const vector<LatticeArc>& arcs = lattice[i];
+ for (int j = 0; j < arcs.size(); ++j) {
+ fst->AddPassThroughTranslation(arcs[j].label, feats);
+ // TODO handle lattice edge features
+ }
+ }
+ }
+ CoverageNodeMap c;
+ queue<State> q;
+ UniqueCoverageSet ucs;
+ const Coverage empty_cov(lattice.size(), false);
+ const Coverage goal_cov(lattice.size(), true);
+ EnqueuePossibleContinuations(empty_cov, &q, &ucs);
+ c[empty_cov] = 0; // have to handle the left edge specially
+ while(!q.empty()) {
+ const State s = q.front();
+ q.pop();
+ // cerr << "(" << s.i << "," << s.j << " ptr=" << s.fst << ") cov=" << s.coverage << endl;
+ const vector<LatticeArc>& arcs = lattice[s.j];
+ if (s.fst->HasData()) {
+ Coverage new_cov = s.coverage;
+ new_cov.Cover(s.i, s.j);
+ EnqueuePossibleContinuations(new_cov, &q, &ucs);
+ const vector<TRulePtr>& phrases = s.fst->GetTranslations()->GetRules();
+ const int phrase_head_index = minus_lm_forest->AddNode(kNT_TYPE)->id_;
+ for (int i = 0; i < phrases.size(); ++i) {
+ Hypergraph::Edge* edge = minus_lm_forest->AddEdge(phrases[i], Hypergraph::TailNodeVector());
+ edge->feature_values_ = edge->rule_->scores_;
+ minus_lm_forest->ConnectEdgeToHeadNode(edge->id_, phrase_head_index);
+ }
+ CoverageNodeMap::iterator cit = c.find(s.coverage);
+ assert(cit != c.end());
+ const int tail_node_plus1 = cit->second;
+ if (tail_node_plus1 == 0) { // left edge
+ c[new_cov] = phrase_head_index + 1;
+ } else { // not left edge
+ int& head_node_plus1 = c[new_cov];
+ if (!head_node_plus1)
+ head_node_plus1 = minus_lm_forest->AddNode(kNT_TYPE)->id_ + 1;
+ Hypergraph::TailNodeVector tail(2, tail_node_plus1 - 1);
+ tail[1] = phrase_head_index;
+ const int concat_edge = minus_lm_forest->AddEdge(kCONCAT_RULE, tail)->id_;
+ minus_lm_forest->ConnectEdgeToHeadNode(concat_edge, head_node_plus1 - 1);
+ }
+ }
+ if (s.j == lattice.size()) continue;
+ for (int l = 0; l < arcs.size(); ++l) {
+ const LatticeArc& arc = arcs[l];
+
+ const FSTNode* next_fst_state = s.fst->Extend(arc.label);
+ const int next_j = s.j + arc.dist2next;
+ if (next_fst_state &&
+ !s.coverage.Collides(s.i, next_j)) {
+ q.push(State(s.coverage, s.i, next_j, next_fst_state));
+ }
+ }
+ }
+ if (add_pass_through_rules)
+ fst->ClearPassThroughTranslations();
+ int pregoal_plus1 = c[goal_cov];
+ if (pregoal_plus1 > 0) {
+ TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [X,1]"));
+ int goal = minus_lm_forest->AddNode(TD::Convert("Goal") * -1)->id_;
+ int gedge = minus_lm_forest->AddEdge(kGOAL_RULE, Hypergraph::TailNodeVector(1, pregoal_plus1 - 1))->id_;
+ minus_lm_forest->ConnectEdgeToHeadNode(gedge, goal);
+ // they are almost topo, but not quite always
+ minus_lm_forest->TopologicallySortNodesAndEdges(goal);
+ minus_lm_forest->Reweight(weights);
+ return true;
+ } else {
+ return false; // composition failed
+ }
+ }
+
+ const bool add_pass_through_rules;
+ const int max_distortion;
+ TRulePtr kSOURCE_RULE;
+ const TRulePtr kCONCAT_RULE;
+ const WordID kNT_TYPE;
+ boost::shared_ptr<FSTNode> fst;
+};
+
+PhraseBasedTranslator::PhraseBasedTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new PhraseBasedTranslatorImpl(conf)) {}
+
+bool PhraseBasedTranslator::Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
+}
diff --git a/src/phrasebased_translator.h b/src/phrasebased_translator.h
new file mode 100644
index 00000000..d42ce79c
--- /dev/null
+++ b/src/phrasebased_translator.h
@@ -0,0 +1,18 @@
+#ifndef _PHRASEBASED_TRANSLATOR_H_
+#define _PHRASEBASED_TRANSLATOR_H_
+
+#include "translator.h"
+
+class PhraseBasedTranslatorImpl;
+class PhraseBasedTranslator : public Translator {
+ public:
+ PhraseBasedTranslator(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<PhraseBasedTranslatorImpl> pimpl_;
+};
+
+#endif
diff --git a/src/phrasetable_fst.cc b/src/phrasetable_fst.cc
new file mode 100644
index 00000000..f421e941
--- /dev/null
+++ b/src/phrasetable_fst.cc
@@ -0,0 +1,141 @@
+#include "phrasetable_fst.h"
+
+#include <cassert>
+#include <iostream>
+#include <map>
+
+#include <boost/shared_ptr.hpp>
+
+#include "filelib.h"
+#include "tdict.h"
+
+using boost::shared_ptr;
+using namespace std;
+
+TargetPhraseSet::~TargetPhraseSet() {}
+FSTNode::~FSTNode() {}
+
+class TextTargetPhraseSet : public TargetPhraseSet {
+ public:
+ void AddRule(TRulePtr rule) {
+ rules_.push_back(rule);
+ }
+ const vector<TRulePtr>& GetRules() const {
+ return rules_;
+ }
+
+ private:
+ // all rules must have arity 0
+ vector<TRulePtr> rules_;
+};
+
+class TextFSTNode : public FSTNode {
+ public:
+ const TargetPhraseSet* GetTranslations() const { return data.get(); }
+ bool HasData() const { return (bool)data; }
+ bool HasOutgoingNonEpsilonEdges() const { return !ptr.empty(); }
+ const FSTNode* Extend(const WordID& t) const {
+ map<WordID, TextFSTNode>::const_iterator it = ptr.find(t);
+ if (it == ptr.end()) return NULL;
+ return &it->second;
+ }
+
+ void AddPhrase(const string& phrase);
+
+ void AddPassThroughTranslation(const WordID& w, const SparseVector<double>& feats);
+ void ClearPassThroughTranslations();
+ private:
+ vector<WordID> passthroughs;
+ shared_ptr<TargetPhraseSet> data;
+ map<WordID, TextFSTNode> ptr;
+};
+
+#ifdef DEBUG_CHART_PARSER
+static string TrimRule(const string& r) {
+ size_t start = r.find(" |||") + 5;
+ size_t end = r.rfind(" |||");
+ return r.substr(start, end - start);
+}
+#endif
+
+void TextFSTNode::AddPhrase(const string& phrase) {
+ vector<WordID> words;
+ TRulePtr rule(TRule::CreateRulePhrasetable(phrase));
+ if (!rule) {
+ static int err = 0;
+ ++err;
+ if (err > 2) { cerr << "TOO MANY PHRASETABLE ERRORS\n"; exit(1); }
+ return;
+ }
+
+ TextFSTNode* fsa = this;
+ for (int i = 0; i < rule->FLength(); ++i)
+ fsa = &fsa->ptr[rule->f_[i]];
+
+ if (!fsa->data)
+ fsa->data.reset(new TextTargetPhraseSet);
+ static_cast<TextTargetPhraseSet*>(fsa->data.get())->AddRule(rule);
+}
+
+void TextFSTNode::AddPassThroughTranslation(const WordID& w, const SparseVector<double>& feats) {
+ TextFSTNode* next = &ptr[w];
+ // current, rules are only added if the symbol is completely missing as a
+ // word starting the phrase. As a result, it is possible that some sentences
+ // won't parse. If this becomes a problem, fix it here.
+ if (!next->data) {
+ TextTargetPhraseSet* tps = new TextTargetPhraseSet;
+ next->data.reset(tps);
+ TRule* rule = new TRule;
+ rule->e_.resize(1, w);
+ rule->f_.resize(1, w);
+ rule->lhs_ = TD::Convert("___PHRASE") * -1;
+ rule->scores_ = feats;
+ rule->arity_ = 0;
+ tps->AddRule(TRulePtr(rule));
+ passthroughs.push_back(w);
+ }
+}
+
+void TextFSTNode::ClearPassThroughTranslations() {
+ for (int i = 0; i < passthroughs.size(); ++i)
+ ptr.erase(passthroughs[i]);
+ passthroughs.clear();
+}
+
+static void AddPhrasetableToFST(istream* in, TextFSTNode* fst) {
+ int lc = 0;
+ bool flag = false;
+ while(*in) {
+ string line;
+ getline(*in, line);
+ if (line.empty()) continue;
+ ++lc;
+ fst->AddPhrase(line);
+ if (lc % 10000 == 0) { flag = true; cerr << '.' << flush; }
+ if (lc % 500000 == 0) { flag = false; cerr << " [" << lc << ']' << endl << flush; }
+ }
+ if (flag) cerr << endl;
+ cerr << "Loaded " << lc << " source phrases\n";
+}
+
+FSTNode* LoadTextPhrasetable(istream* in) {
+ TextFSTNode *fst = new TextFSTNode;
+ AddPhrasetableToFST(in, fst);
+ return fst;
+}
+
+FSTNode* LoadTextPhrasetable(const vector<string>& filenames) {
+ TextFSTNode* fst = new TextFSTNode;
+ for (int i = 0; i < filenames.size(); ++i) {
+ ReadFile rf(filenames[i]);
+ cerr << "Reading phrase from " << filenames[i] << endl;
+ AddPhrasetableToFST(rf.stream(), fst);
+ }
+ return fst;
+}
+
+FSTNode* LoadBinaryPhrasetable(const string& fname_prefix) {
+ (void) fname_prefix;
+ assert(!"not implemented yet");
+}
+
diff --git a/src/phrasetable_fst.h b/src/phrasetable_fst.h
new file mode 100644
index 00000000..477de1f7
--- /dev/null
+++ b/src/phrasetable_fst.h
@@ -0,0 +1,34 @@
+#ifndef _PHRASETABLE_FST_H_
+#define _PHRASETABLE_FST_H_
+
+#include <vector>
+#include <string>
+
+#include "sparse_vector.h"
+#include "trule.h"
+
+class TargetPhraseSet {
+ public:
+ virtual ~TargetPhraseSet();
+ virtual const std::vector<TRulePtr>& GetRules() const = 0;
+};
+
+class FSTNode {
+ public:
+ virtual ~FSTNode();
+ virtual const TargetPhraseSet* GetTranslations() const = 0;
+ virtual bool HasData() const = 0;
+ virtual bool HasOutgoingNonEpsilonEdges() const = 0;
+ virtual const FSTNode* Extend(const WordID& t) const = 0;
+
+ // these should only be called on q_0:
+ virtual void AddPassThroughTranslation(const WordID& w, const SparseVector<double>& feats) = 0;
+ virtual void ClearPassThroughTranslations() = 0;
+};
+
+// attn caller: you own the memory
+FSTNode* LoadTextPhrasetable(const std::vector<std::string>& filenames);
+FSTNode* LoadTextPhrasetable(std::istream* in);
+FSTNode* LoadBinaryPhrasetable(const std::string& fname_prefix);
+
+#endif
diff --git a/src/prob.h b/src/prob.h
new file mode 100644
index 00000000..bc297870
--- /dev/null
+++ b/src/prob.h
@@ -0,0 +1,8 @@
+#ifndef _PROB_H_
+#define _PROB_H_
+
+#include "logval.h"
+
+typedef LogVal<double> prob_t;
+
+#endif
diff --git a/src/sampler.h b/src/sampler.h
new file mode 100644
index 00000000..e5840f41
--- /dev/null
+++ b/src/sampler.h
@@ -0,0 +1,136 @@
+#ifndef SAMPLER_H_
+#define SAMPLER_H_
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <iostream>
+#include <fstream>
+#include <vector>
+
+#include <boost/random/mersenne_twister.hpp>
+#include <boost/random/uniform_real.hpp>
+#include <boost/random/variate_generator.hpp>
+#include <boost/random/normal_distribution.hpp>
+#include <boost/random/poisson_distribution.hpp>
+
+#include "prob.h"
+
+struct SampleSet;
+
+template <typename RNG>
+struct RandomNumberGenerator {
+ static uint32_t GetTrulyRandomSeed() {
+ uint32_t seed;
+ std::ifstream r("/dev/urandom");
+ if (r) {
+ r.read((char*)&seed,sizeof(uint32_t));
+ }
+ if (r.fail() || !r) {
+ std::cerr << "Warning: could not read from /dev/urandom. Seeding from clock" << std::endl;
+ seed = time(NULL);
+ }
+ std::cerr << "Seeding random number sequence to " << seed << std::endl;
+ return seed;
+ }
+
+ RandomNumberGenerator() : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) {
+ uint32_t seed = GetTrulyRandomSeed();
+ m_generator.seed(seed);
+ }
+ explicit RandomNumberGenerator(uint32_t seed) : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) {
+ if (!seed) seed = GetTrulyRandomSeed();
+ m_generator.seed(seed);
+ }
+
+ size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) {
+ if (T == 1.0) {
+ if (this->next() > (a / (a + b))) return 1; else return 0;
+ } else {
+ assert(!"not implemented");
+ }
+ }
+
+ // T is the annealing temperature, if desired
+ size_t SelectSample(const SampleSet& ss, double T = 1.0);
+
+ // draw a value from U(0,1)
+ double next() {return m_random();}
+
+ // draw a value from N(mean,var)
+ double NextNormal(double mean, double var) {
+ return boost::normal_distribution<double>(mean, var)(m_random);
+ }
+
+ // draw a value from a Poisson distribution
+ // lambda must be greater than 0
+ int NextPoisson(int lambda) {
+ return boost::poisson_distribution<int>(lambda)(m_random);
+ }
+
+ bool AcceptMetropolisHastings(const prob_t& p_cur,
+ const prob_t& p_prev,
+ const prob_t& q_cur,
+ const prob_t& q_prev) {
+ const prob_t a = (p_cur / p_prev) * (q_prev / q_cur);
+ if (log(a) >= 0.0) return true;
+ return (prob_t(this->next()) < a);
+ }
+
+ private:
+ boost::uniform_real<> m_dist;
+ RNG m_generator;
+ boost::variate_generator<RNG&, boost::uniform_real<> > m_random;
+};
+
+typedef RandomNumberGenerator<boost::mt19937> MT19937;
+
+class SampleSet {
+ public:
+ const prob_t& operator[](int i) const { return m_scores[i]; }
+ bool empty() const { return m_scores.empty(); }
+ void add(const prob_t& s) { m_scores.push_back(s); }
+ void clear() { m_scores.clear(); }
+ size_t size() const { return m_scores.size(); }
+ std::vector<prob_t> m_scores;
+};
+
+template <typename RNG>
+size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
+ assert(T > 0.0);
+ assert(ss.m_scores.size() > 0);
+ if (ss.m_scores.size() == 1) return 0;
+ const prob_t annealing_factor(1.0 / T);
+ const bool anneal = (annealing_factor != prob_t::One());
+ prob_t sum = prob_t::Zero();
+ if (anneal) {
+ for (int i = 0; i < ss.m_scores.size(); ++i)
+ sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T)
+ } else {
+ sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero());
+ }
+ //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ",";
+ //std::cerr << std::endl;
+
+ prob_t random(this->next()); // random number between 0 and 1
+ random *= sum; // scale with normalization factor
+ //std::cerr << "Random number " << random << std::endl;
+
+ //now figure out which sample
+ size_t position = 1;
+ sum = ss.m_scores[0];
+ if (anneal) {
+ sum.poweq(annealing_factor);
+ for (; position < ss.m_scores.size() && sum < random; ++position)
+ sum += ss.m_scores[position].pow(annealing_factor);
+ } else {
+ for (; position < ss.m_scores.size() && sum < random; ++position)
+ sum += ss.m_scores[position];
+ }
+ //std::cout << "random: " << random << " sample: " << position << std::endl;
+ //std::cerr << "Sample: " << position-1 << std::endl;
+ //exit(1);
+ return position-1;
+}
+
+#endif
diff --git a/src/scfg_translator.cc b/src/scfg_translator.cc
new file mode 100644
index 00000000..03602c6b
--- /dev/null
+++ b/src/scfg_translator.cc
@@ -0,0 +1,66 @@
+#include "translator.h"
+
+#include <vector>
+
+#include "hg.h"
+#include "grammar.h"
+#include "bottom_up_parser.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+Translator::~Translator() {}
+
+struct SCFGTranslatorImpl {
+ SCFGTranslatorImpl(const boost::program_options::variables_map& conf) :
+ max_span_limit(conf["scfg_max_span_limit"].as<int>()),
+ add_pass_through_rules(conf.count("add_pass_through_rules")),
+ goal(conf["goal"].as<string>()),
+ default_nt(conf["scfg_default_nt"].as<string>()) {
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ for (int i = 0; i < gfiles.size(); ++i) {
+ cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+ TextGrammar* g = new TextGrammar(gfiles[i]);
+ g->SetMaxSpan(max_span_limit);
+ grammars.push_back(GrammarPtr(g));
+ }
+ if (!conf.count("scfg_no_hiero_glue_grammar"))
+ grammars.push_back(GrammarPtr(new GlueGrammar(goal, default_nt)));
+ if (conf.count("scfg_extra_glue_grammar"))
+ grammars.push_back(GrammarPtr(new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>())));
+ }
+
+ const int max_span_limit;
+ const bool add_pass_through_rules;
+ const string goal;
+ const string default_nt;
+ vector<GrammarPtr> grammars;
+
+ bool Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ vector<GrammarPtr> glist = grammars;
+ Lattice lattice;
+ LatticeTools::ConvertTextOrPLF(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ if (add_pass_through_rules)
+ glist.push_back(GrammarPtr(new PassThroughGrammar(lattice, default_nt)));
+ ExhaustiveBottomUpParser parser(goal, glist);
+ if (!parser.Parse(lattice, forest))
+ return false;
+ forest->Reweight(weights);
+ return true;
+ }
+};
+
+SCFGTranslator::SCFGTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new SCFGTranslatorImpl(conf)) {}
+
+bool SCFGTranslator::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
+}
+
diff --git a/src/sentence_metadata.h b/src/sentence_metadata.h
new file mode 100644
index 00000000..0178f1f5
--- /dev/null
+++ b/src/sentence_metadata.h
@@ -0,0 +1,42 @@
+#ifndef _SENTENCE_METADATA_H_
+#define _SENTENCE_METADATA_H_
+
+#include <cassert>
+#include "lattice.h"
+
+struct SentenceMetadata {
+ SentenceMetadata(int id, const Lattice& ref) :
+ sent_id_(id),
+ src_len_(-1),
+ has_reference_(ref.size() > 0),
+ trg_len_(ref.size()),
+ ref_(has_reference_ ? &ref : NULL) {}
+
+ // this should be called by the Translator object after
+ // it has parsed the source
+ void SetSourceLength(int sl) { src_len_ = sl; }
+
+ // this should be called if a separate model needs to
+ // specify how long the target sentence should be
+ void SetTargetLength(int tl) {
+ assert(!has_reference_);
+ trg_len_ = tl;
+ }
+ bool HasReference() const { return has_reference_; }
+ const Lattice& GetReference() const { return *ref_; }
+ int GetSourceLength() const { return src_len_; }
+ int GetTargetLength() const { return trg_len_; }
+ int GetSentenceID() const { return sent_id_; }
+
+ private:
+ const int sent_id_;
+ int src_len_;
+
+ // you need to be very careful when depending on these values
+ // they will only be set during training / alignment contexts
+ const bool has_reference_;
+ int trg_len_;
+ const Lattice* const ref_;
+};
+
+#endif
diff --git a/src/small_vector.h b/src/small_vector.h
new file mode 100644
index 00000000..800c1df1
--- /dev/null
+++ b/src/small_vector.h
@@ -0,0 +1,187 @@
+#ifndef _SMALL_VECTOR_H_
+
+#include <streambuf> // std::max - where to get this?
+#include <cstring>
+#include <cassert>
+
+#define __SV_MAX_STATIC 2
+
+class SmallVector {
+
+ public:
+ SmallVector() : size_(0) {}
+
+ explicit SmallVector(size_t s, int v = 0) : size_(s) {
+ assert(s < 0x80);
+ if (s <= __SV_MAX_STATIC) {
+ for (int i = 0; i < s; ++i) data_.vals[i] = v;
+ } else {
+ capacity_ = s;
+ size_ = s;
+ data_.ptr = new int[s];
+ for (int i = 0; i < size_; ++i) data_.ptr[i] = v;
+ }
+ }
+
+ SmallVector(const SmallVector& o) : size_(o.size_) {
+ if (size_ <= __SV_MAX_STATIC) {
+ for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i];
+ } else {
+ capacity_ = size_ = o.size_;
+ data_.ptr = new int[capacity_];
+ std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int));
+ }
+ }
+
+ const SmallVector& operator=(const SmallVector& o) {
+ if (size_ <= __SV_MAX_STATIC) {
+ if (o.size_ <= __SV_MAX_STATIC) {
+ size_ = o.size_;
+ for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i];
+ } else {
+ capacity_ = size_ = o.size_;
+ data_.ptr = new int[capacity_];
+ std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int));
+ }
+ } else {
+ if (o.size_ <= __SV_MAX_STATIC) {
+ delete[] data_.ptr;
+ size_ = o.size_;
+ for (int i = 0; i < size_; ++i) data_.vals[i] = o.data_.vals[i];
+ } else {
+ if (capacity_ < o.size_) {
+ delete[] data_.ptr;
+ capacity_ = o.size_;
+ data_.ptr = new int[capacity_];
+ }
+ size_ = o.size_;
+ for (int i = 0; i < size_; ++i)
+ data_.ptr[i] = o.data_.ptr[i];
+ }
+ }
+ return *this;
+ }
+
+ ~SmallVector() {
+ if (size_ <= __SV_MAX_STATIC) return;
+ delete[] data_.ptr;
+ }
+
+ void clear() {
+ if (size_ > __SV_MAX_STATIC) {
+ delete[] data_.ptr;
+ }
+ size_ = 0;
+ }
+
+ bool empty() const { return size_ == 0; }
+ size_t size() const { return size_; }
+
+ inline void ensure_capacity(unsigned char min_size) {
+ assert(min_size > __SV_MAX_STATIC);
+ if (min_size < capacity_) return;
+ unsigned char new_cap = std::max(static_cast<unsigned char>(capacity_ << 1), min_size);
+ int* tmp = new int[new_cap];
+ std::memcpy(tmp, data_.ptr, capacity_ * sizeof(int));
+ delete[] data_.ptr;
+ data_.ptr = tmp;
+ capacity_ = new_cap;
+ }
+
+ inline void copy_vals_to_ptr() {
+ capacity_ = __SV_MAX_STATIC * 2;
+ int* tmp = new int[capacity_];
+ for (int i = 0; i < __SV_MAX_STATIC; ++i) tmp[i] = data_.vals[i];
+ data_.ptr = tmp;
+ }
+
+ inline void push_back(int v) {
+ if (size_ < __SV_MAX_STATIC) {
+ data_.vals[size_] = v;
+ ++size_;
+ return;
+ } else if (size_ == __SV_MAX_STATIC) {
+ copy_vals_to_ptr();
+ } else if (size_ == capacity_) {
+ ensure_capacity(size_ + 1);
+ }
+ data_.ptr[size_] = v;
+ ++size_;
+ }
+
+ int& back() { return this->operator[](size_ - 1); }
+ const int& back() const { return this->operator[](size_ - 1); }
+ int& front() { return this->operator[](0); }
+ const int& front() const { return this->operator[](0); }
+
+ void resize(size_t s, int v = 0) {
+ if (s <= __SV_MAX_STATIC) {
+ if (size_ > __SV_MAX_STATIC) {
+ int tmp[__SV_MAX_STATIC];
+ for (int i = 0; i < s; ++i) tmp[i] = data_.ptr[i];
+ delete[] data_.ptr;
+ for (int i = 0; i < s; ++i) data_.vals[i] = tmp[i];
+ size_ = s;
+ return;
+ }
+ if (s <= size_) {
+ size_ = s;
+ return;
+ } else {
+ for (int i = size_; i < s; ++i)
+ data_.vals[i] = v;
+ size_ = s;
+ return;
+ }
+ } else {
+ if (size_ <= __SV_MAX_STATIC)
+ copy_vals_to_ptr();
+ if (s > capacity_)
+ ensure_capacity(s);
+ if (s > size_) {
+ for (int i = size_; i < s; ++i)
+ data_.ptr[i] = v;
+ }
+ size_ = s;
+ }
+ }
+
+ int& operator[](size_t i) {
+ if (size_ <= __SV_MAX_STATIC) return data_.vals[i];
+ return data_.ptr[i];
+ }
+
+ const int& operator[](size_t i) const {
+ if (size_ <= __SV_MAX_STATIC) return data_.vals[i];
+ return data_.ptr[i];
+ }
+
+ bool operator==(const SmallVector& o) const {
+ if (size_ != o.size_) return false;
+ if (size_ <= __SV_MAX_STATIC) {
+ for (size_t i = 0; i < size_; ++i)
+ if (data_.vals[i] != o.data_.vals[i]) return false;
+ return true;
+ } else {
+ for (size_t i = 0; i < size_; ++i)
+ if (data_.ptr[i] != o.data_.ptr[i]) return false;
+ return true;
+ }
+ }
+
+ private:
+ unsigned char capacity_; // only defined when size_ >= __SV_MAX_STATIC
+ unsigned char size_;
+ union StorageType {
+ int vals[__SV_MAX_STATIC];
+ int* ptr;
+ };
+ StorageType data_;
+
+};
+
+inline bool operator!=(const SmallVector& a, const SmallVector& b) {
+ return !(a==b);
+}
+
+#endif
diff --git a/src/small_vector_test.cc b/src/small_vector_test.cc
new file mode 100644
index 00000000..84237791
--- /dev/null
+++ b/src/small_vector_test.cc
@@ -0,0 +1,129 @@
+#include "small_vector.h"
+
+#include <gtest/gtest.h>
+#include <iostream>
+#include <cassert>
+#include <vector>
+
+using namespace std;
+
+class SVTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(SVTest, LargerThan2) {
+ SmallVector v;
+ SmallVector v2;
+ v.push_back(0);
+ v.push_back(1);
+ v.push_back(2);
+ assert(v.size() == 3);
+ assert(v[2] == 2);
+ assert(v[1] == 1);
+ assert(v[0] == 0);
+ v2 = v;
+ SmallVector copy(v);
+ assert(copy.size() == 3);
+ assert(copy[0] == 0);
+ assert(copy[1] == 1);
+ assert(copy[2] == 2);
+ assert(copy == v2);
+ copy[1] = 99;
+ assert(copy != v2);
+ assert(v2.size() == 3);
+ assert(v2[2] == 2);
+ assert(v2[1] == 1);
+ assert(v2[0] == 0);
+ v2[0] = -2;
+ v2[1] = -1;
+ v2[2] = 0;
+ assert(v2[2] == 0);
+ assert(v2[1] == -1);
+ assert(v2[0] == -2);
+ SmallVector v3(1,1);
+ assert(v3[0] == 1);
+ v2 = v3;
+ assert(v2.size() == 1);
+ assert(v2[0] == 1);
+ SmallVector v4(10, 1);
+ assert(v4.size() == 10);
+ assert(v4[5] == 1);
+ assert(v4[9] == 1);
+ v4 = v;
+ assert(v4.size() == 3);
+ assert(v4[2] == 2);
+ assert(v4[1] == 1);
+ assert(v4[0] == 0);
+ SmallVector v5(10, 2);
+ assert(v5.size() == 10);
+ assert(v5[7] == 2);
+ assert(v5[0] == 2);
+ assert(v.size() == 3);
+ v = v5;
+ assert(v.size() == 10);
+ assert(v[2] == 2);
+ assert(v[9] == 2);
+ SmallVector cc;
+ for (int i = 0; i < 33; ++i)
+ cc.push_back(i);
+ for (int i = 0; i < 33; ++i)
+ assert(cc[i] == i);
+ cc.resize(20);
+ assert(cc.size() == 20);
+ for (int i = 0; i < 20; ++i)
+ assert(cc[i] == i);
+ cc[0]=-1;
+ cc.resize(1, 999);
+ assert(cc.size() == 1);
+ assert(cc[0] == -1);
+ cc.resize(99, 99);
+ for (int i = 1; i < 99; ++i) {
+ cerr << i << " " << cc[i] << endl;
+ assert(cc[i] == 99);
+ }
+ cc.clear();
+ assert(cc.size() == 0);
+}
+
+TEST_F(SVTest, Small) {
+ SmallVector v;
+ SmallVector v1(1,0);
+ SmallVector v2(2,10);
+ SmallVector v1a(2,0);
+ EXPECT_TRUE(v1 != v1a);
+ EXPECT_TRUE(v1 == v1);
+ EXPECT_EQ(v1[0], 0);
+ EXPECT_EQ(v2[1], 10);
+ EXPECT_EQ(v2[0], 10);
+ ++v2[1];
+ --v2[0];
+ EXPECT_EQ(v2[0], 9);
+ EXPECT_EQ(v2[1], 11);
+ SmallVector v3(v2);
+ assert(v3[0] == 9);
+ assert(v3[1] == 11);
+ assert(!v3.empty());
+ assert(v3.size() == 2);
+ v3.clear();
+ assert(v3.empty());
+ assert(v3.size() == 0);
+ assert(v3 != v2);
+ assert(v2 != v3);
+ v3 = v2;
+ assert(v3 == v2);
+ assert(v2 == v3);
+ assert(v3[0] == 9);
+ assert(v3[1] == 11);
+ assert(!v3.empty());
+ assert(v3.size() == 2);
+ cerr << sizeof(SmallVector) << endl;
+ cerr << sizeof(vector<int>) << endl;
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/src/sparse_vector.cc b/src/sparse_vector.cc
new file mode 100644
index 00000000..4035b9ef
--- /dev/null
+++ b/src/sparse_vector.cc
@@ -0,0 +1,98 @@
+#include "sparse_vector.h"
+
+#include <iostream>
+#include <cstring>
+
+#include "hg_io.h"
+
+using namespace std;
+
+namespace B64 {
+
+void Encode(double objective, const SparseVector<double>& v, ostream* out) {
+ const int num_feats = v.num_active();
+ size_t tot_size = 0;
+ const size_t off_objective = tot_size;
+ tot_size += sizeof(double); // objective
+ const size_t off_num_feats = tot_size;
+ tot_size += sizeof(int); // num_feats
+ const size_t off_data = tot_size;
+ tot_size += sizeof(unsigned char) * num_feats; // lengths of feature names;
+ typedef SparseVector<double>::const_iterator const_iterator;
+ for (const_iterator it = v.begin(); it != v.end(); ++it)
+ tot_size += FD::Convert(it->first).size(); // feature names;
+ tot_size += sizeof(double) * num_feats; // gradient
+ const size_t off_magic = tot_size;
+ tot_size += 4; // magic
+
+ // size_t b64_size = tot_size * 4 / 3;
+ // cerr << "Sparse vector binary size: " << tot_size << " (b64 size=" << b64_size << ")\n";
+ char* data = new char[tot_size];
+ *reinterpret_cast<double*>(&data[off_objective]) = objective;
+ *reinterpret_cast<int*>(&data[off_num_feats]) = num_feats;
+ char* cur = &data[off_data];
+ assert(cur - data == off_data);
+ for (const_iterator it = v.begin(); it != v.end(); ++it) {
+ const string& fname = FD::Convert(it->first);
+ *cur++ = static_cast<char>(fname.size()); // name len
+ memcpy(cur, &fname[0], fname.size());
+ cur += fname.size();
+ *reinterpret_cast<double*>(cur) = it->second;
+ cur += sizeof(double);
+ }
+ assert(cur - data == off_magic);
+ *reinterpret_cast<unsigned int*>(cur) = 0xBAABABBAu;
+ cur += sizeof(unsigned int);
+ assert(cur - data == tot_size);
+ b64encode(data, tot_size, out);
+ delete[] data;
+}
+
+bool Decode(double* objective, SparseVector<double>* v, const char* in, size_t size) {
+ v->clear();
+ if (size % 4 != 0) {
+ cerr << "B64 error - line % 4 != 0\n";
+ return false;
+ }
+ const size_t decoded_size = size * 3 / 4 - sizeof(unsigned int);
+ const size_t buf_size = decoded_size + sizeof(unsigned int);
+ if (decoded_size < 6) { cerr << "SparseVector decoding error: too short!\n"; return false; }
+ char* data = new char[buf_size];
+ if (!b64decode(reinterpret_cast<const unsigned char*>(in), size, data, buf_size)) {
+ delete[] data;
+ return false;
+ }
+ size_t cur = 0;
+ *objective = *reinterpret_cast<double*>(data);
+ cur += sizeof(double);
+ const int num_feats = *reinterpret_cast<int*>(&data[cur]);
+ cur += sizeof(int);
+ int fc = 0;
+ while(fc < num_feats && cur < decoded_size) {
+ ++fc;
+ const int fname_len = data[cur++];
+ assert(fname_len > 0);
+ assert(fname_len < 256);
+ string fname(fname_len, '\0');
+ memcpy(&fname[0], &data[cur], fname_len);
+ cur += fname_len;
+ const double val = *reinterpret_cast<double*>(&data[cur]);
+ cur += sizeof(double);
+ int fid = FD::Convert(fname);
+ v->set_value(fid, val);
+ }
+ if(num_feats != fc) {
+ cerr << "Expected " << num_feats << " but only decoded " << fc << "!\n";
+ delete[] data;
+ return false;
+ }
+ if (*reinterpret_cast<unsigned int*>(&data[cur]) != 0xBAABABBAu) {
+ cerr << "SparseVector decodeding error : magic does not match!\n";
+ delete[] data;
+ return false;
+ }
+ delete[] data;
+ return true;
+}
+
+}
diff --git a/src/sparse_vector.h b/src/sparse_vector.h
new file mode 100644
index 00000000..6a8c9bf4
--- /dev/null
+++ b/src/sparse_vector.h
@@ -0,0 +1,264 @@
+#ifndef _SPARSE_VECTOR_H_
+#define _SPARSE_VECTOR_H_
+
+// this is a modified version of code originally written
+// by Phil Blunsom
+
+#include <iostream>
+#include <map>
+#include <vector>
+#include <valarray>
+
+#include "fdict.h"
+
+template <typename T>
+class SparseVector {
+public:
+ SparseVector() {}
+
+ const T operator[](int index) const {
+ typename std::map<int, T>::const_iterator found = _values.find(index);
+ if (found == _values.end())
+ return T(0);
+ else
+ return found->second;
+ }
+
+ void set_value(int index, const T &value) {
+ _values[index] = value;
+ }
+
+ void add_value(int index, const T &value) {
+ _values[index] += value;
+ }
+
+ T value(int index) const {
+ typename std::map<int, T>::const_iterator found = _values.find(index);
+ if (found != _values.end())
+ return found->second;
+ else
+ return T(0);
+ }
+
+ void store(std::valarray<T>* target) const {
+ (*target) *= 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it) {
+ if (it->first >= target->size()) break;
+ (*target)[it->first] = it->second;
+ }
+ }
+
+ int max_index() const {
+ if (_values.empty()) return 0;
+ typename std::map<int, T>::const_iterator found =_values.end();
+ --found;
+ return found->first;
+ }
+
+ // dot product with a unit vector of the same length
+ // as the sparse vector
+ T dot() const {
+ T sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += it->second;
+ return sum;
+ }
+
+ template<typename S>
+ S dot(const SparseVector<S> &vec) const {
+ S sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ {
+ typename std::map<int, T>::const_iterator
+ found = vec._values.find(it->first);
+ if (found != vec._values.end())
+ sum += it->second * found->second;
+ }
+ return sum;
+ }
+
+ template<typename S>
+ S dot(const std::vector<S> &vec) const {
+ S sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ {
+ if (it->first < static_cast<int>(vec.size()))
+ sum += it->second * vec[it->first];
+ }
+ return sum;
+ }
+
+ template<typename S>
+ S dot(const S *vec) const {
+ // this is not range checked!
+ S sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += it->second * vec[it->first];
+ std::cout << "dot(*vec) " << sum << std::endl;
+ return sum;
+ }
+
+ T l1norm() const {
+ T sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += fabs(it->second);
+ return sum;
+ }
+
+ T l2norm() const {
+ T sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += it->second * it->second;
+ return sqrt(sum);
+ }
+
+ SparseVector<T> &operator+=(const SparseVector<T> &other) {
+ for (typename std::map<int, T>::const_iterator
+ it = other._values.begin(); it != other._values.end(); ++it)
+ {
+ T v = (_values[it->first] += it->second);
+ if (v == 0)
+ _values.erase(it->first);
+ }
+ return *this;
+ }
+
+ SparseVector<T> &operator-=(const SparseVector<T> &other) {
+ for (typename std::map<int, T>::const_iterator
+ it = other._values.begin(); it != other._values.end(); ++it)
+ {
+ T v = (_values[it->first] -= it->second);
+ if (v == 0)
+ _values.erase(it->first);
+ }
+ return *this;
+ }
+
+ SparseVector<T> &operator-=(const double &x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second -= x;
+ return *this;
+ }
+
+ SparseVector<T> &operator+=(const double &x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second += x;
+ return *this;
+ }
+
+ SparseVector<T> &operator/=(const double &x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second /= x;
+ return *this;
+ }
+
+ SparseVector<T> &operator*=(const T& x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second *= x;
+ return *this;
+ }
+
+ SparseVector<T> operator+(const double &x) const {
+ SparseVector<T> result = *this;
+ return result += x;
+ }
+
+ SparseVector<T> operator-(const double &x) const {
+ SparseVector<T> result = *this;
+ return result -= x;
+ }
+
+ SparseVector<T> operator/(const double &x) const {
+ SparseVector<T> result = *this;
+ return result /= x;
+ }
+
+ std::ostream &operator<<(std::ostream &out) const {
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ out << (it == _values.begin() ? "" : ";")
+ << FD::Convert(it->first) << '=' << it->second;
+ return out;
+ }
+
+ bool operator<(const SparseVector<T> &other) const {
+ typename std::map<int, T>::const_iterator it = _values.begin();
+ typename std::map<int, T>::const_iterator other_it = other._values.begin();
+
+ for (; it != _values.end() && other_it != other._values.end(); ++it, ++other_it)
+ {
+ if (it->first < other_it->first) return true;
+ if (it->first > other_it->first) return false;
+ if (it->second < other_it->second) return true;
+ if (it->second > other_it->second) return false;
+ }
+ return _values.size() < other._values.size();
+ }
+
+ int num_active() const { return _values.size(); }
+ bool empty() const { return _values.empty(); }
+
+ typedef typename std::map<int, T>::const_iterator const_iterator;
+ const_iterator begin() const { return _values.begin(); }
+ const_iterator end() const { return _values.end(); }
+
+ void clear() {
+ _values.clear();
+ }
+
+ void swap(SparseVector<T>& other) {
+ _values.swap(other._values);
+ }
+
+private:
+ std::map<int, T> _values;
+};
+
+template <typename T>
+SparseVector<T> operator+(const SparseVector<T>& a, const SparseVector<T>& b) {
+ SparseVector<T> result = a;
+ return result += b;
+}
+
+template <typename T>
+SparseVector<T> operator*(const SparseVector<T>& a, const double& b) {
+ SparseVector<T> result = a;
+ return result *= b;
+}
+
+template <typename T>
+SparseVector<T> operator*(const SparseVector<T>& a, const T& b) {
+ SparseVector<T> result = a;
+ return result *= b;
+}
+
+template <typename T>
+SparseVector<T> operator*(const double& a, const SparseVector<T>& b) {
+ SparseVector<T> result = b;
+ return result *= a;
+}
+
+template <typename T>
+std::ostream &operator<<(std::ostream &out, const SparseVector<T> &vec)
+{
+ return vec.operator<<(out);
+}
+
+namespace B64 {
+ void Encode(double objective, const SparseVector<double>& v, std::ostream* out);
+ // returns false if failed to decode
+ bool Decode(double* objective, SparseVector<double>* v, const char* data, size_t size);
+}
+
+#endif
diff --git a/src/stringlib.cc b/src/stringlib.cc
new file mode 100644
index 00000000..3ed74bef
--- /dev/null
+++ b/src/stringlib.cc
@@ -0,0 +1,97 @@
+#include "stringlib.h"
+
+#include <cstdlib>
+#include <cassert>
+#include <iostream>
+#include <map>
+
+#include "lattice.h"
+
+using namespace std;
+
+void ParseTranslatorInput(const string& line, string* input, string* ref) {
+ size_t hint = 0;
+ if (line.find("{\"rules\":") == 0) {
+ hint = line.find("}}");
+ if (hint == string::npos) {
+ cerr << "Syntax error: " << line << endl;
+ abort();
+ }
+ hint += 2;
+ }
+ size_t pos = line.find("|||", hint);
+ if (pos == string::npos) { *input = line; return; }
+ ref->clear();
+ *input = line.substr(0, pos - 1);
+ string rline = line.substr(pos + 4);
+ if (rline.size() > 0) {
+ assert(ref);
+ *ref = rline;
+ }
+}
+
+void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) {
+ string sref;
+ ParseTranslatorInput(line, input, &sref);
+ if (sref.size() > 0) {
+ assert(ref);
+ LatticeTools::ConvertTextOrPLF(sref, ref);
+ }
+}
+
+void ProcessAndStripSGML(string* pline, map<string, string>* out) {
+ map<string, string>& meta = *out;
+ string& line = *pline;
+ string lline = LowercaseString(line);
+ if (lline.find("<seg")!=0) return;
+ size_t close = lline.find(">");
+ if (close == string::npos) return; // error
+ size_t end = lline.find("</seg>");
+ string seg = Trim(lline.substr(4, close-4));
+ string text = line.substr(close+1, end - close - 1);
+ for (size_t i = 1; i < seg.size(); i++) {
+ if (seg[i] == '=' && seg[i-1] == ' ') {
+ string less = seg.substr(0, i-1) + seg.substr(i);
+ seg = less; i = 0; continue;
+ }
+ if (seg[i] == '=' && seg[i+1] == ' ') {
+ string less = seg.substr(0, i+1);
+ if (i+2 < seg.size()) less += seg.substr(i+2);
+ seg = less; i = 0; continue;
+ }
+ }
+ line = Trim(text);
+ if (seg == "") return;
+ for (size_t i = 1; i < seg.size(); i++) {
+ if (seg[i] == '=') {
+ string label = seg.substr(0, i);
+ string val = seg.substr(i+1);
+ if (val[0] == '"') {
+ val = val.substr(1);
+ size_t close = val.find('"');
+ if (close == string::npos) {
+ cerr << "SGML parse error: missing \"\n";
+ seg = "";
+ i = 0;
+ } else {
+ seg = val.substr(close+1);
+ val = val.substr(0, close);
+ i = 0;
+ }
+ } else {
+ size_t close = val.find(' ');
+ if (close == string::npos) {
+ seg = "";
+ i = 0;
+ } else {
+ seg = val.substr(close+1);
+ val = val.substr(0, close);
+ }
+ }
+ label = Trim(label);
+ seg = Trim(seg);
+ meta[label] = val;
+ }
+ }
+}
+
diff --git a/src/stringlib.h b/src/stringlib.h
new file mode 100644
index 00000000..d26952c7
--- /dev/null
+++ b/src/stringlib.h
@@ -0,0 +1,91 @@
+#ifndef _STRINGLIB_H_
+
+#include <map>
+#include <vector>
+#include <cctype>
+#include <string>
+
+// read line in the form of either:
+// source
+// source ||| target
+// source will be returned as a string, target must be a sentence or
+// a lattice (in PLF format) and will be returned as a Lattice object
+void ParseTranslatorInput(const std::string& line, std::string* input, std::string* ref);
+struct Lattice;
+void ParseTranslatorInputLattice(const std::string& line, std::string* input, Lattice* ref);
+
+inline const std::string Trim(const std::string& str, const std::string& dropChars = " \t") {
+ std::string res = str;
+ res.erase(str.find_last_not_of(dropChars)+1);
+ return res.erase(0, res.find_first_not_of(dropChars));
+}
+
+inline void Tokenize(const std::string& str, char delimiter, std::vector<std::string>* res) {
+ std::string s = str;
+ int last = 0;
+ res->clear();
+ for (int i=0; i < s.size(); ++i)
+ if (s[i] == delimiter) {
+ s[i]=0;
+ if (last != i) {
+ res->push_back(&s[last]);
+ }
+ last = i + 1;
+ }
+ if (last != s.size())
+ res->push_back(&s[last]);
+}
+
+inline std::string LowercaseString(const std::string& in) {
+ std::string res(in.size(),' ');
+ for (int i = 0; i < in.size(); ++i)
+ res[i] = tolower(in[i]);
+ return res;
+}
+
+inline int CountSubstrings(const std::string& str, const std::string& sub) {
+ size_t p = 0;
+ int res = 0;
+ while (p < str.size()) {
+ p = str.find(sub, p);
+ if (p == std::string::npos) break;
+ ++res;
+ p += sub.size();
+ }
+ return res;
+}
+
+inline int SplitOnWhitespace(const std::string& in, std::vector<std::string>* out) {
+ out->clear();
+ int i = 0;
+ int start = 0;
+ std::string cur;
+ while(i < in.size()) {
+ if (in[i] == ' ' || in[i] == '\t') {
+ if (i - start > 0)
+ out->push_back(in.substr(start, i - start));
+ start = i + 1;
+ }
+ ++i;
+ }
+ if (i > start)
+ out->push_back(in.substr(start, i - start));
+ return out->size();
+}
+
+inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::string* param) {
+ cmd->clear();
+ param->clear();
+ std::vector<std::string> x;
+ SplitOnWhitespace(in, &x);
+ if (x.size() == 0) return;
+ *cmd = x[0];
+ for (int i = 1; i < x.size(); ++i) {
+ if (i > 1) { *param += " "; }
+ *param += x[i];
+ }
+}
+
+void ProcessAndStripSGML(std::string* line, std::map<std::string, std::string>* out);
+
+#endif
diff --git a/src/synparse.cc b/src/synparse.cc
new file mode 100644
index 00000000..96588f1e
--- /dev/null
+++ b/src/synparse.cc
@@ -0,0 +1,212 @@
+#include <iostream>
+#include <ext/hash_map>
+#include <ext/hash_set>
+#include <utility>
+
+#include <boost/multi_array.hpp>
+#include <boost/functional/hash.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "prob.h"
+#include "tdict.h"
+#include "filelib.h"
+
+using namespace std;
+using namespace __gnu_cxx;
+namespace po = boost::program_options;
+
+const prob_t kMONO(1.0); // 0.6
+const prob_t kINV(1.0); // 0.1
+const prob_t kLEX(1.0); // 0.3
+
+typedef hash_map<vector<WordID>, hash_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > >, boost::hash<vector<WordID> > > PTable;
+typedef boost::multi_array<prob_t, 4> CChart;
+typedef pair<int,int> SpanType;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("phrasetable,p",po::value<string>(), "[REQD] Phrase pairs for ITG alignment")
+ ("input,i",po::value<string>()->default_value("-"), "Input file")
+ ("help,h", "Help");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ bool flag = false;
+ if (!conf->count("phrasetable")) {
+ cerr << "Please specify a grammar file with -p <GRAMMAR.TXT>\n";
+ flag = true;
+ }
+ if (flag || conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+void LoadITGPhrasetable(const string& fname, PTable* ptable) {
+ const WordID sep = TD::Convert("|||");
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ assert(in);
+ int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ ++lc;
+ vector<WordID> full, f, e;
+ TD::ConvertSentence(line, &full);
+ int i = 0;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ f.push_back(full[i]);
+ }
+ ++i;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ e.push_back(full[i]);
+ }
+ ++i;
+ prob_t prob(0.000001);
+ if (i < full.size()) { prob = prob_t(atof(TD::Convert(full[i]))); ++i; }
+
+ if (i < full.size()) { cerr << "Warning line " << lc << " has extra stuff.\n"; }
+ assert(f.size() > 0);
+ assert(e.size() > 0);
+ (*ptable)[f][e] = prob;
+ }
+ cerr << "Read " << lc << " phrase pairs\n";
+}
+
+void FindPhrases(const vector<WordID>& e, const vector<WordID>& f, const PTable& pt, CChart* pcc) {
+ CChart& cc = *pcc;
+ const size_t n = f.size();
+ const size_t m = e.size();
+ typedef hash_map<vector<WordID>, vector<SpanType>, boost::hash<vector<WordID> > > PhraseToSpan;
+ PhraseToSpan e_locations;
+ for (int i = 0; i < m; ++i) {
+ const int mel = m - i;
+ vector<WordID> e_phrase;
+ for (int el = 0; el < mel; ++el) {
+ e_phrase.push_back(e[i + el]);
+ e_locations[e_phrase].push_back(make_pair(i, i + el + 1));
+ }
+ }
+ //cerr << "Cached the locations of " << e_locations.size() << " e-phrases\n";
+
+ for (int s = 0; s < n; ++s) {
+ const int mfl = n - s;
+ vector<WordID> f_phrase;
+ for (int fl = 0; fl < mfl; ++fl) {
+ f_phrase.push_back(f[s + fl]);
+ PTable::const_iterator it = pt.find(f_phrase);
+ if (it == pt.end()) continue;
+ const hash_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > >& es = it->second;
+ for (hash_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > >::const_iterator eit = es.begin(); eit != es.end(); ++eit) {
+ PhraseToSpan::iterator loc = e_locations.find(eit->first);
+ if (loc == e_locations.end()) continue;
+ const vector<SpanType>& espans = loc->second;
+ for (int j = 0; j < espans.size(); ++j) {
+ cc[s][s + fl + 1][espans[j].first][espans[j].second] = eit->second;
+ //cerr << '[' << s << ',' << (s + fl + 1) << ',' << espans[j].first << ',' << espans[j].second << "] is C\n";
+ }
+ }
+ }
+ }
+}
+
+long long int evals = 0;
+
+void ProcessSynchronousCell(const int s,
+ const int t,
+ const int u,
+ const int v,
+ const prob_t& lex,
+ const prob_t& mono,
+ const prob_t& inv,
+ const CChart& tc, CChart* ntc) {
+ prob_t& inside = (*ntc)[s][t][u][v];
+ // cerr << log(tc[s][t][u][v]) << " + " << log(lex) << endl;
+ inside = tc[s][t][u][v] * lex;
+ // cerr << " terminal span: " << log(inside) << endl;
+ if (t - s == 1) return;
+ if (v - u == 1) return;
+ for (int x = s+1; x < t; ++x) {
+ for (int y = u+1; y < v; ++y) {
+ const prob_t m = (*ntc)[s][x][u][y] * (*ntc)[x][t][y][v] * mono;
+ const prob_t i = (*ntc)[s][x][y][v] * (*ntc)[x][t][u][y] * inv;
+ // cerr << log(i) << "\t" << log(m) << endl;
+ inside += m;
+ inside += i;
+ evals++;
+ }
+ }
+ // cerr << " span: " << log(inside) << endl;
+}
+
+prob_t SynchronousParse(const int n, const int m, const prob_t& lex, const prob_t& mono, const prob_t& inv, const CChart& tc, CChart* ntc) {
+ for (int fl = 0; fl < n; ++fl) {
+ for (int el = 0; el < m; ++el) {
+ const int ms = n - fl;
+ for (int s = 0; s < ms; ++s) {
+ const int t = s + fl + 1;
+ const int mu = m - el;
+ for (int u = 0; u < mu; ++u) {
+ const int v = u + el + 1;
+ //cerr << "Processing cell [" << s << ',' << t << ',' << u << ',' << v << "]\n";
+ ProcessSynchronousCell(s, t, u, v, lex, mono, inv, tc, ntc);
+ }
+ }
+ }
+ }
+ return (*ntc)[0][n][0][m];
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ PTable ptable;
+ LoadITGPhrasetable(conf["phrasetable"].as<string>(), &ptable);
+ ReadFile rf(conf["input"].as<string>());
+ istream& in = *rf.stream();
+ int lc = 0;
+ const WordID sep = TD::Convert("|||");
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ ++lc;
+ vector<WordID> full, f, e;
+ TD::ConvertSentence(line, &full);
+ int i = 0;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ f.push_back(full[i]);
+ }
+ ++i;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ e.push_back(full[i]);
+ }
+ if (e.empty()) cerr << "E is empty!\n";
+ if (f.empty()) cerr << "F is empty!\n";
+ if (e.empty() || f.empty()) continue;
+ int n = f.size();
+ int m = e.size();
+ cerr << "Synchronous chart has " << (n * n * m * m) << " cells\n";
+ clock_t start = clock();
+ CChart cc(boost::extents[n+1][n+1][m+1][m+1]);
+ FindPhrases(e, f, ptable, &cc);
+ CChart ntc(boost::extents[n+1][n+1][m+1][m+1]);
+ prob_t likelihood = SynchronousParse(n, m, kLEX, kMONO, kINV, cc, &ntc);
+ clock_t end = clock();
+ cerr << "log Z: " << log(likelihood) << endl;
+ cerr << " Z: " << likelihood << endl;
+ double etime = (end - start) / 1000000.0;
+ cout << " time: " << etime << endl;
+ cout << "evals: " << evals << endl;
+ }
+ return 0;
+}
+
diff --git a/src/tdict.cc b/src/tdict.cc
new file mode 100644
index 00000000..c00d20b8
--- /dev/null
+++ b/src/tdict.cc
@@ -0,0 +1,49 @@
+#include "Ngram.h"
+#include "dict.h"
+#include "tdict.h"
+#include "Vocab.h"
+
+using namespace std;
+
+Vocab* TD::dict_ = new Vocab;
+
+static const string empty;
+static const string space = " ";
+
+WordID TD::Convert(const std::string& s) {
+ return dict_->addWord((VocabString)s.c_str());
+}
+
+const char* TD::Convert(const WordID& w) {
+ return dict_->getWord((VocabIndex)w);
+}
+
+void TD::GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids) {
+ ids->clear();
+ for (vector<string>::const_iterator i = strings.begin(); i != strings.end(); ++i)
+ ids->push_back(TD::Convert(*i));
+}
+
+std::string TD::GetString(const std::vector<WordID>& str) {
+ string res;
+ for (vector<WordID>::const_iterator i = str.begin(); i != str.end(); ++i)
+ res += (i == str.begin() ? empty : space) + TD::Convert(*i);
+ return res;
+}
+
+void TD::ConvertSentence(const std::string& sent, std::vector<WordID>* ids) {
+ string s = sent;
+ int last = 0;
+ ids->clear();
+ for (int i=0; i < s.size(); ++i)
+ if (s[i] == 32 || s[i] == '\t') {
+ s[i]=0;
+ if (last != i) {
+ ids->push_back(Convert(&s[last]));
+ }
+ last = i + 1;
+ }
+ if (last != s.size())
+ ids->push_back(Convert(&s[last]));
+}
+
diff --git a/src/tdict.h b/src/tdict.h
new file mode 100644
index 00000000..9d4318fe
--- /dev/null
+++ b/src/tdict.h
@@ -0,0 +1,19 @@
+#ifndef _TDICT_H_
+#define _TDICT_H_
+
+#include <string>
+#include <vector>
+#include "wordid.h"
+
+class Vocab;
+
+struct TD {
+ static Vocab* dict_;
+ static void ConvertSentence(const std::string& sent, std::vector<WordID>* ids);
+ static void GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids);
+ static std::string GetString(const std::vector<WordID>& str);
+ static WordID Convert(const std::string& s);
+ static const char* Convert(const WordID& w);
+};
+
+#endif
diff --git a/src/timing_stats.cc b/src/timing_stats.cc
new file mode 100644
index 00000000..85b95de5
--- /dev/null
+++ b/src/timing_stats.cc
@@ -0,0 +1,24 @@
+#include "timing_stats.h"
+
+#include <iostream>
+
+using namespace std;
+
+map<string, TimerInfo> Timer::stats;
+
+Timer::Timer(const string& timername) : start_t(clock()), cur(stats[timername]) {}
+
+Timer::~Timer() {
+ ++cur.calls;
+ const clock_t end_t = clock();
+ const double elapsed = (end_t - start_t) / 1000000.0;
+ cur.total_time += elapsed;
+}
+
+void Timer::Summarize() {
+ for (map<string, TimerInfo>::iterator it = stats.begin(); it != stats.end(); ++it) {
+ cerr << it->first << ": " << it->second.total_time << " secs (" << it->second.calls << " calls)\n";
+ }
+ stats.clear();
+}
+
diff --git a/src/timing_stats.h b/src/timing_stats.h
new file mode 100644
index 00000000..0a9f7656
--- /dev/null
+++ b/src/timing_stats.h
@@ -0,0 +1,25 @@
+#ifndef _TIMING_STATS_H_
+#define _TIMING_STATS_H_
+
+#include <string>
+#include <map>
+
+struct TimerInfo {
+ int calls;
+ double total_time;
+ TimerInfo() : calls(), total_time() {}
+};
+
+struct Timer {
+ Timer(const std::string& info);
+ ~Timer();
+ static void Summarize();
+ private:
+ static std::map<std::string, TimerInfo> stats;
+ clock_t start_t;
+ TimerInfo& cur;
+ Timer(const Timer& other);
+ const Timer& operator=(const Timer& other);
+};
+
+#endif
diff --git a/src/translator.h b/src/translator.h
new file mode 100644
index 00000000..194efbaa
--- /dev/null
+++ b/src/translator.h
@@ -0,0 +1,54 @@
+#ifndef _TRANSLATOR_H_
+#define _TRANSLATOR_H_
+
+#include <string>
+#include <vector>
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+class Hypergraph;
+class SentenceMetadata;
+
+class Translator {
+ public:
+ virtual ~Translator();
+ // returns true if goal reached, false otherwise
+ // minus_lm_forest will contain the unpruned forest. the
+ // feature values from the phrase table / grammar / etc
+ // should be in the forest already - the "late" features
+ // should not just copy values that are available without
+ // any context or computation.
+ // SentenceMetadata contains information about the sentence,
+ // but it is an input/output parameter since the Translator
+ // is also responsible for setting the value of src_len.
+ virtual bool Translate(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) = 0;
+};
+
+class SCFGTranslatorImpl;
+class SCFGTranslator : public Translator {
+ public:
+ SCFGTranslator(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<SCFGTranslatorImpl> pimpl_;
+};
+
+class FSTTranslatorImpl;
+class FSTTranslator : public Translator {
+ public:
+ FSTTranslator(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<FSTTranslatorImpl> pimpl_;
+};
+
+#endif
diff --git a/src/trule.cc b/src/trule.cc
new file mode 100644
index 00000000..b8f6995e
--- /dev/null
+++ b/src/trule.cc
@@ -0,0 +1,237 @@
+#include "trule.h"
+
+#include <sstream>
+
+#include "stringlib.h"
+#include "tdict.h"
+
+using namespace std;
+
+static WordID ConvertTrgString(const string& w) {
+ int len = w.size();
+ WordID id = 0;
+ // [X,0] or [0]
+ // for target rules, we ignore the category, just keep the index
+ if (len > 2 && w[0]=='[' && w[len-1]==']' && w[len-2] > '0' && w[len-2] <= '9' &&
+ (len == 3 || (len > 4 && w[len-3] == ','))) {
+ id = w[len-2] - '0';
+ id = 1 - id;
+ } else {
+ id = TD::Convert(w);
+ }
+ return id;
+}
+
+static WordID ConvertSrcString(const string& w, bool mono = false) {
+ int len = w.size();
+ // [X,0]
+ // for source rules, we keep the category and ignore the index (source rules are
+ // always numbered 1, 2, 3...
+ if (mono) {
+ if (len > 2 && w[0]=='[' && w[len-1]==']') {
+ if (len > 4 && w[len-3] == ',') {
+ cerr << "[ERROR] Monolingual rules mut not have non-terminal indices:\n "
+ << w << endl;
+ exit(1);
+ }
+ // TODO check that source indices go 1,2,3,etc.
+ return TD::Convert(w.substr(1, len-2)) * -1;
+ } else {
+ return TD::Convert(w);
+ }
+ } else {
+ if (len > 4 && w[0]=='[' && w[len-1]==']' && w[len-3] == ',' && w[len-2] > '0' && w[len-2] <= '9') {
+ return TD::Convert(w.substr(1, len-4)) * -1;
+ } else {
+ return TD::Convert(w);
+ }
+ }
+}
+
+static WordID ConvertLHS(const string& w) {
+ if (w[0] == '[') {
+ int len = w.size();
+ if (len < 3) { cerr << "Format error: " << w << endl; exit(1); }
+ return TD::Convert(w.substr(1, len-2)) * -1;
+ } else {
+ return TD::Convert(w) * -1;
+ }
+}
+
+TRule* TRule::CreateRuleSynchronous(const std::string& rule) {
+ TRule* res = new TRule;
+ if (res->ReadFromString(rule, true, false)) return res;
+ cerr << "[ERROR] Failed to creating rule from: " << rule << endl;
+ delete res;
+ return NULL;
+}
+
+TRule* TRule::CreateRulePhrasetable(const string& rule) {
+ // TODO make this faster
+ // TODO add configuration for default NT type
+ if (rule[0] == '[') {
+ cerr << "Phrasetable rules shouldn't have a LHS / non-terminals:\n " << rule << endl;
+ return NULL;
+ }
+ TRule* res = new TRule("[X] ||| " + rule, true, false);
+ if (res->Arity() != 0) {
+ cerr << "Phrasetable rules should have arity 0:\n " << rule << endl;
+ delete res;
+ return NULL;
+ }
+ return res;
+}
+
+TRule* TRule::CreateRuleMonolingual(const string& rule) {
+ return new TRule(rule, false, true);
+}
+
+bool TRule::ReadFromString(const string& line, bool strict, bool mono) {
+ e_.clear();
+ f_.clear();
+ scores_.clear();
+
+ string w;
+ istringstream is(line);
+ int format = CountSubstrings(line, "|||");
+ if (strict && format < 2) {
+ cerr << "Bad rule format in strict mode:\n" << line << endl;
+ return false;
+ }
+ if (format >= 2 || (mono && format == 1)) {
+ while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); }
+ while(is>>w && w!="|||") { f_.push_back(ConvertSrcString(w, mono)); }
+ if (!mono) {
+ while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); }
+ }
+ int fv = 0;
+ if (is) {
+ string ss;
+ getline(is, ss);
+ //cerr << "L: " << ss << endl;
+ int start = 0;
+ const int len = ss.size();
+ while (start < len) {
+ while(start < len && (ss[start] == ' ' || ss[start] == ';'))
+ ++start;
+ if (start == len) break;
+ int end = start + 1;
+ while(end < len && (ss[end] != '=' && ss[end] != ' ' && ss[end] != ';'))
+ ++end;
+ if (end == len || ss[end] == ' ' || ss[end] == ';') {
+ //cerr << "PROC: '" << ss.substr(start, end - start) << "'\n";
+ // non-named features
+ if (end != len) { ss[end] = 0; }
+ string fname = "PhraseModel_X";
+ if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); }
+ fname[12]='0' + fv;
+ ++fv;
+ scores_.set_value(FD::Convert(fname), atof(&ss[start]));
+ //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl;
+ } else {
+ const int fid = FD::Convert(ss.substr(start, end - start));
+ start = end + 1;
+ end = start + 1;
+ while(end < len && (ss[end] != ' ' && ss[end] != ';'))
+ ++end;
+ if (end < len) { ss[end] = 0; }
+ assert(start < len);
+ scores_.set_value(fid, atof(&ss[start]));
+ //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl;
+ }
+ start = end + 1;
+ }
+ }
+ } else if (format == 1) {
+ while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); }
+ while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); }
+ f_ = e_;
+ int x = ConvertLHS("[X]");
+ for (int i = 0; i < f_.size(); ++i)
+ if (f_[i] <= 0) { f_[i] = x; }
+ } else {
+ cerr << "F: " << format << endl;
+ cerr << "[ERROR] Don't know how to read:\n" << line << endl;
+ }
+ if (mono) {
+ e_ = f_;
+ int ci = 0;
+ for (int i = 0; i < e_.size(); ++i)
+ if (e_[i] < 0)
+ e_[i] = ci--;
+ }
+ ComputeArity();
+ return SanityCheck();
+}
+
+bool TRule::SanityCheck() const {
+ vector<int> used(f_.size(), 0);
+ int ac = 0;
+ for (int i = 0; i < e_.size(); ++i) {
+ int ind = e_[i];
+ if (ind > 0) continue;
+ ind = -ind;
+ if ((++used[ind]) != 1) {
+ cerr << "[ERROR] e-side variable index " << (ind+1) << " used more than once!\n";
+ return false;
+ }
+ ac++;
+ }
+ if (ac != Arity()) {
+ cerr << "[ERROR] e-side arity mismatches f-side\n";
+ return false;
+ }
+ return true;
+}
+
+void TRule::ComputeArity() {
+ int min = 1;
+ for (vector<WordID>::const_iterator i = e_.begin(); i != e_.end(); ++i)
+ if (*i < min) min = *i;
+ arity_ = 1 - min;
+}
+
+static string AnonymousStrVar(int i) {
+ string res("[v]");
+ if(!(i <= 0 && i >= -8)) {
+ cerr << "Can't handle more than 9 non-terminals: index=" << (-i) << endl;
+ abort();
+ }
+ res[1] = '1' - i;
+ return res;
+}
+
+string TRule::AsString(bool verbose) const {
+ ostringstream os;
+ int idx = 0;
+ if (lhs_ && verbose) {
+ os << '[' << TD::Convert(lhs_ * -1) << "] |||";
+ for (int i = 0; i < f_.size(); ++i) {
+ const WordID& w = f_[i];
+ if (w < 0) {
+ int wi = w * -1;
+ ++idx;
+ os << " [" << TD::Convert(wi) << ',' << idx << ']';
+ } else {
+ os << ' ' << TD::Convert(w);
+ }
+ }
+ os << " ||| ";
+ }
+ if (idx > 9) {
+ cerr << "Too many non-terminals!\n partial: " << os.str() << endl;
+ exit(1);
+ }
+ for (int i =0; i<e_.size(); ++i) {
+ if (i) os << ' ';
+ const WordID& w = e_[i];
+ if (w < 1)
+ os << AnonymousStrVar(w);
+ else
+ os << TD::Convert(w);
+ }
+ if (!scores_.empty() && verbose) {
+ os << " ||| " << scores_;
+ }
+ return os.str();
+}
diff --git a/src/trule.h b/src/trule.h
new file mode 100644
index 00000000..d2b1babe
--- /dev/null
+++ b/src/trule.h
@@ -0,0 +1,122 @@
+#ifndef _RULE_H_
+#define _RULE_H_
+
+#include <algorithm>
+#include <vector>
+#include <cassert>
+#include <boost/shared_ptr.hpp>
+
+#include "sparse_vector.h"
+#include "wordid.h"
+
+class TRule;
+typedef boost::shared_ptr<TRule> TRulePtr;
+struct SpanInfo;
+
+// Translation rule
+class TRule {
+ public:
+ TRule() : lhs_(0), prev_i(-1), prev_j(-1) { }
+ explicit TRule(const std::vector<WordID>& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {}
+ TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs) :
+ e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {}
+
+ // deprecated - this will be private soon
+ explicit TRule(const std::string& text, bool strict = false, bool mono = false) {
+ ReadFromString(text, strict, mono);
+ }
+
+ // make a rule from a hiero-like rule table, e.g.
+ // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1]
+ // if misformatted, returns NULL
+ static TRule* CreateRuleSynchronous(const std::string& rule);
+
+ // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g:
+ // el gato ||| the cat ||| Feature_2=0.34
+ static TRule* CreateRulePhrasetable(const std::string& rule);
+
+ // make a rule from a non-synchrnous CFG representation, e.g.:
+ // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT]
+ static TRule* CreateRuleMonolingual(const std::string& rule);
+
+ void ESubstitute(const std::vector<const std::vector<WordID>* >& var_values,
+ std::vector<WordID>* result) const {
+ int vc = 0;
+ result->clear();
+ for (std::vector<WordID>::const_iterator i = e_.begin(); i != e_.end(); ++i) {
+ const WordID& c = *i;
+ if (c < 1) {
+ ++vc;
+ const std::vector<WordID>& var_value = *var_values[-c];
+ std::copy(var_value.begin(),
+ var_value.end(),
+ std::back_inserter(*result));
+ } else {
+ result->push_back(c);
+ }
+ }
+ assert(vc == var_values.size());
+ }
+
+ void FSubstitute(const std::vector<const std::vector<WordID>* >& var_values,
+ std::vector<WordID>* result) const {
+ int vc = 0;
+ result->clear();
+ for (std::vector<WordID>::const_iterator i = f_.begin(); i != f_.end(); ++i) {
+ const WordID& c = *i;
+ if (c < 1) {
+ const std::vector<WordID>& var_value = *var_values[vc++];
+ std::copy(var_value.begin(),
+ var_value.end(),
+ std::back_inserter(*result));
+ } else {
+ result->push_back(c);
+ }
+ }
+ assert(vc == var_values.size());
+ }
+
+ bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false);
+
+ bool Initialized() const { return e_.size(); }
+
+ std::string AsString(bool verbose = true) const;
+
+ static TRule DummyRule() {
+ TRule res;
+ res.e_.resize(1, 0);
+ return res;
+ }
+
+ const std::vector<WordID>& f() const { return f_; }
+ const std::vector<WordID>& e() const { return e_; }
+
+ int EWords() const { return ELength() - Arity(); }
+ int FWords() const { return FLength() - Arity(); }
+ int FLength() const { return f_.size(); }
+ int ELength() const { return e_.size(); }
+ int Arity() const { return arity_; }
+ bool IsUnary() const { return (Arity() == 1) && (f_.size() == 1); }
+ const SparseVector<double>& GetFeatureValues() const { return scores_; }
+ double Score(int i) const { return scores_[i]; }
+ WordID GetLHS() const { return lhs_; }
+ void ComputeArity();
+
+ // 0 = first variable, -1 = second variable, -2 = third ...
+ std::vector<WordID> e_;
+ // < 0: *-1 = encoding of category of variable
+ std::vector<WordID> f_;
+ WordID lhs_;
+ SparseVector<double> scores_;
+ char arity_;
+ TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding
+
+ // this is only used when doing synchronous parsing
+ short int prev_i;
+ short int prev_j;
+
+ private:
+ bool SanityCheck() const;
+};
+
+#endif
diff --git a/src/trule_test.cc b/src/trule_test.cc
new file mode 100644
index 00000000..02a70764
--- /dev/null
+++ b/src/trule_test.cc
@@ -0,0 +1,65 @@
+#include "trule.h"
+
+#include <gtest/gtest.h>
+#include <cassert>
+#include <iostream>
+#include "tdict.h"
+
+using namespace std;
+
+class TRuleTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(TRuleTest,TestFSubstitute) {
+ TRule r1("[X] ||| ob [X,1] [X,2] sah . ||| whether [X,1] saw [X,2] . ||| 0.99");
+ TRule r2("[X] ||| ich ||| i ||| 1.0");
+ TRule r3("[X] ||| ihn ||| him ||| 1.0");
+ vector<const vector<WordID>*> ants;
+ vector<WordID> res2;
+ r2.FSubstitute(ants, &res2);
+ assert(TD::GetString(res2) == "ich");
+ vector<WordID> res3;
+ r3.FSubstitute(ants, &res3);
+ assert(TD::GetString(res3) == "ihn");
+ ants.push_back(&res2);
+ ants.push_back(&res3);
+ vector<WordID> res;
+ r1.FSubstitute(ants, &res);
+ cerr << TD::GetString(res) << endl;
+ assert(TD::GetString(res) == "ob ich ihn sah .");
+}
+
+TEST_F(TRuleTest,TestPhrasetableRule) {
+ TRulePtr t(TRule::CreateRulePhrasetable("gato ||| cat ||| PhraseModel_0=-23.2;Foo=1;Bar=12"));
+ cerr << t->AsString() << endl;
+ assert(t->scores_.num_active() == 3);
+};
+
+
+TEST_F(TRuleTest,TestMonoRule) {
+ TRulePtr m(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3]"));
+ assert(m->Arity() == 3);
+ cerr << m->AsString() << endl;
+ TRulePtr m2(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3] ||| Feature1=0.23"));
+ assert(m2->Arity() == 3);
+ cerr << m2->AsString() << endl;
+ EXPECT_FLOAT_EQ(m2->scores_.value(FD::Convert("Feature1")), 0.23);
+}
+
+TEST_F(TRuleTest,TestRuleR) {
+ TRule t6;
+ t6.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [X,2] saw the [X,1] . ||| 0.12321 0.23232 0.121");
+ cerr << "TEXT: " << t6.AsString() << endl;
+ EXPECT_EQ(t6.Arity(), 2);
+ EXPECT_EQ(t6.e_[0], -1);
+ EXPECT_EQ(t6.e_[3], 0);
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/src/ttables.cc b/src/ttables.cc
new file mode 100644
index 00000000..2ea960f0
--- /dev/null
+++ b/src/ttables.cc
@@ -0,0 +1,31 @@
+#include "ttables.h"
+
+#include <cassert>
+
+#include "dict.h"
+
+using namespace std;
+using namespace std::tr1;
+
+void TTable::DeserializeProbsFromText(std::istream* in) {
+ int c = 0;
+ while(*in) {
+ string e;
+ string f;
+ double p;
+ (*in) >> e >> f >> p;
+ if (e.empty()) break;
+ ++c;
+ ttable[TD::Convert(e)][TD::Convert(f)] = prob_t(p);
+ }
+ cerr << "Loaded " << c << " translation parameters.\n";
+}
+
+void TTable::SerializeHelper(string* out, const Word2Word2Double& o) {
+ assert(!"not implemented");
+}
+
+void TTable::DeserializeHelper(const string& in, Word2Word2Double* o) {
+ assert(!"not implemented");
+}
+
diff --git a/src/ttables.h b/src/ttables.h
new file mode 100644
index 00000000..3ffc238a
--- /dev/null
+++ b/src/ttables.h
@@ -0,0 +1,87 @@
+#ifndef _TTABLES_H_
+#define _TTABLES_H_
+
+#include <iostream>
+#include <map>
+
+#include "wordid.h"
+#include "prob.h"
+#include "tdict.h"
+
+class TTable {
+ public:
+ TTable() {}
+ typedef std::map<WordID, double> Word2Double;
+ typedef std::map<WordID, Word2Double> Word2Word2Double;
+ inline const prob_t prob(const int& e, const int& f) const {
+ const Word2Word2Double::const_iterator cit = ttable.find(e);
+ if (cit != ttable.end()) {
+ const Word2Double& cpd = cit->second;
+ const Word2Double::const_iterator it = cpd.find(f);
+ if (it == cpd.end()) return prob_t(0.00001);
+ return prob_t(it->second);
+ } else {
+ return prob_t(0.00001);
+ }
+ }
+ inline void Increment(const int& e, const int& f) {
+ counts[e][f] += 1.0;
+ }
+ inline void Increment(const int& e, const int& f, double x) {
+ counts[e][f] += x;
+ }
+ void Normalize() {
+ ttable.swap(counts);
+ for (Word2Word2Double::iterator cit = ttable.begin();
+ cit != ttable.end(); ++cit) {
+ double tot = 0;
+ Word2Double& cpd = cit->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ tot += it->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ it->second /= tot;
+ }
+ counts.clear();
+ }
+ // adds counts from another TTable - probabilities remain unchanged
+ TTable& operator+=(const TTable& rhs) {
+ for (Word2Word2Double::const_iterator it = rhs.counts.begin();
+ it != rhs.counts.end(); ++it) {
+ const Word2Double& cpd = it->second;
+ Word2Double& tgt = counts[it->first];
+ for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ tgt[j->first] += j->second;
+ }
+ }
+ return *this;
+ }
+ void ShowTTable() {
+ for (Word2Word2Double::iterator it = ttable.begin(); it != ttable.end(); ++it) {
+ Word2Double& cpd = it->second;
+ for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ }
+ }
+ }
+ void ShowCounts() {
+ for (Word2Word2Double::iterator it = counts.begin(); it != counts.end(); ++it) {
+ Word2Double& cpd = it->second;
+ for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ }
+ }
+ }
+ void DeserializeProbsFromText(std::istream* in);
+ void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); }
+ void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); }
+ void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); }
+ void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); }
+ private:
+ static void SerializeHelper(std::string*, const Word2Word2Double& o);
+ static void DeserializeHelper(const std::string&, Word2Word2Double* o);
+ public:
+ Word2Word2Double ttable;
+ Word2Word2Double counts;
+};
+
+#endif
diff --git a/src/viterbi.cc b/src/viterbi.cc
new file mode 100644
index 00000000..82b2ce6d
--- /dev/null
+++ b/src/viterbi.cc
@@ -0,0 +1,39 @@
+#include "viterbi.h"
+
+#include <vector>
+#include "hg.h"
+
+using namespace std;
+
+string ViterbiETree(const Hypergraph& hg) {
+ vector<WordID> tmp;
+ const prob_t p = Viterbi<vector<WordID>, ETreeTraversal, prob_t, EdgeProb>(hg, &tmp);
+ return TD::GetString(tmp);
+}
+
+string ViterbiFTree(const Hypergraph& hg) {
+ vector<WordID> tmp;
+ const prob_t p = Viterbi<vector<WordID>, FTreeTraversal, prob_t, EdgeProb>(hg, &tmp);
+ return TD::GetString(tmp);
+}
+
+prob_t ViterbiESentence(const Hypergraph& hg, vector<WordID>* result) {
+ return Viterbi<vector<WordID>, ESentenceTraversal, prob_t, EdgeProb>(hg, result);
+}
+
+prob_t ViterbiFSentence(const Hypergraph& hg, vector<WordID>* result) {
+ return Viterbi<vector<WordID>, FSentenceTraversal, prob_t, EdgeProb>(hg, result);
+}
+
+int ViterbiELength(const Hypergraph& hg) {
+ int len = -1;
+ Viterbi<int, ELengthTraversal, prob_t, EdgeProb>(hg, &len);
+ return len;
+}
+
+int ViterbiPathLength(const Hypergraph& hg) {
+ int len = -1;
+ Viterbi<int, PathLengthTraversal, prob_t, EdgeProb>(hg, &len);
+ return len;
+}
+
diff --git a/src/viterbi.h b/src/viterbi.h
new file mode 100644
index 00000000..46a4f528
--- /dev/null
+++ b/src/viterbi.h
@@ -0,0 +1,130 @@
+#ifndef _VITERBI_H_
+#define _VITERBI_H_
+
+#include <vector>
+#include "prob.h"
+#include "hg.h"
+#include "tdict.h"
+
+// V must implement:
+// void operator()(const vector<const T*>& ants, T* result);
+template<typename T, typename Traversal, typename WeightType, typename WeightFunction>
+WeightType Viterbi(const Hypergraph& hg,
+ T* result,
+ const Traversal& traverse = Traversal(),
+ const WeightFunction& weight = WeightFunction()) {
+ const int num_nodes = hg.nodes_.size();
+ std::vector<T> vit_result(num_nodes);
+ std::vector<WeightType> vit_weight(num_nodes, WeightType::Zero());
+
+ for (int i = 0; i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ WeightType* const cur_node_best_weight = &vit_weight[i];
+ T* const cur_node_best_result = &vit_result[i];
+
+ const int num_in_edges = cur_node.in_edges_.size();
+ if (num_in_edges == 0) {
+ *cur_node_best_weight = WeightType(1);
+ continue;
+ }
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ WeightType score = weight(edge);
+ std::vector<const T*> ants(edge.tail_nodes_.size());
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ const int tail_node_index = edge.tail_nodes_[k];
+ score *= vit_weight[tail_node_index];
+ ants[k] = &vit_result[tail_node_index];
+ }
+ if (*cur_node_best_weight < score) {
+ *cur_node_best_weight = score;
+ traverse(edge, ants, cur_node_best_result);
+ }
+ }
+ }
+ std::swap(*result, vit_result.back());
+ return vit_weight.back();
+}
+
+struct PathLengthTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const int*>& ants,
+ int* result) const {
+ (void) edge;
+ *result = 1;
+ for (int i = 0; i < ants.size(); ++i) *result += *ants[i];
+ }
+};
+
+struct ESentenceTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ edge.rule_->ESubstitute(ants, result);
+ }
+};
+
+struct ELengthTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const int*>& ants,
+ int* result) const {
+ *result = edge.rule_->ELength() - edge.rule_->Arity();
+ for (int i = 0; i < ants.size(); ++i) *result += *ants[i];
+ }
+};
+
+struct FSentenceTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ edge.rule_->FSubstitute(ants, result);
+ }
+};
+
+// create a strings of the form (S (X the man) (X said (X he (X would (X go)))))
+struct ETreeTraversal {
+ ETreeTraversal() : left("("), space(" "), right(")") {}
+ const std::string left;
+ const std::string space;
+ const std::string right;
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ std::vector<WordID> tmp;
+ edge.rule_->ESubstitute(ants, &tmp);
+ const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1);
+ if (cat == "Goal")
+ result->swap(tmp);
+ else
+ TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right,
+ result);
+ }
+};
+
+struct FTreeTraversal {
+ FTreeTraversal() : left("("), space(" "), right(")") {}
+ const std::string left;
+ const std::string space;
+ const std::string right;
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ std::vector<WordID> tmp;
+ edge.rule_->FSubstitute(ants, &tmp);
+ const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1);
+ if (cat == "Goal")
+ result->swap(tmp);
+ else
+ TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right,
+ result);
+ }
+};
+
+prob_t ViterbiESentence(const Hypergraph& hg, std::vector<WordID>* result);
+std::string ViterbiETree(const Hypergraph& hg);
+prob_t ViterbiFSentence(const Hypergraph& hg, std::vector<WordID>* result);
+std::string ViterbiFTree(const Hypergraph& hg);
+int ViterbiELength(const Hypergraph& hg);
+int ViterbiPathLength(const Hypergraph& hg);
+
+#endif
diff --git a/src/weights.cc b/src/weights.cc
new file mode 100644
index 00000000..bb0a878f
--- /dev/null
+++ b/src/weights.cc
@@ -0,0 +1,73 @@
+#include "weights.h"
+
+#include <sstream>
+
+#include "fdict.h"
+#include "filelib.h"
+
+using namespace std;
+
+void Weights::InitFromFile(const std::string& filename, vector<string>* feature_list) {
+ cerr << "Reading weights from " << filename << endl;
+ ReadFile in_file(filename);
+ istream& in = *in_file.stream();
+ assert(in);
+ int weight_count = 0;
+ bool fl = false;
+ while (in) {
+ double val = 0;
+ string buf;
+ getline(in, buf);
+ if (buf.size() == 0) continue;
+ if (buf[0] == '#') continue;
+ for (int i = 0; i < buf.size(); ++i)
+ if (buf[i] == '=') buf[i] = ' ';
+ int start = 0;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ int end = 0;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ int fid = FD::Convert(buf.substr(start, end - start));
+ while(end < buf.size() && buf[end] == ' ') ++end;
+ val = strtod(&buf.c_str()[end], NULL);
+ if (wv_.size() <= fid)
+ wv_.resize(fid + 1);
+ wv_[fid] = val;
+ if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
+ ++weight_count;
+ if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
+ if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ }
+ if (fl) { cerr << endl; }
+ cerr << "Loaded " << weight_count << " feature weights\n";
+}
+
+void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features) const {
+ WriteFile out(fname);
+ ostream& o = *out.stream();
+ assert(o);
+ o.precision(17);
+ const int num_feats = FD::NumFeats();
+ for (int i = 1; i < num_feats; ++i) {
+ const double val = (i < wv_.size() ? wv_[i] : 0.0);
+ if (hide_zero_value_features && val == 0.0) continue;
+ o << FD::Convert(i) << ' ' << val << endl;
+ }
+}
+
+void Weights::InitVector(std::vector<double>* w) const {
+ *w = wv_;
+}
+
+void Weights::InitSparseVector(SparseVector<double>* w) const {
+ for (int i = 1; i < wv_.size(); ++i) {
+ const double& weight = wv_[i];
+ if (weight) w->set_value(i, weight);
+ }
+}
+
+void Weights::InitFromVector(const std::vector<double>& w) {
+ wv_ = w;
+ if (wv_.size() > FD::NumFeats())
+ cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";
+ wv_.resize(FD::NumFeats(), 0);
+}
diff --git a/src/weights.h b/src/weights.h
new file mode 100644
index 00000000..f19aa3ce
--- /dev/null
+++ b/src/weights.h
@@ -0,0 +1,21 @@
+#ifndef _WEIGHTS_H_
+#define _WEIGHTS_H_
+
+#include <string>
+#include <map>
+#include <vector>
+#include "sparse_vector.h"
+
+class Weights {
+ public:
+ Weights() {}
+ void InitFromFile(const std::string& fname, std::vector<std::string>* feature_list = NULL);
+ void WriteToFile(const std::string& fname, bool hide_zero_value_features = true) const;
+ void InitVector(std::vector<double>* w) const;
+ void InitSparseVector(SparseVector<double>* w) const;
+ void InitFromVector(const std::vector<double>& w);
+ private:
+ std::vector<double> wv_;
+};
+
+#endif
diff --git a/src/weights_test.cc b/src/weights_test.cc
new file mode 100644
index 00000000..aa6b3db2
--- /dev/null
+++ b/src/weights_test.cc
@@ -0,0 +1,28 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "weights.h"
+#include "tdict.h"
+#include "hg.h"
+
+using namespace std;
+
+class WeightsTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+
+TEST_F(WeightsTest,Load) {
+ Weights w;
+ w.InitFromFile("test_data/weights");
+ w.WriteToFile("-");
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/wordid.h b/src/wordid.h
new file mode 100644
index 00000000..fb50bcc1
--- /dev/null
+++ b/src/wordid.h
@@ -0,0 +1,6 @@
+#ifndef _WORD_ID_H_
+#define _WORD_ID_H_
+
+typedef int WordID;
+
+#endif
diff --git a/training/Makefile.am b/training/Makefile.am
new file mode 100644
index 00000000..a2888f7a
--- /dev/null
+++ b/training/Makefile.am
@@ -0,0 +1,39 @@
+bin_PROGRAMS = \
+ model1 \
+ mr_optimize_reduce \
+ grammar_convert \
+ atools \
+ plftools \
+ lbfgs_test \
+ mr_em_train \
+ collapse_weights \
+ optimize_test
+
+atools_SOURCES = atools.cc
+
+model1_SOURCES = model1.cc
+model1_LDADD = libhg.a
+
+grammar_convert_SOURCES = grammar_convert.cc
+
+optimize_test_SOURCES = optimize_test.cc
+
+collapse_weights_SOURCES = collapse_weights.cc
+
+lbfgs_test_SOURCES = lbfgs_test.cc
+
+mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc
+mr_optimize_reduce_LDADD = libhg.a
+
+mr_em_train_SOURCES = mr_em_train.cc
+mr_em_train_LDADD = libhg.a
+
+plftools_SOURCES = plftools.cc
+plftools_LDADD = libhg.a
+
+LDADD = libhg.a
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS)
+AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB) -lz
+
+noinst_LIBRARIES = libhg.a
diff --git a/training/atools.cc b/training/atools.cc
new file mode 100644
index 00000000..bac73859
--- /dev/null
+++ b/training/atools.cc
@@ -0,0 +1,207 @@
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+#include <map>
+#include <boost/program_options.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include "filelib.h"
+#include "aligner.h"
+
+namespace po = boost::program_options;
+using namespace std;
+using boost::shared_ptr;
+
+struct Command {
+ virtual ~Command() {}
+ virtual string Name() const = 0;
+
+ // returns 1 for alignment grid output [default]
+ // returns 2 if Summary() should be called [for AER, etc]
+ virtual int Result() const { return 1; }
+
+ virtual bool RequiresTwoOperands() const { return true; }
+ virtual void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) = 0;
+ void EnsureSize(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) {
+ x->resize(max(a.width(), b.width()), max(a.height(), b.width()));
+ }
+ bool Safe(const Array2D<bool>& a, int i, int j) const {
+ if (i < a.width() && j < a.height())
+ return a(i,j);
+ else
+ return false;
+ }
+ virtual void Summary() { assert(!"Summary should have been overridden"); }
+};
+
+// compute fmeasure, second alignment is reference, first is hyp
+struct FMeasureCommand : public Command {
+ FMeasureCommand() : matches(), num_predicted(), num_in_ref() {}
+ int Result() const { return 2; }
+ string Name() const { return "f"; }
+ bool RequiresTwoOperands() const { return true; }
+ void Apply(const Array2D<bool>& hyp, const Array2D<bool>& ref, Array2D<bool>* x) {
+ int i_len = ref.width();
+ int j_len = ref.height();
+ for (int i = 0; i < i_len; ++i) {
+ for (int j = 0; j < j_len; ++j) {
+ if (ref(i,j)) {
+ ++num_in_ref;
+ if (Safe(hyp, i, j)) ++matches;
+ }
+ }
+ }
+ for (int i = 0; i < hyp.width(); ++i)
+ for (int j = 0; j < hyp.height(); ++j)
+ if (hyp(i,j)) ++num_predicted;
+ }
+ void Summary() {
+ if (num_predicted == 0 || num_in_ref == 0) {
+ cerr << "Insufficient statistics to compute f-measure!\n";
+ abort();
+ }
+ const double prec = static_cast<double>(matches) / num_predicted;
+ const double rec = static_cast<double>(matches) / num_in_ref;
+ cout << "P: " << prec << endl;
+ cout << "R: " << rec << endl;
+ const double f = (2.0 * prec * rec) / (rec + prec);
+ cout << "F: " << f << endl;
+ }
+ int matches;
+ int num_predicted;
+ int num_in_ref;
+};
+
+struct ConvertCommand : public Command {
+ string Name() const { return "convert"; }
+ bool RequiresTwoOperands() const { return false; }
+ void Apply(const Array2D<bool>& in, const Array2D<bool>&not_used, Array2D<bool>* x) {
+ *x = in;
+ }
+};
+
+struct InvertCommand : public Command {
+ string Name() const { return "invert"; }
+ bool RequiresTwoOperands() const { return false; }
+ void Apply(const Array2D<bool>& in, const Array2D<bool>&not_used, Array2D<bool>* x) {
+ Array2D<bool>& res = *x;
+ res.resize(in.height(), in.width());
+ for (int i = 0; i < in.height(); ++i)
+ for (int j = 0; j < in.width(); ++j)
+ res(i, j) = in(j, i);
+ }
+};
+
+struct IntersectCommand : public Command {
+ string Name() const { return "intersect"; }
+ bool RequiresTwoOperands() const { return true; }
+ void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) {
+ EnsureSize(a, b, x);
+ Array2D<bool>& res = *x;
+ for (int i = 0; i < a.width(); ++i)
+ for (int j = 0; j < a.height(); ++j)
+ res(i, j) = Safe(a, i, j) && Safe(b, i, j);
+ }
+};
+
+map<string, boost::shared_ptr<Command> > commands;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ ostringstream os;
+ os << "[REQ] Operation to perform:";
+ for (map<string, boost::shared_ptr<Command> >::iterator it = commands.begin();
+ it != commands.end(); ++it) {
+ os << ' ' << it->first;
+ }
+ string cstr = os.str();
+ opts.add_options()
+ ("input_1,i", po::value<string>(), "[REQ] Alignment 1 file, - for STDIN")
+ ("input_2,j", po::value<string>(), "[OPT] Alignment 2 file, - for STDIN")
+ ("command,c", po::value<string>()->default_value("convert"), cstr.c_str())
+ ("help,h", "Print this help message and exit");
+ po::options_description clo("Command line options");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ po::notify(*conf);
+
+ if (conf->count("help") || conf->count("input_1") == 0 || conf->count("command") == 0) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+ const string cmd = (*conf)["command"].as<string>();
+ if (commands.count(cmd) == 0) {
+ cerr << "Don't understand command: " << cmd << endl;
+ exit(1);
+ }
+ if (commands[cmd]->RequiresTwoOperands()) {
+ if (conf->count("input_2") == 0) {
+ cerr << "Command '" << cmd << "' requires two alignment files\n";
+ exit(1);
+ }
+ if ((*conf)["input_1"].as<string>() == "-" && (*conf)["input_2"].as<string>() == "-") {
+ cerr << "Both inputs cannot be STDIN\n";
+ exit(1);
+ }
+ } else {
+ if (conf->count("input_2") != 0) {
+ cerr << "Command '" << cmd << "' requires only one alignment file\n";
+ exit(1);
+ }
+ }
+}
+
+template<class C> static void AddCommand() {
+ C* c = new C;
+ commands[c->Name()].reset(c);
+}
+
+int main(int argc, char **argv) {
+ AddCommand<ConvertCommand>();
+ AddCommand<InvertCommand>();
+ AddCommand<IntersectCommand>();
+ AddCommand<FMeasureCommand>();
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ Command& cmd = *commands[conf["command"].as<string>()];
+ boost::shared_ptr<ReadFile> rf1(new ReadFile(conf["input_1"].as<string>()));
+ boost::shared_ptr<ReadFile> rf2;
+ if (cmd.RequiresTwoOperands())
+ rf2.reset(new ReadFile(conf["input_2"].as<string>()));
+ istream* in1 = rf1->stream();
+ istream* in2 = NULL;
+ if (rf2) in2 = rf2->stream();
+ while(*in1) {
+ string line1;
+ string line2;
+ getline(*in1, line1);
+ if (in2) {
+ getline(*in2, line2);
+ if ((*in1 && !*in2) || (*in2 && !*in1)) {
+ cerr << "Mismatched number of lines!\n";
+ exit(1);
+ }
+ }
+ if (line1.empty() && !*in1) break;
+ shared_ptr<Array2D<bool> > out(new Array2D<bool>);
+ shared_ptr<Array2D<bool> > a1 = AlignerTools::ReadPharaohAlignmentGrid(line1);
+ if (in2) {
+ shared_ptr<Array2D<bool> > a2 = AlignerTools::ReadPharaohAlignmentGrid(line2);
+ cmd.Apply(*a1, *a2, out.get());
+ } else {
+ Array2D<bool> dummy;
+ cmd.Apply(*a1, dummy, out.get());
+ }
+
+ if (cmd.Result() == 1) {
+ AlignerTools::SerializePharaohFormat(*out, &cout);
+ }
+ }
+ if (cmd.Result() == 2)
+ cmd.Summary();
+ return 0;
+}
+
diff --git a/training/cluster-em.pl b/training/cluster-em.pl
new file mode 100755
index 00000000..175870da
--- /dev/null
+++ b/training/cluster-em.pl
@@ -0,0 +1,110 @@
+#!/usr/bin/perl -w
+
+use strict;
+my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR; }
+use Getopt::Long;
+my $parallel = 1;
+
+my $CWD=`pwd`; chomp $CWD;
+my $BIN_DIR = "/chomes/redpony/cdyer-svn-repo/cdec/src";
+my $OPTIMIZER = "$BIN_DIR/mr_em_train";
+my $DECODER = "$BIN_DIR/cdec";
+my $COMBINER_CACHE_SIZE = 150;
+my $PARALLEL = "/chomes/redpony/svn-trunk/sa-utils/parallelize.pl";
+die "Can't find $OPTIMIZER" unless -f $OPTIMIZER;
+die "Can't execute $OPTIMIZER" unless -x $OPTIMIZER;
+die "Can't find $DECODER" unless -f $DECODER;
+die "Can't execute $DECODER" unless -x $DECODER;
+die "Can't find $PARALLEL" unless -f $PARALLEL;
+die "Can't execute $PARALLEL" unless -x $PARALLEL;
+my $restart = '';
+if ($ARGV[0] && $ARGV[0] eq '--restart') { shift @ARGV; $restart = 1; }
+
+die "Usage: $0 [--restart] training.corpus weights.init grammar.file [grammar2.file] ...\n" unless (scalar @ARGV >= 3);
+
+my $training_corpus = shift @ARGV;
+my $initial_weights = shift @ARGV;
+my @in_grammar_files = @ARGV;
+my $pmem="2500mb";
+my $nodes = 40;
+my $max_iteration = 1000;
+my $CFLAG = "-C 1";
+unless ($parallel) { $CFLAG = "-C 500"; }
+my @grammar_files;
+for my $g (@in_grammar_files) {
+ unless ($g =~ /^\//) { $g = $CWD . '/' . $g; }
+ die "Can't find $g" unless -f $g;
+ push @grammar_files, $g;
+}
+
+print STDERR <<EOT;
+EM TRAIN CONFIGURATION INFORMATION
+
+ Grammar file(s): @grammar_files
+ Training corpus: $training_corpus
+ Initial weights: $initial_weights
+ Decoder memory: $pmem
+ Nodes requested: $nodes
+ Max iterations: $max_iteration
+ restart: $restart
+EOT
+
+my $nodelist="1";
+for (my $i=1; $i<$nodes; $i++) { $nodelist .= " 1"; }
+my $iter = 1;
+
+my $dir = "$CWD/emtrain";
+if ($restart) {
+ die "$dir doesn't exist, but --restart specified!\n" unless -d $dir;
+ my $o = `ls -t $dir/weights.*`;
+ my ($a, @x) = split /\n/, $o;
+ if ($a =~ /weights.(\d+)\.gz$/) {
+ $iter = $1;
+ } else {
+ die "Unexpected file: $a!\n";
+ }
+ print STDERR "Restarting at iteration $iter\n";
+} else {
+ die "$dir already exists!\n" if -e $dir;
+ mkdir $dir or die "Can't create $dir: $!";
+
+ unless ($initial_weights =~ /\.gz$/) {
+ `cp $initial_weights $dir/weights.1`;
+ `gzip -9 $dir/weights.1`;
+ } else {
+ `cp $initial_weights $dir/weights.1.gz`;
+ }
+}
+
+while ($iter < $max_iteration) {
+ my $cur_time = `date`; chomp $cur_time;
+ print STDERR "\nStarting iteration $iter...\n";
+ print STDERR " time: $cur_time\n";
+ my $start = time;
+ my $next_iter = $iter + 1;
+ my $gfile = '-g' . (join ' -g ', @grammar_files);
+ my $dec_cmd="$DECODER --feature_expectations -S 999 $CFLAG $gfile -n -w $dir/weights.$iter.gz < $training_corpus 2> $dir/deco.log.$iter";
+ my $opt_cmd = "$OPTIMIZER $gfile -o $dir/weights.$next_iter.gz";
+ my $pcmd = "$PARALLEL -e $dir/err -p $pmem --nodelist \"$nodelist\" -- ";
+ my $cmd = "";
+ if ($parallel) { $cmd = $pcmd; }
+ $cmd .= "$dec_cmd | $opt_cmd";
+
+ print STDERR "EXECUTING: $cmd\n";
+ my $result = `$cmd`;
+ if ($? != 0) {
+ die "Error running iteration $iter: $!";
+ }
+ chomp $result;
+ my $end = time;
+ my $diff = ($end - $start);
+ print STDERR " ITERATION $iter TOOK $diff SECONDS\n";
+ $iter = $next_iter;
+ if ($result =~ /1$/) {
+ print STDERR "Training converged.\n";
+ last;
+ }
+}
+
+print "FINAL WEIGHTS: $dir/weights.$iter\n";
+
diff --git a/training/cluster-ptrain.pl b/training/cluster-ptrain.pl
new file mode 100755
index 00000000..99369cdc
--- /dev/null
+++ b/training/cluster-ptrain.pl
@@ -0,0 +1,144 @@
+#!/usr/bin/perl -w
+
+use strict;
+my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR; }
+use Getopt::Long;
+
+my $MAX_ITER_ATTEMPTS = 5; # number of times to retry a failed function evaluation
+my $CWD=`pwd`; chomp $CWD;
+my $BIN_DIR = $SCRIPT_DIR;
+my $OPTIMIZER = "$BIN_DIR/mr_optimize_reduce";
+my $DECODER = "$BIN_DIR/cdec";
+my $COMBINER_CACHE_SIZE = 150;
+my $PARALLEL = "/chomes/redpony/svn-trunk/sa-utils/parallelize.pl";
+die "Can't find $OPTIMIZER" unless -f $OPTIMIZER;
+die "Can't execute $OPTIMIZER" unless -x $OPTIMIZER;
+my $restart = '';
+if ($ARGV[0] && $ARGV[0] eq '--restart') { shift @ARGV; $restart = 1; }
+
+my $pmem="2500mb";
+my $nodes = 36;
+my $max_iteration = 1000;
+my $PRIOR_FLAG = "";
+my $parallel = 1;
+my $CFLAG = "-C 1";
+my $LOCAL;
+my $PRIOR;
+my $OALG = "lbfgs";
+my $sigsq = 1;
+my $means_file;
+GetOptions("decoder=s" => \$DECODER,
+ "run_locally" => \$LOCAL,
+ "gaussian_prior" => \$PRIOR,
+ "sigma_squared=f" => \$sigsq,
+ "means=s" => \$means_file,
+ "optimizer=s" => \$OALG,
+ "pmem=s" => \$pmem
+ ) or usage();
+usage() unless scalar @ARGV==3;
+my $config_file = shift @ARGV;
+my $training_corpus = shift @ARGV;
+my $initial_weights = shift @ARGV;
+die "Can't find $config_file" unless -f $config_file;
+die "Can't find $DECODER" unless -f $DECODER;
+die "Can't execute $DECODER" unless -x $DECODER;
+if ($LOCAL) { print STDERR "Will running LOCALLY.\n"; $parallel = 0; }
+if ($PRIOR) {
+ $PRIOR_FLAG="-p --sigma_squared $sigsq";
+ if ($means_file) { $PRIOR_FLAG .= " -u $means_file"; }
+}
+
+if ($parallel) {
+ die "Can't find $PARALLEL" unless -f $PARALLEL;
+ die "Can't execute $PARALLEL" unless -x $PARALLEL;
+}
+unless ($parallel) { $CFLAG = "-C 500"; }
+unless ($config_file =~ /^\//) { $config_file = $CWD . '/' . $config_file; }
+
+print STDERR <<EOT;
+PTRAIN CONFIGURATION INFORMATION
+
+ Config file: $config_file
+ Training corpus: $training_corpus
+ Initial weights: $initial_weights
+ Decoder memory: $pmem
+ Nodes requested: $nodes
+ Max iterations: $max_iteration
+ Optimizer: $OALG
+ PRIOR: $PRIOR_FLAG
+ restart: $restart
+EOT
+if ($OALG) { $OALG="-m $OALG"; }
+
+my $nodelist="1";
+for (my $i=1; $i<$nodes; $i++) { $nodelist .= " 1"; }
+my $iter = 1;
+
+my $dir = "$CWD/ptrain";
+if ($restart) {
+ die "$dir doesn't exist, but --restart specified!\n" unless -d $dir;
+ my $o = `ls -t $dir/weights.*`;
+ my ($a, @x) = split /\n/, $o;
+ if ($a =~ /weights.(\d+)\.gz$/) {
+ $iter = $1;
+ } else {
+ die "Unexpected file: $a!\n";
+ }
+ print STDERR "Restarting at iteration $iter\n";
+} else {
+ die "$dir already exists!\n" if -e $dir;
+ mkdir $dir or die "Can't create $dir: $!";
+
+ unless ($initial_weights =~ /\.gz$/) {
+ `cp $initial_weights $dir/weights.1`;
+ `gzip -9 $dir/weights.1`;
+ } else {
+ `cp $initial_weights $dir/weights.1.gz`;
+ }
+}
+
+my $iter_attempts = 1;
+while ($iter < $max_iteration) {
+ my $cur_time = `date`; chomp $cur_time;
+ print STDERR "\nStarting iteration $iter...\n";
+ print STDERR " time: $cur_time\n";
+ my $start = time;
+ my $next_iter = $iter + 1;
+ my $dec_cmd="$DECODER -G $CFLAG -c $config_file -w $dir/weights.$iter.gz < $training_corpus 2> $dir/deco.log.$iter";
+ my $opt_cmd = "$OPTIMIZER $PRIOR_FLAG -M 50 $OALG -s $dir/opt.state -i $dir/weights.$iter.gz -o $dir/weights.$next_iter.gz";
+ my $pcmd = "$PARALLEL -e $dir/err -p $pmem --nodelist \"$nodelist\" -- ";
+ my $cmd = "";
+ if ($parallel) { $cmd = $pcmd; }
+ $cmd .= "$dec_cmd | $opt_cmd";
+
+ print STDERR "EXECUTING: $cmd\n";
+ my $result = `$cmd`;
+ my $exit_code = $? >> 8;
+ if ($exit_code == 99) {
+ $iter_attempts++;
+ if ($iter_attempts > $MAX_ITER_ATTEMPTS) {
+ die "Received restart request $iter_attempts times from optimizer, giving up\n";
+ }
+ print STDERR "Function evaluation failed, retrying (attempt $iter_attempts)\n";
+ next;
+ }
+ if ($? != 0) {
+ die "Error running iteration $iter: $!";
+ }
+ chomp $result;
+ my $end = time;
+ my $diff = ($end - $start);
+ print STDERR " ITERATION $iter TOOK $diff SECONDS\n";
+ $iter = $next_iter;
+ if ($result =~ /1$/) {
+ print STDERR "Training converged.\n";
+ last;
+ }
+ $iter_attempts = 1;
+}
+
+print "FINAL WEIGHTS: $dir/weights.$iter\n";
+
+sub usage {
+ die "Usage: $0 [OPTIONS] cdec.ini training.corpus weights.init\n";
+}
diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc
new file mode 100644
index 00000000..22ba0f46
--- /dev/null
+++ b/training/grammar_convert.cc
@@ -0,0 +1,316 @@
+#include <iostream>
+#include <algorithm>
+#include <sstream>
+
+#include <boost/lexical_cast.hpp>
+#include <boost/program_options.hpp>
+
+#include "tdict.h"
+#include "filelib.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "kbest.h"
+#include "viterbi.h"
+#include "weights.h"
+
+namespace po = boost::program_options;
+using namespace std;
+
+WordID kSTART;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("input,i", po::value<string>()->default_value("-"), "Input file")
+ ("format,f", po::value<string>()->default_value("cfg"), "Input format. Values: cfg, json, split")
+ ("output,o", po::value<string>()->default_value("json"), "Output command. Values: json, 1best")
+ ("reorder,r", "Add Yamada & Knight (2002) reorderings")
+ ("weights,w", po::value<string>(), "Feature weights for k-best derivations [optional]")
+ ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights")
+ ("k_derivations,k", po::value<int>(), "Show k derivations and their features")
+ ("max_reorder,m", po::value<int>()->default_value(999), "Move a constituent at most this far")
+ ("help,h", "Print this help message and exit");
+ po::options_description clo("Command line options");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ po::notify(*conf);
+
+ if (conf->count("help") || conf->count("input") == 0) {
+ cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n";
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int GetOrCreateNode(const WordID& lhs, map<WordID, int>* lhs2node, Hypergraph* hg) {
+ int& node_id = (*lhs2node)[lhs];
+ if (!node_id)
+ node_id = hg->AddNode(lhs)->id_ + 1;
+ return node_id - 1;
+}
+
+void FilterAndCheckCorrectness(int goal, Hypergraph* hg) {
+ if (goal < 0) {
+ cerr << "Error! [S] not found in grammar!\n";
+ exit(1);
+ }
+ if (hg->nodes_[goal].in_edges_.size() != 1) {
+ cerr << "Error! [S] has more than one rewrite!\n";
+ exit(1);
+ }
+ int old_size = hg->nodes_.size();
+ hg->TopologicallySortNodesAndEdges(goal);
+ if (hg->nodes_.size() != old_size) {
+ cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n";
+ }
+}
+
+void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) {
+ Hypergraph::Edge* new_edge = hg->AddEdge(r, tail);
+ hg->ConnectEdgeToHeadNode(new_edge, head_node);
+ new_edge->feature_values_ = r->scores_;
+}
+
+// from a category label like "NP_2", return "NP"
+string PureCategory(WordID cat) {
+ assert(cat < 0);
+ string c = TD::Convert(cat*-1);
+ size_t p = c.find("_");
+ if (p == string::npos) return c;
+ return c.substr(0, p);
+};
+
+string ConstituentOrderFeature(const TRule& rule, const vector<int>& pi) {
+ const static string kTERM_VAR = "x";
+ const vector<WordID>& f = rule.f();
+ map<string, int> used;
+ vector<string> terms(f.size());
+ for (int i = 0; i < f.size(); ++i) {
+ const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR);
+ int& count = used[term];
+ if (!count) {
+ terms[i] = term;
+ } else {
+ ostringstream os;
+ os << term << count;
+ terms[i] = os.str();
+ }
+ ++count;
+ }
+ ostringstream os;
+ os << PureCategory(rule.GetLHS()) << ':';
+ for (int i = 0; i < f.size(); ++i) {
+ if (i > 0) os << '_';
+ os << terms[pi[i]];
+ }
+ return os.str();
+}
+
+bool CheckPermutationMask(const vector<int>& mask, const vector<int>& pi) {
+ assert(mask.size() == pi.size());
+
+ int req_min = -1;
+ int cur_max = 0;
+ int cur_mask = -1;
+ for (int i = 0; i < mask.size(); ++i) {
+ if (mask[i] != cur_mask) {
+ cur_mask = mask[i];
+ req_min = cur_max - 1;
+ }
+ if (pi[i] > req_min) {
+ if (pi[i] > cur_max) cur_max = pi[i];
+ } else {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) {
+ Hypergraph::Node* node = &hg->nodes_[nodeid];
+ if (node->in_edges_.size() != 1) {
+ cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n";
+ cerr << " not recursing!\n";
+ return;
+ }
+ const int oe_index = node->in_edges_.front();
+ const TRule& rule = *hg->edges_[oe_index].rule_;
+ const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_;
+ const int tail_size = orig_tail.size();
+ for (int i = 0; i < tail_size; ++i) {
+ PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg);
+ }
+ const vector<WordID>& of = rule.f_;
+ if (of.size() == 1) return;
+// cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n";
+// cerr << "ORIG: " << rule.AsString() << endl;
+ vector<WordID> pi(of.size(), 0);
+ for (int i = 0; i < pi.size(); ++i) pi[i] = i;
+
+ vector<int> permutation_mask(of.size(), 0);
+ const bool dont_reorder_across_PU = true; // TODO add configuration
+ if (dont_reorder_across_PU) {
+ int cur = 0;
+ for (int i = 0; i < pi.size(); ++i) {
+ if (of[i] >= 0) continue;
+ const string cat = PureCategory(of[i]);
+ if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") {
+ ++cur;
+ permutation_mask[i] = cur;
+ ++cur;
+ } else {
+ permutation_mask[i] = cur;
+ }
+ }
+ }
+ int fid = FD::Convert(ConstituentOrderFeature(rule, pi));
+ hg->edges_[oe_index].feature_values_.set_value(fid, 1.0);
+ while (next_permutation(pi.begin(), pi.end())) {
+ if (!CheckPermutationMask(permutation_mask, pi))
+ continue;
+ vector<WordID> nf(pi.size(), 0);
+ Hypergraph::TailNodeVector tail(pi.size(), 0);
+ bool skip = false;
+ for (int i = 0; i < pi.size(); ++i) {
+ int dist = pi[i] - i; if (dist < 0) dist *= -1;
+ if (dist > max_reorder) { skip = true; break; }
+ nf[i] = of[pi[i]];
+ tail[i] = orig_tail[pi[i]];
+ }
+ if (skip) continue;
+ TRulePtr nr(new TRule(rule));
+ nr->f_ = nf;
+ int fid = FD::Convert(ConstituentOrderFeature(rule, pi));
+ nr->scores_.set_value(fid, 1.0);
+// cerr << "PERM: " << nr->AsString() << endl;
+ CreateEdge(nr, tail, node, hg);
+ }
+}
+
+void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) {
+ assert(hg->nodes_.back().cat_ == kSTART);
+ assert(hg->nodes_.back().in_edges_.size() == 1);
+ PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg);
+}
+
+void CollapseWeights(Hypergraph* hg) {
+ int fid = FD::Convert("Reordering");
+ for (int i = 0; i < hg->edges_.size(); ++i) {
+ Hypergraph::Edge& edge = hg->edges_[i];
+ edge.feature_values_.clear();
+ if (edge.edge_prob_ != prob_t::Zero()) {
+ edge.feature_values_.set_value(fid, log(edge.edge_prob_));
+ }
+ }
+}
+
+void ProcessHypergraph(const vector<double>& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) {
+ if (conf.count("reorder"))
+ PermuteYamadaAndKnight(hg, conf["max_reorder"].as<int>());
+ if (w.size() > 0) { hg->Reweight(w); }
+ if (conf.count("collapse_weights")) CollapseWeights(hg);
+ if (conf["output"].as<string>() == "json") {
+ HypergraphIO::WriteToJSON(*hg, false, &cout);
+ if (!ref.empty()) { cerr << "REF: " << ref << endl; }
+ } else {
+ vector<WordID> onebest;
+ ViterbiESentence(*hg, &onebest);
+ if (ref.empty()) {
+ cout << TD::GetString(onebest) << endl;
+ } else {
+ cout << TD::GetString(onebest) << " ||| " << ref << endl;
+ }
+ }
+ if (conf.count("k_derivations")) {
+ const int k = conf["k_derivations"].as<int>();
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(*hg, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ }
+}
+
+int main(int argc, char **argv) {
+ kSTART = TD::Convert("S") * -1;
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ string infile = conf["input"].as<string>();
+ const bool is_split_input = (conf["format"].as<string>() == "split");
+ const bool is_json_input = is_split_input || (conf["format"].as<string>() == "json");
+ const bool collapse_weights = conf.count("collapse_weights");
+ Weights wts;
+ vector<double> w;
+ if (conf.count("weights")) {
+ wts.InitFromFile(conf["weights"].as<string>());
+ wts.InitVector(&w);
+ }
+ if (collapse_weights && !w.size()) {
+ cerr << "--collapse_weights requires a weights file to be specified!\n";
+ exit(1);
+ }
+ ReadFile rf(infile);
+ istream* in = rf.stream();
+ assert(*in);
+ int lc = 0;
+ Hypergraph hg;
+ map<WordID, int> lhs2node;
+ while(*in) {
+ string line;
+ ++lc;
+ getline(*in, line);
+ if (is_json_input) {
+ if (line.empty() || line[0] == '#') continue;
+ string ref;
+ if (is_split_input) {
+ size_t pos = line.rfind("}}");
+ assert(pos != string::npos);
+ size_t rstart = line.find("||| ", pos);
+ assert(rstart != string::npos);
+ ref = line.substr(rstart + 4);
+ line = line.substr(0, pos + 2);
+ }
+ istringstream is(line);
+ if (HypergraphIO::ReadFromJSON(&is, &hg)) {
+ ProcessHypergraph(w, conf, ref, &hg);
+ hg.clear();
+ } else {
+ cerr << "Error reading grammar from JSON: line " << lc << endl;
+ exit(1);
+ }
+ } else {
+ if (line.empty()) {
+ int goal = lhs2node[kSTART] - 1;
+ FilterAndCheckCorrectness(goal, &hg);
+ ProcessHypergraph(w, conf, "", &hg);
+ hg.clear();
+ lhs2node.clear();
+ continue;
+ }
+ if (line[0] == '#') continue;
+ if (line[0] != '[') {
+ cerr << "Line " << lc << ": bad format\n";
+ exit(1);
+ }
+ TRulePtr tr(TRule::CreateRuleMonolingual(line));
+ Hypergraph::TailNodeVector tail;
+ for (int i = 0; i < tr->f_.size(); ++i) {
+ WordID var_cat = tr->f_[i];
+ if (var_cat < 0)
+ tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg));
+ }
+ const WordID lhs = tr->GetLHS();
+ int head = GetOrCreateNode(lhs, &lhs2node, &hg);
+ Hypergraph::Edge* edge = hg.AddEdge(tr, tail);
+ edge->feature_values_ = tr->scores_;
+ Hypergraph::Node* node = &hg.nodes_[head];
+ hg.ConnectEdgeToHeadNode(edge, node);
+ }
+ }
+}
+
diff --git a/training/lbfgs.h b/training/lbfgs.h
new file mode 100644
index 00000000..e8baecab
--- /dev/null
+++ b/training/lbfgs.h
@@ -0,0 +1,1459 @@
+#ifndef SCITBX_LBFGS_H
+#define SCITBX_LBFGS_H
+
+#include <cstdio>
+#include <cstddef>
+#include <cmath>
+#include <stdexcept>
+#include <algorithm>
+#include <vector>
+#include <string>
+#include <iostream>
+#include <sstream>
+
+namespace scitbx {
+
+//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer.
+/*! Implementation of the
+ Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS)
+ algorithm for large-scale multidimensional minimization
+ problems.
+
+ This code was manually derived from Java code which was
+ in turn derived from the Fortran program
+ <code>lbfgs.f</code>. The Java translation was
+ effected mostly mechanically, with some manual
+ clean-up; in particular, array indices start at 0
+ instead of 1. Most of the comments from the Fortran
+ code have been pasted in.
+
+ Information on the original LBFGS Fortran source code is
+ available at
+ http://www.netlib.org/opt/lbfgs_um.shar . The following
+ information is taken verbatim from the Netlib documentation
+ for the Fortran source.
+
+ <pre>
+ file opt/lbfgs_um.shar
+ for unconstrained optimization problems
+ alg limited memory BFGS method
+ by J. Nocedal
+ contact nocedal@eecs.nwu.edu
+ ref D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for
+ , large scale optimization methods'' Mathematical Programming 45
+ , (1989), pp. 503-528.
+ , (Postscript file of this paper is available via anonymous ftp
+ , to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.)
+ </pre>
+
+ @author Jorge Nocedal: original Fortran version, including comments
+ (July 1990).<br>
+ Robert Dodier: Java translation, August 1997.<br>
+ Ralf W. Grosse-Kunstleve: C++ port, March 2002.<br>
+ Chris Dyer: serialize/deserialize functionality
+ */
+namespace lbfgs {
+
+ //! Generic exception class for %lbfgs %error messages.
+ /*! All exceptions thrown by the minimizer are derived from this class.
+ */
+ class error : public std::exception {
+ public:
+ //! Constructor.
+ error(std::string const& msg) throw()
+ : msg_("lbfgs error: " + msg)
+ {}
+ //! Access to error message.
+ virtual const char* what() const throw() { return msg_.c_str(); }
+ protected:
+ virtual ~error() throw() {}
+ std::string msg_;
+ public:
+ static std::string itoa(unsigned long i) {
+ std::ostringstream os;
+ os << i;
+ return os.str();
+ }
+ };
+
+ //! Specific exception class.
+ class error_internal_error : public error {
+ public:
+ //! Constructor.
+ error_internal_error(const char* file, unsigned long line) throw()
+ : error(
+ "Internal Error: " + std::string(file) + "(" + itoa(line) + ")")
+ {}
+ };
+
+ //! Specific exception class.
+ class error_improper_input_parameter : public error {
+ public:
+ //! Constructor.
+ error_improper_input_parameter(std::string const& msg) throw()
+ : error("Improper input parameter: " + msg)
+ {}
+ };
+
+ //! Specific exception class.
+ class error_improper_input_data : public error {
+ public:
+ //! Constructor.
+ error_improper_input_data(std::string const& msg) throw()
+ : error("Improper input data: " + msg)
+ {}
+ };
+
+ //! Specific exception class.
+ class error_search_direction_not_descent : public error {
+ public:
+ //! Constructor.
+ error_search_direction_not_descent() throw()
+ : error("The search direction is not a descent direction.")
+ {}
+ };
+
+ //! Specific exception class.
+ class error_line_search_failed : public error {
+ public:
+ //! Constructor.
+ error_line_search_failed(std::string const& msg) throw()
+ : error("Line search failed: " + msg)
+ {}
+ };
+
+ //! Specific exception class.
+ class error_line_search_failed_rounding_errors
+ : public error_line_search_failed {
+ public:
+ //! Constructor.
+ error_line_search_failed_rounding_errors(std::string const& msg) throw()
+ : error_line_search_failed(msg)
+ {}
+ };
+
+ namespace detail {
+
+ template <typename NumType>
+ inline
+ NumType
+ pow2(NumType const& x) { return x * x; }
+
+ template <typename NumType>
+ inline
+ NumType
+ abs(NumType const& x) {
+ if (x < NumType(0)) return -x;
+ return x;
+ }
+
+ // This class implements an algorithm for multi-dimensional line search.
+ template <typename FloatType, typename SizeType = std::size_t>
+ class mcsrch
+ {
+ protected:
+ int infoc;
+ FloatType dginit;
+ bool brackt;
+ bool stage1;
+ FloatType finit;
+ FloatType dgtest;
+ FloatType width;
+ FloatType width1;
+ FloatType stx;
+ FloatType fx;
+ FloatType dgx;
+ FloatType sty;
+ FloatType fy;
+ FloatType dgy;
+ FloatType stmin;
+ FloatType stmax;
+
+ static FloatType const& max3(
+ FloatType const& x,
+ FloatType const& y,
+ FloatType const& z)
+ {
+ return x < y ? (y < z ? z : y ) : (x < z ? z : x );
+ }
+
+ public:
+ /* Minimize a function along a search direction. This code is
+ a Java translation of the function <code>MCSRCH</code> from
+ <code>lbfgs.f</code>, which in turn is a slight modification
+ of the subroutine <code>CSRCH</code> of More' and Thuente.
+ The changes are to allow reverse communication, and do not
+ affect the performance of the routine. This function, in turn,
+ calls <code>mcstep</code>.<p>
+
+ The Java translation was effected mostly mechanically, with
+ some manual clean-up; in particular, array indices start at 0
+ instead of 1. Most of the comments from the Fortran code have
+ been pasted in here as well.<p>
+
+ The purpose of <code>mcsrch</code> is to find a step which
+ satisfies a sufficient decrease condition and a curvature
+ condition.<p>
+
+ At each stage this function updates an interval of uncertainty
+ with endpoints <code>stx</code> and <code>sty</code>. The
+ interval of uncertainty is initially chosen so that it
+ contains a minimizer of the modified function
+ <pre>
+ f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s).
+ </pre>
+ If a step is obtained for which the modified function has a
+ nonpositive function value and nonnegative derivative, then
+ the interval of uncertainty is chosen so that it contains a
+ minimizer of <code>f(x+stp*s)</code>.<p>
+
+ The algorithm is designed to find a step which satisfies
+ the sufficient decrease condition
+ <pre>
+ f(x+stp*s) &lt;= f(X) + ftol*stp*(gradf(x)'s),
+ </pre>
+ and the curvature condition
+ <pre>
+ abs(gradf(x+stp*s)'s)) &lt;= gtol*abs(gradf(x)'s).
+ </pre>
+ If <code>ftol</code> is less than <code>gtol</code> and if,
+ for example, the function is bounded below, then there is
+ always a step which satisfies both conditions. If no step can
+ be found which satisfies both conditions, then the algorithm
+ usually stops when rounding errors prevent further progress.
+ In this case <code>stp</code> only satisfies the sufficient
+ decrease condition.<p>
+
+ @author Original Fortran version by Jorge J. More' and
+ David J. Thuente as part of the Minpack project, June 1983,
+ Argonne National Laboratory. Java translation by Robert
+ Dodier, August 1997.
+
+ @param n The number of variables.
+
+ @param x On entry this contains the base point for the line
+ search. On exit it contains <code>x + stp*s</code>.
+
+ @param f On entry this contains the value of the objective
+ function at <code>x</code>. On exit it contains the value
+ of the objective function at <code>x + stp*s</code>.
+
+ @param g On entry this contains the gradient of the objective
+ function at <code>x</code>. On exit it contains the gradient
+ at <code>x + stp*s</code>.
+
+ @param s The search direction.
+
+ @param stp On entry this contains an initial estimate of a
+ satifactory step length. On exit <code>stp</code> contains
+ the final estimate.
+
+ @param ftol Tolerance for the sufficient decrease condition.
+
+ @param xtol Termination occurs when the relative width of the
+ interval of uncertainty is at most <code>xtol</code>.
+
+ @param maxfev Termination occurs when the number of evaluations
+ of the objective function is at least <code>maxfev</code> by
+ the end of an iteration.
+
+ @param info This is an output variable, which can have these
+ values:
+ <ul>
+ <li><code>info = -1</code> A return is made to compute
+ the function and gradient.
+ <li><code>info = 1</code> The sufficient decrease condition
+ and the directional derivative condition hold.
+ </ul>
+
+ @param nfev On exit, this is set to the number of function
+ evaluations.
+
+ @param wa Temporary storage array, of length <code>n</code>.
+ */
+ void run(
+ FloatType const& gtol,
+ FloatType const& stpmin,
+ FloatType const& stpmax,
+ SizeType n,
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ FloatType* s,
+ SizeType is0,
+ FloatType& stp,
+ FloatType ftol,
+ FloatType xtol,
+ SizeType maxfev,
+ int& info,
+ SizeType& nfev,
+ FloatType* wa);
+
+ /* The purpose of this function is to compute a safeguarded step
+ for a linesearch and to update an interval of uncertainty for
+ a minimizer of the function.<p>
+
+ The parameter <code>stx</code> contains the step with the
+ least function value. The parameter <code>stp</code> contains
+ the current step. It is assumed that the derivative at
+ <code>stx</code> is negative in the direction of the step. If
+ <code>brackt</code> is <code>true</code> when
+ <code>mcstep</code> returns then a minimizer has been
+ bracketed in an interval of uncertainty with endpoints
+ <code>stx</code> and <code>sty</code>.<p>
+
+ Variables that must be modified by <code>mcstep</code> are
+ implemented as 1-element arrays.
+
+ @param stx Step at the best step obtained so far.
+ This variable is modified by <code>mcstep</code>.
+ @param fx Function value at the best step obtained so far.
+ This variable is modified by <code>mcstep</code>.
+ @param dx Derivative at the best step obtained so far.
+ The derivative must be negative in the direction of the
+ step, that is, <code>dx</code> and <code>stp-stx</code> must
+ have opposite signs. This variable is modified by
+ <code>mcstep</code>.
+
+ @param sty Step at the other endpoint of the interval of
+ uncertainty. This variable is modified by <code>mcstep</code>.
+ @param fy Function value at the other endpoint of the interval
+ of uncertainty. This variable is modified by
+ <code>mcstep</code>.
+
+ @param dy Derivative at the other endpoint of the interval of
+ uncertainty. This variable is modified by <code>mcstep</code>.
+
+ @param stp Step at the current step. If <code>brackt</code> is set
+ then on input <code>stp</code> must be between <code>stx</code>
+ and <code>sty</code>. On output <code>stp</code> is set to the
+ new step.
+ @param fp Function value at the current step.
+ @param dp Derivative at the current step.
+
+ @param brackt Tells whether a minimizer has been bracketed.
+ If the minimizer has not been bracketed, then on input this
+ variable must be set <code>false</code>. If the minimizer has
+ been bracketed, then on output this variable is
+ <code>true</code>.
+
+ @param stpmin Lower bound for the step.
+ @param stpmax Upper bound for the step.
+
+ If the return value is 1, 2, 3, or 4, then the step has
+ been computed successfully. A return value of 0 indicates
+ improper input parameters.
+
+ @author Jorge J. More, David J. Thuente: original Fortran version,
+ as part of Minpack project. Argonne Nat'l Laboratory, June 1983.
+ Robert Dodier: Java translation, August 1997.
+ */
+ static int mcstep(
+ FloatType& stx,
+ FloatType& fx,
+ FloatType& dx,
+ FloatType& sty,
+ FloatType& fy,
+ FloatType& dy,
+ FloatType& stp,
+ FloatType fp,
+ FloatType dp,
+ bool& brackt,
+ FloatType stpmin,
+ FloatType stpmax);
+
+ void serialize(std::ostream* out) const {
+ out->write((const char*)&infoc,sizeof(infoc));
+ out->write((const char*)&dginit,sizeof(dginit));
+ out->write((const char*)&brackt,sizeof(brackt));
+ out->write((const char*)&stage1,sizeof(stage1));
+ out->write((const char*)&finit,sizeof(finit));
+ out->write((const char*)&dgtest,sizeof(dgtest));
+ out->write((const char*)&width,sizeof(width));
+ out->write((const char*)&width1,sizeof(width1));
+ out->write((const char*)&stx,sizeof(stx));
+ out->write((const char*)&fx,sizeof(fx));
+ out->write((const char*)&dgx,sizeof(dgx));
+ out->write((const char*)&sty,sizeof(sty));
+ out->write((const char*)&fy,sizeof(fy));
+ out->write((const char*)&dgy,sizeof(dgy));
+ out->write((const char*)&stmin,sizeof(stmin));
+ out->write((const char*)&stmax,sizeof(stmax));
+ }
+
+ void deserialize(std::istream* in) const {
+ in->read((char*)&infoc, sizeof(infoc));
+ in->read((char*)&dginit, sizeof(dginit));
+ in->read((char*)&brackt, sizeof(brackt));
+ in->read((char*)&stage1, sizeof(stage1));
+ in->read((char*)&finit, sizeof(finit));
+ in->read((char*)&dgtest, sizeof(dgtest));
+ in->read((char*)&width, sizeof(width));
+ in->read((char*)&width1, sizeof(width1));
+ in->read((char*)&stx, sizeof(stx));
+ in->read((char*)&fx, sizeof(fx));
+ in->read((char*)&dgx, sizeof(dgx));
+ in->read((char*)&sty, sizeof(sty));
+ in->read((char*)&fy, sizeof(fy));
+ in->read((char*)&dgy, sizeof(dgy));
+ in->read((char*)&stmin, sizeof(stmin));
+ in->read((char*)&stmax, sizeof(stmax));
+ }
+ };
+
+ template <typename FloatType, typename SizeType>
+ void mcsrch<FloatType, SizeType>::run(
+ FloatType const& gtol,
+ FloatType const& stpmin,
+ FloatType const& stpmax,
+ SizeType n,
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ FloatType* s,
+ SizeType is0,
+ FloatType& stp,
+ FloatType ftol,
+ FloatType xtol,
+ SizeType maxfev,
+ int& info,
+ SizeType& nfev,
+ FloatType* wa)
+ {
+ if (info != -1) {
+ infoc = 1;
+ if ( n == 0
+ || maxfev == 0
+ || gtol < FloatType(0)
+ || xtol < FloatType(0)
+ || stpmin < FloatType(0)
+ || stpmax < stpmin) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ if (stp <= FloatType(0) || ftol < FloatType(0)) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ // Compute the initial gradient in the search direction
+ // and check that s is a descent direction.
+ dginit = FloatType(0);
+ for (SizeType j = 0; j < n; j++) {
+ dginit += g[j] * s[is0+j];
+ }
+ if (dginit >= FloatType(0)) {
+ throw error_search_direction_not_descent();
+ }
+ brackt = false;
+ stage1 = true;
+ nfev = 0;
+ finit = f;
+ dgtest = ftol*dginit;
+ width = stpmax - stpmin;
+ width1 = FloatType(2) * width;
+ std::copy(x, x+n, wa);
+ // The variables stx, fx, dgx contain the values of the step,
+ // function, and directional derivative at the best step.
+ // The variables sty, fy, dgy contain the value of the step,
+ // function, and derivative at the other endpoint of
+ // the interval of uncertainty.
+ // The variables stp, f, dg contain the values of the step,
+ // function, and derivative at the current step.
+ stx = FloatType(0);
+ fx = finit;
+ dgx = dginit;
+ sty = FloatType(0);
+ fy = finit;
+ dgy = dginit;
+ }
+ for (;;) {
+ if (info != -1) {
+ // Set the minimum and maximum steps to correspond
+ // to the present interval of uncertainty.
+ if (brackt) {
+ stmin = std::min(stx, sty);
+ stmax = std::max(stx, sty);
+ }
+ else {
+ stmin = stx;
+ stmax = stp + FloatType(4) * (stp - stx);
+ }
+ // Force the step to be within the bounds stpmax and stpmin.
+ stp = std::max(stp, stpmin);
+ stp = std::min(stp, stpmax);
+ // If an unusual termination is to occur then let
+ // stp be the lowest point obtained so far.
+ if ( (brackt && (stp <= stmin || stp >= stmax))
+ || nfev >= maxfev - 1 || infoc == 0
+ || (brackt && stmax - stmin <= xtol * stmax)) {
+ stp = stx;
+ }
+ // Evaluate the function and gradient at stp
+ // and compute the directional derivative.
+ // We return to main program to obtain F and G.
+ for (SizeType j = 0; j < n; j++) {
+ x[j] = wa[j] + stp * s[is0+j];
+ }
+ info=-1;
+ break;
+ }
+ info = 0;
+ nfev++;
+ FloatType dg(0);
+ for (SizeType j = 0; j < n; j++) {
+ dg += g[j] * s[is0+j];
+ }
+ FloatType ftest1 = finit + stp*dgtest;
+ // Test for convergence.
+ if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) {
+ throw error_line_search_failed_rounding_errors(
+ "Rounding errors prevent further progress."
+ " There may not be a step which satisfies the"
+ " sufficient decrease and curvature conditions."
+ " Tolerances may be too small.");
+ }
+ if (stp == stpmax && f <= ftest1 && dg <= dgtest) {
+ throw error_line_search_failed(
+ "The step is at the upper bound stpmax().");
+ }
+ if (stp == stpmin && (f > ftest1 || dg >= dgtest)) {
+ throw error_line_search_failed(
+ "The step is at the lower bound stpmin().");
+ }
+ if (nfev >= maxfev) {
+ throw error_line_search_failed(
+ "Number of function evaluations has reached maxfev().");
+ }
+ if (brackt && stmax - stmin <= xtol * stmax) {
+ throw error_line_search_failed(
+ "Relative width of the interval of uncertainty"
+ " is at most xtol().");
+ }
+ // Check for termination.
+ if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) {
+ info = 1;
+ break;
+ }
+ // In the first stage we seek a step for which the modified
+ // function has a nonpositive value and nonnegative derivative.
+ if ( stage1 && f <= ftest1
+ && dg >= std::min(ftol, gtol) * dginit) {
+ stage1 = false;
+ }
+ // A modified function is used to predict the step only if
+ // we have not obtained a step for which the modified
+ // function has a nonpositive function value and nonnegative
+ // derivative, and if a lower function value has been
+ // obtained but the decrease is not sufficient.
+ if (stage1 && f <= fx && f > ftest1) {
+ // Define the modified function and derivative values.
+ FloatType fm = f - stp*dgtest;
+ FloatType fxm = fx - stx*dgtest;
+ FloatType fym = fy - sty*dgtest;
+ FloatType dgm = dg - dgtest;
+ FloatType dgxm = dgx - dgtest;
+ FloatType dgym = dgy - dgtest;
+ // Call cstep to update the interval of uncertainty
+ // and to compute the new step.
+ infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm,
+ brackt, stmin, stmax);
+ // Reset the function and gradient values for f.
+ fx = fxm + stx*dgtest;
+ fy = fym + sty*dgtest;
+ dgx = dgxm + dgtest;
+ dgy = dgym + dgtest;
+ }
+ else {
+ // Call mcstep to update the interval of uncertainty
+ // and to compute the new step.
+ infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg,
+ brackt, stmin, stmax);
+ }
+ // Force a sufficient decrease in the size of the
+ // interval of uncertainty.
+ if (brackt) {
+ if (abs(sty - stx) >= FloatType(0.66) * width1) {
+ stp = stx + FloatType(0.5) * (sty - stx);
+ }
+ width1 = width;
+ width = abs(sty - stx);
+ }
+ }
+ }
+
+ template <typename FloatType, typename SizeType>
+ int mcsrch<FloatType, SizeType>::mcstep(
+ FloatType& stx,
+ FloatType& fx,
+ FloatType& dx,
+ FloatType& sty,
+ FloatType& fy,
+ FloatType& dy,
+ FloatType& stp,
+ FloatType fp,
+ FloatType dp,
+ bool& brackt,
+ FloatType stpmin,
+ FloatType stpmax)
+ {
+ bool bound;
+ FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta;
+ int info = 0;
+ if ( ( brackt && (stp <= std::min(stx, sty)
+ || stp >= std::max(stx, sty)))
+ || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) {
+ return 0;
+ }
+ // Determine if the derivatives have opposite sign.
+ sgnd = dp * (dx / abs(dx));
+ if (fp > fx) {
+ // First case. A higher function value.
+ // The minimum is bracketed. If the cubic step is closer
+ // to stx than the quadratic step, the cubic step is taken,
+ // else the average of the cubic and quadratic steps is taken.
+ info = 1;
+ bound = true;
+ theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
+ s = max3(abs(theta), abs(dx), abs(dp));
+ gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
+ if (stp < stx) gamma = - gamma;
+ p = (gamma - dx) + theta;
+ q = ((gamma - dx) + gamma) + dp;
+ r = p/q;
+ stpc = stx + r * (stp - stx);
+ stpq = stx
+ + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2))
+ * (stp - stx);
+ if (abs(stpc - stx) < abs(stpq - stx)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpc + (stpq - stpc) / FloatType(2);
+ }
+ brackt = true;
+ }
+ else if (sgnd < FloatType(0)) {
+ // Second case. A lower function value and derivatives of
+ // opposite sign. The minimum is bracketed. If the cubic
+ // step is closer to stx than the quadratic (secant) step,
+ // the cubic step is taken, else the quadratic step is taken.
+ info = 2;
+ bound = false;
+ theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
+ s = max3(abs(theta), abs(dx), abs(dp));
+ gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
+ if (stp > stx) gamma = - gamma;
+ p = (gamma - dp) + theta;
+ q = ((gamma - dp) + gamma) + dx;
+ r = p/q;
+ stpc = stp + r * (stx - stp);
+ stpq = stp + (dp / (dp - dx)) * (stx - stp);
+ if (abs(stpc - stp) > abs(stpq - stp)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpq;
+ }
+ brackt = true;
+ }
+ else if (abs(dp) < abs(dx)) {
+ // Third case. A lower function value, derivatives of the
+ // same sign, and the magnitude of the derivative decreases.
+ // The cubic step is only used if the cubic tends to infinity
+ // in the direction of the step or if the minimum of the cubic
+ // is beyond stp. Otherwise the cubic step is defined to be
+ // either stpmin or stpmax. The quadratic (secant) step is also
+ // computed and if the minimum is bracketed then the the step
+ // closest to stx is taken, else the step farthest away is taken.
+ info = 3;
+ bound = true;
+ theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
+ s = max3(abs(theta), abs(dx), abs(dp));
+ gamma = s * std::sqrt(
+ std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s)));
+ if (stp > stx) gamma = -gamma;
+ p = (gamma - dp) + theta;
+ q = (gamma + (dx - dp)) + gamma;
+ r = p/q;
+ if (r < FloatType(0) && gamma != FloatType(0)) {
+ stpc = stp + r * (stx - stp);
+ }
+ else if (stp > stx) {
+ stpc = stpmax;
+ }
+ else {
+ stpc = stpmin;
+ }
+ stpq = stp + (dp / (dp - dx)) * (stx - stp);
+ if (brackt) {
+ if (abs(stp - stpc) < abs(stp - stpq)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpq;
+ }
+ }
+ else {
+ if (abs(stp - stpc) > abs(stp - stpq)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpq;
+ }
+ }
+ }
+ else {
+ // Fourth case. A lower function value, derivatives of the
+ // same sign, and the magnitude of the derivative does
+ // not decrease. If the minimum is not bracketed, the step
+ // is either stpmin or stpmax, else the cubic step is taken.
+ info = 4;
+ bound = false;
+ if (brackt) {
+ theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp;
+ s = max3(abs(theta), abs(dy), abs(dp));
+ gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s));
+ if (stp > sty) gamma = -gamma;
+ p = (gamma - dp) + theta;
+ q = ((gamma - dp) + gamma) + dy;
+ r = p/q;
+ stpc = stp + r * (sty - stp);
+ stpf = stpc;
+ }
+ else if (stp > stx) {
+ stpf = stpmax;
+ }
+ else {
+ stpf = stpmin;
+ }
+ }
+ // Update the interval of uncertainty. This update does not
+ // depend on the new step or the case analysis above.
+ if (fp > fx) {
+ sty = stp;
+ fy = fp;
+ dy = dp;
+ }
+ else {
+ if (sgnd < FloatType(0)) {
+ sty = stx;
+ fy = fx;
+ dy = dx;
+ }
+ stx = stp;
+ fx = fp;
+ dx = dp;
+ }
+ // Compute the new step and safeguard it.
+ stpf = std::min(stpmax, stpf);
+ stpf = std::max(stpmin, stpf);
+ stp = stpf;
+ if (brackt && bound) {
+ if (sty > stx) {
+ stp = std::min(stx + FloatType(0.66) * (sty - stx), stp);
+ }
+ else {
+ stp = std::max(stx + FloatType(0.66) * (sty - stx), stp);
+ }
+ }
+ return info;
+ }
+
+ /* Compute the sum of a vector times a scalar plus another vector.
+ Adapted from the subroutine <code>daxpy</code> in
+ <code>lbfgs.f</code>.
+ */
+ template <typename FloatType, typename SizeType>
+ void daxpy(
+ SizeType n,
+ FloatType da,
+ const FloatType* dx,
+ SizeType ix0,
+ SizeType incx,
+ FloatType* dy,
+ SizeType iy0,
+ SizeType incy)
+ {
+ SizeType i, ix, iy, m;
+ if (n == 0) return;
+ if (da == FloatType(0)) return;
+ if (!(incx == 1 && incy == 1)) {
+ ix = 0;
+ iy = 0;
+ for (i = 0; i < n; i++) {
+ dy[iy0+iy] += da * dx[ix0+ix];
+ ix += incx;
+ iy += incy;
+ }
+ return;
+ }
+ m = n % 4;
+ for (i = 0; i < m; i++) {
+ dy[iy0+i] += da * dx[ix0+i];
+ }
+ for (; i < n;) {
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ }
+ }
+
+ template <typename FloatType, typename SizeType>
+ inline
+ void daxpy(
+ SizeType n,
+ FloatType da,
+ const FloatType* dx,
+ SizeType ix0,
+ FloatType* dy)
+ {
+ daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1));
+ }
+
+ /* Compute the dot product of two vectors.
+ Adapted from the subroutine <code>ddot</code>
+ in <code>lbfgs.f</code>.
+ */
+ template <typename FloatType, typename SizeType>
+ FloatType ddot(
+ SizeType n,
+ const FloatType* dx,
+ SizeType ix0,
+ SizeType incx,
+ const FloatType* dy,
+ SizeType iy0,
+ SizeType incy)
+ {
+ SizeType i, ix, iy, m;
+ FloatType dtemp(0);
+ if (n == 0) return FloatType(0);
+ if (!(incx == 1 && incy == 1)) {
+ ix = 0;
+ iy = 0;
+ for (i = 0; i < n; i++) {
+ dtemp += dx[ix0+ix] * dy[iy0+iy];
+ ix += incx;
+ iy += incy;
+ }
+ return dtemp;
+ }
+ m = n % 5;
+ for (i = 0; i < m; i++) {
+ dtemp += dx[ix0+i] * dy[iy0+i];
+ }
+ for (; i < n;) {
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ }
+ return dtemp;
+ }
+
+ template <typename FloatType, typename SizeType>
+ inline
+ FloatType ddot(
+ SizeType n,
+ const FloatType* dx,
+ const FloatType* dy)
+ {
+ return ddot(
+ n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1));
+ }
+
+ } // namespace detail
+
+ //! Interface to the LBFGS %minimizer.
+ /*! This class solves the unconstrained minimization problem
+ <pre>
+ min f(x), x = (x1,x2,...,x_n),
+ </pre>
+ using the limited-memory BFGS method. The routine is
+ especially effective on problems involving a large number of
+ variables. In a typical iteration of this method an
+ approximation Hk to the inverse of the Hessian
+ is obtained by applying <code>m</code> BFGS updates to a
+ diagonal matrix Hk0, using information from the
+ previous <code>m</code> steps. The user specifies the number
+ <code>m</code>, which determines the amount of storage
+ required by the routine. The user may also provide the
+ diagonal matrices Hk0 (parameter <code>diag</code> in the run()
+ function) if not satisfied with the default choice. The
+ algorithm is described in "On the limited memory BFGS method for
+ large scale optimization", by D. Liu and J. Nocedal, Mathematical
+ Programming B 45 (1989) 503-528.
+
+ The user is required to calculate the function value
+ <code>f</code> and its gradient <code>g</code>. In order to
+ allow the user complete control over these computations,
+ reverse communication is used. The routine must be called
+ repeatedly under the control of the member functions
+ <code>requests_f_and_g()</code>,
+ <code>requests_diag()</code>.
+ If neither requests_f_and_g() nor requests_diag() is
+ <code>true</code> the user should check for convergence
+ (using class traditional_convergence_test or any
+ other custom test). If the convergence test is negative,
+ the minimizer may be called again for the next iteration.
+
+ The steplength (stp()) is determined at each iteration
+ by means of the line search routine <code>mcsrch</code>, which is
+ a slight modification of the routine <code>CSRCH</code> written
+ by More' and Thuente.
+
+ The only variables that are machine-dependent are
+ <code>xtol</code>,
+ <code>stpmin</code> and
+ <code>stpmax</code>.
+
+ Fatal errors cause <code>error</code> exceptions to be thrown.
+ The generic class <code>error</code> is sub-classed (e.g.
+ class <code>error_line_search_failed</code>) to facilitate
+ granular %error handling.
+
+ A note on performance: Using Compaq Fortran V5.4 and
+ Compaq C++ V6.5, the C++ implementation is about 15% slower
+ than the Fortran implementation.
+ */
+ template <typename FloatType, typename SizeType = std::size_t>
+ class minimizer
+ {
+ public:
+ //! Default constructor. Some members are not initialized!
+ minimizer()
+ : n_(0), m_(0), maxfev_(0),
+ gtol_(0), xtol_(0),
+ stpmin_(0), stpmax_(0),
+ ispt(0), iypt(0)
+ {}
+
+ //! Constructor.
+ /*! @param n The number of variables in the minimization problem.
+ Restriction: <code>n &gt; 0</code>.
+
+ @param m The number of corrections used in the BFGS update.
+ Values of <code>m</code> less than 3 are not recommended;
+ large values of <code>m</code> will result in excessive
+ computing time. <code>3 &lt;= m &lt;= 7</code> is
+ recommended.
+ Restriction: <code>m &gt; 0</code>.
+
+ @param maxfev Maximum number of function evaluations
+ <b>per line search</b>.
+ Termination occurs when the number of evaluations
+ of the objective function is at least <code>maxfev</code> by
+ the end of an iteration.
+
+ @param gtol Controls the accuracy of the line search.
+ If the function and gradient evaluations are inexpensive with
+ respect to the cost of the iteration (which is sometimes the
+ case when solving very large problems) it may be advantageous
+ to set <code>gtol</code> to a small value. A typical small
+ value is 0.1.
+ Restriction: <code>gtol</code> should be greater than 1e-4.
+
+ @param xtol An estimate of the machine precision (e.g. 10e-16
+ on a SUN station 3/60). The line search routine will
+ terminate if the relative width of the interval of
+ uncertainty is less than <code>xtol</code>.
+
+ @param stpmin Specifies the lower bound for the step
+ in the line search.
+ The default value is 1e-20. This value need not be modified
+ unless the exponent is too large for the machine being used,
+ or unless the problem is extremely badly scaled (in which
+ case the exponent should be increased).
+
+ @param stpmax specifies the upper bound for the step
+ in the line search.
+ The default value is 1e20. This value need not be modified
+ unless the exponent is too large for the machine being used,
+ or unless the problem is extremely badly scaled (in which
+ case the exponent should be increased).
+ */
+ explicit
+ minimizer(
+ SizeType n,
+ SizeType m = 5,
+ SizeType maxfev = 20,
+ FloatType gtol = FloatType(0.9),
+ FloatType xtol = FloatType(1.e-16),
+ FloatType stpmin = FloatType(1.e-20),
+ FloatType stpmax = FloatType(1.e20))
+ : n_(n), m_(m), maxfev_(maxfev),
+ gtol_(gtol), xtol_(xtol),
+ stpmin_(stpmin), stpmax_(stpmax),
+ iflag_(0), requests_f_and_g_(false), requests_diag_(false),
+ iter_(0), nfun_(0), stp_(0),
+ stp1(0), ftol(0.0001), ys(0), point(0), npt(0),
+ ispt(n+2*m), iypt((n+2*m)+n*m),
+ info(0), bound(0), nfev(0)
+ {
+ if (n_ == 0) {
+ throw error_improper_input_parameter("n = 0.");
+ }
+ if (m_ == 0) {
+ throw error_improper_input_parameter("m = 0.");
+ }
+ if (maxfev_ == 0) {
+ throw error_improper_input_parameter("maxfev = 0.");
+ }
+ if (gtol_ <= FloatType(1.e-4)) {
+ throw error_improper_input_parameter("gtol <= 1.e-4.");
+ }
+ if (xtol_ < FloatType(0)) {
+ throw error_improper_input_parameter("xtol < 0.");
+ }
+ if (stpmin_ < FloatType(0)) {
+ throw error_improper_input_parameter("stpmin < 0.");
+ }
+ if (stpmax_ < stpmin) {
+ throw error_improper_input_parameter("stpmax < stpmin");
+ }
+ w_.resize(n_*(2*m_+1)+2*m_);
+ scratch_array_.resize(n_);
+ }
+
+ //! Number of free parameters (as passed to the constructor).
+ SizeType n() const { return n_; }
+
+ //! Number of corrections kept (as passed to the constructor).
+ SizeType m() const { return m_; }
+
+ /*! \brief Maximum number of evaluations of the objective function
+ per line search (as passed to the constructor).
+ */
+ SizeType maxfev() const { return maxfev_; }
+
+ /*! \brief Control of the accuracy of the line search.
+ (as passed to the constructor).
+ */
+ FloatType gtol() const { return gtol_; }
+
+ //! Estimate of the machine precision (as passed to the constructor).
+ FloatType xtol() const { return xtol_; }
+
+ /*! \brief Lower bound for the step in the line search.
+ (as passed to the constructor).
+ */
+ FloatType stpmin() const { return stpmin_; }
+
+ /*! \brief Upper bound for the step in the line search.
+ (as passed to the constructor).
+ */
+ FloatType stpmax() const { return stpmax_; }
+
+ //! Status indicator for reverse communication.
+ /*! <code>true</code> if the run() function returns to request
+ evaluation of the objective function (<code>f</code>) and
+ gradients (<code>g</code>) for the current point
+ (<code>x</code>). To continue the minimization the
+ run() function is called again with the updated values for
+ <code>f</code> and <code>g</code>.
+ <p>
+ See also: requests_diag()
+ */
+ bool requests_f_and_g() const { return requests_f_and_g_; }
+
+ //! Status indicator for reverse communication.
+ /*! <code>true</code> if the run() function returns to request
+ evaluation of the diagonal matrix (<code>diag</code>)
+ for the current point (<code>x</code>).
+ To continue the minimization the run() function is called
+ again with the updated values for <code>diag</code>.
+ <p>
+ See also: requests_f_and_g()
+ */
+ bool requests_diag() const { return requests_diag_; }
+
+ //! Number of iterations so far.
+ /*! Note that one iteration may involve multiple evaluations
+ of the objective function.
+ <p>
+ See also: nfun()
+ */
+ SizeType iter() const { return iter_; }
+
+ //! Total number of evaluations of the objective function so far.
+ /*! The total number of function evaluations increases by the
+ number of evaluations required for the line search. The total
+ is only increased after a successful line search.
+ <p>
+ See also: iter()
+ */
+ SizeType nfun() const { return nfun_; }
+
+ //! Norm of gradient given gradient array of length n().
+ FloatType euclidean_norm(const FloatType* a) const {
+ return std::sqrt(detail::ddot(n_, a, a));
+ }
+
+ //! Current stepsize.
+ FloatType stp() const { return stp_; }
+
+ //! Execution of one step of the minimization.
+ /*! @param x On initial entry this must be set by the user to
+ the values of the initial estimate of the solution vector.
+
+ @param f Before initial entry or on re-entry under the
+ control of requests_f_and_g(), <code>f</code> must be set
+ by the user to contain the value of the objective function
+ at the current point <code>x</code>.
+
+ @param g Before initial entry or on re-entry under the
+ control of requests_f_and_g(), <code>g</code> must be set
+ by the user to contain the components of the gradient at
+ the current point <code>x</code>.
+
+ The return value is <code>true</code> if either
+ requests_f_and_g() or requests_diag() is <code>true</code>.
+ Otherwise the user should check for convergence
+ (e.g. using class traditional_convergence_test) and
+ call the run() function again to continue the minimization.
+ If the return value is <code>false</code> the user
+ should <b>not</b> update <code>f</code>, <code>g</code> or
+ <code>diag</code> (other overload) before calling
+ the run() function again.
+
+ Note that <code>x</code> is always modified by the run()
+ function. Depending on the situation it can therefore be
+ necessary to evaluate the objective function one more time
+ after the minimization is terminated.
+ */
+ bool run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g)
+ {
+ return generic_run(x, f, g, false, 0);
+ }
+
+ //! Execution of one step of the minimization.
+ /*! @param x See other overload.
+
+ @param f See other overload.
+
+ @param g See other overload.
+
+ @param diag On initial entry or on re-entry under the
+ control of requests_diag(), <code>diag</code> must be set by
+ the user to contain the values of the diagonal matrix Hk0.
+ The routine will return at each iteration of the algorithm
+ with requests_diag() set to <code>true</code>.
+ <p>
+ Restriction: all elements of <code>diag</code> must be
+ positive.
+ */
+ bool run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ const FloatType* diag)
+ {
+ return generic_run(x, f, g, true, diag);
+ }
+
+ void serialize(std::ostream* out) const {
+ out->write((const char*)&n_, sizeof(n_)); // sanity check
+ out->write((const char*)&m_, sizeof(m_)); // sanity check
+ SizeType fs = sizeof(FloatType);
+ out->write((const char*)&fs, sizeof(fs)); // sanity check
+
+ mcsrch_instance.serialize(out);
+ out->write((const char*)&iflag_, sizeof(iflag_));
+ out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_));
+ out->write((const char*)&requests_diag_, sizeof(requests_diag_));
+ out->write((const char*)&iter_, sizeof(iter_));
+ out->write((const char*)&nfun_, sizeof(nfun_));
+ out->write((const char*)&stp_, sizeof(stp_));
+ out->write((const char*)&stp1, sizeof(stp1));
+ out->write((const char*)&ftol, sizeof(ftol));
+ out->write((const char*)&ys, sizeof(ys));
+ out->write((const char*)&point, sizeof(point));
+ out->write((const char*)&npt, sizeof(npt));
+ out->write((const char*)&info, sizeof(info));
+ out->write((const char*)&bound, sizeof(bound));
+ out->write((const char*)&nfev, sizeof(nfev));
+ out->write((const char*)&w_[0], sizeof(FloatType) * w_.size());
+ out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size());
+ }
+
+ void deserialize(std::istream* in) {
+ SizeType n, m, fs;
+ in->read((char*)&n, sizeof(n));
+ in->read((char*)&m, sizeof(m));
+ in->read((char*)&fs, sizeof(fs));
+ assert(n == n_);
+ assert(m == m_);
+ assert(fs == sizeof(FloatType));
+
+ mcsrch_instance.deserialize(in);
+ in->read((char*)&iflag_, sizeof(iflag_));
+ in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_));
+ in->read((char*)&requests_diag_, sizeof(requests_diag_));
+ in->read((char*)&iter_, sizeof(iter_));
+ in->read((char*)&nfun_, sizeof(nfun_));
+ in->read((char*)&stp_, sizeof(stp_));
+ in->read((char*)&stp1, sizeof(stp1));
+ in->read((char*)&ftol, sizeof(ftol));
+ in->read((char*)&ys, sizeof(ys));
+ in->read((char*)&point, sizeof(point));
+ in->read((char*)&npt, sizeof(npt));
+ in->read((char*)&info, sizeof(info));
+ in->read((char*)&bound, sizeof(bound));
+ in->read((char*)&nfev, sizeof(nfev));
+ in->read((char*)&w_[0], sizeof(FloatType) * w_.size());
+ in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size());
+ }
+
+ protected:
+ static void throw_diagonal_element_not_positive(SizeType i) {
+ throw error_improper_input_data(
+ "The " + error::itoa(i) + ". diagonal element of the"
+ " inverse Hessian approximation is not positive.");
+ }
+
+ bool generic_run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ bool diagco,
+ const FloatType* diag);
+
+ detail::mcsrch<FloatType, SizeType> mcsrch_instance;
+ const SizeType n_;
+ const SizeType m_;
+ const SizeType maxfev_;
+ const FloatType gtol_;
+ const FloatType xtol_;
+ const FloatType stpmin_;
+ const FloatType stpmax_;
+ int iflag_;
+ bool requests_f_and_g_;
+ bool requests_diag_;
+ SizeType iter_;
+ SizeType nfun_;
+ FloatType stp_;
+ FloatType stp1;
+ FloatType ftol;
+ FloatType ys;
+ SizeType point;
+ SizeType npt;
+ const SizeType ispt;
+ const SizeType iypt;
+ int info;
+ SizeType bound;
+ SizeType nfev;
+ std::vector<FloatType> w_;
+ std::vector<FloatType> scratch_array_;
+ };
+
+ template <typename FloatType, typename SizeType>
+ bool minimizer<FloatType, SizeType>::generic_run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ bool diagco,
+ const FloatType* diag)
+ {
+ bool execute_entire_while_loop = false;
+ if (!(requests_f_and_g_ || requests_diag_)) {
+ execute_entire_while_loop = true;
+ }
+ requests_f_and_g_ = false;
+ requests_diag_ = false;
+ FloatType* w = &(*(w_.begin()));
+ if (iflag_ == 0) { // Initialize.
+ nfun_ = 1;
+ if (diagco) {
+ for (SizeType i = 0; i < n_; i++) {
+ if (diag[i] <= FloatType(0)) {
+ throw_diagonal_element_not_positive(i);
+ }
+ }
+ }
+ else {
+ std::fill_n(scratch_array_.begin(), n_, FloatType(1));
+ diag = &(*(scratch_array_.begin()));
+ }
+ for (SizeType i = 0; i < n_; i++) {
+ w[ispt + i] = -g[i] * diag[i];
+ }
+ FloatType gnorm = std::sqrt(detail::ddot(n_, g, g));
+ if (gnorm == FloatType(0)) return false;
+ stp1 = FloatType(1) / gnorm;
+ execute_entire_while_loop = true;
+ }
+ if (execute_entire_while_loop) {
+ bound = iter_;
+ iter_++;
+ info = 0;
+ if (iter_ != 1) {
+ if (iter_ > m_) bound = m_;
+ ys = detail::ddot(
+ n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1));
+ if (!diagco) {
+ FloatType yy = detail::ddot(
+ n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1));
+ std::fill_n(scratch_array_.begin(), n_, ys / yy);
+ diag = &(*(scratch_array_.begin()));
+ }
+ else {
+ iflag_ = 2;
+ requests_diag_ = true;
+ return true;
+ }
+ }
+ }
+ if (execute_entire_while_loop || iflag_ == 2) {
+ if (iter_ != 1) {
+ if (diag == 0) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ if (diagco) {
+ for (SizeType i = 0; i < n_; i++) {
+ if (diag[i] <= FloatType(0)) {
+ throw_diagonal_element_not_positive(i);
+ }
+ }
+ }
+ SizeType cp = point;
+ if (point == 0) cp = m_;
+ w[n_ + cp -1] = 1 / ys;
+ SizeType i;
+ for (i = 0; i < n_; i++) {
+ w[i] = -g[i];
+ }
+ cp = point;
+ for (i = 0; i < bound; i++) {
+ if (cp == 0) cp = m_;
+ cp--;
+ FloatType sq = detail::ddot(
+ n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1));
+ SizeType inmc=n_+m_+cp;
+ SizeType iycn=iypt+cp*n_;
+ w[inmc] = w[n_ + cp] * sq;
+ detail::daxpy(n_, -w[inmc], w, iycn, w);
+ }
+ for (i = 0; i < n_; i++) {
+ w[i] *= diag[i];
+ }
+ for (i = 0; i < bound; i++) {
+ FloatType yr = detail::ddot(
+ n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1));
+ FloatType beta = w[n_ + cp] * yr;
+ SizeType inmc=n_+m_+cp;
+ beta = w[inmc] - beta;
+ SizeType iscn=ispt+cp*n_;
+ detail::daxpy(n_, beta, w, iscn, w);
+ cp++;
+ if (cp == m_) cp = 0;
+ }
+ std::copy(w, w+n_, w+(ispt + point * n_));
+ }
+ stp_ = FloatType(1);
+ if (iter_ == 1) stp_ = stp1;
+ std::copy(g, g+n_, w);
+ }
+ mcsrch_instance.run(
+ gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_,
+ stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin())));
+ if (info == -1) {
+ iflag_ = 1;
+ requests_f_and_g_ = true;
+ return true;
+ }
+ if (info != 1) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ nfun_ += nfev;
+ npt = point*n_;
+ for (SizeType i = 0; i < n_; i++) {
+ w[ispt + npt + i] = stp_ * w[ispt + npt + i];
+ w[iypt + npt + i] = g[i] - w[i];
+ }
+ point++;
+ if (point == m_) point = 0;
+ return false;
+ }
+
+ //! Traditional LBFGS convergence test.
+ /*! This convergence test is equivalent to the test embedded
+ in the <code>lbfgs.f</code> Fortran code. The test assumes that
+ there is a meaningful relation between the Euclidean norm of the
+ parameter vector <code>x</code> and the norm of the gradient
+ vector <code>g</code>. Therefore this test should not be used if
+ this assumption is not correct for a given problem.
+ */
+ template <typename FloatType, typename SizeType = std::size_t>
+ class traditional_convergence_test
+ {
+ public:
+ //! Default constructor.
+ traditional_convergence_test()
+ : n_(0), eps_(0)
+ {}
+
+ //! Constructor.
+ /*! @param n The number of variables in the minimization problem.
+ Restriction: <code>n &gt; 0</code>.
+
+ @param eps Determines the accuracy with which the solution
+ is to be found.
+ */
+ explicit
+ traditional_convergence_test(
+ SizeType n,
+ FloatType eps = FloatType(1.e-5))
+ : n_(n), eps_(eps)
+ {
+ if (n_ == 0) {
+ throw error_improper_input_parameter("n = 0.");
+ }
+ if (eps_ < FloatType(0)) {
+ throw error_improper_input_parameter("eps < 0.");
+ }
+ }
+
+ //! Number of free parameters (as passed to the constructor).
+ SizeType n() const { return n_; }
+
+ /*! \brief Accuracy with which the solution is to be found
+ (as passed to the constructor).
+ */
+ FloatType eps() const { return eps_; }
+
+ //! Execution of the convergence test for the given parameters.
+ /*! Returns <code>true</code> if
+ <pre>
+ ||g|| &lt; eps * max(1,||x||),
+ </pre>
+ where <code>||.||</code> denotes the Euclidean norm.
+
+ @param x Current solution vector.
+
+ @param g Components of the gradient at the current
+ point <code>x</code>.
+ */
+ bool
+ operator()(const FloatType* x, const FloatType* g) const
+ {
+ FloatType xnorm = std::sqrt(detail::ddot(n_, x, x));
+ FloatType gnorm = std::sqrt(detail::ddot(n_, g, g));
+ if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true;
+ return false;
+ }
+ protected:
+ const SizeType n_;
+ const FloatType eps_;
+ };
+
+}} // namespace scitbx::lbfgs
+
+template <typename T>
+std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer<T>& min) {
+ return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g();
+}
+
+
+#endif // SCITBX_LBFGS_H
diff --git a/training/lbfgs_test.cc b/training/lbfgs_test.cc
new file mode 100644
index 00000000..4171c118
--- /dev/null
+++ b/training/lbfgs_test.cc
@@ -0,0 +1,112 @@
+#include <cassert>
+#include <iostream>
+#include <sstream>
+#include "lbfgs.h"
+#include "sparse_vector.h"
+#include "fdict.h"
+
+using namespace std;
+
+double TestOptimizer() {
+ cerr << "TESTING NON-PERSISTENT OPTIMIZER\n";
+
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ double x[3];
+ double g[3];
+ scitbx::lbfgs::minimizer<double> opt(3);
+ scitbx::lbfgs::traditional_convergence_test<double> converged(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ do {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+ opt.run(x, obj, g);
+
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ cerr << opt << endl;
+ } while (!converged(x, g));
+ return obj;
+}
+
+double TestPersistentOptimizer() {
+ cerr << "\nTESTING PERSISTENT OPTIMIZER\n";
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ double x[3];
+ double g[3];
+ scitbx::lbfgs::traditional_convergence_test<double> converged(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ string state;
+ do {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+
+ {
+ scitbx::lbfgs::minimizer<double> opt(3);
+ if (state.size() > 0) {
+ istringstream is(state, ios::binary);
+ opt.deserialize(&is);
+ }
+ opt.run(x, obj, g);
+ ostringstream os(ios::binary); opt.serialize(&os); state = os.str();
+ }
+
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ } while (!converged(x, g));
+ return obj;
+}
+
+void TestSparseVector() {
+ cerr << "Testing SparseVector<double> serialization.\n";
+ int f1 = FD::Convert("Feature_1");
+ int f2 = FD::Convert("Feature_2");
+ FD::Convert("LanguageModel");
+ int f4 = FD::Convert("SomeFeature");
+ int f5 = FD::Convert("SomeOtherFeature");
+ SparseVector<double> g;
+ g.set_value(f2, log(0.5));
+ g.set_value(f4, log(0.125));
+ g.set_value(f1, 0);
+ g.set_value(f5, 23.777);
+ ostringstream os;
+ double iobj = 1.5;
+ B64::Encode(iobj, g, &os);
+ cerr << iobj << "\t" << g << endl;
+ string data = os.str();
+ cout << data << endl;
+ SparseVector<double> v;
+ double obj;
+ assert(B64::Decode(&obj, &v, &data[0], data.size()));
+ cerr << obj << "\t" << v << endl;
+ assert(obj == iobj);
+ assert(g.num_active() == v.num_active());
+}
+
+int main() {
+ double o1 = TestOptimizer();
+ double o2 = TestPersistentOptimizer();
+ if (o1 != o2) {
+ cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl;
+ return 1;
+ }
+ TestSparseVector();
+ cerr << "SUCCESS\n";
+ return 0;
+}
+
diff --git a/training/make-lexcrf-grammar.pl b/training/make-lexcrf-grammar.pl
new file mode 100755
index 00000000..0e290492
--- /dev/null
+++ b/training/make-lexcrf-grammar.pl
@@ -0,0 +1,236 @@
+#!/usr/bin/perl -w
+use utf8;
+use strict;
+my ($effile, $model1) = @ARGV;
+die "Usage: $0 corpus.fr-en corpus.model1\n" unless $effile && -f $effile && $model1 && -f $model1;
+
+open EF, "<$effile" or die;
+open M1, "<$model1" or die;
+binmode(EF,":utf8");
+binmode(M1,":utf8");
+binmode(STDOUT,":utf8");
+my %model1;
+while(<M1>) {
+ chomp;
+ my ($f, $e, $lp) = split /\s+/;
+ $model1{$f}->{$e} = $lp;
+}
+
+my $ADD_MODEL1 = 0; # found that model1 hurts performance
+my $IS_FRENCH_F = 0; # indicates that the f language is french
+my $IS_ARABIC_F = 1; # indicates that the f language is arabic
+my $ADD_PREFIX_ID = 0;
+my $ADD_LEN = 1;
+my $ADD_LD = 0;
+my $ADD_DICE = 1;
+my $ADD_111 = 1;
+my $ADD_ID = 1;
+my $ADD_PUNC = 1;
+my $ADD_NUM_MM = 1;
+my $ADD_NULL = 1;
+my $BEAM_RATIO = 50;
+
+my %fdict;
+my %fcounts;
+my %ecounts;
+
+while(<EF>) {
+ chomp;
+ my ($f, $e) = split /\s*\|\|\|\s*/;
+ my @es = split /\s+/, $e;
+ my @fs = split /\s+/, $f;
+ for my $ew (@es){ $ecounts{$ew}++; }
+ push @fs, '<eps>' if $ADD_NULL;
+ for my $fw (@fs){ $fcounts{$fw}++; }
+ for my $fw (@fs){
+ for my $ew (@es){
+ $fdict{$fw}->{$ew}++;
+ }
+ }
+}
+
+print STDERR "Dice 0\n" if $ADD_DICE;
+print STDERR "OneOneOne 0\nId_OneOneOne 0\n" if $ADD_111;
+print STDERR "Identical 0\n" if $ADD_ID;
+print STDERR "PuncMiss 0\n" if $ADD_PUNC;
+print STDERR "IsNull 0\n" if $ADD_NULL;
+print STDERR "Model1 0\n" if $ADD_MODEL1;
+print STDERR "DLen 0\n" if $ADD_LEN;
+print STDERR "NumMM 0\n" if $ADD_NUM_MM;
+print STDERR "Level 0\n" if $ADD_LD;
+print STDERR "PfxIdentical 0\n" if ($ADD_PREFIX_ID);
+my $fc = 1000000;
+for my $f (sort keys %fdict) {
+ my $re = $fdict{$f};
+ my $max;
+ for my $e (sort {$re->{$b} <=> $re->{$a}} keys %$re) {
+ my $efcount = $re->{$e};
+ unless (defined $max) { $max = $efcount; }
+ my $m1 = $model1{$f}->{$e};
+ unless (defined $m1) { next; }
+ $fc++;
+ my $dice = 2 * $efcount / ($ecounts{$e} + $fcounts{$f});
+ my $feats = "F$fc=1";
+ my $oe = $e;
+ my $len_e = length($oe);
+ my $of = $f; # normalized form
+ if ($IS_FRENCH_F) {
+ # see http://en.wikipedia.org/wiki/Use_of_the_circumflex_in_French
+ $of =~ s/â/as/g;
+ $of =~ s/ê/es/g;
+ $of =~ s/î/is/g;
+ $of =~ s/ô/os/g;
+ $of =~ s/û/us/g;
+ } elsif ($IS_ARABIC_F) {
+ if (length($of) > 1 && !($of =~ /\d/)) {
+ $of =~ s/\$/sh/g;
+ }
+ }
+ my $len_f = length($of);
+ $feats .= " Model1=$m1" if ($ADD_MODEL1);
+ $feats .= " Dice=$dice" if $ADD_DICE;
+ my $is_null = undef;
+ if ($ADD_NULL && $f eq '<eps>') {
+ $feats .= " IsNull=1";
+ $is_null = 1;
+ }
+ if ($ADD_LEN) {
+ if (!$is_null) {
+ my $dlen = abs($len_e - $len_f);
+ $feats .= " DLen=$dlen";
+ }
+ }
+ my $f_num = ($of =~ /^-?\d[0-9\.\,]+%?$/); # this matches *two digit* and more numbers
+ my $e_num = ($oe =~ /^-?\d[0-9\.\,]+%?$/);
+ my $both_non_numeric = (!$e_num && !$f_num);
+ if ($ADD_NUM_MM && (($f_num && !$e_num) || ($e_num && !$f_num))) {
+ $feats .= " NumMM=1";
+ }
+ if ($ADD_PREFIX_ID) {
+ if ($len_e > 3 && $len_f > 3 && $both_non_numeric) {
+ my $pe = substr $oe, 0, 3;
+ my $pf = substr $of, 0, 3;
+ if ($pe eq $pf) { $feats .= " PfxIdentical=1"; }
+ }
+ }
+ if ($ADD_LD) {
+ my $ld = 0;
+ if ($is_null) { $ld = length($e); } else {
+ $ld = levenshtein($e, $f);
+ }
+ $feats .= " Leven=$ld";
+ }
+ my $ident = ($e eq $f);
+ if ($ident && $ADD_ID) { $feats .= " Identical=1"; }
+ if ($ADD_111 && ($efcount == 1 && $ecounts{$e} == 1 && $fcounts{$f} == 1)) {
+ if ($ident && $ADD_ID) {
+ $feats .= " Id_OneOneOne=1";
+ }
+ $feats .= " OneOneOne=1";
+ }
+ if ($ADD_PUNC) {
+ if (($f =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $e =~ /[a-z]+/) ||
+ ($e =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $f =~ /[a-z]+/)) {
+ $feats .= " PuncMiss=1";
+ }
+ }
+ my $r = (0.5 - rand)/5;
+ print STDERR "F$fc $r\n";
+ print "$f ||| $e ||| $feats\n";
+ }
+}
+
+sub levenshtein
+{
+ # $s1 and $s2 are the two strings
+ # $len1 and $len2 are their respective lengths
+ #
+ my ($s1, $s2) = @_;
+ my ($len1, $len2) = (length $s1, length $s2);
+
+ # If one of the strings is empty, the distance is the length
+ # of the other string
+ #
+ return $len2 if ($len1 == 0);
+ return $len1 if ($len2 == 0);
+
+ my %mat;
+
+ # Init the distance matrix
+ #
+ # The first row to 0..$len1
+ # The first column to 0..$len2
+ # The rest to 0
+ #
+ # The first row and column are initialized so to denote distance
+ # from the empty string
+ #
+ for (my $i = 0; $i <= $len1; ++$i)
+ {
+ for (my $j = 0; $j <= $len2; ++$j)
+ {
+ $mat{$i}{$j} = 0;
+ $mat{0}{$j} = $j;
+ }
+
+ $mat{$i}{0} = $i;
+ }
+
+ # Some char-by-char processing is ahead, so prepare
+ # array of chars from the strings
+ #
+ my @ar1 = split(//, $s1);
+ my @ar2 = split(//, $s2);
+
+ for (my $i = 1; $i <= $len1; ++$i)
+ {
+ for (my $j = 1; $j <= $len2; ++$j)
+ {
+ # Set the cost to 1 iff the ith char of $s1
+ # equals the jth of $s2
+ #
+ # Denotes a substitution cost. When the char are equal
+ # there is no need to substitute, so the cost is 0
+ #
+ my $cost = ($ar1[$i-1] eq $ar2[$j-1]) ? 0 : 1;
+
+ # Cell $mat{$i}{$j} equals the minimum of:
+ #
+ # - The cell immediately above plus 1
+ # - The cell immediately to the left plus 1
+ # - The cell diagonally above and to the left plus the cost
+ #
+ # We can either insert a new char, delete a char or
+ # substitute an existing char (with an associated cost)
+ #
+ $mat{$i}{$j} = min([$mat{$i-1}{$j} + 1,
+ $mat{$i}{$j-1} + 1,
+ $mat{$i-1}{$j-1} + $cost]);
+ }
+ }
+
+ # Finally, the Levenshtein distance equals the rightmost bottom cell
+ # of the matrix
+ #
+ # Note that $mat{$x}{$y} denotes the distance between the substrings
+ # 1..$x and 1..$y
+ #
+ return $mat{$len1}{$len2};
+}
+
+
+# minimal element of a list
+#
+sub min
+{
+ my @list = @{$_[0]};
+ my $min = $list[0];
+
+ foreach my $i (@list)
+ {
+ $min = $i if ($i < $min);
+ }
+
+ return $min;
+}
+
diff --git a/training/model1.cc b/training/model1.cc
new file mode 100644
index 00000000..f571700f
--- /dev/null
+++ b/training/model1.cc
@@ -0,0 +1,103 @@
+#include <iostream>
+
+#include "lattice.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "ttables.h"
+#include "tdict.h"
+
+using namespace std;
+
+int main(int argc, char** argv) {
+ if (argc != 2) {
+ cerr << "Usage: " << argv[0] << " corpus.fr-en\n";
+ return 1;
+ }
+ const int ITERATIONS = 5;
+ const prob_t BEAM_THRESHOLD(0.0001);
+ TTable tt;
+ const WordID kNULL = TD::Convert("<eps>");
+ bool use_null = true;
+ TTable::Word2Word2Double was_viterbi;
+ for (int iter = 0; iter < ITERATIONS; ++iter) {
+ const bool final_iteration = (iter == (ITERATIONS - 1));
+ cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
+ ReadFile rf(argv[1]);
+ istream& in = *rf.stream();
+ prob_t likelihood = prob_t::One();
+ double denom = 0.0;
+ int lc = 0;
+ bool flag = false;
+ while(true) {
+ string line;
+ getline(in, line);
+ if (!in) break;
+ ++lc;
+ if (lc % 1000 == 0) { cerr << '.'; flag = true; }
+ if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; }
+ string ssrc, strg;
+ ParseTranslatorInput(line, &ssrc, &strg);
+ Lattice src, trg;
+ LatticeTools::ConvertTextToLattice(ssrc, &src);
+ LatticeTools::ConvertTextToLattice(strg, &trg);
+ assert(src.size() > 0);
+ assert(trg.size() > 0);
+ denom += 1.0;
+ vector<prob_t> probs(src.size() + 1);
+ for (int j = 0; j < trg.size(); ++j) {
+ const WordID& f_j = trg[j][0].label;
+ prob_t sum = prob_t::Zero();
+ if (use_null) {
+ probs[0] = tt.prob(kNULL, f_j);
+ sum += probs[0];
+ }
+ for (int i = 1; i <= src.size(); ++i) {
+ probs[i] = tt.prob(src[i-1][0].label, f_j);
+ sum += probs[i];
+ }
+ if (final_iteration) {
+ WordID max_i = 0;
+ prob_t max_p = prob_t::Zero();
+ if (use_null) {
+ max_i = kNULL;
+ max_p = probs[0];
+ }
+ for (int i = 1; i <= src.size(); ++i) {
+ if (probs[i] > max_p) {
+ max_p = probs[i];
+ max_i = src[i-1][0].label;
+ }
+ }
+ was_viterbi[max_i][f_j] = 1.0;
+ } else {
+ if (use_null)
+ tt.Increment(kNULL, f_j, probs[0] / sum);
+ for (int i = 1; i <= src.size(); ++i)
+ tt.Increment(src[i-1][0].label, f_j, probs[i] / sum);
+ }
+ likelihood *= sum;
+ }
+ }
+ if (flag) { cerr << endl; }
+ cerr << " log likelihood: " << log(likelihood) << endl;
+ cerr << " cross entopy: " << (-log(likelihood) / denom) << endl;
+ cerr << " perplexity: " << pow(2.0, -log(likelihood) / denom) << endl;
+ if (!final_iteration) tt.Normalize();
+ }
+ for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {
+ const TTable::Word2Double& cpd = ei->second;
+ const TTable::Word2Double& vit = was_viterbi[ei->first];
+ const string& esym = TD::Convert(ei->first);
+ prob_t max_p = prob_t::Zero();
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
+ if (fi->second > max_p) max_p = prob_t(fi->second);
+ const prob_t threshold = max_p * BEAM_THRESHOLD;
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
+ if (fi->second > threshold || (vit.count(fi->first) > 0)) {
+ cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
+ }
+ }
+ }
+ return 0;
+}
+
diff --git a/training/mr_em_train.cc b/training/mr_em_train.cc
new file mode 100644
index 00000000..a15fbe4c
--- /dev/null
+++ b/training/mr_em_train.cc
@@ -0,0 +1,270 @@
+#include <iostream>
+#include <vector>
+#include <cassert>
+#include <cmath>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "config.h"
+#ifdef HAVE_BOOST_DIGAMMA
+#include <boost/math/special_functions/digamma.hpp>
+using boost::math::digamma;
+#endif
+
+#include "tdict.h"
+#include "filelib.h"
+#include "trule.h"
+#include "fdict.h"
+#include "weights.h"
+#include "sparse_vector.h"
+
+using namespace std;
+using boost::shared_ptr;
+namespace po = boost::program_options;
+
+#ifndef HAVE_BOOST_DIGAMMA
+#warning Using Mark Johnson's digamma()
+double digamma(double x) {
+ double result = 0, xx, xx2, xx4;
+ assert(x > 0);
+ for ( ; x < 7; ++x)
+ result -= 1/x;
+ x -= 1.0/2.0;
+ xx = 1.0/x;
+ xx2 = xx*xx;
+ xx4 = xx2*xx2;
+ result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4;
+ return result;
+}
+#endif
+
+void SanityCheck(const vector<double>& w) {
+ for (int i = 0; i < w.size(); ++i) {
+ assert(!isnan(w[i]));
+ }
+}
+
+struct FComp {
+ const vector<double>& w_;
+ FComp(const vector<double>& w) : w_(w) {}
+ bool operator()(int a, int b) const {
+ return w_[a] > w_[b];
+ }
+};
+
+void ShowLargestFeatures(const vector<double>& w) {
+ vector<int> fnums(w.size() - 1);
+ for (int i = 1; i < w.size(); ++i)
+ fnums[i-1] = i;
+ vector<int>::iterator mid = fnums.begin();
+ mid += (w.size() > 10 ? 10 : w.size()) - 1;
+ partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
+ cerr << "MOST PROBABLE:";
+ for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
+ cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
+ }
+ cerr << endl;
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("output,o",po::value<string>()->default_value("-"),"Output log probs file")
+ ("grammar,g",po::value<vector<string> >()->composing(),"SCFG grammar file(s)")
+ ("optimization_method,m", po::value<string>()->default_value("em"), "Optimization method (em, vb)")
+ ("input_format,f",po::value<string>()->default_value("b64"),"Encoding of the input (b64 or text)");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || !conf->count("grammar")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+// describes a multinomial or multinomial with a prior
+// does not contain the parameters- just the list of events
+// and any hyperparameters
+struct MultinomialInfo {
+ MultinomialInfo() : alpha(1.0) {}
+ vector<int> events; // the events that this multinomial generates
+ double alpha; // hyperparameter for (optional) Dirichlet prior
+};
+
+typedef map<WordID, MultinomialInfo> ModelDefinition;
+
+void LoadModelEvents(const po::variables_map& conf, ModelDefinition* pm) {
+ ModelDefinition& m = *pm;
+ m.clear();
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ for (int i = 0; i < gfiles.size(); ++i) {
+ ReadFile rf(gfiles[i]);
+ istream& in = *rf.stream();
+ int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ ++lc;
+ TRule r(line, true);
+ const SparseVector<double>& f = r.GetFeatureValues();
+ if (f.num_active() == 0) {
+ cerr << "[WARNING] no feature found in " << gfiles[i] << ':' << lc << endl;
+ continue;
+ }
+ if (f.num_active() > 1) {
+ cerr << "[ERROR] more than one feature found in " << gfiles[i] << ':' << lc << endl;
+ exit(1);
+ }
+ SparseVector<double>::const_iterator it = f.begin();
+ if (it->second != 1.0) {
+ cerr << "[ERROR] feature with value != 1 found in " << gfiles[i] << ':' << lc << endl;
+ exit(1);
+ }
+ m[r.GetLHS()].events.push_back(it->first);
+ }
+ }
+ for (ModelDefinition::iterator it = m.begin(); it != m.end(); ++it) {
+ const vector<int>& v = it->second.events;
+ cerr << "Multinomial [" << TD::Convert(it->first*-1) << "]\n";
+ if (v.size() < 1000) {
+ cerr << " generates:";
+ for (int i = 0; i < v.size(); ++i) {
+ cerr << " " << FD::Convert(v[i]);
+ }
+ cerr << endl;
+ }
+ }
+}
+
+void Maximize(const ModelDefinition& m, const bool use_vb, vector<double>* counts) {
+ for (ModelDefinition::const_iterator it = m.begin(); it != m.end(); ++it) {
+ const MultinomialInfo& mult_info = it->second;
+ const vector<int>& events = mult_info.events;
+ cerr << "Multinomial [" << TD::Convert(it->first*-1) << "]";
+ double tot = 0;
+ for (int i = 0; i < events.size(); ++i)
+ tot += (*counts)[events[i]];
+ cerr << " = " << tot << endl;
+ assert(tot > 0.0);
+ double ltot = log(tot);
+ if (use_vb)
+ ltot = digamma(tot + events.size() * mult_info.alpha);
+ for (int i = 0; i < events.size(); ++i) {
+ if (use_vb) {
+ (*counts)[events[i]] = digamma((*counts)[events[i]] + mult_info.alpha) - ltot;
+ } else {
+ (*counts)[events[i]] = log((*counts)[events[i]]) - ltot;
+ }
+ }
+ if (events.size() < 50) {
+ for (int i = 0; i < events.size(); ++i) {
+ cerr << " p(" << FD::Convert(events[i]) << ")=" << exp((*counts)[events[i]]);
+ }
+ cerr << endl;
+ }
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+
+ const bool use_b64 = conf["input_format"].as<string>() == "b64";
+ const bool use_vb = conf["optimization_method"].as<string>() == "vb";
+ if (use_vb)
+ cerr << "Using variational Bayes, make sure alphas are set\n";
+
+ ModelDefinition model_def;
+ LoadModelEvents(conf, &model_def);
+
+ const string s_obj = "**OBJ**";
+ int num_feats = FD::NumFeats();
+ cerr << "Number of features: " << num_feats << endl;
+
+ vector<double> counts(num_feats, 0);
+ double logprob = 0;
+ // 0<TAB>**OBJ**=12.2;Feat1=2.3;Feat2=-0.2;
+ // 0<TAB>**OBJ**=1.1;Feat1=1.0;
+
+ // E-step
+ while(cin) {
+ string line;
+ getline(cin, line);
+ if (line.empty()) continue;
+ int feat;
+ double val;
+ size_t i = line.find("\t");
+ assert(i != string::npos);
+ ++i;
+ if (use_b64) {
+ SparseVector<double> g;
+ double obj;
+ if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) {
+ cerr << "B64 decoder returned error, skipping!\n";
+ continue;
+ }
+ logprob += obj;
+ const SparseVector<double>& cg = g;
+ for (SparseVector<double>::const_iterator it = cg.begin(); it != cg.end(); ++it) {
+ if (it->first >= num_feats) {
+ cerr << "Unexpected feature: " << FD::Convert(it->first) << endl;
+ abort();
+ }
+ counts[it->first] += it->second;
+ }
+ } else { // text encoding - your counts will not be accurate!
+ while (i < line.size()) {
+ size_t start = i;
+ while (line[i] != '=' && i < line.size()) ++i;
+ if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; }
+ string fname = line.substr(start, i - start);
+ if (fname == s_obj) {
+ feat = -1;
+ } else {
+ feat = FD::Convert(line.substr(start, i - start));
+ if (feat >= num_feats) {
+ cerr << "Unexpected feature: " << line.substr(start, i - start) << endl;
+ abort();
+ }
+ }
+ ++i;
+ start = i;
+ while (line[i] != ';' && i < line.size()) ++i;
+ if (i - start == 0) continue;
+ val = atof(line.substr(start, i - start).c_str());
+ ++i;
+ if (feat == -1) {
+ logprob += val;
+ } else {
+ counts[feat] += val;
+ }
+ }
+ }
+ }
+
+ cerr << "LOGPROB: " << logprob << endl;
+ // M-step
+ Maximize(model_def, use_vb, &counts);
+
+ SanityCheck(counts);
+ ShowLargestFeatures(counts);
+ Weights weights;
+ weights.InitFromVector(counts);
+ weights.WriteToFile(conf["output"].as<string>(), false);
+
+ return 0;
+}
diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc
new file mode 100644
index 00000000..56b73c30
--- /dev/null
+++ b/training/mr_optimize_reduce.cc
@@ -0,0 +1,243 @@
+#include <sstream>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <cassert>
+#include <cmath>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "optimize.h"
+#include "fdict.h"
+#include "weights.h"
+#include "sparse_vector.h"
+
+using namespace std;
+using boost::shared_ptr;
+namespace po = boost::program_options;
+
+void SanityCheck(const vector<double>& w) {
+ for (int i = 0; i < w.size(); ++i) {
+ assert(!isnan(w[i]));
+ assert(!isinf(w[i]));
+ }
+}
+
+struct FComp {
+ const vector<double>& w_;
+ FComp(const vector<double>& w) : w_(w) {}
+ bool operator()(int a, int b) const {
+ return fabs(w_[a]) > fabs(w_[b]);
+ }
+};
+
+void ShowLargestFeatures(const vector<double>& w) {
+ vector<int> fnums(w.size());
+ for (int i = 0; i < w.size(); ++i)
+ fnums[i] = i;
+ vector<int>::iterator mid = fnums.begin();
+ mid += (w.size() > 10 ? 10 : w.size());
+ partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
+ cerr << "TOP FEATURES:";
+ for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
+ cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
+ }
+ cerr << endl;
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("input_weights,i",po::value<string>(),"Input feature weights file")
+ ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file")
+ ("optimization_method,m", po::value<string>()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)")
+ ("state,s",po::value<string>(),"Read (and write if output_state is not set) optimizer state from this state file. In the first iteration, the file should not exist.")
+ ("input_format,f",po::value<string>()->default_value("b64"),"Encoding of the input (b64 or text)")
+ ("output_state,S", po::value<string>(), "Output state file (optional override)")
+ ("correction_buffers,M", po::value<int>()->default_value(10), "Number of gradients for LBFGS to maintain in memory")
+ ("eta,e", po::value<double>()->default_value(0.1), "Learning rate for SGD (eta)")
+ ("gaussian_prior,p","Use a Gaussian prior on the weights")
+ ("means,u", po::value<string>(), "File containing the means for Gaussian prior")
+ ("sigma_squared", po::value<double>()->default_value(1.0), "Sigma squared term for spherical Gaussian prior");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || !conf->count("input_weights") || !conf->count("state")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+
+ const bool use_b64 = conf["input_format"].as<string>() == "b64";
+
+ Weights weights;
+ weights.InitFromFile(conf["input_weights"].as<string>());
+ const string s_obj = "**OBJ**";
+ int num_feats = FD::NumFeats();
+ cerr << "Number of features: " << num_feats << endl;
+ const bool gaussian_prior = conf.count("gaussian_prior");
+ vector<double> means(num_feats, 0);
+ if (conf.count("means")) {
+ if (!gaussian_prior) {
+ cerr << "Don't use --means without --gaussian_prior!\n";
+ exit(1);
+ }
+ Weights wm;
+ wm.InitFromFile(conf["means"].as<string>());
+ if (num_feats != FD::NumFeats()) {
+ cerr << "[ERROR] Means file had unexpected features!\n";
+ exit(1);
+ }
+ wm.InitVector(&means);
+ }
+ shared_ptr<Optimizer> o;
+ const string omethod = conf["optimization_method"].as<string>();
+ if (omethod == "sgd")
+ o.reset(new SGDOptimizer(conf["eta"].as<double>()));
+ else if (omethod == "rprop")
+ o.reset(new RPropOptimizer(num_feats)); // TODO add configuration
+ else
+ o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as<int>()));
+ cerr << "Optimizer: " << o->Name() << endl;
+ string state_file = conf["state"].as<string>();
+ {
+ ifstream in(state_file.c_str(), ios::binary);
+ if (in)
+ o->Load(&in);
+ else
+ cerr << "No state file found, assuming ITERATION 1\n";
+ }
+
+ vector<double> lambdas(num_feats, 0);
+ weights.InitVector(&lambdas);
+ double objective = 0;
+ vector<double> gradient(num_feats, 0);
+ // 0<TAB>**OBJ**=12.2;Feat1=2.3;Feat2=-0.2;
+ // 0<TAB>**OBJ**=1.1;Feat1=1.0;
+ int total_lines = 0; // TODO - this should be a count of the
+ // training instances!!
+ while(cin) {
+ string line;
+ getline(cin, line);
+ if (line.empty()) continue;
+ ++total_lines;
+ int feat;
+ double val;
+ size_t i = line.find("\t");
+ assert(i != string::npos);
+ ++i;
+ if (use_b64) {
+ SparseVector<double> g;
+ double obj;
+ if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) {
+ cerr << "B64 decoder returned error, skipping gradient!\n";
+ cerr << " START: " << line.substr(0,min(200ul, line.size())) << endl;
+ if (line.size() > 200)
+ cerr << " END: " << line.substr(line.size() - 200, 200) << endl;
+ cout << "-1\tRESTART\n";
+ exit(99);
+ }
+ objective += obj;
+ const SparseVector<double>& cg = g;
+ for (SparseVector<double>::const_iterator it = cg.begin(); it != cg.end(); ++it) {
+ if (it->first >= num_feats) {
+ cerr << "Unexpected feature in gradient: " << FD::Convert(it->first) << endl;
+ abort();
+ }
+ gradient[it->first] -= it->second;
+ }
+ } else { // text encoding - your gradients will not be accurate!
+ while (i < line.size()) {
+ size_t start = i;
+ while (line[i] != '=' && i < line.size()) ++i;
+ if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; }
+ string fname = line.substr(start, i - start);
+ if (fname == s_obj) {
+ feat = -1;
+ } else {
+ feat = FD::Convert(line.substr(start, i - start));
+ if (feat >= num_feats) {
+ cerr << "Unexpected feature in gradient: " << line.substr(start, i - start) << endl;
+ abort();
+ }
+ }
+ ++i;
+ start = i;
+ while (line[i] != ';' && i < line.size()) ++i;
+ if (i - start == 0) continue;
+ val = atof(line.substr(start, i - start).c_str());
+ ++i;
+ if (feat == -1) {
+ objective += val;
+ } else {
+ gradient[feat] -= val;
+ }
+ }
+ }
+ }
+
+ if (gaussian_prior) {
+ const double sigsq = conf["sigma_squared"].as<double>();
+ double norm = 0;
+ for (int k = 1; k < lambdas.size(); ++k) {
+ const double& lambda_k = lambdas[k];
+ if (lambda_k) {
+ const double param = (lambda_k - means[k]);
+ norm += param * param;
+ gradient[k] += param / sigsq;
+ }
+ }
+ const double reg = norm / (2.0 * sigsq);
+ cerr << "REGULARIZATION TERM: " << reg << endl;
+ objective += reg;
+ }
+ cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl;
+ double gnorm = 0;
+ for (int i = 0; i < gradient.size(); ++i)
+ gnorm += gradient[i] * gradient[i];
+ cerr << " GNORM=" << sqrt(gnorm) << endl;
+ vector<double> old = lambdas;
+ int c = 0;
+ while (old == lambdas) {
+ ++c;
+ if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; }
+ o->Optimize(objective, gradient, &lambdas);
+ assert(c < 5);
+ }
+ old.clear();
+ SanityCheck(lambdas);
+ ShowLargestFeatures(lambdas);
+ weights.InitFromVector(lambdas);
+ weights.WriteToFile(conf["output_weights"].as<string>(), false);
+
+ const bool conv = o->HasConverged();
+ if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; }
+
+ if (conf.count("output_state"))
+ state_file = conf["output_state"].as<string>();
+ ofstream out(state_file.c_str(), ios::binary);
+ cerr << "Writing state to: " << state_file << endl;
+ o->Save(&out);
+ out.close();
+
+ cout << o->EvaluationCount() << "\t" << conv << endl;
+ return 0;
+}
diff --git a/training/optimize.cc b/training/optimize.cc
new file mode 100644
index 00000000..5194752e
--- /dev/null
+++ b/training/optimize.cc
@@ -0,0 +1,114 @@
+#include "optimize.h"
+
+#include <iostream>
+#include <cassert>
+
+#include "lbfgs.h"
+
+using namespace std;
+
+Optimizer::~Optimizer() {}
+
+void Optimizer::Save(ostream* out) const {
+ out->write((const char*)&eval_, sizeof(eval_));
+ out->write((const char*)&has_converged_, sizeof(has_converged_));
+ SaveImpl(out);
+ unsigned int magic = 0xABCDDCBA; // should be uint32_t
+ out->write((const char*)&magic, sizeof(magic));
+}
+
+void Optimizer::Load(istream* in) {
+ in->read((char*)&eval_, sizeof(eval_));
+ ++eval_;
+ in->read((char*)&has_converged_, sizeof(has_converged_));
+ LoadImpl(in);
+ unsigned int magic = 0; // should be uint32_t
+ in->read((char*)&magic, sizeof(magic));
+ assert(magic == 0xABCDDCBA);
+ cerr << Name() << " EVALUATION #" << eval_ << endl;
+}
+
+void Optimizer::SaveImpl(ostream* out) const {
+ (void)out;
+}
+
+void Optimizer::LoadImpl(istream* in) {
+ (void)in;
+}
+
+string RPropOptimizer::Name() const {
+ return "RPropOptimizer";
+}
+
+void RPropOptimizer::OptimizeImpl(const double& obj,
+ const vector<double>& g,
+ vector<double>* x) {
+ for (int i = 0; i < g.size(); ++i) {
+ const double g_i = g[i];
+ const double sign_i = (signbit(g_i) ? -1.0 : 1.0);
+ const double prod = g_i * prev_g_[i];
+ if (prod > 0.0) {
+ const double dij = min(delta_ij_[i] * eta_plus_, delta_max_);
+ (*x)[i] -= dij * sign_i;
+ delta_ij_[i] = dij;
+ prev_g_[i] = g_i;
+ } else if (prod < 0.0) {
+ delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_);
+ prev_g_[i] = 0.0;
+ } else {
+ (*x)[i] -= delta_ij_[i] * sign_i;
+ prev_g_[i] = g_i;
+ }
+ }
+}
+
+void RPropOptimizer::SaveImpl(ostream* out) const {
+ const size_t n = prev_g_.size();
+ out->write((const char*)&n, sizeof(n));
+ out->write((const char*)&prev_g_[0], sizeof(double) * n);
+ out->write((const char*)&delta_ij_[0], sizeof(double) * n);
+}
+
+void RPropOptimizer::LoadImpl(istream* in) {
+ size_t n;
+ in->read((char*)&n, sizeof(n));
+ assert(n == prev_g_.size());
+ assert(n == delta_ij_.size());
+ in->read((char*)&prev_g_[0], sizeof(double) * n);
+ in->read((char*)&delta_ij_[0], sizeof(double) * n);
+}
+
+string SGDOptimizer::Name() const {
+ return "SGDOptimizer";
+}
+
+void SGDOptimizer::OptimizeImpl(const double& obj,
+ const vector<double>& g,
+ vector<double>* x) {
+ (void)obj;
+ for (int i = 0; i < g.size(); ++i)
+ (*x)[i] -= g[i] * eta_;
+}
+
+string LBFGSOptimizer::Name() const {
+ return "LBFGSOptimizer";
+}
+
+LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) :
+ opt_(num_feats, memory_buffers) {}
+
+void LBFGSOptimizer::SaveImpl(ostream* out) const {
+ opt_.serialize(out);
+}
+
+void LBFGSOptimizer::LoadImpl(istream* in) {
+ opt_.deserialize(in);
+}
+
+void LBFGSOptimizer::OptimizeImpl(const double& obj,
+ const vector<double>& g,
+ vector<double>* x) {
+ opt_.run(&(*x)[0], obj, &g[0]);
+ cerr << opt_ << endl;
+}
+
diff --git a/training/optimize.h b/training/optimize.h
new file mode 100644
index 00000000..eddceaad
--- /dev/null
+++ b/training/optimize.h
@@ -0,0 +1,104 @@
+#ifndef _OPTIMIZE_H_
+#define _OPTIMIZE_H_
+
+#include <iostream>
+#include <vector>
+#include <string>
+#include <cassert>
+
+#include "lbfgs.h"
+
+// abstract base class for first order optimizers
+// order of invocation: new, Load(), Optimize(), Save(), delete
+class Optimizer {
+ public:
+ Optimizer() : eval_(1), has_converged_(false) {}
+ virtual ~Optimizer();
+ virtual std::string Name() const = 0;
+ int EvaluationCount() const { return eval_; }
+ bool HasConverged() const { return has_converged_; }
+
+ void Optimize(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x) {
+ assert(g.size() == x->size());
+ OptimizeImpl(obj, g, x);
+ scitbx::lbfgs::traditional_convergence_test<double> converged(g.size());
+ has_converged_ = converged(&(*x)[0], &g[0]);
+ }
+
+ void Save(std::ostream* out) const;
+ void Load(std::istream* in);
+ protected:
+ virtual void SaveImpl(std::ostream* out) const;
+ virtual void LoadImpl(std::istream* in);
+ virtual void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x) = 0;
+
+ int eval_;
+ private:
+ bool has_converged_;
+};
+
+class RPropOptimizer : public Optimizer {
+ public:
+ explicit RPropOptimizer(int num_vars,
+ double eta_plus = 1.2,
+ double eta_minus = 0.5,
+ double delta_0 = 0.1,
+ double delta_max = 50.0,
+ double delta_min = 1e-6) :
+ prev_g_(num_vars, 0.0),
+ delta_ij_(num_vars, delta_0),
+ eta_plus_(eta_plus),
+ eta_minus_(eta_minus),
+ delta_max_(delta_max),
+ delta_min_(delta_min) {
+ assert(eta_plus > 1.0);
+ assert(eta_minus > 0.0 && eta_minus < 1.0);
+ assert(delta_max > 0.0);
+ assert(delta_min > 0.0);
+ }
+ std::string Name() const;
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ void SaveImpl(std::ostream* out) const;
+ void LoadImpl(std::istream* in);
+ private:
+ std::vector<double> prev_g_;
+ std::vector<double> delta_ij_;
+ const double eta_plus_;
+ const double eta_minus_;
+ const double delta_max_;
+ const double delta_min_;
+};
+
+class SGDOptimizer : public Optimizer {
+ public:
+ explicit SGDOptimizer(int num_vars, double eta = 0.1) : eta_(eta) {
+ (void) num_vars;
+ }
+ std::string Name() const;
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ private:
+ const double eta_;
+};
+
+class LBFGSOptimizer : public Optimizer {
+ public:
+ explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10);
+ std::string Name() const;
+ void SaveImpl(std::ostream* out) const;
+ void LoadImpl(std::istream* in);
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ private:
+ scitbx::lbfgs::minimizer<double> opt_;
+};
+
+#endif
diff --git a/training/optimize_test.cc b/training/optimize_test.cc
new file mode 100644
index 00000000..0ada7cbb
--- /dev/null
+++ b/training/optimize_test.cc
@@ -0,0 +1,105 @@
+#include <cassert>
+#include <iostream>
+#include <sstream>
+#include <boost/program_options/variables_map.hpp>
+#include "optimize.h"
+#include "sparse_vector.h"
+#include "fdict.h"
+
+using namespace std;
+
+double TestOptimizer(Optimizer* opt) {
+ cerr << "TESTING NON-PERSISTENT OPTIMIZER\n";
+
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ vector<double> x(3);
+ vector<double> g(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ do {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+ opt->Optimize(obj, g, &x);
+
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ } while (!opt->HasConverged());
+ return obj;
+}
+
+double TestPersistentOptimizer(Optimizer* opt) {
+ cerr << "\nTESTING PERSISTENT OPTIMIZER\n";
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ vector<double> x(3);
+ vector<double> g(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ string state;
+ bool converged = false;
+ while (!converged) {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+
+ {
+ if (state.size() > 0) {
+ istringstream is(state, ios::binary);
+ opt->Load(&is);
+ }
+ opt->Optimize(obj, g, &x);
+ ostringstream os(ios::binary); opt->Save(&os); state = os.str();
+
+ }
+
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ converged = opt->HasConverged();
+ if (!converged) {
+ // now screw up the state (should be undone by Load)
+ obj += 2.0;
+ g[1] = -g[2];
+ vector<double> x2 = x;
+ try {
+ opt->Optimize(obj, g, &x2);
+ } catch (...) { }
+ }
+ }
+ return obj;
+}
+
+template <class O>
+void TestOptimizerVariants(int num_vars) {
+ O oa(num_vars);
+ cerr << "-------------------------------------------------------------------------\n";
+ cerr << "TESTING: " << oa.Name() << endl;
+ double o1 = TestOptimizer(&oa);
+ O ob(num_vars);
+ double o2 = TestPersistentOptimizer(&ob);
+ if (o1 != o2) {
+ cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl;
+ exit(1);
+ }
+ cerr << oa.Name() << " SUCCESS\n";
+}
+
+int main() {
+ int n = 3;
+ TestOptimizerVariants<SGDOptimizer>(n);
+ TestOptimizerVariants<LBFGSOptimizer>(n);
+ TestOptimizerVariants<RPropOptimizer>(n);
+ return 0;
+}
+
diff --git a/training/plftools.cc b/training/plftools.cc
new file mode 100644
index 00000000..903ec54f
--- /dev/null
+++ b/training/plftools.cc
@@ -0,0 +1,93 @@
+#include <iostream>
+#include <fstream>
+#include <vector>
+
+#include <boost/lexical_cast.hpp>
+#include <boost/program_options.hpp>
+
+#include "filelib.h"
+#include "tdict.h"
+#include "prob.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "viterbi.h"
+#include "kbest.h"
+
+namespace po = boost::program_options;
+using namespace std;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("input,i", po::value<string>(), "REQ. Lattice input file (PLF), - for STDIN")
+ ("prior_scale,p", po::value<double>()->default_value(1.0), "Scale path probabilities by this amount < 1 flattens, > 1 sharpens")
+ ("weight,w", po::value<vector<double> >(), "Weight(s) for arc features")
+ ("output,o", po::value<string>()->default_value("plf"), "Output format (text, plf)")
+ ("command,c", po::value<string>()->default_value("push"), "Operation to perform: push, graphviz, 1best, 2best ...")
+ ("help,h", "Print this help message and exit");
+ po::options_description clo("Command line options");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ po::notify(*conf);
+
+ if (conf->count("help") || conf->count("input") == 0) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char **argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ string infile = conf["input"].as<string>();
+ ReadFile rf(infile);
+ istream* in = rf.stream();
+ assert(*in);
+ SparseVector<double> wts;
+ vector<double> wv;
+ if (conf.count("weight") > 0) wv = conf["weight"].as<vector<double> >();
+ if (wv.empty()) wv.push_back(1.0);
+ for (int i = 0; i < wv.size(); ++i) {
+ const string fname = "Feature_" + boost::lexical_cast<string>(i);
+ cerr << "[INFO] Arc weight " << (i+1) << " = " << wv[i] << endl;
+ wts.set_value(FD::Convert(fname), wv[i]);
+ }
+ const string cmd = conf["command"].as<string>();
+ const bool push_weights = cmd == "push";
+ const bool output_plf = cmd == "plf";
+ const bool graphviz = cmd == "graphviz";
+ const bool kbest = cmd.rfind("best") == (cmd.size() - 4) && cmd.size() > 4;
+ int k = 1;
+ if (kbest) {
+ k = boost::lexical_cast<int>(cmd.substr(0, cmd.size() - 4));
+ cerr << "KBEST = " << k << endl;
+ }
+ const double scale = conf["prior_scale"].as<double>();
+ int lc = 0;
+ while(*in) {
+ ++lc;
+ string plf;
+ getline(*in, plf);
+ if (plf.empty()) continue;
+ Hypergraph hg;
+ HypergraphIO::ReadFromPLF(plf, &hg);
+ hg.Reweight(wts);
+ if (graphviz) hg.PrintGraphviz();
+ if (push_weights) hg.PushWeightsToSource(scale);
+ if (output_plf) {
+ cout << HypergraphIO::AsPLF(hg) << endl;
+ } else {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cout << lc << " ||| " << TD::GetString(d->yield) << " ||| " << d->score << endl;
+ }
+ }
+ }
+ return 0;
+}
+
diff --git a/vest/Makefile.am b/vest/Makefile.am
new file mode 100644
index 00000000..87c2383a
--- /dev/null
+++ b/vest/Makefile.am
@@ -0,0 +1,32 @@
+bin_PROGRAMS = \
+ mr_vest_map \
+ mr_vest_reduce \
+ scorer_test \
+ lo_test \
+ mr_vest_generate_mapper_input \
+ fast_score \
+ union_forests
+
+union_forests_SOURCES = union_forests.cc
+union_forests_LDADD = $(top_srcdir)/src/libhg.a -lz
+
+fast_score_SOURCES = fast_score.cc ter.cc comb_scorer.cc scorer.cc viterbi_envelope.cc
+fast_score_LDADD = $(top_srcdir)/src/libhg.a -lz
+
+mr_vest_generate_mapper_input_SOURCES = mr_vest_generate_mapper_input.cc line_optimizer.cc
+mr_vest_generate_mapper_input_LDADD = $(top_srcdir)/src/libhg.a -lz
+
+mr_vest_map_SOURCES = viterbi_envelope.cc error_surface.cc mr_vest_map.cc scorer.cc ter.cc comb_scorer.cc line_optimizer.cc
+mr_vest_map_LDADD = $(top_srcdir)/src/libhg.a -lz
+
+mr_vest_reduce_SOURCES = error_surface.cc mr_vest_reduce.cc scorer.cc ter.cc comb_scorer.cc line_optimizer.cc viterbi_envelope.cc
+mr_vest_reduce_LDADD = $(top_srcdir)/src/libhg.a -lz
+
+scorer_test_SOURCES = scorer_test.cc scorer.cc ter.cc comb_scorer.cc viterbi_envelope.cc
+scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(top_srcdir)/src/libhg.a -lz
+
+lo_test_SOURCES = lo_test.cc scorer.cc ter.cc comb_scorer.cc viterbi_envelope.cc error_surface.cc line_optimizer.cc
+lo_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(top_srcdir)/src/libhg.a -lz
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(BOOST_CPPFLAGS) $(GTEST_CPPFLAGS) -I$(top_srcdir)/src
+AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB)
diff --git a/vest/comb_scorer.cc b/vest/comb_scorer.cc
new file mode 100644
index 00000000..7b2187f4
--- /dev/null
+++ b/vest/comb_scorer.cc
@@ -0,0 +1,81 @@
+#include "comb_scorer.h"
+
+#include <cstdio>
+
+using namespace std;
+
+class BLEUTERCombinationScore : public Score {
+ friend class BLEUTERCombinationScorer;
+ public:
+ ~BLEUTERCombinationScore();
+ float ComputeScore() const {
+ return (bleu->ComputeScore() - ter->ComputeScore()) / 2.0f;
+ }
+ void ScoreDetails(string* details) const {
+ char buf[160];
+ sprintf(buf, "Combi = %.2f, BLEU = %.2f, TER = %.2f",
+ ComputeScore()*100.0f, bleu->ComputeScore()*100.0f, ter->ComputeScore()*100.0f);
+ *details = buf;
+ }
+ void PlusEquals(const Score& delta) {
+ bleu->PlusEquals(*static_cast<const BLEUTERCombinationScore&>(delta).bleu);
+ ter->PlusEquals(*static_cast<const BLEUTERCombinationScore&>(delta).ter);
+ }
+ Score* GetZero() const {
+ BLEUTERCombinationScore* res = new BLEUTERCombinationScore;
+ res->bleu = bleu->GetZero();
+ res->ter = ter->GetZero();
+ return res;
+ }
+ void Subtract(const Score& rhs, Score* res) const {
+ bleu->Subtract(*static_cast<const BLEUTERCombinationScore&>(rhs).bleu,
+ static_cast<BLEUTERCombinationScore*>(res)->bleu);
+ ter->Subtract(*static_cast<const BLEUTERCombinationScore&>(rhs).ter,
+ static_cast<BLEUTERCombinationScore*>(res)->ter);
+ }
+ void Encode(std::string* out) const {
+ string bs, ts;
+ bleu->Encode(&bs);
+ ter->Encode(&ts);
+ out->clear();
+ (*out) += static_cast<char>(bs.size());
+ (*out) += bs;
+ (*out) += ts;
+ }
+ bool IsAdditiveIdentity() const {
+ return bleu->IsAdditiveIdentity() && ter->IsAdditiveIdentity();
+ }
+ private:
+ Score* bleu;
+ Score* ter;
+};
+
+BLEUTERCombinationScore::~BLEUTERCombinationScore() {
+ delete bleu;
+ delete ter;
+}
+
+BLEUTERCombinationScorer::BLEUTERCombinationScorer(const vector<vector<WordID> >& refs) {
+ bleu_ = SentenceScorer::CreateSentenceScorer(IBM_BLEU, refs);
+ ter_ = SentenceScorer::CreateSentenceScorer(TER, refs);
+}
+
+BLEUTERCombinationScorer::~BLEUTERCombinationScorer() {
+ delete bleu_;
+ delete ter_;
+}
+
+Score* BLEUTERCombinationScorer::ScoreCandidate(const std::vector<WordID>& hyp) const {
+ BLEUTERCombinationScore* res = new BLEUTERCombinationScore;
+ res->bleu = bleu_->ScoreCandidate(hyp);
+ res->ter = ter_->ScoreCandidate(hyp);
+ return res;
+}
+
+Score* BLEUTERCombinationScorer::ScoreFromString(const std::string& in) {
+ int bss = in[0];
+ BLEUTERCombinationScore* r = new BLEUTERCombinationScore;
+ r->bleu = SentenceScorer::CreateScoreFromString(IBM_BLEU, in.substr(1, bss));
+ r->ter = SentenceScorer::CreateScoreFromString(TER, in.substr(1 + bss));
+ return r;
+}
diff --git a/vest/dist-vest.pl b/vest/dist-vest.pl
new file mode 100755
index 00000000..5528838c
--- /dev/null
+++ b/vest/dist-vest.pl
@@ -0,0 +1,642 @@
+#!/usr/bin/env perl
+
+use Getopt::Long;
+use IPC::Open2;
+use strict;
+use POSIX ":sys_wait_h";
+
+my $mydir = `dirname $0`;
+chomp $mydir;
+# Default settings
+my $srcFile = "/fs/cliplab/mteval/Evaluation/Chinese-English/mt03.src.txt";
+my $refFiles = "/fs/cliplab/mteval/Evaluation/Chinese-English/mt03.ref.txt.*";
+my $bin_dir = "/fs/clip-software/cdec/bin";
+$bin_dir = "/Users/redpony/cdyer-svn-root/cdec/vest/bin_dir";
+die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir;
+my $FAST_SCORE="$bin_dir/fast_score";
+die "Can't find $FAST_SCORE" unless -x $FAST_SCORE;
+my $MAPINPUT = "$bin_dir/mr_vest_generate_mapper_input";
+my $MAPPER = "$bin_dir/mr_vest_map";
+my $REDUCER = "$bin_dir/mr_vest_reduce";
+my $SCORER = $FAST_SCORE;
+die "Can't find $MAPPER" unless -x $MAPPER;
+my $forestUnion = "$bin_dir/union_forests";
+die "Can't find $forestUnion" unless -x $forestUnion;
+my $cdec = "$bin_dir/cdec";
+die "Can't find decoder in $cdec" unless -x $cdec;
+my $decoder = $cdec;
+my $lines_per_mapper = 440;
+my $rand_directions = 10;
+my $iteration = 1;
+my $run_local = 0;
+my $best_weights;
+my $max_iterations = 40;
+my $optimization_iters = 6;
+my $num_rand_points = 20;
+my $mert_nodes = join(" ", grep(/^c\d\d$/, split(/\n/, `pbsnodes -a`))); # "1 1 1 1 1" fails due to file staging conflicts
+my $decode_nodes = "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"; # start 15 jobs
+my $pmem = "3g";
+my $disable_clean = 0;
+my %seen_weights;
+my $normalize;
+my $help = 0;
+my $epsilon = 0.0001;
+my $interval = 5;
+my $dryrun = 0;
+my $ranges;
+my $restart = 0;
+my $metric = "ibm_bleu";
+my $dir;
+my $iniFile;
+my $weights;
+my $initialWeights;
+my $decoderOpt;
+
+# Process command-line options
+Getopt::Long::Configure("no_auto_abbrev");
+if (GetOptions(
+ "decoder=s" => \$decoderOpt,
+ "decode-nodes=s" => \$decode_nodes,
+ "dont-clean" => \$disable_clean,
+ "dry-run" => \$dryrun,
+ "epsilon" => \$epsilon,
+ "help" => \$help,
+ "interval" => \$interval,
+ "iteration=i" => \$iteration,
+ "local" => \$run_local,
+ "max-iterations=i" => \$max_iterations,
+ "mert-nodes=s" => \$mert_nodes,
+ "normalize=s" => \$normalize,
+ "pmem=s" => \$pmem,
+ "ranges=s" => \$ranges,
+ "rand-directions=i" => \$rand_directions,
+ "ref-files=s" => \$refFiles,
+ "metric=s" => \$metric,
+ "restart" => \$restart,
+ "source-file=s" => \$srcFile,
+ "weights=s" => \$initialWeights,
+ "workdir=s" => \$dir
+) == 0 || @ARGV!=1 || $help) {
+ print_help();
+ exit;
+}
+
+if ($metric =~ /^(combi|ter)$/i) {
+ $lines_per_mapper = 40;
+}
+
+($iniFile) = @ARGV;
+
+sub write_config;
+sub enseg;
+sub print_help;
+
+my $nodelist;
+my $host =`hostname`; chomp $host;
+my $bleu;
+my $best_bleu = 0.0;
+my $interval_count = 0;
+my $epsilon_bleu = $best_bleu + $epsilon;
+my $logfile;
+my $projected_score;
+
+my $refs_comma_sep = get_comma_sep_refs($refFiles);
+
+unless ($dir){
+ $dir = "vest";
+}
+unless ($dir =~ /^\//){ # convert relative path to absolute path
+ my $basedir = `pwd`;
+ chomp $basedir;
+ $dir = "$basedir/$dir";
+}
+if ($restart){
+ $iniFile = `ls $dir/*.ini`; chomp $iniFile;
+ unless (-e $iniFile){
+ die "ERROR: Could not find ini file in $dir to restart\n";
+ }
+ $logfile = "$dir/mert.log";
+ open(LOGFILE, ">>$logfile");
+ print LOGFILE "RESTARTING STOPPED OPTIMIZATION\n\n";
+
+ # figure out best weights so far and iteration number
+ open(LOG, "$dir/mert.log");
+ my $wi = 0;
+ while (my $line = <LOG>){
+ chomp $line;
+ if ($line =~ /ITERATION (\d+)/) {
+ $iteration = $1;
+ }
+ }
+
+ $iteration = $wi + 1;
+}
+
+if ($decoderOpt){ $decoder = $decoderOpt; }
+
+
+# Initializations and helper functions
+srand;
+
+my @childpids = ();
+my @cleanupcmds = ();
+
+sub cleanup {
+ print STDERR "Cleanup...\n";
+ for my $pid (@childpids){ `kill $pid`; }
+ for my $cmd (@cleanupcmds){`$cmd`; }
+ exit 1;
+};
+$SIG{INT} = "cleanup";
+$SIG{TERM} = "cleanup";
+$SIG{HUP} = "cleanup";
+
+my $decoderBase = `basename $decoder`; chomp $decoderBase;
+my $newIniFile = "$dir/$decoderBase.ini";
+my $parallelize = "$mydir/parallelize.pl";
+my $inputFileName = "$dir/input";
+my $user = $ENV{"USER"};
+
+# process ini file
+-e $iniFile || die "Error: could not open $iniFile for reading\n";
+open(INI, $iniFile);
+
+if ($dryrun){
+ write_config(*STDERR);
+ exit 0;
+} else {
+ if (-e $dir){
+ unless($restart){
+ die "ERROR: working dir $dir already exists\n\n";
+ }
+ } else {
+ mkdir $dir;
+ mkdir "$dir/hgs";
+ mkdir "$dir/hgs-current";
+ unless (-e $initialWeights) {
+ print STDERR "Please specify an initial weights file with --initial-weights\n";
+ print_help();
+ exit;
+ }
+ `cp $initialWeights $dir/weights.0`;
+ die "Can't find weights.0" unless (-e "$dir/weights.0");
+ }
+ unless($restart){
+ $logfile = "$dir/mert.log";
+ open(LOGFILE, ">$logfile");
+ }
+ write_config(*LOGFILE);
+}
+
+
+# Generate initial files and values
+unless ($restart){ `cp $iniFile $newIniFile`; }
+$iniFile = $newIniFile;
+
+my $newsrc = "$dir/dev.input";
+unless($restart){ enseg($srcFile, $newsrc); }
+$srcFile = $newsrc;
+my $devSize = 0;
+open F, "<$srcFile" or die "Can't read $srcFile: $!";
+while(<F>) { $devSize++; }
+close F;
+
+unless($best_weights){ $best_weights = $weights; }
+unless($projected_score){ $projected_score = 0.0; }
+$seen_weights{$weights} = 1;
+
+my $random_seed = int(time / 1000);
+my $lastWeightsFile;
+my $lastPScore = 0;
+# main optimization loop
+while (1){
+ print LOGFILE "\n\nITERATION $iteration\n==========\n";
+
+ # iteration-specific files
+ my $runFile="$dir/run.raw.$iteration";
+ my $onebestFile="$dir/1best.$iteration";
+ my $logdir="$dir/logs.$iteration";
+ my $decoderLog="$logdir/decoder.sentserver.log.$iteration";
+ my $scorerLog="$logdir/scorer.log.$iteration";
+ `mkdir -p $logdir`;
+
+ #decode
+ print LOGFILE "DECODE\n";
+ print LOGFILE `date`;
+ my $im1 = $iteration - 1;
+ my $weightsFile="$dir/weights.$im1";
+ my $decoder_cmd = "$decoder -c $iniFile -w $weightsFile -O $dir/hgs-current";
+ my $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -n \"$decode_nodes\" -- ";
+ if ($run_local) { $pcmd = "cat $srcFile |"; }
+ my $cmd = $pcmd . "$decoder_cmd 2> $decoderLog 1> $runFile";
+ print LOGFILE "COMMAND:\n$cmd\n";
+ my $result = 0;
+ $result = system($cmd);
+ unless ($result == 0){
+ cleanup();
+ print LOGFILE "ERROR: Parallel decoder returned non-zero exit code $result\n";
+ die;
+ }
+ my $dec_score = `cat $runFile | $SCORER $refs_comma_sep -l $metric`;
+ chomp $dec_score;
+ print LOGFILE "DECODER SCORE: $dec_score\n";
+
+ # save space
+ `gzip $runFile`;
+ `gzip $decoderLog`;
+
+ if ($iteration > $max_iterations){
+ print LOGFILE "\nREACHED STOPPING CRITERION: Maximum iterations\n";
+ last;
+ }
+
+ # run optimizer
+ print LOGFILE "\nUNION FORESTS\n";
+ print LOGFILE `date`;
+ my $mergeLog="$logdir/prune-merge.log.$iteration";
+ $cmd = "$forestUnion -r $dir/hgs -n $dir/hgs-current -s $devSize";
+ print LOGFILE "COMMAND:\n$cmd\n";
+ $result = system($cmd);
+ unless ($result == 0){
+ cleanup();
+ print LOGFILE "ERROR: merge command returned non-zero exit code $result\n";
+ die;
+ }
+ `rm -f $dir/hgs-current/*.json.gz`; # clean up old HGs, they've been moved to the repository
+
+ my $score = 0;
+ my $icc = 0;
+ my $inweights="$dir/weights.$im1";
+ for (my $opt_iter=1; $opt_iter<$optimization_iters; $opt_iter++) {
+ print LOGFILE "\nGENERATE OPTIMIZATION STRATEGY (OPT-ITERATION $opt_iter/$optimization_iters)\n";
+ print LOGFILE `date`;
+ $icc++;
+ $cmd="$MAPINPUT -w $inweights -r $dir/hgs -s $devSize -d $rand_directions > $dir/agenda.$im1-$opt_iter";
+ print LOGFILE "COMMAND:\n$cmd\n";
+ $result = system($cmd);
+ unless ($result == 0){
+ cleanup();
+ print LOGFILE "ERROR: mapinput command returned non-zero exit code $result\n";
+ die;
+ }
+
+ `mkdir $dir/splag.$im1`;
+ $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1-$opt_iter $dir/splag.$im1/mapinput.";
+ print LOGFILE "COMMAND:\n$cmd\n";
+ $result = system($cmd);
+ unless ($result == 0){
+ cleanup();
+ print LOGFILE "ERROR: split command returned non-zero exit code $result\n";
+ die;
+ }
+ opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!";
+ my @shards = grep { /^mapinput\./ } readdir(DIR);
+ closedir DIR;
+ die "No shards!" unless scalar @shards > 0;
+ my $joblist = "";
+ my $nmappers = 0;
+ my @mapoutputs = ();
+ @cleanupcmds = ();
+ my %o2i = ();
+ my $first_shard = 1;
+ for my $shard (@shards) {
+ my $mapoutput = $shard;
+ my $client_name = $shard;
+ $client_name =~ s/mapinput.//;
+ $client_name = "fmert.$client_name";
+ $mapoutput =~ s/mapinput/mapoutput/;
+ push @mapoutputs, "$dir/splag.$im1/$mapoutput";
+ $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard";
+ my $script = "$MAPPER -l $metric $refs_comma_sep < $dir/splag.$im1/$shard | sort -k1 > $dir/splag.$im1/$mapoutput";
+ if ($run_local) {
+ print LOGFILE "COMMAND:\n$script\n";
+ $result = system($script);
+ unless ($result == 0){
+ cleanup();
+ print LOGFILE "ERROR: mapper returned non-zero exit code $result\n";
+ die;
+ }
+ } else {
+ my $todo = "qsub -q batch -l pmem=3000mb,walltime=5:00:00 -N $client_name -o /dev/null -e $logdir/$client_name.ER";
+ local(*QOUT, *QIN);
+ open2(\*QOUT, \*QIN, $todo);
+ print QIN $script;
+ if ($first_shard) { print LOGFILE "$script\n"; $first_shard=0; }
+ close QIN;
+ $nmappers++;
+ while (my $jobid=<QOUT>){
+ chomp $jobid;
+ push(@cleanupcmds, "`qdel $jobid 2> /dev/null`");
+ $jobid =~ s/^(\d+)(.*?)$/\1/g;
+ print STDERR "short job id $jobid\n";
+ if ($joblist == "") { $joblist = $jobid; }
+ else {$joblist = $joblist . "\|" . $jobid; }
+ }
+ close QOUT;
+ }
+ }
+ if ($run_local) {
+ } else {
+ print LOGFILE "Launched $nmappers mappers.\n";
+ print LOGFILE "Waiting for mappers to complete...\n";
+ while ($nmappers > 0) {
+ sleep 5;
+ my @livejobs = grep(/$joblist/, split(/\n/, `qstat`));
+ $nmappers = scalar @livejobs;
+ }
+ print LOGFILE "All mappers complete.\n";
+ }
+ my $tol = 0;
+ my $til = 0;
+ for my $mo (@mapoutputs) {
+ my $olines = get_lines($mo);
+ my $ilines = get_lines($o2i{$mo});
+ $tol += $olines;
+ $til += $ilines;
+ die "$mo: output lines ($olines) doesn't match input lines ($ilines)" unless $olines==$ilines;
+ }
+ print LOGFILE "Results for $tol/$til lines\n";
+ print LOGFILE "\nSORTING AND RUNNING FMERT REDUCER\n";
+ print LOGFILE `date`;
+ $cmd="sort -k1 @mapoutputs | $REDUCER > $dir/redoutput.$im1";
+ print LOGFILE "COMMAND:\n$cmd\n";
+ $result = system($cmd);
+ unless ($result == 0){
+ cleanup();
+ print LOGFILE "ERROR: reducer command returned non-zero exit code $result\n";
+ die;
+ }
+ $cmd="sort -rnk3 '-t|' $dir/redoutput.$im1 | head -1";
+ my $best=`$cmd`; chomp $best;
+ print LOGFILE "$best\n";
+ my ($oa, $x, $xscore) = split /\|/, $best;
+ $score = $xscore;
+ print LOGFILE "PROJECTED SCORE: $score\n";
+ if (abs($x) < $epsilon) {
+ print LOGFILE "\nOPTIMIZER: no score improvement: abs($x) < $epsilon\n";
+ last;
+ }
+ my ($origin, $axis) = split /\s+/, $oa;
+
+ my %ori = convert($origin);
+ my %axi = convert($axis);
+
+ my $finalFile="$dir/weights.$im1-$opt_iter";
+ open W, ">$finalFile" or die "Can't write: $finalFile: $!";
+ for my $k (sort keys %ori) {
+ my $v = $ori{$k} + $axi{$k} * $x;
+ print W "$k $v\n";
+ }
+
+ `rm -rf $dir/splag.$im1`;
+ $inweights = $finalFile;
+ }
+ $lastWeightsFile = "$dir/weights.$iteration";
+ `cp $inweights $lastWeightsFile`;
+ if ($icc < 2) {
+ print LOGFILE "\nREACHED STOPPING CRITERION: score change too little\n";
+ last;
+ }
+ $lastPScore = $score;
+ $iteration++;
+ print LOGFILE "\n==========\n";
+}
+
+print LOGFILE "\nFINAL WEIGHTS: $dir/$lastWeightsFile\n(Use -w <this file> with hiero)\n\n";
+
+sub normalize_weights {
+ my ($rfn, $rpts, $feat) = @_;
+ my @feat_names = @$rfn;
+ my @pts = @$rpts;
+ my $z = 1.0;
+ for (my $i=0; $i < scalar @feat_names; $i++) {
+ if ($feat_names[$i] eq $feat) {
+ $z = $pts[$i];
+ last;
+ }
+ }
+ for (my $i=0; $i < scalar @feat_names; $i++) {
+ $pts[$i] /= $z;
+ }
+ print LOGFILE " NORM WEIGHTS: @pts\n";
+ return @pts;
+}
+
+sub get_lines {
+ my $fn = shift @_;
+ open FL, "<$fn" or die "Couldn't read $fn: $!";
+ my $lc = 0;
+ while(<FL>) { $lc++; }
+ return $lc;
+}
+
+sub get_comma_sep_refs {
+ my ($p) = @_;
+ my $o = `echo $p`;
+ chomp $o;
+ my @files = split /\s+/, $o;
+ return "-r " . join(' -r ', @files);
+}
+
+sub read_weights_file {
+ my ($file) = @_;
+ open F, "<$file" or die "Couldn't read $file: $!";
+ my @r = ();
+ my $pm = -1;
+ while(<F>) {
+ next if /^#/;
+ next if /^\s*$/;
+ chomp;
+ if (/^(.+)\s+(.+)$/) {
+ my $m = $1;
+ my $w = $2;
+ die "Weights out of order: $m <= $pm" unless $m > $pm;
+ push @r, $w;
+ } else {
+ warn "Unexpected feature name in weight file: $_";
+ }
+ }
+ close F;
+ return join ' ', @r;
+}
+
+# subs
+sub write_config {
+ my $fh = shift;
+ my $cleanup = "yes";
+ if ($disable_clean) {$cleanup = "no";}
+
+ print $fh "\n";
+ print $fh "DECODER: $decoder\n";
+ print $fh "INI FILE: $iniFile\n";
+ print $fh "WORKING DIR: $dir\n";
+ print $fh "SOURCE (DEV): $srcFile\n";
+ print $fh "REFS (DEV): $refFiles\n";
+ print $fh "EVAL METRIC: $metric\n";
+ print $fh "START ITERATION: $iteration\n";
+ print $fh "MAX ITERATIONS: $max_iterations\n";
+ print $fh "MERT NODES: $mert_nodes\n";
+ print $fh "DECODE NODES: $decode_nodes\n";
+ print $fh "HEAD NODE: $host\n";
+ print $fh "PMEM (DECODING): $pmem\n";
+ print $fh "CLEANUP: $cleanup\n";
+ print $fh "INITIAL WEIGHTS: $initialWeights\n";
+
+ if ($restart){
+ print $fh "PROJECTED BLEU: $projected_score\n";
+ print $fh "BEST WEIGHTS: $best_weights\n";
+ print $fh "BEST BLEU: $best_bleu\n";
+ }
+}
+
+sub update_weights_file {
+ my ($neww, $rfn, $rpts) = @_;
+ my @feats = @$rfn;
+ my @pts = @$rpts;
+ my $num_feats = scalar @feats;
+ my $num_pts = scalar @pts;
+ die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts;
+ open G, ">$neww" or die;
+ for (my $i = 0; $i < $num_feats; $i++) {
+ my $f = $feats[$i];
+ my $lambda = $pts[$i];
+ print G "$f $lambda\n";
+ }
+ close G;
+}
+
+sub enseg {
+ my $src = shift;
+ my $newsrc = shift;
+ open(SRC, $src);
+ open(NEWSRC, ">$newsrc");
+ my $i=0;
+ while (my $line=<SRC>){
+ chomp $line;
+ print NEWSRC "<seg id=\"$i\">$line</seg>\n";
+ $i++;
+ }
+ close SRC;
+ close NEWSRC;
+}
+
+sub print_help {
+
+ my $executable = `basename $0`; chomp $executable;
+ print << "Help";
+
+Usage: $executable [options] <ini file>
+ $executable --restart <work dir>
+
+ $executable [options] <ini file>
+ Runs a complete MERT optimization and test set decoding, using
+ the decoder configuration in ini file. Note that many of the
+ options have default values that are inferred automatically
+ based on certain conventions. For details, refer to descriptions
+ of the options --decoder, --weights, and --workdir.
+
+ $executable --restart <work dir>
+ Continues an optimization run that was stopped for some reason,
+ using configuration information found in the working directory
+ left behind by the stopped run.
+
+Options:
+
+ --local
+ Run the decoder and optimizer locally.
+
+ --decoder <decoder path>
+ Decoder binary to use.
+
+ --decode-nodes <nodelist>
+ A list of nodes used for parallel decoding. If specific nodes
+ are not desired, use "1" for each node requested. Defaults to
+ "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", which indicates a request for
+ 15 nodes.
+
+ --dont-clean
+ If present, this flag prevents intermediate files, including
+ run files and cumulative files, from being automatically removed
+ after a successful optimization run (these files are left if the
+ run fails for any reason). If used, a makefile containing
+ cleanup commands is written to the directory. To clean up
+ the intermediate files, invoke make without any arguments.
+
+ --dry-run
+ Prints out the settings and exits without doing anything.
+
+ --epsilon <epsilon>
+ Require that the dev set BLEU score improve by at least <epsilon>
+ within <interval> iterations (controlled by parameter --interval).
+ If not specified, defaults to .002.
+
+ --help
+ Print this message and exit.
+
+ --interval <i>
+ Require that the dev set BLEU score improve by at least <epsilon>
+ (controlled by parameter --epsilon) within <interval> iterations.
+ If not specified, defaults to 5.
+
+ --iteration <I>
+ Starting iteration number. If not specified, defaults to 1.
+
+ --max-iterations <M>
+ Maximum number of iterations to run. If not specified, defaults
+ to 10.
+
+ --pmem <N>
+ Amount of physical memory requested for parallel decoding jobs,
+ in the format expected by qsub. If not specified, defaults to
+ 2g.
+
+ --ref-files <files>
+ Dev set ref files. This option takes only a single string argument.
+ To use multiple files (including file globbing), this argument should
+ be quoted. If not specified, defaults to
+ /fs/cliplab/mteval/Evaluation/Chinese-English/mt03.ref.txt.*
+
+ --metric <method>
+ Metric to optimize. See fmert's --metric option for values.
+ Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi
+
+ --normalize <feature-name>
+ After each iteration, rescale all feature weights such that feature-
+ name has a weight of 1.0.
+
+ --rand-directions <num>
+ MERT will attempt to optimize along all of the principle directions,
+ set this parameter to explore other directions. Defaults to 5.
+
+ --source-file <file>
+ Dev set source file. If not specified, defaults to
+ /fs/cliplab/mteval/Evaluation/Chinese-English/mt03.src.txt
+
+ --weights <file>
+ A file specifying initial feature weights. The format is
+ FeatureName_1 value1
+ FeatureName_2 value2
+
+ --workdir <dir>
+ Directory for intermediate and output files. If not specified, the
+ name is derived from the ini filename. Assuming that the ini
+ filename begins with the decoder name and ends with ini, the default
+ name of the working directory is inferred from the middle part of
+ the filename. E.g. an ini file named decoder.foo.ini would have
+ a default working directory name foo.
+
+Help
+}
+
+sub convert {
+ my ($str) = @_;
+ my @ps = split /;/, $str;
+ my %dict = ();
+ for my $p (@ps) {
+ my ($k, $v) = split /=/, $p;
+ $dict{$k} = $v;
+ }
+ return %dict;
+}
+
+
diff --git a/vest/error_surface.cc b/vest/error_surface.cc
new file mode 100644
index 00000000..4e0af35c
--- /dev/null
+++ b/vest/error_surface.cc
@@ -0,0 +1,46 @@
+#include "error_surface.h"
+
+#include <cassert>
+#include <sstream>
+
+using namespace std;
+
+ErrorSurface::~ErrorSurface() {
+ for (ErrorSurface::iterator i = begin(); i != end(); ++i)
+ //delete i->delta;
+ ;
+}
+
+void ErrorSurface::Serialize(std::string* out) const {
+ const int segments = this->size();
+ ostringstream os(ios::binary);
+ os.write((const char*)&segments,sizeof(segments));
+ for (int i = 0; i < segments; ++i) {
+ const ErrorSegment& cur = (*this)[i];
+ string senc;
+ cur.delta->Encode(&senc);
+ assert(senc.size() < 256);
+ unsigned char len = senc.size();
+ os.write((const char*)&cur.x, sizeof(cur.x));
+ os.write((const char*)&len, sizeof(len));
+ os.write((const char*)&senc[0], len);
+ }
+ *out = os.str();
+}
+
+void ErrorSurface::Deserialize(ScoreType type, const std::string& in) {
+ istringstream is(in, ios::binary);
+ int segments;
+ is.read((char*)&segments, sizeof(segments));
+ this->resize(segments);
+ for (int i = 0; i < segments; ++i) {
+ ErrorSegment& cur = (*this)[i];
+ unsigned char len;
+ is.read((char*)&cur.x, sizeof(cur.x));
+ is.read((char*)&len, sizeof(len));
+ string senc(len, '\0'); assert(senc.size() == len);
+ is.read((char*)&senc[0], len);
+ cur.delta = SentenceScorer::CreateScoreFromString(type, senc);
+ }
+}
+
diff --git a/vest/fast_score.cc b/vest/fast_score.cc
new file mode 100644
index 00000000..45b60d78
--- /dev/null
+++ b/vest/fast_score.cc
@@ -0,0 +1,74 @@
+#include <iostream>
+#include <vector>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "filelib.h"
+#include "tdict.h"
+#include "scorer.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("reference,r",po::value<vector<string> >(), "[REQD] Reference translation(s) (tokenized text file)")
+ ("loss_function,l",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)")
+ ("in_file,i", po::value<string>()->default_value("-"), "Input file")
+ ("help,h", "Help");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ bool flag = false;
+ if (!conf->count("reference")) {
+ cerr << "Please specify one or more references using -r <REF1.TXT> -r <REF2.TXT> ...\n";
+ flag = true;
+ }
+ if (flag || conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const string loss_function = conf["loss_function"].as<string>();
+ ScoreType type = ScoreTypeFromString(loss_function);
+ DocScorer ds(type, conf["reference"].as<vector<string> >());
+ cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl;
+
+ ReadFile rf(conf["in_file"].as<string>());
+ Score* acc = NULL;
+ istream& in = *rf.stream();
+ int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ vector<WordID> sent;
+ TD::ConvertSentence(line, &sent);
+ Score* sentscore = ds[lc]->ScoreCandidate(sent);
+ if (!acc) { acc = sentscore->GetZero(); }
+ acc->PlusEquals(*sentscore);
+ delete sentscore;
+ ++lc;
+ }
+ assert(lc > 0);
+ if (lc > ds.size()) {
+ cerr << "Too many (" << lc << ") translations in input, expected " << ds.size() << endl;
+ return 1;
+ }
+ if (lc != ds.size())
+ cerr << "Fewer sentences in hyp (" << lc << ") than refs ("
+ << ds.size() << "): scoring partial set!\n";
+ float score = acc->ComputeScore();
+ string details;
+ acc->ScoreDetails(&details);
+ delete acc;
+ cerr << details << endl;
+ cout << score << endl;
+ return 0;
+}
diff --git a/vest/line_optimizer.cc b/vest/line_optimizer.cc
new file mode 100644
index 00000000..98dcec34
--- /dev/null
+++ b/vest/line_optimizer.cc
@@ -0,0 +1,101 @@
+#include "line_optimizer.h"
+
+#include <limits>
+#include <algorithm>
+
+#include "sparse_vector.h"
+#include "scorer.h"
+
+using namespace std;
+
+typedef ErrorSurface::const_iterator ErrorIter;
+
+// sort by increasing x-ints
+struct IntervalComp {
+ bool operator() (const ErrorIter& a, const ErrorIter& b) const {
+ return a->x < b->x;
+ }
+};
+
+double LineOptimizer::LineOptimize(
+ const vector<ErrorSurface>& surfaces,
+ const LineOptimizer::ScoreType type,
+ float* best_score,
+ const double epsilon) {
+ vector<ErrorIter> all_ints;
+ for (vector<ErrorSurface>::const_iterator i = surfaces.begin();
+ i != surfaces.end(); ++i) {
+ const ErrorSurface& surface = *i;
+ for (ErrorIter j = surface.begin(); j != surface.end(); ++j)
+ all_ints.push_back(j);
+ }
+ sort(all_ints.begin(), all_ints.end(), IntervalComp());
+ double last_boundary = all_ints.front()->x;
+ Score* acc = all_ints.front()->delta->GetZero();
+ float& cur_best_score = *best_score;
+ cur_best_score = (type == MAXIMIZE_SCORE ?
+ -numeric_limits<float>::max() : numeric_limits<float>::max());
+ bool left_edge = true;
+ double pos = numeric_limits<double>::quiet_NaN();
+ for (vector<ErrorIter>::iterator i = all_ints.begin();
+ i != all_ints.end(); ++i) {
+ const ErrorSegment& seg = **i;
+ assert(seg.delta);
+ if (seg.x - last_boundary > epsilon) {
+ float sco = acc->ComputeScore();
+ if ((type == MAXIMIZE_SCORE && sco > cur_best_score) ||
+ (type == MINIMIZE_SCORE && sco < cur_best_score) ) {
+ cur_best_score = sco;
+ if (left_edge) {
+ pos = seg.x - 0.1;
+ left_edge = false;
+ } else {
+ pos = last_boundary + (seg.x - last_boundary) / 2;
+ }
+ // cerr << "NEW BEST: " << pos << " (score=" << cur_best_score << ")\n";
+ }
+ // cerr << "---- s=" << sco << "\n";
+ last_boundary = seg.x;
+ }
+ // cerr << "x-boundary=" << seg.x << "\n";
+ acc->PlusEquals(*seg.delta);
+ }
+ float sco = acc->ComputeScore();
+ if ((type == MAXIMIZE_SCORE && sco > cur_best_score) ||
+ (type == MINIMIZE_SCORE && sco < cur_best_score) ) {
+ cur_best_score = sco;
+ if (left_edge) {
+ pos = 0;
+ } else {
+ pos = last_boundary + 1000.0;
+ }
+ }
+ delete acc;
+ return pos;
+}
+
+void LineOptimizer::RandomUnitVector(const vector<int>& features_to_optimize,
+ SparseVector<double>* axis,
+ RandomNumberGenerator<boost::mt19937>* rng) {
+ axis->clear();
+ for (int i = 0; i < features_to_optimize.size(); ++i)
+ axis->set_value(features_to_optimize[i], rng->next() - 0.5);
+ (*axis) /= axis->l2norm();
+}
+
+void LineOptimizer::CreateOptimizationDirections(
+ const vector<int>& features_to_optimize,
+ int additional_random_directions,
+ RandomNumberGenerator<boost::mt19937>* rng,
+ vector<SparseVector<double> >* dirs) {
+ const int num_directions = features_to_optimize.size() + additional_random_directions;
+ dirs->resize(num_directions);
+ for (int i = 0; i < num_directions; ++i) {
+ SparseVector<double>& axis = (*dirs)[i];
+ if (i < features_to_optimize.size())
+ axis.set_value(features_to_optimize[i], 1.0);
+ else
+ RandomUnitVector(features_to_optimize, &axis, rng);
+ }
+ cerr << "Generated " << num_directions << " total axes to optimize along.\n";
+}
diff --git a/vest/lo_test.cc b/vest/lo_test.cc
new file mode 100644
index 00000000..0acae5e0
--- /dev/null
+++ b/vest/lo_test.cc
@@ -0,0 +1,201 @@
+#include <cmath>
+#include <iostream>
+#include <fstream>
+
+#include <boost/shared_ptr.hpp>
+#include <gtest/gtest.h>
+
+#include "fdict.h"
+#include "hg.h"
+#include "kbest.h"
+#include "hg_io.h"
+#include "filelib.h"
+#include "inside_outside.h"
+#include "viterbi.h"
+#include "viterbi_envelope.h"
+#include "line_optimizer.h"
+#include "scorer.h"
+
+using namespace std;
+using boost::shared_ptr;
+
+class OptTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+const char* ref11 = "australia reopens embassy in manila";
+const char* ref12 = "( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack .";
+const char* ref21 = "australia reopened manila embassy";
+const char* ref22 = "( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack .";
+const char* ref31 = "australia to reopen embassy in manila";
+const char* ref32 = "( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so - called confirmed terrorist attack threats .";
+const char* ref41 = "australia to re - open its embassy to manila";
+const char* ref42 = "( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so - called \" clear \" threat of terrorist attack 7 weeks ago .";
+
+TEST_F(OptTest, TestCheckNaN) {
+ double x = 0;
+ double y = 0;
+ double z = x / y;
+ EXPECT_EQ(true, isnan(z));
+}
+
+TEST_F(OptTest,TestViterbiEnvelope) {
+ shared_ptr<Segment> a1(new Segment(-1, 0));
+ shared_ptr<Segment> b1(new Segment(1, 0));
+ shared_ptr<Segment> a2(new Segment(-1, 1));
+ shared_ptr<Segment> b2(new Segment(1, -1));
+ vector<shared_ptr<Segment> > sa; sa.push_back(a1); sa.push_back(b1);
+ vector<shared_ptr<Segment> > sb; sb.push_back(a2); sb.push_back(b2);
+ ViterbiEnvelope a(sa);
+ cerr << a << endl;
+ ViterbiEnvelope b(sb);
+ ViterbiEnvelope c = a;
+ c *= b;
+ cerr << a << " (*) " << b << " = " << c << endl;
+ EXPECT_EQ(3, c.size());
+}
+
+TEST_F(OptTest,TestViterbiEnvelopeInside) {
+ const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}";
+ Hypergraph hg;
+ istringstream instr(json);
+ HypergraphIO::ReadFromJSON(&instr, &hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 1.0);
+ hg.Reweight(wts);
+ vector<pair<vector<WordID>, prob_t> > list;
+ std::vector<SparseVector<double> > features;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ SparseVector<double> dir; dir.set_value(FD::Convert("f1"), 1.0);
+ ViterbiEnvelopeWeightFunction wf(wts, dir);
+ ViterbiEnvelope env = Inside<ViterbiEnvelope, ViterbiEnvelopeWeightFunction>(hg, NULL, wf);
+ cerr << env << endl;
+ const vector<boost::shared_ptr<Segment> >& segs = env.GetSortedSegs();
+ dir *= segs[1]->x;
+ wts += dir;
+ hg.Reweight(wts);
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest2(hg, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest2.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ for (int i = 0; i < segs.size(); ++i) {
+ cerr << "seg=" << i << endl;
+ vector<WordID> trans;
+ segs[i]->ConstructTranslation(&trans);
+ cerr << TD::GetString(trans) << endl;
+ }
+}
+
+TEST_F(OptTest, TestS1) {
+ int fPhraseModel_0 = FD::Convert("PhraseModel_0");
+ int fPhraseModel_1 = FD::Convert("PhraseModel_1");
+ int fPhraseModel_2 = FD::Convert("PhraseModel_2");
+ int fLanguageModel = FD::Convert("LanguageModel");
+ int fWordPenalty = FD::Convert("WordPenalty");
+ int fPassThrough = FD::Convert("PassThrough");
+ SparseVector<double> wts;
+ wts.set_value(fWordPenalty, 4.25);
+ wts.set_value(fLanguageModel, -1.1165);
+ wts.set_value(fPhraseModel_0, -0.96);
+ wts.set_value(fPhraseModel_1, -0.65);
+ wts.set_value(fPhraseModel_2, -0.77);
+ wts.set_value(fPassThrough, -10.0);
+
+ vector<int> to_optimize;
+ to_optimize.push_back(fWordPenalty);
+ to_optimize.push_back(fLanguageModel);
+ to_optimize.push_back(fPhraseModel_0);
+ to_optimize.push_back(fPhraseModel_1);
+ to_optimize.push_back(fPhraseModel_2);
+
+ Hypergraph hg;
+ ReadFile rf("./test_data/0.json.gz");
+ HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ hg.Reweight(wts);
+
+ Hypergraph hg2;
+ ReadFile rf2("./test_data/1.json.gz");
+ HypergraphIO::ReadFromJSON(rf2.stream(), &hg2);
+ hg2.Reweight(wts);
+
+ vector<vector<WordID> > refs1(4);
+ TD::ConvertSentence(ref11, &refs1[0]);
+ TD::ConvertSentence(ref21, &refs1[1]);
+ TD::ConvertSentence(ref31, &refs1[2]);
+ TD::ConvertSentence(ref41, &refs1[3]);
+ vector<vector<WordID> > refs2(4);
+ TD::ConvertSentence(ref12, &refs2[0]);
+ TD::ConvertSentence(ref22, &refs2[1]);
+ TD::ConvertSentence(ref32, &refs2[2]);
+ TD::ConvertSentence(ref42, &refs2[3]);
+ ScoreType type = ScoreTypeFromString("ibm_bleu");
+ SentenceScorer* scorer1 = SentenceScorer::CreateSentenceScorer(type, refs1);
+ SentenceScorer* scorer2 = SentenceScorer::CreateSentenceScorer(type, refs2);
+ vector<ViterbiEnvelope> envs(2);
+
+ RandomNumberGenerator<boost::mt19937> rng;
+
+ vector<SparseVector<double> > axes;
+ LineOptimizer::CreateOptimizationDirections(
+ to_optimize,
+ 10,
+ &rng,
+ &axes);
+ assert(axes.size() == 10 + to_optimize.size());
+ for (int i = 0; i < axes.size(); ++i)
+ cerr << axes[i] << endl;
+ const SparseVector<double>& axis = axes[0];
+
+ cerr << "Computing Viterbi envelope using inside algorithm...\n";
+ cerr << "axis: " << axis << endl;
+ clock_t t_start=clock();
+ ViterbiEnvelopeWeightFunction wf(wts, axis);
+ envs[0] = Inside<ViterbiEnvelope, ViterbiEnvelopeWeightFunction>(hg, NULL, wf);
+ envs[1] = Inside<ViterbiEnvelope, ViterbiEnvelopeWeightFunction>(hg2, NULL, wf);
+
+ vector<ErrorSurface> es(2);
+ scorer1->ComputeErrorSurface(envs[0], &es[0]);
+ scorer2->ComputeErrorSurface(envs[1], &es[1]);
+ cerr << envs[0].size() << " " << envs[1].size() << endl;
+ cerr << es[0].size() << " " << es[1].size() << endl;
+ envs.clear();
+ clock_t t_env=clock();
+ float score;
+ double m = LineOptimizer::LineOptimize(es, LineOptimizer::MAXIMIZE_SCORE, &score);
+ clock_t t_opt=clock();
+ cerr << "line optimizer returned: " << m << " (SCORE=" << score << ")\n";
+ EXPECT_FLOAT_EQ(0.48719698, score);
+ SparseVector<double> res = axis;
+ res *= m;
+ res += wts;
+ cerr << "res: " << res << endl;
+ cerr << "ENVELOPE PROCESSING=" << (static_cast<double>(t_env - t_start) / 1000.0) << endl;
+ cerr << " LINE OPTIMIZATION=" << (static_cast<double>(t_opt - t_env) / 1000.0) << endl;
+ hg.Reweight(res);
+ hg2.Reweight(res);
+ vector<WordID> t1,t2;
+ ViterbiESentence(hg, &t1);
+ ViterbiESentence(hg2, &t2);
+ cerr << TD::GetString(t1) << endl;
+ cerr << TD::GetString(t2) << endl;
+ delete scorer1;
+ delete scorer2;
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc
new file mode 100644
index 00000000..c96a61e4
--- /dev/null
+++ b/vest/mr_vest_generate_mapper_input.cc
@@ -0,0 +1,72 @@
+#include <iostream>
+#include <vector>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "filelib.h"
+#include "weights.h"
+#include "line_optimizer.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("dev_set_size,s",po::value<unsigned int>(),"[REQD] Development set size (# of parallel sentences)")
+ ("forest_repository,r",po::value<string>(),"[REQD] Path to forest repository")
+ ("weights,w",po::value<string>(),"[REQD] Current feature weights file")
+ ("optimize_feature,o",po::value<vector<string> >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)")
+ ("random_directions,d",po::value<unsigned int>()->default_value(20),"Number of random directions to run the line optimizer in")
+ ("help,h", "Help");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ bool flag = false;
+ if (conf->count("dev_set_size") == 0) {
+ cerr << "Please specify the size of the development set using -d N\n";
+ flag = true;
+ }
+ if (conf->count("weights") == 0) {
+ cerr << "Please specify the starting-point weights using -w <weightfile.txt>\n";
+ flag = true;
+ }
+ if (conf->count("forest_repository") == 0) {
+ cerr << "Please specify the forest repository location using -r <DIR>\n";
+ flag = true;
+ }
+ if (flag || conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ RandomNumberGenerator<boost::mt19937> rng;
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ Weights weights;
+ vector<string> features;
+ weights.InitFromFile(conf["weights"].as<string>(), &features);
+ const string forest_repository = conf["forest_repository"].as<string>();
+ assert(DirectoryExists(forest_repository));
+ SparseVector<double> origin;
+ weights.InitSparseVector(&origin);
+ if (conf.count("optimize_feature") > 0)
+ features=conf["optimize_feature"].as<vector<string> >();
+ vector<SparseVector<double> > axes;
+ vector<int> fids(features.size());
+ for (int i = 0; i < features.size(); ++i)
+ fids[i] = FD::Convert(features[i]);
+ LineOptimizer::CreateOptimizationDirections(
+ fids,
+ conf["random_directions"].as<unsigned int>(),
+ &rng,
+ &axes);
+ int dev_set_size = conf["dev_set_size"].as<unsigned int>();
+ for (int i = 0; i < dev_set_size; ++i)
+ for (int j = 0; j < axes.size(); ++j)
+ cout << forest_repository << '/' << i << ".json.gz " << i << ' ' << origin << ' ' << axes[j] << endl;
+ return 0;
+}
diff --git a/vest/mr_vest_map.cc b/vest/mr_vest_map.cc
new file mode 100644
index 00000000..80e84218
--- /dev/null
+++ b/vest/mr_vest_map.cc
@@ -0,0 +1,98 @@
+#include <sstream>
+#include <iostream>
+#include <fstream>
+#include <vector>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "sparse_vector.h"
+#include "scorer.h"
+#include "viterbi_envelope.h"
+#include "inside_outside.h"
+#include "error_surface.h"
+#include "hg.h"
+#include "hg_io.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("reference,r",po::value<vector<string> >(), "[REQD] Reference translation (tokenized text)")
+ ("loss_function,l",po::value<string>()->default_value("ibm_bleu"), "Loss function being optimized")
+ ("help,h", "Help");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ bool flag = false;
+ if (!conf->count("reference")) {
+ cerr << "Please specify one or more references using -r <REF.TXT>\n";
+ flag = true;
+ }
+ if (flag || conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+bool ReadSparseVectorString(const string& s, SparseVector<double>* v) {
+ vector<string> fields;
+ Tokenize(s, ';', &fields);
+ if (fields.empty()) return false;
+ for (int i = 0; i < fields.size(); ++i) {
+ vector<string> pair(2);
+ Tokenize(fields[i], '=', &pair);
+ if (pair.size() != 2) {
+ cerr << "Error parsing vector string: " << fields[i] << endl;
+ return false;
+ }
+ v->set_value(FD::Convert(pair[0]), atof(pair[1].c_str()));
+ }
+ return true;
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const string loss_function = conf["loss_function"].as<string>();
+ ScoreType type = ScoreTypeFromString(loss_function);
+ DocScorer ds(type, conf["reference"].as<vector<string> >());
+ cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl;
+ Hypergraph hg;
+ string last_file;
+ while(cin) {
+ string line;
+ getline(cin, line);
+ if (line.empty()) continue;
+ istringstream is(line);
+ int sent_id;
+ string file, s_origin, s_axis;
+ is >> file >> sent_id >> s_origin >> s_axis;
+ SparseVector<double> origin;
+ assert(ReadSparseVectorString(s_origin, &origin));
+ SparseVector<double> axis;
+ assert(ReadSparseVectorString(s_axis, &axis));
+ // cerr << "File: " << file << "\nAxis: " << axis << "\n X: " << origin << endl;
+ if (last_file != file) {
+ last_file = file;
+ ReadFile rf(file);
+ HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ }
+ ViterbiEnvelopeWeightFunction wf(origin, axis);
+ ViterbiEnvelope ve = Inside<ViterbiEnvelope, ViterbiEnvelopeWeightFunction>(hg, NULL, wf);
+ ErrorSurface es;
+ ds[sent_id]->ComputeErrorSurface(ve, &es);
+ //cerr << "Viterbi envelope has " << ve.size() << " segments\n";
+ cerr << "Error surface has " << es.size() << " segments\n";
+ string val;
+ es.Serialize(&val);
+ cout << 'M' << ' ' << s_origin << ' ' << s_axis << '\t';
+ B64::b64encode(val.c_str(), val.size(), &cout);
+ cout << endl;
+ }
+ return 0;
+}
diff --git a/vest/mr_vest_reduce.cc b/vest/mr_vest_reduce.cc
new file mode 100644
index 00000000..c1347065
--- /dev/null
+++ b/vest/mr_vest_reduce.cc
@@ -0,0 +1,80 @@
+#include <sstream>
+#include <iostream>
+#include <fstream>
+#include <vector>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "sparse_vector.h"
+#include "error_surface.h"
+#include "line_optimizer.h"
+#include "hg_io.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("loss_function,l",po::value<string>()->default_value("ibm_bleu"), "Loss function being optimized")
+ ("help,h", "Help");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ bool flag = false;
+ if (flag || conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const string loss_function = conf["loss_function"].as<string>();
+ ScoreType type = ScoreTypeFromString(loss_function);
+ LineOptimizer::ScoreType opt_type = LineOptimizer::MAXIMIZE_SCORE;
+ if (type == TER)
+ opt_type = LineOptimizer::MINIMIZE_SCORE;
+ string last_key;
+ vector<ErrorSurface> esv;
+ while(cin) {
+ string line;
+ getline(cin, line);
+ if (line.empty()) continue;
+ size_t ks = line.find("\t");
+ assert(string::npos != ks);
+ assert(ks > 2);
+ string key = line.substr(2, ks - 2);
+ string val = line.substr(ks + 1);
+ if (key != last_key) {
+ if (!last_key.empty()) {
+ float score;
+ double x = LineOptimizer::LineOptimize(esv, opt_type, &score);
+ cout << last_key << "|" << x << "|" << score << endl;
+ }
+ last_key = key;
+ esv.clear();
+ }
+ if (val.size() % 4 != 0) {
+ cerr << "B64 encoding error 1! Skipping.\n";
+ continue;
+ }
+ string encoded(val.size() / 4 * 3, '\0');
+ if (!B64::b64decode(reinterpret_cast<const unsigned char*>(&val[0]), val.size(), &encoded[0], encoded.size())) {
+ cerr << "B64 encoding error 2! Skipping.\n";
+ continue;
+ }
+ esv.push_back(ErrorSurface());
+ esv.back().Deserialize(type, encoded);
+ }
+ if (!esv.empty()) {
+ cerr << "ESV=" << esv.size() << endl;
+ for (int i = 0; i < esv.size(); ++i) { cerr << esv[i].size() << endl; }
+ float score;
+ double x = LineOptimizer::LineOptimize(esv, opt_type, &score);
+ cout << last_key << "|" << x << "|" << score << endl;
+ }
+ return 0;
+}
diff --git a/vest/scorer.cc b/vest/scorer.cc
new file mode 100644
index 00000000..e242bb46
--- /dev/null
+++ b/vest/scorer.cc
@@ -0,0 +1,485 @@
+#include "scorer.h"
+
+#include <map>
+#include <sstream>
+#include <iostream>
+#include <fstream>
+#include <cstdio>
+#include <valarray>
+
+#include <boost/shared_ptr.hpp>
+
+#include "viterbi_envelope.h"
+#include "error_surface.h"
+#include "ter.h"
+#include "comb_scorer.h"
+#include "tdict.h"
+#include "stringlib.h"
+
+using boost::shared_ptr;
+using namespace std;
+
+const bool minimize_segments = true; // if adjacent segments have equal scores, merge them
+
+ScoreType ScoreTypeFromString(const std::string& st) {
+ const string sl = LowercaseString(st);
+ if (sl == "ser")
+ return SER;
+ if (sl == "ter")
+ return TER;
+ if (sl == "bleu" || sl == "ibm_bleu")
+ return IBM_BLEU;
+ if (sl == "nist_bleu")
+ return NIST_BLEU;
+ if (sl == "koehn_bleu")
+ return Koehn_BLEU;
+ if (sl == "combi")
+ return BLEU_minus_TER_over_2;
+ cerr << "Don't understand score type '" << sl << "', defaulting to ibm_bleu.\n";
+ return IBM_BLEU;
+}
+
+Score::~Score() {}
+SentenceScorer::~SentenceScorer() {}
+
+class SERScore : public Score {
+ friend class SERScorer;
+ public:
+ SERScore() : correct(0), total(0) {}
+ float ComputeScore() const {
+ return static_cast<float>(correct) / static_cast<float>(total);
+ }
+ void ScoreDetails(string* details) const {
+ ostringstream os;
+ os << "SER= " << ComputeScore() << " (" << correct << '/' << total << ')';
+ *details = os.str();
+ }
+ void PlusEquals(const Score& delta) {
+ correct += static_cast<const SERScore&>(delta).correct;
+ total += static_cast<const SERScore&>(delta).total;
+ }
+ Score* GetZero() const { return new SERScore; }
+ void Subtract(const Score& rhs, Score* res) const {
+ SERScore* r = static_cast<SERScore*>(res);
+ r->correct = correct - static_cast<const SERScore&>(rhs).correct;
+ r->total = total - static_cast<const SERScore&>(rhs).total;
+ }
+ void Encode(std::string* out) const {
+ assert(!"not implemented");
+ }
+ bool IsAdditiveIdentity() const {
+ return (total == 0 && correct == 0); // correct is always 0 <= n <= total
+ }
+ private:
+ int correct, total;
+};
+
+class SERScorer : public SentenceScorer {
+ public:
+ SERScorer(const vector<vector<WordID> >& references) : refs_(references) {}
+ Score* ScoreCandidate(const std::vector<WordID>& hyp) const {
+ SERScore* res = new SERScore;
+ res->total = 1;
+ for (int i = 0; i < refs_.size(); ++i)
+ if (refs_[i] == hyp) res->correct = 1;
+ return res;
+ }
+ static Score* ScoreFromString(const std::string& data) {
+ assert(!"Not implemented");
+ }
+ private:
+ vector<vector<WordID> > refs_;
+};
+
+class BLEUScore : public Score {
+ friend class BLEUScorerBase;
+ public:
+ BLEUScore(int n) : correct_ngram_hit_counts(0,n), hyp_ngram_counts(0,n) {
+ ref_len = 0;
+ hyp_len = 0; }
+ float ComputeScore() const;
+ void ScoreDetails(string* details) const;
+ void PlusEquals(const Score& delta);
+ Score* GetZero() const;
+ void Subtract(const Score& rhs, Score* res) const;
+ void Encode(std::string* out) const;
+ bool IsAdditiveIdentity() const {
+ if (fabs(ref_len) > 0.1f || hyp_len != 0) return false;
+ for (int i = 0; i < correct_ngram_hit_counts.size(); ++i)
+ if (hyp_ngram_counts[i] != 0 ||
+ correct_ngram_hit_counts[i] != 0) return false;
+ return true;
+ }
+ private:
+ float ComputeScore(vector<float>* precs, float* bp) const;
+ valarray<int> correct_ngram_hit_counts;
+ valarray<int> hyp_ngram_counts;
+ float ref_len;
+ int hyp_len;
+};
+
+class BLEUScorerBase : public SentenceScorer {
+ public:
+ BLEUScorerBase(const std::vector<std::vector<WordID> >& references,
+ int n
+ );
+ Score* ScoreCandidate(const std::vector<WordID>& hyp) const;
+ static Score* ScoreFromString(const std::string& in);
+
+ protected:
+ virtual float ComputeRefLength(const vector<WordID>& hyp) const = 0;
+ private:
+ struct NGramCompare {
+ int operator() (const vector<WordID>& a, const vector<WordID>& b) {
+ size_t as = a.size();
+ size_t bs = b.size();
+ const size_t s = (as < bs ? as : bs);
+ for (size_t i = 0; i < s; ++i) {
+ int d = a[i] - b[i];
+ if (d < 0) return true;
+ if (d > 0) return false;
+ }
+ return as < bs;
+ }
+ };
+ typedef map<vector<WordID>, pair<int,int>, NGramCompare> NGramCountMap;
+ void CountRef(const vector<WordID>& ref) {
+ NGramCountMap tc;
+ vector<WordID> ngram(n_);
+ int s = ref.size();
+ for (int j=0; j<s; ++j) {
+ int remaining = s-j;
+ int k = (n_ < remaining ? n_ : remaining);
+ ngram.clear();
+ for (int i=1; i<=k; ++i) {
+ ngram.push_back(ref[j + i - 1]);
+ tc[ngram].first++;
+ }
+ }
+ for (NGramCountMap::iterator i = tc.begin(); i != tc.end(); ++i) {
+ pair<int,int>& p = ngrams_[i->first];
+ if (p.first < i->second.first)
+ p = i->second;
+ }
+ }
+
+ void ComputeNgramStats(const vector<WordID>& sent,
+ valarray<int>* correct,
+ valarray<int>* hyp) const {
+ assert(correct->size() == n_);
+ assert(hyp->size() == n_);
+ vector<WordID> ngram(n_);
+ (*correct) *= 0;
+ (*hyp) *= 0;
+ int s = sent.size();
+ for (int j=0; j<s; ++j) {
+ int remaining = s-j;
+ int k = (n_ < remaining ? n_ : remaining);
+ ngram.clear();
+ for (int i=1; i<=k; ++i) {
+ ngram.push_back(sent[j + i - 1]);
+ pair<int,int>& p = ngrams_[ngram];
+ if (p.second < p.first) {
+ ++p.second;
+ (*correct)[i-1]++;
+ }
+ // if the 1 gram isn't found, don't try to match don't need to match any 2- 3- .. grams:
+ if (!p.first) {
+ for (; i<=k; ++i)
+ (*hyp)[i-1]++;
+ } else {
+ (*hyp)[i-1]++;
+ }
+ }
+ }
+ }
+
+ mutable NGramCountMap ngrams_;
+ int n_;
+ vector<int> lengths_;
+};
+
+Score* BLEUScorerBase::ScoreFromString(const std::string& in) {
+ istringstream is(in);
+ int n;
+ is >> n;
+ BLEUScore* r = new BLEUScore(n);
+ is >> r->ref_len >> r->hyp_len;
+
+ for (int i = 0; i < n; ++i) {
+ is >> r->correct_ngram_hit_counts[i];
+ is >> r->hyp_ngram_counts[i];
+ }
+ return r;
+}
+
+class IBM_BLEUScorer : public BLEUScorerBase {
+ public:
+ IBM_BLEUScorer(const std::vector<std::vector<WordID> >& references,
+ int n=4) : BLEUScorerBase(references, n), lengths_(references.size()) {
+ for (int i=0; i < references.size(); ++i)
+ lengths_[i] = references[i].size();
+ }
+ protected:
+ float ComputeRefLength(const vector<WordID>& hyp) const {
+ if (lengths_.size() == 1) return lengths_[0];
+ int bestd = 2000000;
+ int hl = hyp.size();
+ int bl = -1;
+ for (vector<int>::const_iterator ci = lengths_.begin(); ci != lengths_.end(); ++ci) {
+ int cl = *ci;
+ if (abs(cl - hl) < bestd) {
+ bestd = abs(cl - hl);
+ bl = cl;
+ }
+ }
+ return bl;
+ }
+ private:
+ vector<int> lengths_;
+};
+
+class NIST_BLEUScorer : public BLEUScorerBase {
+ public:
+ NIST_BLEUScorer(const std::vector<std::vector<WordID> >& references,
+ int n=4) : BLEUScorerBase(references, n),
+ shortest_(references[0].size()) {
+ for (int i=1; i < references.size(); ++i)
+ if (references[i].size() < shortest_)
+ shortest_ = references[i].size();
+ }
+ protected:
+ float ComputeRefLength(const vector<WordID>& hyp) const {
+ return shortest_;
+ }
+ private:
+ float shortest_;
+};
+
+class Koehn_BLEUScorer : public BLEUScorerBase {
+ public:
+ Koehn_BLEUScorer(const std::vector<std::vector<WordID> >& references,
+ int n=4) : BLEUScorerBase(references, n),
+ avg_(0) {
+ for (int i=0; i < references.size(); ++i)
+ avg_ += references[i].size();
+ avg_ /= references.size();
+ }
+ protected:
+ float ComputeRefLength(const vector<WordID>& hyp) const {
+ return avg_;
+ }
+ private:
+ float avg_;
+};
+
+SentenceScorer* SentenceScorer::CreateSentenceScorer(const ScoreType type,
+ const std::vector<std::vector<WordID> >& refs) {
+ switch (type) {
+ case IBM_BLEU: return new IBM_BLEUScorer(refs, 4);
+ case NIST_BLEU: return new NIST_BLEUScorer(refs, 4);
+ case Koehn_BLEU: return new Koehn_BLEUScorer(refs, 4);
+ case TER: return new TERScorer(refs);
+ case SER: return new SERScorer(refs);
+ case BLEU_minus_TER_over_2: return new BLEUTERCombinationScorer(refs);
+ default:
+ assert(!"Not implemented!");
+ }
+}
+
+Score* SentenceScorer::CreateScoreFromString(const ScoreType type, const std::string& in) {
+ switch (type) {
+ case IBM_BLEU:
+ case NIST_BLEU:
+ case Koehn_BLEU:
+ return BLEUScorerBase::ScoreFromString(in);
+ case TER:
+ return TERScorer::ScoreFromString(in);
+ case SER:
+ return SERScorer::ScoreFromString(in);
+ case BLEU_minus_TER_over_2:
+ return BLEUTERCombinationScorer::ScoreFromString(in);
+ default:
+ assert(!"Not implemented!");
+ }
+}
+
+void SentenceScorer::ComputeErrorSurface(const ViterbiEnvelope& ve, ErrorSurface* env) const {
+ vector<WordID> prev_trans;
+ const vector<shared_ptr<Segment> >& ienv = ve.GetSortedSegs();
+ env->resize(ienv.size());
+ Score* prev_score = NULL;
+ int j = 0;
+ for (int i = 0; i < ienv.size(); ++i) {
+ const Segment& seg = *ienv[i];
+ vector<WordID> trans;
+ seg.ConstructTranslation(&trans);
+ // cerr << "Scoring: " << TD::GetString(trans) << endl;
+ if (trans == prev_trans) {
+ if (!minimize_segments) {
+ assert(prev_score); // if this fails, it means
+ // the decoder can generate null translations
+ ErrorSegment& out = (*env)[j];
+ out.delta = prev_score->GetZero();
+ out.x = seg.x;
+ ++j;
+ }
+ // cerr << "Identical translation, skipping scoring\n";
+ } else {
+ Score* score = ScoreCandidate(trans);
+ // cerr << "score= " << score->ComputeScore() << "\n";
+ Score* cur_delta = score->GetZero();
+ // just record the score diffs
+ if (!prev_score)
+ prev_score = score->GetZero();
+
+ score->Subtract(*prev_score, cur_delta);
+ delete prev_score;
+ prev_trans.swap(trans);
+ prev_score = score;
+ if ((!minimize_segments) || (!cur_delta->IsAdditiveIdentity())) {
+ ErrorSegment& out = (*env)[j];
+ out.delta = cur_delta;
+ out.x = seg.x;
+ ++j;
+ }
+ }
+ }
+ delete prev_score;
+ // cerr << " In segments: " << ienv.size() << endl;
+ // cerr << "Out segments: " << j << endl;
+ assert(j > 0);
+ env->resize(j);
+}
+
+void BLEUScore::ScoreDetails(std::string* details) const {
+ char buf[2000];
+ vector<float> precs(4);
+ float bp;
+ float bleu = ComputeScore(&precs, &bp);
+ sprintf(buf, "BLEU = %.2f, %.1f|%.1f|%.1f|%.1f (brev=%.3f)",
+ bleu*100.0,
+ precs[0]*100.0,
+ precs[1]*100.0,
+ precs[2]*100.0,
+ precs[3]*100.0,
+ bp);
+ *details = buf;
+}
+
+float BLEUScore::ComputeScore(vector<float>* precs, float* bp) const {
+ float log_bleu = 0;
+ if (precs) precs->clear();
+ int count = 0;
+ for (int i = 0; i < hyp_ngram_counts.size(); ++i) {
+ if (hyp_ngram_counts[i] > 0) {
+ float lprec = log(correct_ngram_hit_counts[i]) - log(hyp_ngram_counts[i]);
+ if (precs) precs->push_back(exp(lprec));
+ log_bleu += lprec;
+ ++count;
+ }
+ }
+ log_bleu /= static_cast<float>(count);
+ float lbp = 0.0;
+ if (hyp_len < ref_len)
+ lbp = (hyp_len - ref_len) / hyp_len;
+ log_bleu += lbp;
+ if (bp) *bp = exp(lbp);
+ return exp(log_bleu);
+}
+
+float BLEUScore::ComputeScore() const {
+ return ComputeScore(NULL, NULL);
+}
+
+void BLEUScore::Subtract(const Score& rhs, Score* res) const {
+ const BLEUScore& d = static_cast<const BLEUScore&>(rhs);
+ BLEUScore* o = static_cast<BLEUScore*>(res);
+ o->ref_len = ref_len - d.ref_len;
+ o->hyp_len = hyp_len - d.hyp_len;
+ o->correct_ngram_hit_counts = correct_ngram_hit_counts - d.correct_ngram_hit_counts;
+ o->hyp_ngram_counts = hyp_ngram_counts - d.hyp_ngram_counts;
+}
+
+void BLEUScore::PlusEquals(const Score& delta) {
+ const BLEUScore& d = static_cast<const BLEUScore&>(delta);
+ correct_ngram_hit_counts += d.correct_ngram_hit_counts;
+ hyp_ngram_counts += d.hyp_ngram_counts;
+ ref_len += d.ref_len;
+ hyp_len += d.hyp_len;
+}
+
+Score* BLEUScore::GetZero() const {
+ return new BLEUScore(hyp_ngram_counts.size());
+}
+
+void BLEUScore::Encode(std::string* out) const {
+ ostringstream os;
+ const int n = correct_ngram_hit_counts.size();
+ os << n << ' ' << ref_len << ' ' << hyp_len;
+ for (int i = 0; i < n; ++i)
+ os << ' ' << correct_ngram_hit_counts[i] << ' ' << hyp_ngram_counts[i];
+ *out = os.str();
+}
+
+BLEUScorerBase::BLEUScorerBase(const std::vector<std::vector<WordID> >& references,
+ int n) : n_(n) {
+ for (vector<vector<WordID> >::const_iterator ci = references.begin();
+ ci != references.end(); ++ci) {
+ lengths_.push_back(ci->size());
+ CountRef(*ci);
+ }
+}
+
+Score* BLEUScorerBase::ScoreCandidate(const vector<WordID>& hyp) const {
+ BLEUScore* bs = new BLEUScore(n_);
+ for (NGramCountMap::iterator i=ngrams_.begin(); i != ngrams_.end(); ++i)
+ i->second.second = 0;
+ ComputeNgramStats(hyp, &bs->correct_ngram_hit_counts, &bs->hyp_ngram_counts);
+ bs->ref_len = ComputeRefLength(hyp);
+ bs->hyp_len = hyp.size();
+ return bs;
+}
+
+DocScorer::~DocScorer() {
+ for (int i=0; i < scorers_.size(); ++i)
+ delete scorers_[i];
+}
+
+DocScorer::DocScorer(
+ const ScoreType type,
+ const std::vector<std::string>& ref_files) {
+ // TODO stop using valarray, start using ReadFile
+ cerr << "Loading references (" << ref_files.size() << " files)\n";
+ valarray<ifstream> ifs(ref_files.size());
+ for (int i=0; i < ref_files.size(); ++i) {
+ ifs[i].open(ref_files[i].c_str());
+ assert(ifs[i].good());
+ }
+ char buf[64000];
+ bool expect_eof = false;
+ while (!ifs[0].eof()) {
+ vector<vector<WordID> > refs(ref_files.size());
+ for (int i=0; i < ref_files.size(); ++i) {
+ if (ifs[i].eof()) break;
+ ifs[i].getline(buf, 64000);
+ refs[i].clear();
+ if (strlen(buf) == 0) {
+ if (ifs[i].eof()) {
+ if (!expect_eof) {
+ assert(i == 0);
+ expect_eof = true;
+ }
+ break;
+ }
+ } else {
+ TD::ConvertSentence(buf, &refs[i]);
+ assert(!refs[i].empty());
+ }
+ assert(!expect_eof);
+ }
+ if (!expect_eof) scorers_.push_back(SentenceScorer::CreateSentenceScorer(type, refs));
+ }
+ cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n";
+}
+
diff --git a/vest/scorer_test.cc b/vest/scorer_test.cc
new file mode 100644
index 00000000..fca219cc
--- /dev/null
+++ b/vest/scorer_test.cc
@@ -0,0 +1,178 @@
+#include <iostream>
+#include <fstream>
+#include <valarray>
+#include <gtest/gtest.h>
+
+#include "tdict.h"
+#include "scorer.h"
+
+using namespace std;
+
+class ScorerTest : public testing::Test {
+ protected:
+ virtual void SetUp() {
+ refs0.resize(4);
+ refs1.resize(4);
+ TD::ConvertSentence("export of high-tech products in guangdong in first two months this year reached 3.76 billion us dollars", &refs0[0]);
+ TD::ConvertSentence("guangdong's export of new high technology products amounts to us $ 3.76 billion in first two months of this year", &refs0[1]);
+ TD::ConvertSentence("guangdong exports us $ 3.76 billion worth of high technology products in the first two months of this year", &refs0[2]);
+ TD::ConvertSentence("in the first 2 months this year , the export volume of new hi-tech products in guangdong province reached 3.76 billion us dollars .", &refs0[3]);
+ TD::ConvertSentence("xinhua news agency , guangzhou , march 16 ( reporter chen ji ) the latest statistics show that from january through february this year , the export of high-tech products in guangdong province reached 3.76 billion us dollars , up 34.8 \% over the same period last year and accounted for 25.5 \% of the total export in the province .", &refs1[0]);
+ TD::ConvertSentence("xinhua news agency , guangzhou , march 16 ( reporter : chen ji ) -- latest statistic indicates that guangdong's export of new high technology products amounts to us $ 3.76 billion , up 34.8 \% over corresponding period and accounts for 25.5 \% of the total exports of the province .", &refs1[1]);
+ TD::ConvertSentence("xinhua news agency report of march 16 from guangzhou ( by staff reporter chen ji ) - latest statistics indicate guangdong province exported us $ 3.76 billion worth of high technology products , up 34.8 percent from the same period last year , which account for 25.5 percent of the total exports of the province .", &refs1[2]);
+ TD::ConvertSentence("guangdong , march 16 , ( xinhua ) -- ( chen ji reports ) as the newest statistics shows , in january and feberuary this year , the export volume of new hi-tech products in guangdong province reached 3.76 billion us dollars , up 34.8 \% than last year , making up 25.5 \% of the province's total .", &refs1[3]);
+ TD::ConvertSentence("one guangdong province will next export us $ 3.76 high-tech product two months first this year 3.76 billion us dollars", &hyp1);
+ TD::ConvertSentence("xinhua news agency , guangzhou , 16th of march ( reporter chen ) -- latest statistics suggest that guangdong exports new advanced technology product totals $ 3.76 million , 34.8 percent last corresponding period and accounts for 25.5 percent of the total export province .", &hyp2);
+ }
+
+ virtual void TearDown() { }
+
+ vector<vector<WordID> > refs0;
+ vector<vector<WordID> > refs1;
+ vector<WordID> hyp1;
+ vector<WordID> hyp2;
+};
+
+TEST_F(ScorerTest, TestCreateFromFiles) {
+ vector<string> files;
+ files.push_back("test_data/re.txt.0");
+ files.push_back("test_data/re.txt.1");
+ files.push_back("test_data/re.txt.2");
+ files.push_back("test_data/re.txt.3");
+ DocScorer ds(IBM_BLEU, files);
+}
+
+TEST_F(ScorerTest, TestBLEUScorer) {
+ SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(IBM_BLEU, refs0);
+ SentenceScorer* s2 = SentenceScorer::CreateSentenceScorer(IBM_BLEU, refs1);
+ Score* b1 = s1->ScoreCandidate(hyp1);
+ EXPECT_FLOAT_EQ(0.23185077, b1->ComputeScore());
+ Score* b2 = s2->ScoreCandidate(hyp2);
+ EXPECT_FLOAT_EQ(0.38101241, b2->ComputeScore());
+ b1->PlusEquals(*b2);
+ EXPECT_FLOAT_EQ(0.348854, b1->ComputeScore());
+ EXPECT_FALSE(b1->IsAdditiveIdentity());
+ string details;
+ b1->ScoreDetails(&details);
+ EXPECT_EQ("BLEU = 34.89, 81.5|50.8|29.5|18.6 (brev=0.898)", details);
+ cerr << details << endl;
+ string enc;
+ b1->Encode(&enc);
+ Score* b3 = SentenceScorer::CreateScoreFromString(IBM_BLEU, enc);
+ details.clear();
+ cerr << "Encoded BLEU score size: " << enc.size() << endl;
+ b3->ScoreDetails(&details);
+ cerr << details << endl;
+ EXPECT_FALSE(b3->IsAdditiveIdentity());
+ EXPECT_EQ("BLEU = 34.89, 81.5|50.8|29.5|18.6 (brev=0.898)", details);
+ Score* bz = b3->GetZero();
+ EXPECT_TRUE(bz->IsAdditiveIdentity());
+ delete bz;
+ delete b1;
+ delete s1;
+ delete b2;
+ delete s2;
+}
+
+TEST_F(ScorerTest, TestTERScorer) {
+ SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(TER, refs0);
+ SentenceScorer* s2 = SentenceScorer::CreateSentenceScorer(TER, refs1);
+ string details;
+ Score* t1 = s1->ScoreCandidate(hyp1);
+ t1->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ cerr << t1->ComputeScore() << endl;
+ Score* t2 = s2->ScoreCandidate(hyp2);
+ t2->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ cerr << t2->ComputeScore() << endl;
+ t1->PlusEquals(*t2);
+ cerr << t1->ComputeScore() << endl;
+ t1->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ EXPECT_EQ("TER = 44.16, 4| 8| 16| 6 (len=77)", details);
+ string enc;
+ t1->Encode(&enc);
+ Score* t3 = SentenceScorer::CreateScoreFromString(TER, enc);
+ details.clear();
+ t3->ScoreDetails(&details);
+ EXPECT_EQ("TER = 44.16, 4| 8| 16| 6 (len=77)", details);
+ EXPECT_FALSE(t3->IsAdditiveIdentity());
+ Score* tz = t3->GetZero();
+ EXPECT_TRUE(tz->IsAdditiveIdentity());
+ delete tz;
+ delete t3;
+ delete t1;
+ delete s1;
+ delete t2;
+ delete s2;
+}
+
+TEST_F(ScorerTest, TestTERScorerSimple) {
+ vector<vector<WordID> > ref(1);
+ TD::ConvertSentence("1 2 3 A B", &ref[0]);
+ vector<WordID> hyp;
+ TD::ConvertSentence("A B 1 2 3", &hyp);
+ SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(TER, ref);
+ string details;
+ Score* t1 = s1->ScoreCandidate(hyp);
+ t1->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ delete t1;
+ delete s1;
+}
+
+TEST_F(ScorerTest, TestSERScorerSimple) {
+ vector<vector<WordID> > ref(1);
+ TD::ConvertSentence("A B C D", &ref[0]);
+ vector<WordID> hyp1;
+ TD::ConvertSentence("A B C", &hyp1);
+ vector<WordID> hyp2;
+ TD::ConvertSentence("A B C D", &hyp2);
+ SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(SER, ref);
+ string details;
+ Score* t1 = s1->ScoreCandidate(hyp1);
+ t1->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ Score* t2 = s1->ScoreCandidate(hyp2);
+ t2->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ t2->PlusEquals(*t1);
+ t2->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ delete t1;
+ delete t2;
+ delete s1;
+}
+
+TEST_F(ScorerTest, TestCombiScorer) {
+ SentenceScorer* s1 = SentenceScorer::CreateSentenceScorer(BLEU_minus_TER_over_2, refs0);
+ string details;
+ Score* t1 = s1->ScoreCandidate(hyp1);
+ t1->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ cerr << t1->ComputeScore() << endl;
+ string enc;
+ t1->Encode(&enc);
+ Score* t2 = SentenceScorer::CreateScoreFromString(BLEU_minus_TER_over_2, enc);
+ details.clear();
+ t2->ScoreDetails(&details);
+ cerr << "DETAILS: " << details << endl;
+ Score* cz = t2->GetZero();
+ EXPECT_FALSE(t2->IsAdditiveIdentity());
+ EXPECT_TRUE(cz->IsAdditiveIdentity());
+ cz->PlusEquals(*t2);
+ EXPECT_FALSE(cz->IsAdditiveIdentity());
+ string d2;
+ cz->ScoreDetails(&d2);
+ EXPECT_EQ(d2, details);
+ delete cz;
+ delete t2;
+ delete t1;
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/vest/ter.cc b/vest/ter.cc
new file mode 100644
index 00000000..ef66f3b7
--- /dev/null
+++ b/vest/ter.cc
@@ -0,0 +1,518 @@
+#include "ter.h"
+
+#include <cstdio>
+#include <cassert>
+#include <iostream>
+#include <limits>
+#include <sstream>
+#include <tr1/unordered_map>
+#include <set>
+#include <valarray>
+#include <boost/functional/hash.hpp>
+
+#include "tdict.h"
+
+const bool ter_use_average_ref_len = true;
+const int ter_short_circuit_long_sentences = -1;
+
+using namespace std;
+using namespace std::tr1;
+
+struct COSTS {
+ static const float substitution;
+ static const float deletion;
+ static const float insertion;
+ static const float shift;
+};
+const float COSTS::substitution = 1.0f;
+const float COSTS::deletion = 1.0f;
+const float COSTS::insertion = 1.0f;
+const float COSTS::shift = 1.0f;
+
+static const int MAX_SHIFT_SIZE = 10;
+static const int MAX_SHIFT_DIST = 50;
+
+struct Shift {
+ unsigned int d_;
+ Shift() : d_() {}
+ Shift(int b, int e, int m) : d_() {
+ begin(b);
+ end(e);
+ moveto(m);
+ }
+ inline int begin() const {
+ return d_ & 0x3ff;
+ }
+ inline int end() const {
+ return (d_ >> 10) & 0x3ff;
+ }
+ inline int moveto() const {
+ int m = (d_ >> 20) & 0x7ff;
+ if (m > 1024) { m -= 1024; m *= -1; }
+ return m;
+ }
+ inline void begin(int b) {
+ d_ &= 0xfffffc00u;
+ d_ |= (b & 0x3ff);
+ }
+ inline void end(int e) {
+ d_ &= 0xfff003ffu;
+ d_ |= (e & 0x3ff) << 10;
+ }
+ inline void moveto(int m) {
+ bool neg = (m < 0);
+ if (neg) { m *= -1; m += 1024; }
+ d_ &= 0xfffff;
+ d_ |= (m & 0x7ff) << 20;
+ }
+};
+
+class TERScorerImpl {
+
+ public:
+ enum TransType { MATCH, SUBSTITUTION, INSERTION, DELETION };
+
+ explicit TERScorerImpl(const vector<WordID>& ref) : ref_(ref) {
+ for (int i = 0; i < ref.size(); ++i)
+ rwexists_.insert(ref[i]);
+ }
+
+ float Calculate(const vector<WordID>& hyp, int* subs, int* ins, int* dels, int* shifts) const {
+ return CalculateAllShifts(hyp, subs, ins, dels, shifts);
+ }
+
+ inline int GetRefLength() const {
+ return ref_.size();
+ }
+
+ private:
+ vector<WordID> ref_;
+ set<WordID> rwexists_;
+
+ typedef unordered_map<vector<WordID>, set<int>, boost::hash<vector<WordID> > > NgramToIntsMap;
+ mutable NgramToIntsMap nmap_;
+
+ static float MinimumEditDistance(
+ const vector<WordID>& hyp,
+ const vector<WordID>& ref,
+ vector<TransType>* path) {
+ vector<vector<TransType> > bmat(hyp.size() + 1, vector<TransType>(ref.size() + 1, MATCH));
+ vector<vector<float> > cmat(hyp.size() + 1, vector<float>(ref.size() + 1, 0));
+ for (int i = 0; i <= hyp.size(); ++i)
+ cmat[i][0] = i;
+ for (int j = 0; j <= ref.size(); ++j)
+ cmat[0][j] = j;
+ for (int i = 1; i <= hyp.size(); ++i) {
+ const WordID& hw = hyp[i-1];
+ for (int j = 1; j <= ref.size(); ++j) {
+ const WordID& rw = ref[j-1];
+ float& cur_c = cmat[i][j];
+ TransType& cur_b = bmat[i][j];
+
+ if (rw == hw) {
+ cur_c = cmat[i-1][j-1];
+ cur_b = MATCH;
+ } else {
+ cur_c = cmat[i-1][j-1] + COSTS::substitution;
+ cur_b = SUBSTITUTION;
+ }
+ float cwoi = cmat[i-1][j];
+ if (cur_c > cwoi + COSTS::insertion) {
+ cur_c = cwoi + COSTS::insertion;
+ cur_b = INSERTION;
+ }
+ float cwod = cmat[i][j-1];
+ if (cur_c > cwod + COSTS::deletion) {
+ cur_c = cwod + COSTS::deletion;
+ cur_b = DELETION;
+ }
+ }
+ }
+
+ // trace back along the best path and record the transition types
+ path->clear();
+ int i = hyp.size();
+ int j = ref.size();
+ while (i > 0 || j > 0) {
+ if (j == 0) {
+ --i;
+ path->push_back(INSERTION);
+ } else if (i == 0) {
+ --j;
+ path->push_back(DELETION);
+ } else {
+ TransType t = bmat[i][j];
+ path->push_back(t);
+ switch (t) {
+ case SUBSTITUTION:
+ case MATCH:
+ --i; --j; break;
+ case INSERTION:
+ --i; break;
+ case DELETION:
+ --j; break;
+ }
+ }
+ }
+ reverse(path->begin(), path->end());
+ return cmat[hyp.size()][ref.size()];
+ }
+
+ void BuildWordMatches(const vector<WordID>& hyp, NgramToIntsMap* nmap) const {
+ nmap->clear();
+ set<WordID> exists_both;
+ for (int i = 0; i < hyp.size(); ++i)
+ if (rwexists_.find(hyp[i]) != rwexists_.end())
+ exists_both.insert(hyp[i]);
+ for (int start=0; start<ref_.size(); ++start) {
+ if (exists_both.find(ref_[start]) == exists_both.end()) continue;
+ vector<WordID> cp;
+ int mlen = min(MAX_SHIFT_SIZE, static_cast<int>(ref_.size() - start));
+ for (int len=0; len<mlen; ++len) {
+ if (len && exists_both.find(ref_[start + len]) == exists_both.end()) break;
+ cp.push_back(ref_[start + len]);
+ (*nmap)[cp].insert(start);
+ }
+ }
+ }
+
+ static void PerformShift(const vector<WordID>& in,
+ int start, int end, int moveto, vector<WordID>* out) {
+ // cerr << "ps: " << start << " " << end << " " << moveto << endl;
+ out->clear();
+ if (moveto == -1) {
+ for (int i = start; i <= end; ++i)
+ out->push_back(in[i]);
+ for (int i = 0; i < start; ++i)
+ out->push_back(in[i]);
+ for (int i = end+1; i < in.size(); ++i)
+ out->push_back(in[i]);
+ } else if (moveto < start) {
+ for (int i = 0; i <= moveto; ++i)
+ out->push_back(in[i]);
+ for (int i = start; i <= end; ++i)
+ out->push_back(in[i]);
+ for (int i = moveto+1; i < start; ++i)
+ out->push_back(in[i]);
+ for (int i = end+1; i < in.size(); ++i)
+ out->push_back(in[i]);
+ } else if (moveto > end) {
+ for (int i = 0; i < start; ++i)
+ out->push_back(in[i]);
+ for (int i = end+1; i <= moveto; ++i)
+ out->push_back(in[i]);
+ for (int i = start; i <= end; ++i)
+ out->push_back(in[i]);
+ for (int i = moveto+1; i < in.size(); ++i)
+ out->push_back(in[i]);
+ } else {
+ for (int i = 0; i < start; ++i)
+ out->push_back(in[i]);
+ for (int i = end+1; (i < in.size()) && (i <= end + (moveto - start)); ++i)
+ out->push_back(in[i]);
+ for (int i = start; i <= end; ++i)
+ out->push_back(in[i]);
+ for (int i = (end + (moveto - start))+1; i < in.size(); ++i)
+ out->push_back(in[i]);
+ }
+ if (out->size() != in.size()) {
+ cerr << "ps: " << start << " " << end << " " << moveto << endl;
+ cerr << "in=" << TD::GetString(in) << endl;
+ cerr << "out=" << TD::GetString(*out) << endl;
+ }
+ assert(out->size() == in.size());
+ // cerr << "ps: " << TD::GetString(*out) << endl;
+ }
+
+ void GetAllPossibleShifts(const vector<WordID>& hyp,
+ const vector<int>& ralign,
+ const vector<bool>& herr,
+ const vector<bool>& rerr,
+ const int min_size,
+ vector<vector<Shift> >* shifts) const {
+ for (int start = 0; start < hyp.size(); ++start) {
+ vector<WordID> cp(1, hyp[start]);
+ NgramToIntsMap::iterator niter = nmap_.find(cp);
+ if (niter == nmap_.end()) continue;
+ bool ok = false;
+ int moveto;
+ for (set<int>::iterator i = niter->second.begin(); i != niter->second.end(); ++i) {
+ moveto = *i;
+ int rm = ralign[moveto];
+ ok = (start != rm &&
+ (rm - start) < MAX_SHIFT_DIST &&
+ (start - rm - 1) < MAX_SHIFT_DIST);
+ if (ok) break;
+ }
+ if (!ok) continue;
+ cp.clear();
+ for (int end = start + min_size - 1;
+ ok && end < hyp.size() && end < (start + MAX_SHIFT_SIZE); ++end) {
+ cp.push_back(hyp[end]);
+ vector<Shift>& sshifts = (*shifts)[end - start];
+ ok = false;
+ NgramToIntsMap::iterator niter = nmap_.find(cp);
+ if (niter == nmap_.end()) break;
+ bool any_herr = false;
+ for (int i = start; i <= end && !any_herr; ++i)
+ any_herr = herr[i];
+ if (!any_herr) {
+ ok = true;
+ continue;
+ }
+ for (set<int>::iterator mi = niter->second.begin();
+ mi != niter->second.end(); ++mi) {
+ int moveto = *mi;
+ int rm = ralign[moveto];
+ if (! ((rm != start) &&
+ ((rm < start) || (rm > end)) &&
+ (rm - start <= MAX_SHIFT_DIST) &&
+ ((start - rm - 1) <= MAX_SHIFT_DIST))) continue;
+ ok = true;
+ bool any_rerr = false;
+ for (int i = 0; (i <= end - start) && (!any_rerr); ++i)
+ any_rerr = rerr[moveto+i];
+ if (!any_rerr) continue;
+ for (int roff = 0; roff <= (end - start); ++roff) {
+ int rmr = ralign[moveto+roff];
+ if ((start != rmr) && ((roff == 0) || (rmr != ralign[moveto])))
+ sshifts.push_back(Shift(start, end, moveto + roff));
+ }
+ }
+ }
+ }
+ }
+
+ bool CalculateBestShift(const vector<WordID>& cur,
+ const vector<WordID>& hyp,
+ float curerr,
+ const vector<TransType>& path,
+ vector<WordID>* new_hyp,
+ float* newerr,
+ vector<TransType>* new_path) const {
+ vector<bool> herr, rerr;
+ vector<int> ralign;
+ int hpos = -1;
+ for (int i = 0; i < path.size(); ++i) {
+ switch (path[i]) {
+ case MATCH:
+ ++hpos;
+ herr.push_back(false);
+ rerr.push_back(false);
+ ralign.push_back(hpos);
+ break;
+ case SUBSTITUTION:
+ ++hpos;
+ herr.push_back(true);
+ rerr.push_back(true);
+ ralign.push_back(hpos);
+ break;
+ case INSERTION:
+ ++hpos;
+ herr.push_back(true);
+ break;
+ case DELETION:
+ rerr.push_back(true);
+ ralign.push_back(hpos);
+ break;
+ }
+ }
+#if 0
+ cerr << "RALIGN: ";
+ for (int i = 0; i < rerr.size(); ++i)
+ cerr << ralign[i] << " ";
+ cerr << endl;
+ cerr << "RERR: ";
+ for (int i = 0; i < rerr.size(); ++i)
+ cerr << (bool)rerr[i] << " ";
+ cerr << endl;
+ cerr << "HERR: ";
+ for (int i = 0; i < herr.size(); ++i)
+ cerr << (bool)herr[i] << " ";
+ cerr << endl;
+#endif
+
+ vector<vector<Shift> > shifts(MAX_SHIFT_SIZE + 1);
+ GetAllPossibleShifts(cur, ralign, herr, rerr, 1, &shifts);
+ float cur_best_shift_cost = 0;
+ *newerr = curerr;
+ vector<TransType> cur_best_path;
+ vector<WordID> cur_best_hyp;
+
+ bool res = false;
+ for (int i = shifts.size() - 1; i >=0; --i) {
+ float curfix = curerr - (cur_best_shift_cost + *newerr);
+ float maxfix = 2.0f * (1 + i) - COSTS::shift;
+ if ((curfix > maxfix) || ((cur_best_shift_cost == 0) && (curfix == maxfix))) break;
+ for (int j = 0; j < shifts[i].size(); ++j) {
+ const Shift& s = shifts[i][j];
+ curfix = curerr - (cur_best_shift_cost + *newerr);
+ maxfix = 2.0f * (1 + i) - COSTS::shift; // TODO remove?
+ if ((curfix > maxfix) || ((cur_best_shift_cost == 0) && (curfix == maxfix))) continue;
+ vector<WordID> shifted(cur.size());
+ PerformShift(cur, s.begin(), s.end(), ralign[s.moveto()], &shifted);
+ vector<TransType> try_path;
+ float try_cost = MinimumEditDistance(shifted, ref_, &try_path);
+ float gain = (*newerr + cur_best_shift_cost) - (try_cost + COSTS::shift);
+ if (gain > 0.0f || ((cur_best_shift_cost == 0.0f) && (gain == 0.0f))) {
+ *newerr = try_cost;
+ cur_best_shift_cost = COSTS::shift;
+ new_path->swap(try_path);
+ new_hyp->swap(shifted);
+ res = true;
+ // cerr << "Found better shift " << s.begin() << "..." << s.end() << " moveto " << s.moveto() << endl;
+ }
+ }
+ }
+
+ return res;
+ }
+
+ static void GetPathStats(const vector<TransType>& path, int* subs, int* ins, int* dels) {
+ *subs = *ins = *dels = 0;
+ for (int i = 0; i < path.size(); ++i) {
+ switch (path[i]) {
+ case SUBSTITUTION:
+ ++(*subs);
+ case MATCH:
+ break;
+ case INSERTION:
+ ++(*ins); break;
+ case DELETION:
+ ++(*dels); break;
+ }
+ }
+ }
+
+ float CalculateAllShifts(const vector<WordID>& hyp,
+ int* subs, int* ins, int* dels, int* shifts) const {
+ BuildWordMatches(hyp, &nmap_);
+ vector<TransType> path;
+ float med_cost = MinimumEditDistance(hyp, ref_, &path);
+ float edits = 0;
+ vector<WordID> cur = hyp;
+ *shifts = 0;
+ if (ter_short_circuit_long_sentences < 0 ||
+ ref_.size() < ter_short_circuit_long_sentences) {
+ while (true) {
+ vector<WordID> new_hyp;
+ vector<TransType> new_path;
+ float new_med_cost;
+ if (!CalculateBestShift(cur, hyp, med_cost, path, &new_hyp, &new_med_cost, &new_path))
+ break;
+ edits += COSTS::shift;
+ ++(*shifts);
+ med_cost = new_med_cost;
+ path.swap(new_path);
+ cur.swap(new_hyp);
+ }
+ }
+ GetPathStats(path, subs, ins, dels);
+ return med_cost + edits;
+ }
+};
+
+class TERScore : public Score {
+ friend class TERScorer;
+
+ public:
+ static const unsigned kINSERTIONS = 0;
+ static const unsigned kDELETIONS = 1;
+ static const unsigned kSUBSTITUTIONS = 2;
+ static const unsigned kSHIFTS = 3;
+ static const unsigned kREF_WORDCOUNT = 4;
+ static const unsigned kDUMMY_LAST_ENTRY = 5;
+
+ TERScore() : stats(0,kDUMMY_LAST_ENTRY) {}
+ float ComputeScore() const {
+ float edits = static_cast<float>(stats[kINSERTIONS] + stats[kDELETIONS] + stats[kSUBSTITUTIONS] + stats[kSHIFTS]);
+ return edits / static_cast<float>(stats[kREF_WORDCOUNT]);
+ }
+ void ScoreDetails(string* details) const;
+ void PlusEquals(const Score& delta) {
+ stats += static_cast<const TERScore&>(delta).stats;
+ }
+ Score* GetZero() const {
+ return new TERScore;
+ }
+ void Subtract(const Score& rhs, Score* res) const {
+ static_cast<TERScore*>(res)->stats = stats - static_cast<const TERScore&>(rhs).stats;
+ }
+ void Encode(std::string* out) const {
+ ostringstream os;
+ os << stats[kINSERTIONS] << ' '
+ << stats[kDELETIONS] << ' '
+ << stats[kSUBSTITUTIONS] << ' '
+ << stats[kSHIFTS] << ' '
+ << stats[kREF_WORDCOUNT];
+ *out = os.str();
+ }
+ bool IsAdditiveIdentity() const {
+ for (int i = 0; i < kDUMMY_LAST_ENTRY; ++i)
+ if (stats[i] != 0) return false;
+ return true;
+ }
+ private:
+ valarray<int> stats;
+};
+
+Score* TERScorer::ScoreFromString(const std::string& data) {
+ istringstream is(data);
+ TERScore* r = new TERScore;
+ is >> r->stats[TERScore::kINSERTIONS]
+ >> r->stats[TERScore::kDELETIONS]
+ >> r->stats[TERScore::kSUBSTITUTIONS]
+ >> r->stats[TERScore::kSHIFTS]
+ >> r->stats[TERScore::kREF_WORDCOUNT];
+ return r;
+}
+
+void TERScore::ScoreDetails(std::string* details) const {
+ char buf[200];
+ sprintf(buf, "TER = %.2f, %3d|%3d|%3d|%3d (len=%d)",
+ ComputeScore() * 100.0f,
+ stats[kINSERTIONS],
+ stats[kDELETIONS],
+ stats[kSUBSTITUTIONS],
+ stats[kSHIFTS],
+ stats[kREF_WORDCOUNT]);
+ *details = buf;
+}
+
+TERScorer::~TERScorer() {
+ for (vector<TERScorerImpl*>::iterator i = impl_.begin(); i != impl_.end(); ++i)
+ delete *i;
+}
+
+TERScorer::TERScorer(const vector<vector<WordID> >& refs) : impl_(refs.size()) {
+ for (int i = 0; i < refs.size(); ++i)
+ impl_[i] = new TERScorerImpl(refs[i]);
+}
+
+Score* TERScorer::ScoreCandidate(const std::vector<WordID>& hyp) const {
+ float best_score = numeric_limits<float>::max();
+ TERScore* res = new TERScore;
+ int avg_len = 0;
+ for (int i = 0; i < impl_.size(); ++i)
+ avg_len += impl_[i]->GetRefLength();
+ avg_len /= impl_.size();
+ for (int i = 0; i < impl_.size(); ++i) {
+ int subs, ins, dels, shifts;
+ float score = impl_[i]->Calculate(hyp, &subs, &ins, &dels, &shifts);
+ // cerr << "Component TER cost: " << score << endl;
+ if (score < best_score) {
+ res->stats[TERScore::kINSERTIONS] = ins;
+ res->stats[TERScore::kDELETIONS] = dels;
+ res->stats[TERScore::kSUBSTITUTIONS] = subs;
+ res->stats[TERScore::kSHIFTS] = shifts;
+ if (ter_use_average_ref_len) {
+ res->stats[TERScore::kREF_WORDCOUNT] = avg_len;
+ } else {
+ res->stats[TERScore::kREF_WORDCOUNT] = impl_[i]->GetRefLength();
+ }
+
+ best_score = score;
+ }
+ }
+ return res;
+}
diff --git a/vest/test_data/0.json.gz b/vest/test_data/0.json.gz
new file mode 100644
index 00000000..30f8dd77
--- /dev/null
+++ b/vest/test_data/0.json.gz
Binary files differ
diff --git a/vest/test_data/1.json.gz b/vest/test_data/1.json.gz
new file mode 100644
index 00000000..c82cc179
--- /dev/null
+++ b/vest/test_data/1.json.gz
Binary files differ
diff --git a/vest/test_data/c2e.txt.0 b/vest/test_data/c2e.txt.0
new file mode 100644
index 00000000..12c4abe9
--- /dev/null
+++ b/vest/test_data/c2e.txt.0
@@ -0,0 +1,2 @@
+australia reopens embassy in manila
+( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack .
diff --git a/vest/test_data/c2e.txt.1 b/vest/test_data/c2e.txt.1
new file mode 100644
index 00000000..4ac12df1
--- /dev/null
+++ b/vest/test_data/c2e.txt.1
@@ -0,0 +1,2 @@
+australia reopened manila embassy
+( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack .
diff --git a/vest/test_data/c2e.txt.2 b/vest/test_data/c2e.txt.2
new file mode 100644
index 00000000..2f67b72f
--- /dev/null
+++ b/vest/test_data/c2e.txt.2
@@ -0,0 +1,2 @@
+australia to reopen embassy in manila
+( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so-called confirmed terrorist attack threats .
diff --git a/vest/test_data/c2e.txt.3 b/vest/test_data/c2e.txt.3
new file mode 100644
index 00000000..5483cef6
--- /dev/null
+++ b/vest/test_data/c2e.txt.3
@@ -0,0 +1,2 @@
+australia to re - open its embassy to manila
+( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so-called " clear " threat of terrorist attack 7 weeks ago .
diff --git a/vest/test_data/re.txt.0 b/vest/test_data/re.txt.0
new file mode 100644
index 00000000..86eff087
--- /dev/null
+++ b/vest/test_data/re.txt.0
@@ -0,0 +1,5 @@
+erdogan states turkey to reject any pressures to urge it to recognize cyprus
+ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara will reject any pressure by the european union to urge it to recognize cyprus . this comes two weeks before the summit of european union state and government heads who will decide whether or nor membership negotiations with ankara should be opened .
+erdogan told " ntv " television station that " the european union cannot address us by imposing new conditions on us with regard to cyprus .
+we will discuss this dossier in the course of membership negotiations . "
+he added " let me be clear , i cannot sidestep turkey , this is something we cannot accept . "
diff --git a/vest/test_data/re.txt.1 b/vest/test_data/re.txt.1
new file mode 100644
index 00000000..2140f198
--- /dev/null
+++ b/vest/test_data/re.txt.1
@@ -0,0 +1,5 @@
+erdogan confirms turkey will resist any pressure to recognize cyprus
+ankara 12 - 1 ( afp ) - the turkish head of government , recep tayyip erdogan , announced today ( wednesday ) that ankara would resist any pressure the european union might exercise in order to force it into recognizing cyprus . this comes two weeks before a summit of european union heads of state and government , who will decide whether or not to open membership negotiations with ankara .
+erdogan said to the ntv television channel : " the european union cannot engage with us through imposing new conditions on us with regard to cyprus .
+we shall discuss this issue in the course of the membership negotiations . "
+he added : " let me be clear - i cannot confine turkey . this is something we do not accept . "
diff --git a/vest/test_data/re.txt.2 b/vest/test_data/re.txt.2
new file mode 100644
index 00000000..94e46286
--- /dev/null
+++ b/vest/test_data/re.txt.2
@@ -0,0 +1,5 @@
+erdogan confirms that turkey will reject any pressures to encourage it to recognize cyprus
+ankara , 12 / 1 ( afp ) - the turkish prime minister recep tayyip erdogan declared today , wednesday , that ankara will reject any pressures that the european union may apply on it to encourage to recognize cyprus . this comes two weeks before a summit of the heads of countries and governments of the european union , who will decide on whether or not to start negotiations on joining with ankara .
+erdogan told the ntv television station that " it is not possible for the european union to talk to us by imposing new conditions on us regarding cyprus .
+we shall discuss this dossier during the negotiations on joining . "
+and he added , " let me be clear . turkey's arm should not be twisted ; this is something we cannot accept . "
diff --git a/vest/test_data/re.txt.3 b/vest/test_data/re.txt.3
new file mode 100644
index 00000000..f87c3308
--- /dev/null
+++ b/vest/test_data/re.txt.3
@@ -0,0 +1,5 @@
+erdogan stresses that turkey will reject all pressures to force it to recognize cyprus
+ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara would refuse all pressures applied on it by the european union to force it to recognize cyprus . that came two weeks before the summit of the presidents and prime ministers of the european union , who would decide on whether to open negotiations on joining with ankara or not .
+erdogan said to " ntv " tv station that the " european union can not communicate with us by imposing on us new conditions related to cyprus .
+we will discuss this file during the negotiations on joining . "
+he added , " let me be clear . turkey's arm should not be twisted . this is unacceptable to us . "
diff --git a/vest/union_forests.cc b/vest/union_forests.cc
new file mode 100644
index 00000000..207ecb5c
--- /dev/null
+++ b/vest/union_forests.cc
@@ -0,0 +1,73 @@
+#include <iostream>
+#include <string>
+#include <sstream>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "hg.h"
+#include "hg_io.h"
+#include "filelib.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("dev_set_size,s",po::value<unsigned int>(),"[REQD] Development set size (# of parallel sentences)")
+ ("forest_repository,r",po::value<string>(),"[REQD] Path to forest repository")
+ ("new_forest_repository,n",po::value<string>(),"[REQD] Path to new forest repository")
+ ("help,h", "Help");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ bool flag = false;
+ if (conf->count("dev_set_size") == 0) {
+ cerr << "Please specify the size of the development set using -d N\n";
+ flag = true;
+ }
+ if (conf->count("new_forest_repository") == 0) {
+ cerr << "Please specify the starting-point weights using -n PATH\n";
+ flag = true;
+ }
+ if (conf->count("forest_repository") == 0) {
+ cerr << "Please specify the forest repository location using -r PATH\n";
+ flag = true;
+ }
+ if (flag || conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const int size = conf["dev_set_size"].as<unsigned int>();
+ const string repo = conf["forest_repository"].as<string>();
+ const string new_repo = conf["new_forest_repository"].as<string>();
+ for (int i = 0; i < size; ++i) {
+ ostringstream sfin, sfout;
+ sfin << new_repo << '/' << i << ".json.gz";
+ sfout << repo << '/' << i << ".json.gz";
+ const string fin = sfin.str();
+ const string fout = sfout.str();
+ Hypergraph existing_hg;
+ cerr << "Processing " << fin << endl;
+ assert(FileExists(fin));
+ if (FileExists(fout)) {
+ ReadFile rf(fout);
+ assert(HypergraphIO::ReadFromJSON(rf.stream(), &existing_hg));
+ }
+ Hypergraph new_hg;
+ if (true) {
+ ReadFile rf(fin);
+ assert(HypergraphIO::ReadFromJSON(rf.stream(), &new_hg));
+ }
+ existing_hg.Union(new_hg);
+ WriteFile wf(fout);
+ assert(HypergraphIO::WriteToJSON(existing_hg, false, wf.stream()));
+ }
+ return 0;
+}
diff --git a/vest/viterbi_envelope.cc b/vest/viterbi_envelope.cc
new file mode 100644
index 00000000..1122030a
--- /dev/null
+++ b/vest/viterbi_envelope.cc
@@ -0,0 +1,167 @@
+#include "viterbi_envelope.h"
+
+#include <cassert>
+#include <limits>
+
+using namespace std;
+using boost::shared_ptr;
+
+ostream& operator<<(ostream& os, const ViterbiEnvelope& env) {
+ os << '<';
+ const vector<shared_ptr<Segment> >& segs = env.GetSortedSegs();
+ for (int i = 0; i < segs.size(); ++i)
+ os << (i==0 ? "" : "|") << "x=" << segs[i]->x << ",b=" << segs[i]->b << ",m=" << segs[i]->m << ",p1=" << segs[i]->p1 << ",p2=" << segs[i]->p2;
+ return os << '>';
+}
+
+ViterbiEnvelope::ViterbiEnvelope(int i) {
+ if (i == 0) {
+ // do nothing - <>
+ } else if (i == 1) {
+ segs.push_back(shared_ptr<Segment>(new Segment(0, 0, 0, shared_ptr<Segment>(), shared_ptr<Segment>())));
+ assert(this->IsMultiplicativeIdentity());
+ } else {
+ cerr << "Only can create ViterbiEnvelope semiring 0 and 1 with this constructor!\n";
+ abort();
+ }
+}
+
+struct SlopeCompare {
+ bool operator() (const shared_ptr<Segment>& a, const shared_ptr<Segment>& b) const {
+ return a->m < b->m;
+ }
+};
+
+const ViterbiEnvelope& ViterbiEnvelope::operator+=(const ViterbiEnvelope& other) {
+ if (!other.is_sorted) other.Sort();
+ if (segs.empty()) {
+ segs = other.segs;
+ return *this;
+ }
+ is_sorted = false;
+ int j = segs.size();
+ segs.resize(segs.size() + other.segs.size());
+ for (int i = 0; i < other.segs.size(); ++i)
+ segs[j++] = other.segs[i];
+ assert(j == segs.size());
+ return *this;
+}
+
+void ViterbiEnvelope::Sort() const {
+ sort(segs.begin(), segs.end(), SlopeCompare());
+ const int k = segs.size();
+ int j = 0;
+ for (int i = 0; i < k; ++i) {
+ Segment l = *segs[i];
+ l.x = kMinusInfinity;
+ // cerr << "m=" << l.m << endl;
+ if (0 < j) {
+ if (segs[j-1]->m == l.m) { // lines are parallel
+ if (l.b <= segs[j-1]->b) continue;
+ --j;
+ }
+ while(0 < j) {
+ l.x = (l.b - segs[j-1]->b) / (segs[j-1]->m - l.m);
+ if (segs[j-1]->x < l.x) break;
+ --j;
+ }
+ if (0 == j) l.x = kMinusInfinity;
+ }
+ *segs[j++] = l;
+ }
+ segs.resize(j);
+ is_sorted = true;
+}
+
+const ViterbiEnvelope& ViterbiEnvelope::operator*=(const ViterbiEnvelope& other) {
+ if (other.IsMultiplicativeIdentity()) { return *this; }
+ if (this->IsMultiplicativeIdentity()) { (*this) = other; return *this; }
+
+ if (!is_sorted) Sort();
+ if (!other.is_sorted) other.Sort();
+
+ if (this->IsEdgeEnvelope()) {
+// if (other.size() > 1)
+// cerr << *this << " (TIMES) " << other << endl;
+ shared_ptr<Segment> edge_parent = segs[0];
+ const double& edge_b = edge_parent->b;
+ const double& edge_m = edge_parent->m;
+ segs.clear();
+ for (int i = 0; i < other.segs.size(); ++i) {
+ const Segment& seg = *other.segs[i];
+ const double m = seg.m + edge_m;
+ const double b = seg.b + edge_b;
+ const double& x = seg.x; // x's don't change with *
+ segs.push_back(shared_ptr<Segment>(new Segment(x, m, b, edge_parent, other.segs[i])));
+ assert(segs.back()->p1->rule);
+ }
+// if (other.size() > 1)
+// cerr << " = " << *this << endl;
+ } else {
+ vector<shared_ptr<Segment> > new_segs;
+ int this_i = 0;
+ int other_i = 0;
+ const int this_size = segs.size();
+ const int other_size = other.segs.size();
+ double cur_x = kMinusInfinity; // moves from left to right across the
+ // real numbers, stopping for all inter-
+ // sections
+ double this_next_val = (1 < this_size ? segs[1]->x : kPlusInfinity);
+ double other_next_val = (1 < other_size ? other.segs[1]->x : kPlusInfinity);
+ while (this_i < this_size && other_i < other_size) {
+ const Segment& this_seg = *segs[this_i];
+ const Segment& other_seg= *other.segs[other_i];
+ const double m = this_seg.m + other_seg.m;
+ const double b = this_seg.b + other_seg.b;
+
+ new_segs.push_back(shared_ptr<Segment>(new Segment(cur_x, m, b, segs[this_i], other.segs[other_i])));
+ int comp = 0;
+ if (this_next_val < other_next_val) comp = -1; else
+ if (this_next_val > other_next_val) comp = 1;
+ if (0 == comp) { // the next values are equal, advance both indices
+ ++this_i;
+ ++other_i;
+ cur_x = this_next_val; // could be other_next_val (they're equal!)
+ this_next_val = (this_i+1 < this_size ? segs[this_i+1]->x : kPlusInfinity);
+ other_next_val = (other_i+1 < other_size ? other.segs[other_i+1]->x : kPlusInfinity);
+ } else { // advance the i with the lower x, update cur_x
+ if (-1 == comp) {
+ ++this_i;
+ cur_x = this_next_val;
+ this_next_val = (this_i+1 < this_size ? segs[this_i+1]->x : kPlusInfinity);
+ } else {
+ ++other_i;
+ cur_x = other_next_val;
+ other_next_val = (other_i+1 < other_size ? other.segs[other_i+1]->x : kPlusInfinity);
+ }
+ }
+ }
+ segs.swap(new_segs);
+ }
+ //cerr << "Multiply: result=" << (*this) << endl;
+ return *this;
+}
+
+// recursively construct translation
+void Segment::ConstructTranslation(vector<WordID>* trans) const {
+ const Segment* cur = this;
+ vector<vector<WordID> > ant_trans;
+ while(!cur->rule) {
+ ant_trans.resize(ant_trans.size() + 1);
+ cur->p2->ConstructTranslation(&ant_trans.back());
+ cur = cur->p1.get();
+ }
+ size_t ant_size = ant_trans.size();
+ vector<const vector<WordID>*> pants(ant_size);
+ --ant_size;
+ for (int i = 0; i < pants.size(); ++i) pants[ant_size - i] = &ant_trans[i];
+ cur->rule->ESubstitute(pants, trans);
+}
+
+ViterbiEnvelope ViterbiEnvelopeWeightFunction::operator()(const Hypergraph::Edge& e) const {
+ const double m = direction.dot(e.feature_values_);
+ const double b = origin.dot(e.feature_values_);
+ Segment* seg = new Segment(m, b, e.rule_);
+ return ViterbiEnvelope(1, seg);
+}
+